flake-update-20260505
  1/**
  2 * Thread Search and Reading Extension
  3 *
  4 * Provides find_threads and search_thread tools for searching and reading
  5 * past conversation sessions.
  6 */
  7
  8import * as fs from "node:fs";
  9import * as os from "node:os";
 10import * as path from "node:path";
 11import { Type } from "@sinclair/typebox";
 12import { StringEnum } from "@mariozechner/pi-ai";
 13import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
 14import {
 15	parseSessionEntries,
 16	type FileEntry,
 17	type SessionEntry,
 18	type SessionHeader,
 19} from "@mariozechner/pi-coding-agent";
 20import { Text, Container, Spacer } from "@mariozechner/pi-tui";
 21
 22// ============================================================================
 23// Session Parsing Utilities
 24// ============================================================================
 25
 26function getSessionsDir(): string {
 27	return path.join(os.homedir(), ".pi", "agent", "sessions");
 28}
 29
 30function loadSessionFile(filePath: string): FileEntry[] {
 31	const content = fs.readFileSync(filePath, "utf-8");
 32	return parseSessionEntries(content);
 33}
 34
 35function getSessionHeader(entries: FileEntry[]): SessionHeader | null {
 36	return entries.find((e): e is SessionHeader => e.type === "session") ?? null;
 37}
 38
 39function getSessionEntries(entries: FileEntry[]): SessionEntry[] {
 40	return entries.filter((e): e is SessionEntry => e.type !== "session");
 41}
 42
 43function getLeafEntry(entries: SessionEntry[]): SessionEntry | null {
 44	if (entries.length === 0) return null;
 45
 46	const parentIds = new Set<string>();
 47	for (const entry of entries) {
 48		if ("parentId" in entry && entry.parentId) {
 49			parentIds.add(entry.parentId);
 50		}
 51	}
 52
 53	for (let i = entries.length - 1; i >= 0; i--) {
 54		const entry = entries[i]!;
 55		if (!parentIds.has(entry.id)) {
 56			return entry;
 57		}
 58	}
 59
 60	return entries[entries.length - 1] ?? null;
 61}
 62
 63function getEntryPath(entries: SessionEntry[]): SessionEntry[] {
 64	const leaf = getLeafEntry(entries);
 65	if (!leaf) return [];
 66
 67	const byId = new Map(entries.map((entry) => [entry.id, entry]));
 68	const path: SessionEntry[] = [];
 69	let current: SessionEntry | undefined = leaf;
 70
 71	while (current) {
 72		path.push(current);
 73		const parentId = "parentId" in current ? current.parentId : null;
 74		if (!parentId) break;
 75		current = byId.get(parentId);
 76	}
 77
 78	return path.reverse();
 79}
 80
 81function extractTextContent(content: unknown): string {
 82	if (typeof content === "string") return content;
 83	if (Array.isArray(content)) {
 84		const parts: string[] = [];
 85		for (const part of content) {
 86			if (part.type === "text" && part.text) {
 87				parts.push(part.text);
 88			} else if (part.type === "toolCall" && part.name) {
 89				parts.push(`[Tool: ${part.name}]`);
 90			}
 91		}
 92		return parts.join("\n");
 93	}
 94	return "";
 95}
 96
 97function getFirstUserMessage(entries: SessionEntry[]): string {
 98	for (const entry of entries) {
 99		if (entry.type === "message" && entry.message.role === "user") {
100			const text = extractTextContent(entry.message.content);
101			if (text) return text.slice(0, 200);
102		}
103	}
104	return "(no user message)";
105}
106
/** Number of entries with `type === "message"` (other entry kinds, e.g. custom messages, are not counted). */
function countMessages(entries: SessionEntry[]): number {
	return entries.reduce((total, entry) => (entry.type === "message" ? total + 1 : total), 0);
}
110
111// ============================================================================
112// Search Functions
113// ============================================================================
114
/**
 * Count, per file under `sessionsDir`, the lines that contain `query`
 * (case-insensitive, literal text rather than a regular expression).
 *
 * ripgrep is tried first. If it throws, or exits with anything other than 0
 * (matches) or 1 (no matches), `onFallback` is invoked and `grep -r` is used
 * instead. If grep also fails, an empty map is returned (best effort).
 *
 * @param exec Command runner resolving with stdout/stderr/exit code.
 * @param query Text to search for. It is passed with fixed-string matching,
 *   so characters such as "(" or "." are matched literally, consistent with
 *   search_thread's substring matching. It is also passed after `--`, so a
 *   leading "-" is never parsed as an option.
 * @param sessionsDir Directory searched recursively.
 * @param onFallback Called once, just before falling back to grep.
 * @returns Map from file path (as printed by rg/grep) to matching-line count.
 *   Files with zero matching lines are omitted.
 */
