main
1import {
2 GoogleGenerativeAI,
3 Content,
4 Part,
5 FunctionDeclaration,
6 SchemaType,
7} from "@google/generative-ai";
8import { LLMProvider, Message, ToolDefinition, LLMResponse } from "./types.js";
9import { TSchema, Kind } from "@sinclair/typebox";
10
/**
 * `LLMProvider` implementation backed by Google's Gemini models through the
 * `@google/generative-ai` SDK.
 */
export class GoogleProvider implements LLMProvider {
  name = "google";
  private readonly client: GoogleGenerativeAI;
  private readonly model: string;

  /**
   * @param apiKey - API key handed to the `GoogleGenerativeAI` client.
   * @param model - Model identifier passed to `getGenerativeModel` on every call.
   */
  constructor(apiKey: string, model: string) {
    this.client = new GoogleGenerativeAI(apiKey);
    this.model = model;
  }

  /**
   * Sends a conversation to Gemini and normalizes the first candidate into an
   * `LLMResponse`.
   *
   * Messages with role "system" are dropped; the system prompt comes only from
   * `options.systemPrompt`. Assistant messages are sent with Gemini's "model"
   * role and every other remaining role as "user".
   *
   * @param messages - Conversation history; each message becomes one text part.
   * @param options.systemPrompt - Passed to Gemini as `systemInstruction`.
   * @param options.tools - Tools exposed to the model as function declarations.
   * @param options.maxTokens - Output token cap; a missing or 0 value falls back to 4096.
   * @returns The concatenated text of all text parts (null if there are none),
   *   any function calls, and a stop reason. The stop reason is "error" when the
   *   response has no candidates, "tool_use" when a function call is present,
   *   "max_tokens" when Gemini reports MAX_TOKENS, and "end_turn" otherwise.
   */
  async chat(
    messages: Message[],
    options: {
      systemPrompt?: string;
      tools?: ToolDefinition[];
      maxTokens?: number;
    }
  ): Promise<LLMResponse> {
    const genModel = this.client.getGenerativeModel({
      model: this.model,
      systemInstruction: options.systemPrompt,
    });

    const contents: Content[] = messages
      .filter((m) => m.role !== "system")
      .map((m) => ({
        role: m.role === "assistant" ? "model" : "user",
        parts: [{ text: m.content }] as Part[],
      }));

    // Gemini rejects requests carrying more than one non-search Tool entry,
    // so every function declaration goes into a single Tool. With no tools,
    // the field is omitted entirely rather than sent as an empty array.
    const functionDeclarations = (options.tools ?? []).map((t) =>
      this.toFunctionDeclaration(t)
    );
    const tools =
      functionDeclarations.length > 0 ? [{ functionDeclarations }] : undefined;

    const result = await genModel.generateContent({
      contents,
      tools,
      generationConfig: {
        // `||` (not `??`) on purpose: 0 is not a usable cap and means "unset".
        maxOutputTokens: options.maxTokens || 4096,
      },
    });

    const candidate = result.response.candidates?.[0];
    if (!candidate) {
      return {
        content: null,
        toolCalls: [],
        stopReason: "error",
      };
    }

    // A candidate can arrive without content (e.g. one stopped for safety);
    // treat that as an empty reply instead of throwing on `.parts`.
    const parts = candidate.content?.parts ?? [];

    const textParts: string[] = [];
    const toolCalls: LLMResponse["toolCalls"] = [];

    for (const part of parts) {
      if ("text" in part && part.text) {
        // Gemini may split one reply across several text parts; keep them all.
        textParts.push(part.text);
      } else if ("functionCall" in part && part.functionCall) {
        toolCalls.push({
          // Gemini does not assign call ids, so synthesize a unique one.
          id: `call_${Date.now()}_${Math.random().toString(36).slice(2)}`,
          name: part.functionCall.name,
          arguments: (part.functionCall.args as Record<string, unknown>) || {},
        });
      }
    }

    let stopReason: LLMResponse["stopReason"] = "end_turn";
    if (toolCalls.length > 0) {
      stopReason = "tool_use";
    } else if (candidate.finishReason === "MAX_TOKENS") {
      stopReason = "max_tokens";
    }

    return {
      // Same joining rule as the SDK's `response.text()`: parts concatenated with "".
      content: textParts.length > 0 ? textParts.join("") : null,
      toolCalls,
      stopReason,
    };
  }

