auto-update-daily-20260202
  1import os
  2import asyncio
  3import httpx
  4import json
  5import logging
  6from fastapi import FastAPI, Request, Response
  7from fastapi.responses import StreamingResponse
  8
  9import uvicorn
 10from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
 11
 12# Configurable Ollama host (via env variable or defaults to localhost)
 13OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434")
 14
 15logging.basicConfig()
 16logger = logging.getLogger(__name__)
 17LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
 18logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
 19
 20app = FastAPI()
 21
 22OLLAMA_CHAT_REQUEST_COUNT = Counter("ollama_requests_total", "Total chat requests", ["model"])
 23
 24OLLAMA_TOTAL_DURATION =       Histogram("ollama_response_seconds", "Total time spent for the response", ["model"])
 25OLLAMA_LOAD_DURATION =        Histogram("ollama_load_duration_seconds", "Time spent loading the model", ["model"])
 26OLLAMA_PROMPT_EVAL_DURATION = Histogram("ollama_prompt_eval_duration_seconds", "Time spent evaluating prompt", ["model"])
 27OLLAMA_EVAL_DURATION =        Histogram("ollama_eval_duration_seconds", "Time spent generating the response", ["model"])
 28
 29OLLAMA_PROMPT_EVAL_COUNT = Counter("ollama_tokens_processed_total", "Number of tokens in the prompt", ["model"])
 30OLLAMA_EVAL_COUNT =        Counter("ollama_tokens_generated_total", "Number of tokens in the response", ["model"])
 31
 32OLLAMA_TOKENS_PER_SECOND = Histogram(
 33    "ollama_tokens_per_second",
 34    "Tokens generated per second",
 35    ["model"],
 36    # Use buckets with suitable ranges for tokens/s measurements
 37    buckets=[5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
 38)
 39
 40
 41def extract_and_record_metrics(response_data, model):
 42    """Extract and record metrics from Ollama response data."""
 43    if not isinstance(response_data, dict):
 44        return
 45
 46    # Support both native Ollama API and OpenAI-compatible v1 API
 47    # Native API has timing data at top level, v1 API might have it in different location
 48    # https://github.com/ollama/ollama/blob/main/docs/api.md#response
 49
 50    # Try to extract from native Ollama format first
 51    total_duration = response_data.get("total_duration", 0) # total time spent in nanoseconds generating the response
 52    load_duration = response_data.get("load_duration", 0) # time spent in nanoseconds loading the model
 53    prompt_eval_duration = response_data.get("prompt_eval_duration", 0) # time spent in nanoseconds evaluating the prompt
 54    prompt_eval_count = response_data.get("prompt_eval_count", 0) # number of tokens in the prompt
 55    eval_duration = response_data.get("eval_duration", 0) # time spent in nanoseconds generating the response
 56    eval_count = response_data.get("eval_count", 0) # number of tokens in the response
 57
 58    # For v1 API, try to extract from usage field if available
 59    usage = response_data.get("usage", {})
 60    if usage and not prompt_eval_count:
 61        prompt_eval_count = usage.get("prompt_tokens", 0)
 62    if usage and not eval_count:
 63        eval_count = usage.get("completion_tokens", 0)
 64
 65    if total_duration > 0:
 66        total_duration_seconds = total_duration / 1_000_000_000
 67        OLLAMA_TOTAL_DURATION.labels(model=model).observe(total_duration_seconds)
 68        logger.debug(f"Model: {model}, Total Duration: {total_duration_seconds:.2f} seconds")
 69    if load_duration > 0:
 70        load_duration_seconds = load_duration / 1_000_000_000
 71        OLLAMA_LOAD_DURATION.labels(model=model).observe(load_duration_seconds)
 72        logger.debug(f"Model: {model}, Load Duration: {load_duration_seconds:.2f} seconds")
 73    if prompt_eval_duration > 0:
 74        prompt_eval_time_seconds = prompt_eval_duration / 1_000_000_000
 75        OLLAMA_PROMPT_EVAL_DURATION.labels(model=model).observe(prompt_eval_time_seconds)
 76        logger.debug(f"Model: {model}, Prompt Eval Duration: {prompt_eval_time_seconds:.2f} seconds")
 77    if prompt_eval_count > 0:
 78        OLLAMA_PROMPT_EVAL_COUNT.labels(model=model).inc(prompt_eval_count)
 79        logger.debug(f"Model: {model}, Prompt Eval Count: {prompt_eval_count}")
 80    if eval_duration > 0:
 81        eval_duration_seconds = eval_duration / 1_000_000_000
 82        OLLAMA_EVAL_DURATION.labels(model=model).observe(eval_duration_seconds)
 83        logger.debug(f"Model: {model}, Eval Duration: {eval_duration_seconds:.2f} seconds")
 84    if eval_count > 0:
 85        OLLAMA_EVAL_COUNT.labels(model=model).inc(eval_count)
 86        logger.debug(f"Model: {model}, Eval Count: {eval_count}")
 87    if eval_duration > 0 and eval_count > 0:
 88        tps = eval_count / eval_duration * 1_000_000_000
 89        OLLAMA_TOKENS_PER_SECOND.labels(model=model).observe(tps)
 90        logger.debug(f"Model: {model}, Tokens per Second: {tps:.2f}")
 91
 92@app.get("/metrics")
 93def metrics():
 94    """Expose Prometheus metrics."""
 95    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
 96
 97@app.post("/api/chat")
 98@app.post("/api/generate")
 99@app.post("/v1/chat/completions")
100@app.post("/v1/completions")
101async def chat_with_metrics(request: Request):
102    """Handle chat and generate requests with streaming support and metrics extraction."""
103    body = await request.json()
104    model = body.get("model", "unknown")
105    # logger.debug(f"Chat request body: {json.dumps(body, indent=4)}")
106    is_streaming = body.get("stream", False)
107
108    headers = dict(request.headers)
109    headers.pop("host", None)
110    headers.pop("content-length", None)
111    headers.pop("content-type", None)
112
113    OLLAMA_CHAT_REQUEST_COUNT.labels(model=model).inc()
114
115    if is_streaming:
116        async def generate_stream():
117            endpoint = request.url.path  # /api/chat or /api/generate
118            async with httpx.AsyncClient(timeout=httpx.Timeout(900.0, read=900.0)) as client:
119                async with client.stream("POST", f"{OLLAMA_HOST}{endpoint}", headers=headers, json=body, params=request.query_params) as response:
120
121                    final_chunk_data = None
122
123                    async for chunk in response.aiter_bytes():
124                        # Forward the chunk immediately to the client
125                        yield chunk
126
127                        # Try to parse the chunk to look for metrics
128                        if chunk:
129                            try:
130                                chunk_text = chunk.decode('utf-8')
131                                lines = chunk_text.strip().split('\n')
132
133                                for line in lines:
134                                    if line.strip():
135                                        try:
136                                            chunk_json = json.loads(line)
137                                            # Check if this is the final chunk (contains "done": true)
138                                            if chunk_json.get("done", False):
139                                                final_chunk_data = chunk_json
140                                        except json.JSONDecodeError:
141                                            continue
142
143                            except UnicodeDecodeError:
144                                pass
145
146                    # Extract metrics from the final chunk if available
147                    if final_chunk_data:
148                        extract_and_record_metrics(final_chunk_data, model)
149
150        return StreamingResponse(generate_stream(), media_type="application/json")
151    else:
152        endpoint = request.url.path  # /api/chat or /api/generate
153        async with httpx.AsyncClient(timeout=httpx.Timeout(900.0, read=900.0)) as client:
154            response = await client.post(f"{OLLAMA_HOST}{endpoint}", headers=headers, json=body, params=request.query_params)
155
156            if response.status_code == 200:
157                try:
158                    response_data = response.json()
159                    extract_and_record_metrics(response_data, model)
160                except (json.JSONDecodeError, TypeError):
161                    pass
162
163            return Response(content=response.content, status_code=response.status_code, headers=dict(response.headers))
164
# Hop-by-hop headers describe one connection only and must not be relayed.
_PROXY_HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})