async function searchWithGrep(
	exec: (cmd: string, args: string[], opts?: { timeout?: number }) => Promise<{ stdout: string; stderr: string; code: number }>,
	query: string,
	sessionsDir: string,
	onFallback?: () => void,
): Promise<Map<string, number>> {
	// Parse `path:count` lines. The greedy `.+` keeps any ':' that is part of the path.
	const parseCounts = (stdout: string): Map<string, number> => {
		const counts = new Map<string, number>();
		for (const line of stdout.split("\n")) {
			const match = /^(.+):(\d+)$/.exec(line.trimEnd());
			if (!match) continue;
			const file = match[1];
			const count = Number(match[2]);
			if (file !== undefined && count > 0) counts.set(file, count);
		}
		return counts;
	};

	// Try ripgrep first
	try {
		const { stdout, code } = await exec(
			"rg",
			["--count", "--ignore-case", "--fixed-strings", "--with-filename", "--", query, sessionsDir],
			{ timeout: 10000 },
		);
		// Exit code 1 only means "no matches", which is fine.
		if (code === 0 || code === 1) return parseCounts(stdout);
	} catch {
		// ripgrep not found or failed to run; fall back to grep below.
	}

	// Fallback to grep
	onFallback?.();
	try {
		const { stdout } = await exec("grep", ["-r", "-c", "-i", "-F", "-H", "--", query, sessionsDir], { timeout: 30000 });
		return parseCounts(stdout);
	} catch {
		// grep also failed; report no matches rather than failing the search.
		return new Map<string, number>();
	}
}
154
/**
 * List candidate session files: every non-directory entry ending in `.jsonl`
 * exactly one directory level below `sessionsDir`. Subdirectories whose name
 * starts with "." are skipped.
 *
 * Returns an empty list when `sessionsDir` is missing or unreadable. A
 * subdirectory that cannot be read is skipped rather than aborting the scan,
 * so one bad directory does not break the search tools.
 */
async function getAllSessions(sessionsDir: string): Promise<string[]> {
	// readdirSync that treats unreadable/missing directories as empty.
	const readDir = (dir: string): fs.Dirent[] => {
		try {
			return fs.readdirSync(dir, { withFileTypes: true });
		} catch {
			return [];
		}
	};

	const sessions: string[] = [];
	for (const dirEntry of readDir(sessionsDir)) {
		if (!dirEntry.isDirectory() || dirEntry.name.startsWith(".")) continue;
		const dirPath = path.join(sessionsDir, dirEntry.name);
		for (const fileEntry of readDir(dirPath)) {
			// Exclude directories that merely happen to be named "*.jsonl"
			// (symlinked files are still accepted).
			if (!fileEntry.isDirectory() && fileEntry.name.endsWith(".jsonl")) {
				sessions.push(path.join(dirPath, fileEntry.name));
			}
		}
	}
	return sessions;
}
170
171// ============================================================================
172// Extension
173// ============================================================================
174
/**
 * Tool-call parameter schema for `find_threads` (registered as its `parameters`).
 * Every field is optional. The `default` values below are advertised in the
 * schema, while the tool's `execute` applies its own fallbacks for missing
 * values (limit 10, sort "recent").
 */
const FindThreadsParams = Type.Object({
	// Content filter; forwarded to the ripgrep/grep search over the sessions directory.
	query: Type.Optional(Type.String({ description: "Text to search for in messages (uses ripgrep)" })),
	// Case-insensitive substring match against the session header's cwd.
	cwd: Type.Optional(Type.String({ description: "Filter by working directory (partial match)" })),
	limit: Type.Optional(Type.Number({ description: "Maximum results to return (default: 10)", default: 10 })),
	sort: Type.Optional(
		StringEnum(["recent", "oldest", "relevance"] as const, {
			description: "Sort order: recent (default), oldest, or relevance (by match count)",
			default: "recent",
		}),
	),
});
186
/**
 * Tool-call parameter schema for `search_thread` (registered as its `parameters`).
 * Only `thread_id` is required; everything else narrows or trims the returned messages.
 */
