Commit 5881488af2a0

Vincent Demeester <vincent@sbr.pm>
2026-02-17 14:42:27
feat(pi): optimize tool token budget
Added conditional tool loading extension that filters tools based on workspace context at session start. GitHub tools require a github remote, jira limited to work projects, lsp enabled when project markers exist. Also trimmed github tool description (~75% reduction), removed unused stack_overflow_search tool, and added explicit head parameter to pr-create for worktree workflows. New commands: /load-tool, /tools-all, /tools-reset.
1 parent 4d18a40
Changed files (4)
dots
pi
agent
extensions
dots/pi/agent/extensions/github/actions/pr.ts
@@ -302,14 +302,11 @@ export async function handlePRCreate(
 
 	const args = ["pr", "create", "--title", params.title];
 
-	// Auto-detect --head when not explicitly provided.
-	// This prevents "No commits between main and main" errors when the tool's
-	// cwd (ctx.cwd) is the main repo but work was done in a worktree on a
-	// different branch. We resolve the effective git cwd (worktree-aware) and
-	// read the branch from there. If it differs from the default branch we set
-	// --head so gh targets the correct branch even if the process cwd falls
-	// back to the main repo checkout.
-	if (!params.head) {
+	// Use explicit --head when provided (e.g. "vdemeester:feature-branch"),
+	// otherwise auto-detect from the cwd's git context.
+	if (params.head) {
+		args.push("--head", params.head);
+	} else {
 		try {
 			const gitCwd = await resolveGitCwd(pi, ctx);
 			const branchResult = await pi.exec(
dots/pi/agent/extensions/github/index.ts
@@ -148,22 +148,10 @@ export default function (pi: ExtensionAPI) {
 		label: "GitHub",
 		description:
 			"Manage GitHub PRs, issues, checks, and runs via gh CLI. " +
-			"Actions: pr-list, pr-view, pr-diff, pr-create, pr-checkout, pr-merge, " +
-			"pr-review, pr-comment, pr-line-comment, pr-review-comments, " +
-			"pr-reviews-list, pr-review-edit, pr-review-comments-list, pr-review-comment-edit, pr-review-comment-delete, " +
-			"pr-ready, pr-close, " +
-			"checks, checks-log, checks-restart, run-list, run-view, " +
-			"issue-list, issue-view, issue-create, issue-close, issue-comment, issue-edit, " +
-			"issue-add-sub-issue, issue-remove-sub-issue, " +
-			"repo-view, release-list. " +
-			"Use pr-line-comment to post an inline comment on a specific file/line in the diff. " +
-			"Use pr-review-comments to submit a review (approve/request-changes/comment) with multiple inline comments at once. " +
-			"Use pr-reviews-list/pr-review-edit to list and edit top-level review bodies. " +
-			"Use pr-review-comments-list/pr-review-comment-edit/pr-review-comment-delete to manage inline review comments. " +
-			"Use checks-log with either runId or number (PR number) to get failed check logs - if PR number is provided, the first failed run will be used. " +
-			"Use issue-create with 'parent' param to automatically link as sub-issue. " +
-			"Use issue-add-sub-issue/issue-remove-sub-issue to manage sub-issue relationships on existing issues (number=parent, subIssueNumber=child). " +
-			"Write operations (create, merge, review, comment, close, restart, edit, delete) require user approval.",
+			"Write operations require user approval. " +
+			"checks-log accepts runId or number (PR) — PR auto-selects first failed run. " +
+			"pr-review-comments submits a review with inline comments. " +
+			"issue-create with parent auto-links as sub-issue.",
 
 		parameters: Type.Object({
 			action: StringEnum([
@@ -214,6 +202,7 @@ export default function (pi: ExtensionAPI) {
 			// PR create
 			title: Type.Optional(Type.String({ description: "PR or issue title" })),
 			body: Type.Optional(Type.String({ description: "PR/issue body or comment text" })),
+			head: Type.Optional(Type.String({ description: "Head branch for PR (owner:branch). Overrides auto-detection from cwd." })),
 			draft: Type.Optional(Type.Boolean({ description: "Create as draft PR" })),
 			reviewers: Type.Optional(Type.Array(Type.String(), { description: "PR reviewers to request" })),
 			labels: Type.Optional(Type.Array(Type.String(), { description: "Labels to add" })),
dots/pi/agent/extensions/search/index.ts
@@ -258,78 +258,4 @@ export default function (pi: ExtensionAPI) {
     },
   });
 
-  /**
-   * Stack Overflow Search Tool
-   */
-  pi.registerTool({
-    name: "stack_overflow_search",
-    label: "Stack Overflow Search",
-    description: "Searches for questions on Stack Overflow.",
-    parameters: Type.Object({
-      query: Type.String({
-        description: "The question to search for.",
-      }),
-      limit: Type.Optional(Type.Number({ description: "Maximum number of results (default 10)" })),
-    }),
-    renderCall(args, theme) {
-      let text = theme.fg("toolTitle", theme.bold("stack_overflow_search "));
-      text += theme.fg("muted", `"${args.query}"`);
-      return new Text(text, 0, 0);
-    },
-
-    renderResult(result, { expanded }, theme) {
-      if (result.isError) {
-        return new Text(theme.fg("error", result.content?.[0]?.text || "Search failed"), 0, 0);
-      }
-      const content = result.content?.[0]?.text || "";
-      const count = (content.match(/^\[Score:/gm) || []).length;
-      let text = count > 0
-        ? theme.fg("success", `✓ ${count} results`)
-        : theme.fg("warning", content);
-      if (expanded && content) {
-        text += "\n" + theme.fg("muted", content);
-      }
-      return new Text(text, 0, 0);
-    },
-
-    async execute(_toolCallId, params, signal) {
-      try {
-        const url = `https://api.stackexchange.com/2.3/search?order=desc&sort=relevance&site=stackoverflow&intitle=${encodeURIComponent(params.query)}`;
-        const response = await fetch(url, { signal });
-
-        if (!response.ok) {
-          return {
-            content: [{
-              type: "text",
-              text: `Error: Stack Exchange API returned status ${response.status}`,
-            }],
-          };
-        }
-
-        const data = await response.json();
-        const items = data.items || [];
-
-        if (items.length === 0) {
-          return {
-            content: [{ type: "text", text: "No questions found on Stack Overflow." }],
-          };
-        }
-
-        const maxResults = params.limit ?? 10;
-        const results = items
-          .slice(0, maxResults)
-          .map(
-            (item: any) =>
-              `[Score: ${item.score}] ${item.is_answered ? "(Answered)" : ""} ${item.title}\nURL: ${item.link}`,
-          )
-          .join("\n---\n");
-
-        return { content: [{ type: "text", text: truncate(results, 4000) }] };
-      } catch (e: any) {
-        return {
-          content: [{ type: "text", text: `Stack Overflow search error: ${e.message}` }],
-        };
-      }
-    },
-  });
 }
dots/pi/agent/extensions/tool-filter.ts
@@ -0,0 +1,169 @@
+/**
+ * Pi Extension: Conditional Tool Loading
+ *
+ * Reduces per-request token cost by disabling tools not relevant
+ * to the current workspace. Tools are filtered at session start
+ * based on cwd context (git remotes, project markers, path patterns).
+ *
+ * Commands:
+ *   /load-tool [name]  — Enable a specific tool (no args = list inactive)
+ *   /tools-all         — Enable all registered tools
+ *   /tools-reset       — Re-apply conditional filtering
+ *
+ * Override:
+ *   PI_ALL_TOOLS=1     — Disable filtering entirely
+ */
+
+import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
+import { existsSync } from "fs";
+import { join } from "path";
+
+// Tools that are always active regardless of context
+const ALWAYS_ACTIVE = new Set([
+	// built-ins
+	"read", "bash", "edit", "write",
+	// personal workflow
+	"org_todo", "get_current_time", "web_search",
+	// session management
+	"find_threads", "search_thread",
+	"save_session_to_history", "list_saved_sessions", "read_saved_session",
+	"save_learning", "save_research", "save_plan",
+	// delegation & worktrees
+	"subagent", "git_worktree",
+]);
+
+// Conditional tools: name → async (cwd, exec) => should_enable
+type ExecFn = (cmd: string, args: string[], opts?: any) => Promise<{ code: number; stdout: string; stderr: string }>;
+
+const CONDITIONAL: Record<string, (cwd: string, exec: ExecFn) => Promise<boolean>> = {
+	github: async (cwd, exec) => {
+		const result = await exec("git", ["remote", "-v"], { cwd, timeout: 3000 });
+		return result.code === 0 && result.stdout.includes("github.com");
+	},
+
+	github_search: async (cwd, exec) => {
+		const result = await exec("git", ["remote", "-v"], { cwd, timeout: 3000 });
+		return result.code === 0 && result.stdout.includes("github.com");
+	},
+
+	jira: async (cwd) => {
+		return /\b(tektoncd|osp|redhat|chapeau-rouge)\b/.test(cwd) ||
+			existsSync(join(cwd, ".jira"));
+	},
+
+	lsp: async (cwd) => {
+		// Match root markers from all LSP server configs in lsp-core.ts
+		return [
+			"flake.nix", "default.nix", "shell.nix",               // nix (nil/nixd)
+			"go.mod", "go.work",                                     // go (gopls)
+			"tsconfig.json", "jsconfig.json", "package.json",       // typescript
+			"Cargo.toml",                                            // rust
+			"pyproject.toml", "pyrightconfig.json", "setup.py",     // python
+			"pubspec.yaml",                                          // dart
+			"settings.gradle", "settings.gradle.kts",               // kotlin
+		].some(f => existsSync(join(cwd, f)));
+	},
+
+	// Pending removal — always disabled
+	kitty_control: async () => false,
+};
+
+export default function (pi: ExtensionAPI) {
+	async function applyFilter(cwd: string): Promise<{ all: string[]; kept: string[]; dropped: string[] }> {
+		const all = pi.getActiveTools();
+		const kept: string[] = [];
+
+		for (const name of all) {
+			if (ALWAYS_ACTIVE.has(name)) {
+				kept.push(name);
+				continue;
+			}
+			const rule = CONDITIONAL[name];
+			if (rule) {
+				try {
+					const enabled = await rule(cwd, pi.exec.bind(pi));
+					if (enabled) kept.push(name);
+				} catch {
+					// On error, keep the tool (safe default)
+					kept.push(name);
+				}
+				continue;
+			}
+			// Unknown tools (from other extensions): keep by default
+			kept.push(name);
+		}
+
+		const dropped = all.filter(n => !kept.includes(n));
+		if (dropped.length > 0) {
+			pi.setActiveTools(kept);
+		}
+
+		return { all, kept, dropped };
+	}
+
+	pi.on("session_start", async (_event, ctx) => {
+		if (process.env.PI_ALL_TOOLS === "1") return;
+		await applyFilter(ctx.cwd);
+	});
+
+	pi.registerCommand("load-tool", {
+		description: "Enable a specific tool for this session (no args = list inactive)",
+		getArgumentCompletions: (prefix) => {
+			const all = pi.getAllTools().map(t => t.name);
+			const active = new Set(pi.getActiveTools());
+			return all
+				.filter(n => !active.has(n) && n.startsWith(prefix))
+				.map(n => ({ value: n, description: "Enable tool" }));
+		},
+		handler: async (args, ctx) => {
+			const name = args.trim();
+			if (!name) {
+				const all = pi.getAllTools().map(t => t.name);
+				const active = new Set(pi.getActiveTools());
+				const inactive = all.filter(n => !active.has(n));
+				ctx.ui.notify(
+					inactive.length > 0
+						? `Inactive tools: ${inactive.join(", ")}`
+						: "All tools are active",
+					"info",
+				);
+				return;
+			}
+			const allNames = pi.getAllTools().map(t => t.name);
+			if (!allNames.includes(name)) {
+				ctx.ui.notify(`Unknown tool: ${name}`, "error");
+				return;
+			}
+			const active = pi.getActiveTools();
+			if (active.includes(name)) {
+				ctx.ui.notify(`${name} is already active`, "info");
+				return;
+			}
+			pi.setActiveTools([...active, name]);
+			ctx.ui.notify(`Enabled: ${name}`, "info");
+		},
+	});
+
+	pi.registerCommand("tools-all", {
+		description: "Enable all registered tools for this session",
+		handler: async (_args, ctx) => {
+			const all = pi.getAllTools().map(t => t.name);
+			pi.setActiveTools(all);
+			ctx.ui.notify(`All ${all.length} tools enabled`, "info");
+		},
+	});
+
+	pi.registerCommand("tools-reset", {
+		description: "Re-apply conditional tool filtering",
+		handler: async (_args, ctx) => {
+			// Restore all first, then filter
+			const all = pi.getAllTools().map(t => t.name);
+			pi.setActiveTools(all);
+			const { kept, dropped } = await applyFilter(ctx.cwd);
+			const msg = dropped.length > 0
+				? `Tools: ${kept.length}/${all.length} active (dropped: ${dropped.join(", ")})`
+				: `All ${all.length} tools active (nothing filtered)`;
+			ctx.ui.notify(msg, "info");
+		},
+	});
+}