// s05: Tool Hook Boundary — version 5 of mini Pi
//
// Leaves one socket before and one after every tool execution: the "before"
// socket can veto the call, the "after" socket can rewrite the result.
// Vocabulary introduced in this chapter: beforeToolCall / afterToolCall /
// ToolHooks / BeforeToolCallResult / executeToolCall / allow / block.
// Key point: hooks are an outer decoration and ToolRegistry itself is unchanged
// (R2); execution plus error capture is funnelled into executeToolCall.

// Minimal ambient declaration so the file does not depend on Node typings.
declare const process: {
  argv: string[];
  exitCode?: number;
};

// —— Stop reasons (since s04) ——

/** Why the provider ended a message. */
export type StopReason = "stop" | "toolUse" | "error";

// —— Messages (since s01, plus ToolResultMessage from s04) ——

/** Text typed by the user. */
export type UserMessage = { role: "user"; content: string };

/** A finished assistant reply together with the reason it stopped. */
export type AssistantMessage = {
  role: "assistant";
  content: string;
  stopReason: StopReason;
};

/** The outcome of one tool call, linked back to the call through its id. */
export type ToolResultMessage = {
  role: "toolResult";
  toolCallId: string;
  content: string;
};

/** Every kind of message that can be stored in the conversation. */
export type AgentMessage = UserMessage | AssistantMessage | ToolResultMessage;

/** Conversation state; the tool loop appends to `messages` in place. */
export type AgentState = { messages: AgentMessage[] };

// —— Tool contract (since s02) ——

/** What a tool advertises to the provider: name, description, input field docs. */
export type ToolSpec = {
  name: string;
  description: string;
  input: Record<string, string>;
};

/** Synchronous tool implementation: string fields in, string result out. */
export type ToolHandler = (input: Record<string, string>) => string;

/** A request to run the tool called `name` with `input`. */
export type ToolCall = {
  id: string;
  name: string;
  input: Record<string, string>;
};

/** A spec paired with the handler that implements it. */
export type Tool = { spec: ToolSpec; handler: ToolHandler };

/** Name-indexed collection of tools. It knows nothing about hooks. */
export class ToolRegistry {
  private readonly tools = new Map<string, Tool>();

  /** Adds `tool`; a tool registered later under the same name replaces the earlier one. */
  register(tool: Tool): void {
    this.tools.set(tool.spec.name, tool);
  }

  /** Returns the specs of all registered tools in registration order. */
  getSpecs(): ToolSpec[] {
    return [...this.tools.values()].map((tool) => tool.spec);
  }

  /**
   * Runs the handler registered under `call.name`.
   *
   * An unknown name is reported as an ordinary result string rather than an
   * exception. Whatever the handler itself throws propagates to the caller.
   */
  run(call: ToolCall): string {
    const tool = this.tools.get(call.name);
    if (!tool) return `unknown tool: ${call.name}`;
    return tool.handler(call.input);
  }
}

// —— Provider-facing types (since s04) ——

/** Message shape handed to the provider (no stopReason). */
export type ProviderMessage =
  | { role: "user" | "assistant"; content: string }
  | { role: "toolResult"; toolCallId: string; content: string };

/** Everything the provider receives for one turn. */
export type ProviderInput = { messages: ProviderMessage[]; tools: ToolSpec[] };

/** Events a provider emits while streaming one message. */
export type ProviderEvent =
  | { type: "message_start" }
  | { type: "text_delta"; text: string }
  | { type: "tool_call"; call: ToolCall }
  | { type: "message_end"; stopReason: StopReason };

/** A model backend that streams events for a given input. */
export interface Provider {
  stream(input: ProviderInput): AsyncGenerator<ProviderEvent>;
}

/** Line-oriented sink used for all trace output. */
export type Output = { log(line: string): void };

/** Creates an Output that forwards every line to `console.log`. */
export function createConsoleOutput(): Output {
  return { log: (line) => console.log(line) };
}

// ============ New in s05: the two sockets around execution ============

/** Verdict of the before-hook: let the call run, or block it with a reason. */
export type BeforeToolCallResult =
  | { type: "allow" }
  | { type: "block"; reason: string };

/** Optional hooks wrapped around each tool execution. */
export type ToolHooks = {
  beforeToolCall?: (call: ToolCall) => BeforeToolCallResult;
  afterToolCall?: (call: ToolCall, result: string) => string;
};