  /**
   * Converts one tool definition into a Gemini function declaration.
   *
   * `parameters` is omitted when the tool's schema declares no properties,
   * because Gemini rejects OBJECT schemas whose `properties` is empty.
   */
  private toFunctionDeclaration(tool: ToolDefinition): FunctionDeclaration {
    const declaration: FunctionDeclaration = {
      name: tool.name,
      description: tool.description,
    };
    const properties = tool.parameters.properties as
      | Record<string, unknown>
      | undefined;
    if (properties && Object.keys(properties).length > 0) {
      declaration.parameters = typeboxToGeminiSchema(tool.parameters);
    }
    return declaration;
  }
}
99
/**
 * Converts a TypeBox schema into the schema shape Gemini function declarations
 * accept.
 *
 * Copies `type` (via `getGeminiSchemaType`) and `description`. For objects it
 * also copies `properties` and `required`, and for arrays it copies `items`,
 * recursing into nested schemas. Every other keyword is dropped.
 *
 * @param schema - TypeBox schema to convert.
 * @returns The Gemini schema, typed as `FunctionDeclaration["parameters"]`.
 */
function typeboxToGeminiSchema(
  schema: TSchema
): FunctionDeclaration["parameters"] {
  const result: Record<string, unknown> = {
    type: getGeminiSchemaType(schema),
  };

  if (schema.description) {
    result.description = schema.description;
  }

  if (schema[Kind] === "Object" && schema.properties) {
    const sourceProps = schema.properties as Record<string, TSchema>;
    const properties: Record<string, unknown> = {};
    for (const [key, prop] of Object.entries(sourceProps)) {
      properties[key] = typeboxToGeminiSchema(prop);
    }
    result.properties = properties;

    // TypeBox never changes `[Kind]` for `Type.Optional(...)`. Instead,
    // `Type.Object` lists only non-optional keys in the object's own `required`
    // array, so that array is the source of truth. Keys not present in
    // `properties` are dropped so the output stays self-consistent.
    const required = Array.isArray(schema.required)
      ? (schema.required as string[]).filter((key) =>
          Object.prototype.hasOwnProperty.call(sourceProps, key)
        )
      : [];
    if (required.length > 0) {
      result.required = required;
    }
  }

  if (schema[Kind] === "Array" && schema.items) {
    result.items = typeboxToGeminiSchema(schema.items as TSchema);
  }

  // The same builder produces both root and nested schemas, so the
  // root-specific SDK type can only be reached through a structural cast.
  return result as unknown as FunctionDeclaration["parameters"];
}
136
/**
 * Maps a TypeBox schema's `[Kind]` to the corresponding Gemini `SchemaType`.
 *
 * `Literal` schemas map by the runtime type of their `const` value. `Object`
 * and every kind without a dedicated mapping (e.g. `Union`, `Any`) fall back
 * to OBJECT.
 *
 * @param schema - TypeBox schema whose kind is inspected.
 * @returns The Gemini schema type to emit for it.
 */
function getGeminiSchemaType(schema: TSchema): SchemaType {
  switch (schema[Kind]) {
    case "String":
      return SchemaType.STRING;
    case "Number":
      return SchemaType.NUMBER;
    case "Integer":
      // Gemini has a dedicated integer type; NUMBER would let the model emit fractions.
      return SchemaType.INTEGER;
    case "Boolean":
      return SchemaType.BOOLEAN;
    case "Array":
      return SchemaType.ARRAY;
    case "Literal":
      switch (typeof schema.const) {
        case "string":
          return SchemaType.STRING;
        case "number":
          return SchemaType.NUMBER;
        case "boolean":
          return SchemaType.BOOLEAN;
        default:
          return SchemaType.OBJECT;
      }
    case "Object":
    default:
      return SchemaType.OBJECT;
  }
}