const SearchThreadParams = Type.Object({
	// Either a session header ID or a path to a session .jsonl file.
	thread_id: Type.String({ description: "Thread ID (session UUID) or file path" }),
	query: Type.Optional(Type.String({ description: "Search for messages containing this text (case-insensitive). If omitted, returns all messages." })),
	// Neighbours are counted within the list that remains after the `roles` filter.
	context: Type.Optional(Type.Number({ description: "Include N messages before/after each match (default: 0)", default: 0 })),
	// Compared case-insensitively with each message's role; custom messages use "custom" or "custom:<customType>".
	roles: Type.Optional(Type.Array(Type.String(), { description: "Filter to specific roles: user, assistant, toolResult (default: all)" })),
	// Applied after the role/query filters and keeps the *last* N messages.
	max_messages: Type.Optional(Type.Number({ description: "Maximum messages to return" })),
	max_content_length: Type.Optional(Type.Number({ description: "Truncate each message content to N chars" })),
});
195
196export default function (pi: ExtensionAPI) {
197	// ========================================================================
198	// find_threads tool
199	// ========================================================================
200	pi.registerTool({
201		name: "find_threads",
202		label: "Find Threads",
203		description:
204			"Search through past conversation sessions. Use to find previous discussions, code changes, or decisions. Searches message content using ripgrep for speed.",
205		parameters: FindThreadsParams,
206
207		async execute(_toolCallId, params, _signal, _onUpdate, ctx) {
208			const startTime = Date.now();
209			const sessionsDir = getSessionsDir();
210			const limit = params.limit ?? 10;
211			const sort = params.sort ?? "recent";
212
213			let sessionFiles = await getAllSessions(sessionsDir);
214			let matchCounts: Map<string, number> | null = null;
215
216			// Filter by query using ripgrep (with grep fallback)
217			if (params.query) {
218				matchCounts = await searchWithGrep(
219					pi.exec.bind(pi),
220					params.query,
221					sessionsDir,
222					() => ctx.ui.notify("ripgrep not found, falling back to grep (slower)", "warning"),
223				);
224				sessionFiles = sessionFiles.filter((f) => matchCounts!.has(f));
225			}
226
227			// Filter by cwd
228			if (params.cwd) {
229				const cwdFilter = params.cwd.toLowerCase();
230				sessionFiles = sessionFiles.filter((f) => {
231					const entries = loadSessionFile(f);
232					const header = getSessionHeader(entries);
233					return header?.cwd?.toLowerCase().includes(cwdFilter);
234				});
235			}
236
237			// Parse and build results
238			const results: Array<{
239				id: string;
240				cwd: string;
241				timestamp: string;
242				preview: string;
243				messageCount: number;
244				filePath: string;
245				matchCount?: number;
246			}> = [];
247
248			for (const filePath of sessionFiles) {
249				const entries = loadSessionFile(filePath);
250				const header = getSessionHeader(entries);
251				if (!header) continue;
252
253				const sessionEntries = getSessionEntries(entries);
254				results.push({
255					id: header.id,
256					cwd: header.cwd || "",
257					timestamp: header.timestamp,
258					preview: getFirstUserMessage(sessionEntries),
259					messageCount: countMessages(sessionEntries),
260					filePath,
261					matchCount: matchCounts?.get(filePath),
262				});
263			}
264
265			// Sort
266			if (sort === "recent") {
267				results.sort((a, b) => new Date(b.timestamp).getTime() - new Date(a.timestamp).getTime());
268			} else if (sort === "oldest") {
269				results.sort((a, b) => new Date(a.timestamp).getTime() - new Date(b.timestamp).getTime());
270			} else if (sort === "relevance" && matchCounts) {
271				results.sort((a, b) => (b.matchCount ?? 0) - (a.matchCount ?? 0));
272			}
273
274			const limitedResults = results.slice(0, limit);
275			const searchTime = Date.now() - startTime;
276
277			// Format text output
278			let text = `Found ${results.length} threads`;
279			if (params.query) text += ` matching "${params.query}"`;
280			if (params.cwd) text += ` in ${params.cwd}`;
281			text += ` (${searchTime}ms)\n\n`;
282
283			for (const r of limitedResults) {
284				const date = new Date(r.timestamp).toLocaleDateString();
285				text += `**${r.id}** (${date})\n`;
286				text += `  📁 ${r.cwd}\n`;
287				text += `  💬 ${r.messageCount} messages`;
288				if (r.matchCount) text += ` | ${r.matchCount} matches`;
289				text += `\n  📝 ${r.preview}\n\n`;
290			}
291
292			if (results.length > limit) {
293				text += `... and ${results.length - limit} more. Use limit parameter to see more.`;
294			}
295
296			return {
297				content: [{ type: "text", text }],
298				details: { threads: limitedResults, searchTime, totalSessions: sessionFiles.length },
299			};
300		},
301
302		renderCall(args, theme) {
303			let text = theme.fg("toolTitle", theme.bold("find_threads"));
304			if (args.query) text += " " + theme.fg("accent", `"${args.query}"`);
305			if (args.cwd) text += " " + theme.fg("muted", `in ${args.cwd}`);
306			if (args.limit) text += " " + theme.fg("dim", `limit:${args.limit}`);
307			return new Text(text, 0, 0);
308		},
309
310		renderResult(result, { expanded }, theme) {
311			const { details } = result;
312			if (!details?.threads) {
313				const text = result.content[0];
314				return new Text(text?.type === "text" ? text.text : "(no output)", 0, 0);
315			}
316
317			const { threads, searchTime } = details;
318			const icon = threads.length > 0 ? theme.fg("success", "✓") : theme.fg("muted", "○");
319
320			if (expanded) {
321				const container = new Container();
322				container.addChild(
323					new Text(`${icon} Found ${theme.fg("accent", String(threads.length))} threads (${searchTime}ms)`, 0, 0),
324				);
325
326				for (const t of threads) {
327					container.addChild(new Spacer(1));
328					const date = new Date(t.timestamp).toLocaleDateString();
329					container.addChild(new Text(theme.fg("accent", t.id) + theme.fg("dim", ` (${date})`), 0, 0));
330					container.addChild(new Text(theme.fg("muted", `  📁 ${t.cwd}`), 0, 0));
331					container.addChild(
332						new Text(
333							theme.fg("dim", `  💬 ${t.messageCount} msgs`) +
334								(t.matchCount ? theme.fg("warning", ` | ${t.matchCount} matches`) : ""),
335							0,
336							0,
337						),
338					);
339					const preview = t.preview.length > 80 ? t.preview.slice(0, 80) + "..." : t.preview;
340					container.addChild(new Text(theme.fg("toolOutput", `  ${preview}`), 0, 0));
341				}
342				return container;
343			}
344
345			// Collapsed view
346			let text = `${icon} Found ${theme.fg("accent", String(threads.length))} threads (${searchTime}ms)`;
347			for (const t of threads.slice(0, 3)) {
348				const date = new Date(t.timestamp).toLocaleDateString();
349				const preview = t.preview.length > 50 ? t.preview.slice(0, 50) + "..." : t.preview;
350				text += `\n  ${theme.fg("accent", t.id.slice(0, 8))} ${theme.fg("dim", date)} ${theme.fg("muted", preview)}`;
351			}
352			if (threads.length > 3) {
353				text += `\n  ${theme.fg("muted", `... +${threads.length - 3} more (Ctrl+O to expand)`)}`;
354			}
355			return new Text(text, 0, 0);
356		},
357	});
358
359	// ========================================================================
360	// search_thread tool
361	// ========================================================================
362	pi.registerTool({
363		name: "search_thread",
364		label: "Search Thread",
365		description:
366			"Search and read a specific conversation thread by ID or file path. Returns conversation messages with optional filtering by query text, roles, and context window.",
367		parameters: SearchThreadParams,
368
369		async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
370			const { thread_id, query, context = 0, roles, max_messages, max_content_length } = params;
371			const sessionsDir = getSessionsDir();
372
373			// Find the session file
374			let filePath: string | null = null;
375
376			if (thread_id.endsWith(".jsonl") || thread_id.startsWith("/")) {
377				filePath = thread_id;
378			} else {
379				const allSessions = await getAllSessions(sessionsDir);
380				for (const sessionPath of allSessions) {
381					const entries = loadSessionFile(sessionPath);
382					const header = getSessionHeader(entries);
383					if (header?.id === thread_id) {
384						filePath = sessionPath;
385						break;
386					}
387				}
388			}
389
390			if (!filePath || !fs.existsSync(filePath)) {
391				return {
392					content: [{ type: "text", text: `Thread not found: ${thread_id}` }],
393					details: { thread: null, error: "Thread not found" },
394					isError: true,
395				};
396			}
397
398			const fileEntries = loadSessionFile(filePath);
399			const header = getSessionHeader(fileEntries);
400
401			if (!header) {
402				return {
403					content: [{ type: "text", text: `Invalid session file: ${filePath}` }],
404					details: { thread: null, error: "Invalid session file" },
405					isError: true,
406				};
407			}
408
409			const sessionEntries = getSessionEntries(fileEntries);
410			const branchEntries = getEntryPath(sessionEntries);
411
412			// Build message list
413			const allMessages: Array<{
414				role: string;
415				content: string;
416				timestamp?: string;
417				model?: string;
418				toolName?: string;
419			}> = [];
420			let totalTokens = 0;
421			let totalCost = 0;
422
423			for (const entry of branchEntries) {
424				if (entry.type === "custom_message") {
425					const customEntry = entry as any;
426					const customContent = extractTextContent(customEntry.content);
427					if (!customContent.trim()) continue;
428					allMessages.push({
429						role: customEntry.customType ? `custom:${customEntry.customType}` : "custom",
430						content: customContent.trim(),
431						timestamp: customEntry.timestamp,
432					});
433					continue;
434				}
435
436				if (entry.type !== "message") continue;
437
438				const msg = entry.message;
439				const content = extractTextContent(msg.content);
440				if (!content.trim()) continue;
441
442				allMessages.push({
443					role: msg.role,
444					content: content.trim(),
445					timestamp: entry.timestamp,
446					model: "model" in msg ? (msg as any).model : undefined,
447					toolName: "toolName" in msg ? (msg as any).toolName : undefined,
448				});
449
450				if ("usage" in msg && msg.usage) {
451					const usage = msg.usage as { input?: number; output?: number; cost?: { total?: number } };
452					totalTokens += (usage.input || 0) + (usage.output || 0);
453					totalCost += usage.cost?.total || 0;
454				}
455			}
456
457			// Apply filtering
458			let filteredMessages = allMessages;
459			const originalCount = allMessages.length;
460
461			// 1. Filter by roles if specified
462			if (roles && roles.length > 0) {
463				const rolesLower = roles.map(r => r.toLowerCase());
464				filteredMessages = filteredMessages.filter(msg => 
465					rolesLower.includes(msg.role.toLowerCase())
466				);
467			}
468
469			// 2. Filter by query if specified, with context
470			let matchCount = 0;
471			if (query) {
472				const queryLower = query.toLowerCase();
473				const matchIndices = new Set<number>();
474				
475				// Find matching message indices
476				for (let i = 0; i < filteredMessages.length; i++) {
477					if (filteredMessages[i].content.toLowerCase().includes(queryLower)) {
478						matchCount++;
479						// Add the match and context
480						for (let j = Math.max(0, i - context); j <= Math.min(filteredMessages.length - 1, i + context); j++) {
481							matchIndices.add(j);
482						}
483					}
484				}
485				
486				// Keep only messages in the match set (preserving order)
487				filteredMessages = filteredMessages.filter((_, idx) => matchIndices.has(idx));
488			}
489
490			// 3. Apply max_messages limit (from end)
491			const limitedMessages = max_messages ? filteredMessages.slice(-max_messages) : filteredMessages;
492
493			// 4. Apply content length truncation
494			const finalMessages = max_content_length
495				? limitedMessages.map(msg => ({
496						...msg,
497						content: msg.content.length > max_content_length 
498							? msg.content.slice(0, max_content_length) + "..."
499							: msg.content,
500					}))
501				: limitedMessages;
502
503			const thread = {
504				id: header.id,
505				cwd: header.cwd || "",
506				timestamp: header.timestamp,
507				messages: finalMessages,
508				totalTokens,
509				totalCost,
510			};
511
512			// Format output
513			let text = `## Thread ${thread.id}\n`;
514			text += `**Directory:** ${thread.cwd}\n`;
515			text += `**Started:** ${new Date(thread.timestamp).toLocaleString()}\n`;
516			text += `**Messages:** ${originalCount} total | **Tokens:** ${totalTokens.toLocaleString()} | **Cost:** $${totalCost.toFixed(4)}\n`;
517			
518			// Show filtering info
519			const filters: string[] = [];
520			if (query) filters.push(`query="${query}" (${matchCount} matches)`);
521			if (roles) filters.push(`roles=[${roles.join(", ")}]`);
522			if (context > 0) filters.push(`context=${context}`);
523			if (max_messages) filters.push(`max_messages=${max_messages}`);
524			if (max_content_length) filters.push(`truncate=${max_content_length}`);
525			
526			if (filters.length > 0) {
527				text += `**Filters:** ${filters.join(" | ")}\n`;
528				text += `**Showing:** ${finalMessages.length} of ${originalCount} messages\n`;
529			}
530			
531			text += "\n---\n\n";
532
533			for (const msg of finalMessages) {
534				const roleIcon = msg.role === "user" ? "👤" : msg.role === "assistant" ? "🤖" : "🔧";
535				const roleLabel = msg.role === "toolResult" ? `tool:${msg.toolName}` : msg.role;
536				text += `### ${roleIcon} ${roleLabel}\n`;
537				if (msg.model) text += `*${msg.model}*\n`;
538				text += `\n${msg.content}\n\n`;
539			}
540
541			return {
542				content: [{ type: "text", text }],
543				details: { thread, matchCount, originalCount },
544			};
545		},
546
547		renderCall(args, theme) {
548			let text = theme.fg("toolTitle", theme.bold("search_thread"));
549			text += " " + theme.fg("accent", args.thread_id.slice(0, 36));
550			if (args.query) text += " " + theme.fg("warning", `"${args.query}"`);
551			if (args.roles) text += " " + theme.fg("dim", `roles:[${args.roles.join(",")}]`);
552			if (args.context) text += " " + theme.fg("dim", `ctx:${args.context}`);
553			if (args.max_messages) text += " " + theme.fg("dim", `last:${args.max_messages}`);
554			return new Text(text, 0, 0);
555		},
556
557		renderResult(result, { expanded }, theme) {
558			const { details } = result;
559			if (!details?.thread) {
560				const text = result.content[0];
561				return new Text(
562					text?.type === "text" ? theme.fg("error", text.text) : theme.fg("error", "(error)"),
563					0,
564					0,
565				);
566			}
567
568			const { thread, matchCount, originalCount } = details;
569			const icon = theme.fg("success", "✓");
570			const countInfo = matchCount !== undefined 
571				? `${matchCount} matches, ${thread.messages.length}/${originalCount} shown`
572				: `${thread.messages.length} messages`;
573
574			if (expanded) {
575				const container = new Container();
576				container.addChild(
577					new Text(`${icon} Thread ${theme.fg("accent", thread.id.slice(0, 8))} (${countInfo})`, 0, 0),
578				);
579				container.addChild(new Text(theme.fg("muted", `📁 ${thread.cwd}`), 0, 0));
580				container.addChild(
581					new Text(theme.fg("dim", `📊 ${thread.totalTokens.toLocaleString()} tokens | $${thread.totalCost.toFixed(4)}`), 0, 0),
582				);
583
584				for (const msg of thread.messages) {
585					container.addChild(new Spacer(1));
586					const roleIcon = msg.role === "user" ? "👤" : msg.role === "assistant" ? "🤖" : "🔧";
587					const preview = msg.content.length > 200 ? msg.content.slice(0, 200) + "..." : msg.content;
588					container.addChild(new Text(`${roleIcon} ${theme.fg("accent", msg.role)}`, 0, 0));
589					container.addChild(new Text(theme.fg("toolOutput", preview), 0, 0));
590				}
591				return container;
592			}
593
594			// Collapsed
595			let text = `${icon} Thread ${theme.fg("accent", thread.id.slice(0, 8))} (${countInfo})`;
596			text += `\n  ${theme.fg("muted", thread.cwd)}`;
597			for (const msg of thread.messages.slice(0, 3)) {
598				const preview = msg.content.slice(0, 60).replace(/\n/g, " ");
599				const roleIcon = msg.role === "user" ? "👤" : msg.role === "assistant" ? "🤖" : "🔧";
600				text += `\n  ${roleIcon} ${theme.fg("dim", preview)}${msg.content.length > 60 ? "..." : ""}`;
601			}
602			if (thread.messages.length > 3) {
603				text += `\n  ${theme.fg("muted", `... +${thread.messages.length - 3} more (Ctrl+O to expand)`)}`;
604			}
605			return new Text(text, 0, 0);
606		},
607	});
608}