/**
 * Runs one tool call as three stages: before → run → after.
 *
 * The before-hook may block the call; the tool is then never run and the
 * after-hook is skipped. The after-hook may rewrite the result. The middle
 * stage is still s04's `registry.run`. Errors thrown by the tool are captured
 * here (R4) and turned into an `error: …` result instead of propagating.
 * Errors thrown by the hooks themselves are not caught.
 *
 * @param registry - Registry used to run the call.
 * @param hooks - Optional before/after hooks; pass `{}` for none.
 * @param call - The tool call to execute.
 * @returns The tool result message to append to the conversation.
 */
export function executeToolCall(
  registry: ToolRegistry,
  hooks: ToolHooks,
  call: ToolCall,
): ToolResultMessage {
  const before: BeforeToolCallResult = hooks.beforeToolCall?.(call) ?? { type: "allow" };
  if (before.type === "block") {
    return {
      role: "toolResult",
      toolCallId: call.id,
      content: `blocked: ${before.reason}`,
    };
  }

  let result: string;
  try {
    result = registry.run(call);
  } catch (error: unknown) {
    result = `error: ${error instanceof Error ? error.message : String(error)}`;
  }

  // The after-hook also sees captured errors, since they are ordinary result strings.
  const finalResult = hooks.afterToolCall?.(call, result) ?? result;
  return { role: "toolResult", toolCallId: call.id, content: finalResult };
}

// ============ Constructors ============

/** Creates an empty conversation state. */
export function createInitialState(): AgentState {
  return { messages: [] };
}

/** Wraps raw text as a user message. */
export function createUserMessage(content: string): UserMessage {
  return { role: "user", content };
}

/**
 * Projects the stored conversation into the provider's input shape.
 *
 * Tool results keep their `toolCallId`; user and assistant messages are
 * reduced to role + content, which drops the assistant's `stopReason`.
 */
export function buildProviderInput(state: AgentState, registry: ToolRegistry): ProviderInput {
  return {
    messages: state.messages.map((message): ProviderMessage => {
      if (message.role === "toolResult") {
        return {
          role: "toolResult",
          toolCallId: message.toolCallId,
          content: message.content,
        };
      }
      return { role: message.role, content: message.content };
    }),
    tools: registry.getSpecs(),
  };
}

// ============ Tool loop (since s04; s05 adds hooks, tool_call goes through executeToolCall) ============

// Maximum number of provider streams consumed for a single user input.
const MAX_TURNS = 8;

/**
 * Drives the provider until it answers without requesting a tool.
 *
 * Appends the user message, then repeatedly streams from the provider. Each
 * `tool_call` event is executed immediately through `executeToolCall` and its
 * result is appended to `state.messages`. A turn that saw a tool call and ended
 * with `toolUse` triggers another turn; any other turn ends the loop with an
 * assistant message. After `MAX_TURNS` turns a fixed "stopped" assistant
 * message is appended and returned instead.
 *
 * @param state - Conversation state, mutated in place.
 * @param provider - Source of streamed events.
 * @param registry - Tools offered to the provider and used for execution.
 * @param hooks - Hooks applied around every tool execution.
 * @param userInput - Text of the new user message.
 * @param output - Sink receiving one trace line per event.
 * @returns The final assistant message, which is also the last entry in `state.messages`.
 */
export async function runEventedToolLoop(
  state: AgentState,
  provider: Provider,
  registry: ToolRegistry,
  hooks: ToolHooks,
  userInput: string,
  output: Output,
): Promise<AssistantMessage> {
  state.messages.push(createUserMessage(userInput));

  for (let turn = 1; turn <= MAX_TURNS; turn += 1) {
    const providerInput = buildProviderInput(state, registry);
    let content = "";
    let stopReason: StopReason = "stop";
    let sawToolCall = false;

    for await (const event of provider.stream(providerInput)) {
      switch (event.type) {
        case "message_start":
          output.log("message_start");
          break;
        case "text_delta":
          output.log(`text_delta: ${event.text}`);
          content += event.text;
          break;
        case "tool_call": {
          sawToolCall = true;
          output.log(`tool_call: ${event.call.name}`);
          // s05: execution is delegated to executeToolCall; the hooks run inside it.
          const resultMessage = executeToolCall(registry, hooks, event.call);
          state.messages.push(resultMessage);
          output.log(`tool_result: ${resultMessage.content}`);
          break;
        }
        case "message_end":
          stopReason = event.stopReason;
          output.log(`message_end: ${stopReason}`);
          break;
      }
    }

    // Only a turn that both called a tool and ended with "toolUse" continues the loop.
    if (!sawToolCall || stopReason !== "toolUse") {
      const assistant: AssistantMessage = { role: "assistant", content, stopReason };
      state.messages.push(assistant);
      return assistant;
    }
  }

  // Turn budget exhausted: end the conversation with a fixed notice.
  const stopped: AssistantMessage = {
    role: "assistant",
    content: "(达到最大轮次,停止)",
    stopReason: "stop",
  };
  state.messages.push(stopped);
  return stopped;
}

