flake-update-20260201
1import os
2import asyncio
3import httpx
4import json
5import logging
6from fastapi import FastAPI, Request, Response
7from fastapi.responses import StreamingResponse
8
9import uvicorn
10from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
11
# Base URL of the upstream Ollama server (env var OLLAMA_HOST, default localhost).
# A trailing "/" is stripped because request paths are appended directly
# (f"{OLLAMA_HOST}{endpoint}"), and "http://host:11434/" would yield "//api/...".
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")


def resolve_log_level(name, default=logging.INFO):
    """Return the numeric ``logging`` level named *name*, or *default*.

    ``getattr(logging, name)`` alone is not enough: it also resolves module
    attributes that are not levels (``"BASIC_FORMAT"`` is a format string),
    which ``Logger.setLevel`` rejects with ``ValueError``.

    Args:
        name: Level name such as ``"DEBUG"``; matched case-insensitively.
        default: Level returned when *name* is not a known level.

    Returns:
        An ``int`` suitable for ``Logger.setLevel``.
    """
    level = getattr(logging, name.upper(), None)
    if isinstance(level, int) and not isinstance(level, bool):
        return level
    return default


logging.basicConfig()
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logger.setLevel(resolve_log_level(LOG_LEVEL))
19
# The ASGI application; routes are registered on it by the decorators below.
app = FastAPI()

# Incremented once per proxied chat/generate/completions request, per model.
OLLAMA_CHAT_REQUEST_COUNT = Counter(
    "ollama_requests_total", "Total chat requests", labelnames=("model",)
)

# Per-model timing histograms, observed in seconds.
OLLAMA_TOTAL_DURATION = Histogram(
    "ollama_response_seconds",
    "Total time spent for the response",
    labelnames=("model",),
)
OLLAMA_LOAD_DURATION = Histogram(
    "ollama_load_duration_seconds",
    "Time spent loading the model",
    labelnames=("model",),
)
OLLAMA_PROMPT_EVAL_DURATION = Histogram(
    "ollama_prompt_eval_duration_seconds",
    "Time spent evaluating prompt",
    labelnames=("model",),
)
OLLAMA_EVAL_DURATION = Histogram(
    "ollama_eval_duration_seconds",
    "Time spent generating the response",
    labelnames=("model",),
)

# Running token totals per model (prompt side and generated side).
OLLAMA_PROMPT_EVAL_COUNT = Counter(
    "ollama_tokens_processed_total",
    "Number of tokens in the prompt",
    labelnames=("model",),
)
OLLAMA_EVAL_COUNT = Counter(
    "ollama_tokens_generated_total",
    "Number of tokens in the response",
    labelnames=("model",),
)

# Generation throughput. Buckets are 5, then 10 through 100 in steps of 10 tokens/s.
OLLAMA_TOKENS_PER_SECOND = Histogram(
    "ollama_tokens_per_second",
    "Tokens generated per second",
    labelnames=("model",),
    buckets=(5, *range(10, 101, 10)),
)
39
40
_NANOSECONDS_PER_SECOND = 1_000_000_000


