main
  1import OpenAI from "openai";
  2import { LLMProvider, Message, ToolDefinition, LLMResponse } from "./types.js";
  3import { TSchema, Kind } from "@sinclair/typebox";
  4
  /**
   * `LLMProvider` implementation backed by the OpenAI Chat Completions API.
   *
   * An optional `baseUrl` is forwarded to the client as `baseURL`, so the same
   * class can target an OpenAI-compatible endpoint.
   */
  export class OpenAIProvider implements LLMProvider {
    name = "openai";
    private readonly client: OpenAI;
    private readonly model: string;

    /**
     * @param apiKey - API key handed to the OpenAI client.
     * @param model - Model identifier sent with every completion request.
     * @param baseUrl - Optional API base URL, passed through as the client's `baseURL`.
     */
    constructor(apiKey: string, model: string, baseUrl?: string) {
      this.client = new OpenAI({
        apiKey,
        baseURL: baseUrl,
      });
      this.model = model;
    }

    /**
     * Send one Chat Completions request and normalise the first choice.
     *
     * - `options.systemPrompt`, when non-empty, is sent as the leading system
     *   message; `system` entries inside `messages` are skipped.
     * - `options.tools` are forwarded as function tools, with their TypeBox
     *   parameter schemas converted to plain JSON Schema.
     * - `options.maxTokens` defaults to 4096 (any falsy value, including 0,
     *   also falls back to 4096).
     *
     * @returns The first choice's text content, its function tool calls, and a
     *   stop reason: `"tool_use"` for finish reason `tool_calls`, `"max_tokens"`
     *   for `length`, `"error"` when the API returned no choices, and
     *   `"end_turn"` for anything else.
     */
    async chat(
      messages: Message[],
      options: {
        systemPrompt?: string;
        tools?: ToolDefinition[];
        maxTokens?: number;
      }
    ): Promise<LLMResponse> {
      const openaiMessages: OpenAI.ChatCompletionMessageParam[] = [];

      if (options.systemPrompt) {
        openaiMessages.push({
          role: "system",
          content: options.systemPrompt,
        });
      }

      for (const m of messages) {
        // The system prompt is taken from options.systemPrompt only.
        if (m.role === "system") continue;
        openaiMessages.push({
          role: m.role,
          content: m.content,
        });
      }

      const tools = options.tools?.map((t) => ({
        type: "function" as const,
        function: {
          name: t.name,
          description: t.description,
          parameters: typeboxToJsonSchema(t.parameters),
        },
      }));

      const response = await this.client.chat.completions.create({
        model: this.model,
        messages: openaiMessages,
        // Omit the key entirely rather than sending an empty tools array.
        tools: tools?.length ? tools : undefined,
        max_tokens: options.maxTokens || 4096,
      });

      const choice = response.choices[0];
      if (!choice) {
        return {
          content: null,
          toolCalls: [],
          stopReason: "error",
        };
      }

      const textContent = choice.message.content;
      const toolCalls: LLMResponse["toolCalls"] = [];

      if (choice.message.tool_calls) {
        for (const tc of choice.message.tool_calls) {
          // Only function tool calls are surfaced.
          if (tc.type === "function") {
            toolCalls.push({
              id: tc.id,
              name: tc.function.name,
              arguments: OpenAIProvider.parseToolArguments(tc.function.arguments),
            });
          }
        }
      }

      let stopReason: LLMResponse["stopReason"] = "end_turn";
      if (choice.finish_reason === "tool_calls") {
        stopReason = "tool_use";
      } else if (choice.finish_reason === "length") {
        stopReason = "max_tokens";
      }

      return {
        content: textContent,
        toolCalls,
        stopReason,
      };
    }

    /**
     * Decode a tool call's JSON-encoded `arguments` string.
     *
     * The string is model output and is not guaranteed to be valid JSON (for
     * example when generation stops at the token limit in the middle of a call).
     * Instead of letting `JSON.parse` throw out of `chat()` and losing the whole
     * response, an empty, malformed, or non-object payload decodes to `{}`; the
     * caller still receives the stop reason to see why.
     */
    private static parseToolArguments(
      raw: string | null | undefined
    ): Record<string, unknown> {
      if (!raw) return {};
      try {
        const parsed: unknown = JSON.parse(raw);
        if (typeof parsed === "object" && parsed !== null && !Array.isArray(parsed)) {
          return parsed as Record<string, unknown>;
        }
      } catch {
        // Malformed JSON: fall through to the empty-arguments fallback below.
      }
      return {};
    }
  }
 97
 98function typeboxToJsonSchema(schema: TSchema): Record<string, unknown> {
 99  const result: Record<string, unknown> = {
100    type: getJsonSchemaType(schema),
101  };
102
103  if (schema.description) {
104    result.description = schema.description;
105  }
106
107  if (schema[Kind] === "Object" && schema.properties) {
108    result.properties = {};
109    const required: string[] = [];
110
111    for (const [key, prop] of Object.entries(
112      schema.properties as Record<string, TSchema>
113    )) {
114      (result.properties as Record<string, unknown>)[key] =
115        typeboxToJsonSchema(prop);
116      if (prop[Kind] !== "Optional") {
117        required.push(key);
118      }
119    }
120
121    if (required.length > 0) {
122      result.required = required;
123    }
124  }
125
126  if (schema[Kind] === "Array" && schema.items) {
127    result.items = typeboxToJsonSchema(schema.items as TSchema);
128  }
129
130  if (schema[Kind] === "Union" && schema.anyOf) {
131    result.anyOf = (schema.anyOf as TSchema[]).map(typeboxToJsonSchema);
132    delete result.type;
133  }
134
135  if (schema.default !== undefined) {
136    result.default = schema.default;
137  }
138
139  return result;
140}
141
142function getJsonSchemaType(schema: TSchema): string {
143  const kind = schema[Kind];
144  switch (kind) {
145    case "String":
146      return "string";
147    case "Number":
148    case "Integer":
149      return "number";
150    case "Boolean":
151      return "boolean";
152    case "Array":
153      return "array";
154    case "Object":
155      return "object";
156    default:
157      return "object";
158  }
159}