def _proxy_response_headers(upstream_headers):
    """Return the upstream response headers that are safe to send to the client.

    httpx hands back an already-decoded body, so when the upstream reply was
    content-encoded its ``content-encoding`` and ``content-length`` no longer
    describe the bytes being relayed and are dropped (the framework recomputes
    the length). Hop-by-hop headers are always dropped.
    """
    headers = dict(upstream_headers)
    body_was_decoded = any(name.lower() == "content-encoding" for name in headers)
    stale = {"content-encoding", "content-length"} if body_was_decoded else set()
    return {
        name: value
        for name, value in headers.items()
        if name.lower() not in _PROXY_HOP_BY_HOP_HEADERS and name.lower() not in stale
    }


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def simple_proxy(request: Request, path: str):
    """Simple pass-through proxy for all other endpoints.

    Forwards the method, query parameters, headers and raw body to
    ``{OLLAMA_HOST}/{path}`` and returns the upstream status code, body and
    relayable headers.

    Args:
        request: Incoming request.
        path: Path captured by the catch-all route, without the leading slash.
    """
    logger.debug(f"Proxying {request.method} request to /{path}")
    headers = dict(request.headers)
    # host/content-length are regenerated by httpx for the new request.
    # accept-encoding is dropped so httpx only advertises encodings it can
    # decode itself, because the body is relayed decoded.
    for name in ("host", "content-length", "accept-encoding"):
        headers.pop(name, None)

    async with httpx.AsyncClient(timeout=httpx.Timeout(900.0, read=900.0)) as client:
        response = await client.request(
            method=request.method,
            url=f"{OLLAMA_HOST}/{path}",
            headers=headers,
            content=await request.body(),
            params=request.query_params,
        )

    logger.debug(f"Proxy response: {response.status_code} for {request.method} /{path}")
    return Response(
        content=response.content,
        status_code=response.status_code,
        headers=_proxy_response_headers(response.headers),
    )
178
async def verify_ollama_connection():
    """Probe ``{OLLAMA_HOST}/api/version`` and log whether Ollama answered.

    A 200 reply is logged at info level; any other status, or an exception
    raised while probing, is logged as an error. Nothing is returned.
    """
    logger.debug(f"Verifying connection to Ollama server at {OLLAMA_HOST}")

    try:
        probe_timeout = httpx.Timeout(10.0)
        async with httpx.AsyncClient(timeout=probe_timeout) as probe:
            reply = await probe.get(f"{OLLAMA_HOST}/api/version")
            if reply.status_code != 200:
                logger.error(f"Failed to connect to Ollama server. Status code: {reply.status_code}")
            else:
                logger.info("Connected to Ollama")
    except Exception as exc:
        # Startup check only: report the problem and let the caller continue.
        logger.error(f"Failed to connect to Ollama server at {OLLAMA_HOST}: {exc}")
        logger.error("Please ensure Ollama is running and accessible at the configured host")
193
async def main():
    """Run the Ollama connectivity check, then serve ``app`` with uvicorn.

    The server listens on 0.0.0.0:8000 with uvicorn's log level set to
    "info", and this coroutine returns when ``server.serve()`` returns.
    """
    await verify_ollama_connection()
    server = uvicorn.Server(
        uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="info")
    )
    await server.serve()
199
# Script entry point: when the file is executed directly, run main() via asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())