// ============ Demo provider (fake) ============

/**
 * Fake provider that requests the tool named at construction time, used to
 * demonstrate both the allow path and the block path.
 *
 * If the last input message is a tool result it replies with text quoting that
 * result and stops; otherwise it emits one tool call with id `call_1`.
 */
export class DemoProvider implements Provider {
  /** The most recent input passed to `stream`. */
  public lastInput: ProviderInput | undefined;

  constructor(private readonly requestedTool: string) {}

  async *stream(input: ProviderInput): AsyncGenerator<ProviderEvent> {
    this.lastInput = input;
    const last = input.messages[input.messages.length - 1];

    yield { type: "message_start" };

    if (last?.role === "toolResult") {
      yield { type: "text_delta", text: `工具结果是:${last.content}` };
      yield { type: "message_end", stopReason: "stop" };
      return;
    }

    yield {
      type: "tool_call",
      call: { id: "call_1", name: this.requestedTool, input: { text: "hi" } },
    };
    yield { type: "message_end", stopReason: "toolUse" };
  }
}

// ============ Demo scaffolding ============

// Builds the demo registry: an "echo" tool plus a "dangerous" tool that the demo hooks block.
function createRegistry(): ToolRegistry {
  const registry = new ToolRegistry();
  registry.register({
    spec: { name: "echo", description: "原样返回输入", input: { text: "要复读的文本" } },
    handler: (input) => input.text ?? "(空)",
  });
  registry.register({
    spec: { name: "dangerous", description: "一个被禁用的演示工具", input: {} },
    handler: () => "不该执行到这里",
  });
  return registry;
}

// Builds the demo hooks: block "dangerous", allow everything else, and prefix
// every executed result with "checked: ". Both hooks trace to `output`.
function createHooks(output: Output): ToolHooks {
  return {
    beforeToolCall(call) {
      output.log("[beforeToolCall]");
      if (call.name === "dangerous") {
        output.log(`block: ${call.name}`);
        return { type: "block", reason: "这个工具在演示里被禁用" };
      }
      output.log(`allow: ${call.name}`);
      return { type: "allow" };
    },
    afterToolCall(call, result) {
      output.log("[afterToolCall]");
      output.log(`${call.name} -> ${result}`);
      return `checked: ${result}`;
    },
  };
}

// Reads `--case <name>` from argv; anything other than "blocked" selects "normal".
function getCase(): "normal" | "blocked" {
  const index = process.argv.indexOf("--case");
  const value = index >= 0 ? process.argv[index + 1] : undefined;
  return value === "blocked" ? "blocked" : "normal";
}

// Prints the final assistant message followed by a blank line.
function printAssistantMessage(output: Output, message: AssistantMessage): void {
  output.log("[assistant]");
  output.log(`content: ${message.content}`);
  output.log(`stopReason: ${message.stopReason}`);
  output.log("");
}

// Runs the demo for the selected case and prints the whole trace.
async function main(): Promise<void> {
  const output = createConsoleOutput();
  const state = createInitialState();
  const registry = createRegistry();
  const hooks = createHooks(output);

  const caseName = getCase();
  const requestedTool = caseName === "blocked" ? "dangerous" : "echo";
  const userInput = caseName === "blocked" ? "调用危险工具" : "复读一下 hi";
  const provider = new DemoProvider(requestedTool);

  output.log("s05: Tool Hook Boundary");
  output.log("");
  output.log("[user]");
  output.log(userInput);
  output.log("");

  const assistant = await runEventedToolLoop(
    state,
    provider,
    registry,
    hooks,
    userInput,
    output,
  );

  output.log("");
  printAssistantMessage(output, assistant);
}

main().catch((error: unknown) => {
  console.error(error);
  process.exitCode = 1;
});