def _as_metric_number(value):
    """Return *value* when it is an int or float (but not a bool), otherwise 0.

    Responses are parsed from JSON, so a field may be missing, ``null`` or of an
    unexpected type; comparing such a value with ``0`` would raise ``TypeError``
    and abort metric collection for the whole response.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return 0
    return value


def extract_and_record_metrics(response_data, model):
    """Extract and record Prometheus metrics from an Ollama response payload.

    Supports both the native Ollama API, whose timing fields (in nanoseconds)
    and token counts sit at the top level
    (https://github.com/ollama/ollama/blob/main/docs/api.md#response), and the
    OpenAI-compatible ``/v1`` API, whose token counts live in a ``usage`` object.
    Fields that are absent, non-positive or not numeric are skipped.

    Args:
        response_data: Decoded JSON response (or final stream chunk). Anything
            other than a dict is ignored.
        model: Value used for the ``model`` label on every metric.
    """
    if not isinstance(response_data, dict):
        return

    total_duration = _as_metric_number(response_data.get("total_duration"))  # whole request
    load_duration = _as_metric_number(response_data.get("load_duration"))  # loading the model
    prompt_eval_duration = _as_metric_number(response_data.get("prompt_eval_duration"))
    prompt_eval_count = _as_metric_number(response_data.get("prompt_eval_count"))  # prompt tokens
    eval_duration = _as_metric_number(response_data.get("eval_duration"))  # generating output
    eval_count = _as_metric_number(response_data.get("eval_count"))  # response tokens

    # OpenAI-compatible responses carry token counts in "usage" instead.
    usage = response_data.get("usage")
    if isinstance(usage, dict):
        if not prompt_eval_count:
            prompt_eval_count = _as_metric_number(usage.get("prompt_tokens"))
        if not eval_count:
            eval_count = _as_metric_number(usage.get("completion_tokens"))

    if total_duration > 0:
        seconds = total_duration / _NANOSECONDS_PER_SECOND
        OLLAMA_TOTAL_DURATION.labels(model=model).observe(seconds)
        logger.debug("Model: %s, Total Duration: %.2f seconds", model, seconds)
    if load_duration > 0:
        seconds = load_duration / _NANOSECONDS_PER_SECOND
        OLLAMA_LOAD_DURATION.labels(model=model).observe(seconds)
        logger.debug("Model: %s, Load Duration: %.2f seconds", model, seconds)
    if prompt_eval_duration > 0:
        seconds = prompt_eval_duration / _NANOSECONDS_PER_SECOND
        OLLAMA_PROMPT_EVAL_DURATION.labels(model=model).observe(seconds)
        logger.debug("Model: %s, Prompt Eval Duration: %.2f seconds", model, seconds)
    if prompt_eval_count > 0:
        OLLAMA_PROMPT_EVAL_COUNT.labels(model=model).inc(prompt_eval_count)
        logger.debug("Model: %s, Prompt Eval Count: %s", model, prompt_eval_count)
    if eval_duration > 0:
        seconds = eval_duration / _NANOSECONDS_PER_SECOND
        OLLAMA_EVAL_DURATION.labels(model=model).observe(seconds)
        logger.debug("Model: %s, Eval Duration: %.2f seconds", model, seconds)
    if eval_count > 0:
        OLLAMA_EVAL_COUNT.labels(model=model).inc(eval_count)
        logger.debug("Model: %s, Eval Count: %s", model, eval_count)
    if eval_duration > 0 and eval_count > 0:
        tps = eval_count / eval_duration * _NANOSECONDS_PER_SECOND
        OLLAMA_TOKENS_PER_SECOND.labels(model=model).observe(tps)
        logger.debug("Model: %s, Tokens per Second: %.2f", model, tps)
91
@app.get("/metrics")
def metrics():
    """Serve the default Prometheus registry in the text exposition format."""
    exposition = generate_latest()
    return Response(content=exposition, media_type=CONTENT_TYPE_LATEST)
96
# Hop-by-hop headers (RFC 9110 section 7.6.1) are scoped to a single connection,
# so a proxy must not copy them between the client and upstream connections.
_CHAT_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _chat_forward_headers(headers, drop=()):
    """Return *headers* as a dict without hop-by-hop names or the lower-case names in *drop*."""
    return {
        name: value
        for name, value in headers.items()
        if name.lower() not in _CHAT_HOP_BY_HOP_HEADERS and name.lower() not in drop
    }


def _parse_stream_line(raw_line):
    """Decode one line of a streamed response into a JSON object, or return None.

    Native endpoints stream newline-delimited JSON. The OpenAI-compatible ``/v1``
    endpoints stream server-sent events whose payload follows a ``data:`` prefix,
    terminated by ``data: [DONE]``. Blank lines, other SSE fields and non-object
    JSON all yield None.
    """
    text = raw_line.strip()
    if text.startswith(b"data:"):
        text = text[len(b"data:"):].strip()
    if not text or text == b"[DONE]":
        return None
    try:
        payload = json.loads(text)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return None
    return payload if isinstance(payload, dict) else None


def _is_final_stream_chunk(payload):
    """Return True for the chunk that carries end-of-response statistics."""
    # Native API: the last object has "done": true.
    # OpenAI-compatible API: token counts arrive in a non-empty "usage" object,
    # presumably only on the last chunk and only when the client asks for it.
    return bool(payload.get("done")) or bool(payload.get("usage"))


@app.post("/api/chat")
@app.post("/api/generate")
@app.post("/v1/chat/completions")
@app.post("/v1/completions")
async def chat_with_metrics(request: Request):
    """Proxy chat/generate/completions calls to Ollama and record metrics.

    Streaming and buffered requests are both supported. Metrics come from the
    final streamed chunk or from the buffered JSON body, and are only recorded
    for upstream 200 responses.

    Returns:
        A 400 JSON error when the request body is not a JSON object. Otherwise
        a response carrying the upstream status code and body.
    """
    try:
        body = await request.json()
    except ValueError:
        body = None
    if not isinstance(body, dict):
        return Response(
            content=json.dumps({"error": "request body must be a JSON object"}),
            status_code=400,
            media_type="application/json",
        )

    model = body.get("model", "unknown")
    endpoint = request.url.path  # one of the four routes above
    # Ollama's native /api/* endpoints stream unless "stream": false is sent;
    # the OpenAI-compatible /v1/* endpoints only stream when asked to.
    requested_stream = body.get("stream")
    if requested_stream is None:
        is_streaming = endpoint.startswith("/api/")
    else:
        is_streaming = bool(requested_stream)

    # httpx sets its own Host, Content-Length and Content-Type for json=body.
    headers = _chat_forward_headers(request.headers, drop=("host", "content-length", "content-type"))
    upstream_url = f"{OLLAMA_HOST}{endpoint}"
    timeout = httpx.Timeout(900.0, read=900.0)

    OLLAMA_CHAT_REQUEST_COUNT.labels(model=model).inc()

    if is_streaming:
        # Open the upstream stream *before* building the response so its real
        # status code and content type can be passed on to the client.
        client = httpx.AsyncClient(timeout=timeout)
        try:
            upstream = await client.send(
                client.build_request(
                    "POST", upstream_url, headers=headers, json=body, params=request.query_params
                ),
                stream=True,
            )
        except BaseException:
            await client.aclose()
            raise

        async def generate_stream():
            collect = upstream.status_code == 200
            final_chunk_data = None
            pending = b""
            try:
                async for chunk in upstream.aiter_bytes():
                    # Forward the chunk immediately to the client.
                    yield chunk
                    if not collect:
                        continue
                    # A JSON line can span chunk boundaries, so keep the
                    # incomplete tail for the next chunk. Splitting raw bytes on
                    # b"\n" is safe for UTF-8: 0x0A never occurs inside a
                    # multi-byte sequence.
                    pending += chunk
                    *complete_lines, pending = pending.split(b"\n")
                    for line in complete_lines:
                        payload = _parse_stream_line(line)
                        if payload is not None and _is_final_stream_chunk(payload):
                            final_chunk_data = payload
                if collect:
                    payload = _parse_stream_line(pending)  # last line may lack "\n"
                    if payload is not None and _is_final_stream_chunk(payload):
                        final_chunk_data = payload
            finally:
                await upstream.aclose()
                await client.aclose()

            if final_chunk_data is not None:
                extract_and_record_metrics(final_chunk_data, model)

        return StreamingResponse(
            generate_stream(),
            status_code=upstream.status_code,
            media_type=upstream.headers.get("content-type", "application/json"),
        )

    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            upstream_url, headers=headers, json=body, params=request.query_params
        )

    if response.status_code == 200:
        try:
            extract_and_record_metrics(response.json(), model)
        except (ValueError, TypeError):
            logger.debug("Could not extract metrics from %s response", endpoint)

    # httpx has already decoded any Content-Encoding, and Starlette recomputes
    # Content-Length from the body, so those upstream headers no longer apply.
    return Response(
        content=response.content,
        status_code=response.status_code,
        headers=_chat_forward_headers(response.headers, drop=("content-length", "content-encoding")),
    )
164
# Hop-by-hop headers (RFC 9110 section 7.6.1) are scoped to a single connection,
# so a proxy must not copy them between the client and upstream connections.
_PASSTHROUGH_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _passthrough_headers(headers, drop=()):
    """Return *headers* as a dict without hop-by-hop names or the lower-case names in *drop*."""
    return {
        name: value
        for name, value in headers.items()
        if name.lower() not in _PASSTHROUGH_HOP_BY_HOP_HEADERS and name.lower() not in drop
    }


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def simple_proxy(request: Request, path: str):
    """Forward any other request to Ollama and return the buffered reply.

    Args:
        request: Incoming request; its method, body and query string are forwarded.
        path: Request path without the leading "/".

    Returns:
        A ``Response`` with the upstream status, body and end-to-end headers.
    """
    logger.debug("Proxying %s request to /%s", request.method, path)
    headers = _passthrough_headers(request.headers, drop=("host", "content-length"))

    async with httpx.AsyncClient(timeout=httpx.Timeout(900.0, read=900.0)) as client:
        response = await client.request(
            method=request.method,
            url=f"{OLLAMA_HOST}/{path}",
            headers=headers,
            content=await request.body(),
            params=request.query_params,
        )

    logger.debug("Proxy response: %s for %s /%s", response.status_code, request.method, path)
    # When upstream compressed the body, httpx has decoded it, so the original
    # Content-Encoding/Content-Length would now be wrong. Otherwise Content-Length
    # is kept (it is still correct, including for HEAD replies with no body).
    if "content-encoding" in response.headers:
        drop = ("content-encoding", "content-length")
    else:
        drop = ()
    return Response(
        content=response.content,
        status_code=response.status_code,
        headers=_passthrough_headers(response.headers, drop=drop),
    )
178
async def verify_ollama_connection():
    """Probe Ollama's ``/api/version`` endpoint once and log the outcome.

    Any ``Exception`` raised while probing is logged rather than propagated,
    so startup continues even when Ollama is unreachable.
    """
    version_url = f"{OLLAMA_HOST}/api/version"
    logger.debug(f"Verifying connection to Ollama server at {OLLAMA_HOST}")

    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as probe:
            reply = await probe.get(version_url)
            status = reply.status_code
            if status != 200:
                logger.error(f"Failed to connect to Ollama server. Status code: {status}")
                return
            logger.info("Connected to Ollama")
    except Exception as exc:
        logger.error(f"Failed to connect to Ollama server at {OLLAMA_HOST}: {exc}")
        logger.error("Please ensure Ollama is running and accessible at the configured host")
193
async def main():
    """Log whether Ollama is reachable, then serve the proxy on 0.0.0.0:8000."""
    await verify_ollama_connection()
    server = uvicorn.Server(
        uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="info")
    )
    await server.serve()


if __name__ == "__main__":
    # Script entry point: run the startup probe and the server in one event loop.
    asyncio.run(main())