import io
import json
import logging
import os
import time

import docx
import httpx
import pypdf
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from pydantic import BaseModel, Field
from typing import Optional

from app.auth import verify_api_key
from app.database import get_db
from app.utils.chunking import chunk_text
from app.utils.stats import track_usage

router = APIRouter()
logger = logging.getLogger(__name__)

LITELLM_URL = os.getenv("LITELLM_PROXY_URL", "http://litellm:4000")
LITELLM_MASTER = os.getenv("LITELLM_MASTER_KEY")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "cosair/multilingual-e5-large-instruct")


class VectorStoreCreate(BaseModel):
    name: str
    metadata: dict = {}


class VectorStoreResponse(BaseModel):
    id: str
    object: str = "vector_store"
    name: str
    metadata: dict = {}
    created_at: int


class FileUploadRequest(BaseModel):
    texts: list[str]
    metadata: list[dict] = []


class SearchRequest(BaseModel):
    query: str
    top_k: int = Field(default=5, ge=1, le=50)
    rerank: bool = False
    rerank_model: Optional[str] = None
    filters: Optional[dict] = None


class EmbeddingRequest(BaseModel):
    input: str | list[str]
    model: Optional[str] = None
    encoding_format: Optional[str] = "float"


class RAGRequest(BaseModel):
    query: str
    model: str = "cosair/gemma4:31b"
    top_k: int = 5
    rerank: bool = False
    system_prompt: Optional[str] = None
    messages: list[dict] = []


# Helper functions

def is_embedding_model(model: dict) -> bool:
    """Check whether a model is an embedding model - based on its mode only."""
    mode = (
        model.get("mode")
        or model.get("model_info", {}).get("mode")
    )
    return mode == "embedding"


async def _get_all_models() -> list[dict]:
    """
    Fetch all models using the master key.

    The master key returns correct mode info for all models.
    """
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.get(
                f"{LITELLM_URL}/model_group/info",
                headers={"Authorization": f"Bearer {LITELLM_MASTER}"},
                timeout=10.0
            )
        except httpx.RequestError as e:
            raise HTTPException(503, f"LiteLLM is not reachable: {e}")
    if resp.status_code != 200:
        raise HTTPException(502, f"Could not fetch models: {resp.text}")

    # Normalize: expose the model group name as "id".
    models = []
    for m in resp.json().get("data", []):
        models.append({
            **m,
            "id": m.get("model_group", m.get("id", "")),
        })
    return models


async def _embed(
    text: str,
    token: str,
    model: Optional[str] = None
) -> list[float]:
    """Generate an embedding via LiteLLM."""
    use_model = model or EMBEDDING_MODEL
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json"
            },
            json={
                "model": use_model,
                "input": text
            },
            timeout=30.0
        )
    if resp.status_code != 200:
        logger.error(f"Embedding error: {resp.status_code} - {resp.text}")
        raise HTTPException(502, f"Embedding failed: {resp.text}")
    return resp.json()["data"][0]["embedding"]


async def _check_access(db, store_id: str, user_id: str):
    """Check that the user may access the store."""
    row = await db.fetchrow(
        "SELECT owner_user_id FROM vector_stores WHERE id=$1",
        store_id
    )
    if not row:
        raise HTTPException(404, detail={
            "error": {
                "message": f"No vector store found with id '{store_id}'",
                "type": "invalid_request_error",
                "code": "not_found"
            }
        })
    if row["owner_user_id"] != user_id:
        shared = await db.fetchval(
            "SELECT 1 FROM store_permissions WHERE store_id=$1 AND user_id=$2",
            store_id, user_id
        )
        if not shared:
            raise HTTPException(403, detail={
                "error": {
                    "message": "You don't have access to this vector store",
                    "type": "invalid_request_error",
                    "code": "permission_denied"
                }
            })


async def _rerank(
    query: str,
    results: list[dict],
    model: str,
    token: str
) -> list[dict]:
    """Re-sort results with a reranker model."""

    def _text(r: dict) -> str:
        # Results carry their text either as a plain string (RAG endpoint)
        # or as an OpenAI-style content list (search endpoint); handle both.
        content = r["content"]
        return content if isinstance(content, str) else content[0]["text"]

    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{LITELLM_URL}/rerank",
            headers={"Authorization": f"Bearer {token}"},
            json={
                "model": model,
                "query": query,
                "documents": [_text(r) for r in results]
            },
            timeout=30.0
        )
    if resp.status_code != 200:
        # Reranking is best-effort: fall back to the original order.
        logger.error(f"Rerank error: {resp.text}")
        return results

    reranked = resp.json()["results"]
    return [
        {**results[r["index"]], "score": r["relevance_score"]}
        for r in sorted(reranked, key=lambda x: x["relevance_score"], reverse=True)
    ]


# Models Endpoints

@router.get("/models")
async def list_models(
    user: dict = Depends(verify_api_key),
):
    """List all models available from LiteLLM."""
    models = await _get_all_models()
    return {
        "object": "list",
        "data": [
            {
                "id": m["id"],
                "object": "model",
                "mode": m.get("mode"),
                "owned_by": "system",
            }
            for m in models
        ]
    }


@router.get("/models/{model_id:path}")
async def get_model(
    model_id: str,
    user: dict = Depends(verify_api_key),
):
    """Fetch a single model from LiteLLM."""
    all_models = await _get_all_models()
    model_lookup = {m["id"]: m for m in all_models}
    if model_id not in model_lookup:
        raise HTTPException(404, {
            "error": {
                "message": f"Model '{model_id}' not found",
                "type": "invalid_request_error",
                "code": "not_found"
            }
        })
    m = model_lookup[model_id]
    return {
        "id": m["id"],
        "object": "model",
        "mode": m.get("mode"),
        "owned_by": "system",
    }


# Embedding Endpoints

@router.get("/embeddings/models")
async def list_embedding_models(
    user: dict = Depends(verify_api_key),
):
    """Embedding models only - filtered by mode using the master key."""
    all_models = await _get_all_models()
    embedding_models = [
        {
            "id": m["id"],
            "object": "model",
            "owned_by": "system",
            "default": m["id"] == EMBEDDING_MODEL,
        }
        for m in all_models if is_embedding_model(m)
    ]
    return {
        "object": "list",
        "default": EMBEDDING_MODEL,
        "data": embedding_models
    }


@router.post("/embeddings")
async def create_embeddings(
    body: EmbeddingRequest,
    user: dict = Depends(verify_api_key),
):
    """Create embeddings - for a single text or a list of texts."""
    start = time.time()
    model = body.model or EMBEDDING_MODEL
    inputs = body.input if isinstance(body.input, list) else [body.input]

    all_models = await _get_all_models()
    model_lookup = {m["id"]: m for m in all_models}
    if model in model_lookup and not is_embedding_model(model_lookup[model]):
        raise HTTPException(400, {
            "error": {
                "message": f"'{model}' is not an embedding model",
                "type": "invalid_request_error",
                "code": "invalid_model"
            }
        })

    # Embed each input separately and aggregate the token usage.
    embeddings = []
    total_tokens = 0
    async with httpx.AsyncClient() as client:
        for i, text in enumerate(inputs):
            resp = await client.post(
                f"{LITELLM_URL}/embeddings",
                headers={
                    "Authorization": f"Bearer {user['token']}",
                    "Content-Type": "application/json"
                },
                json={"model": model, "input": text},
                timeout=30.0
            )
            if resp.status_code != 200:
                logger.error(f"Embedding error: {resp.status_code} - {resp.text}")
                raise HTTPException(502, f"Embedding failed: {resp.text}")
            data = resp.json()
            total_tokens += data.get("usage", {}).get("total_tokens", 0)
            embeddings.append({
                "object": "embedding",
                "index": i,
                "embedding": data["data"][0]["embedding"]
            })

    await track_usage(
        user_id=user["user_id"],
        action="embed",
        tokens=total_tokens,
        duration=time.time() - start
    )
    return {
        "object": "list",
        "model": model,
        "data": embeddings,
        "usage": {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens
        }
    }
# Vector Store Endpoints

@router.post("/vector_stores", response_model=VectorStoreResponse)
async def create_vector_store(
    body: VectorStoreCreate,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Create a new vector store."""
    row = await db.fetchrow(
        """INSERT INTO vector_stores (name, owner_user_id)
           VALUES ($1, $2)
           RETURNING id, name, created_at""",
        body.name, user["user_id"]
    )
    return VectorStoreResponse(
        id=str(row["id"]),
        name=row["name"],
        metadata=body.metadata,
        created_at=int(row["created_at"].timestamp())
    )


@router.get("/vector_stores")
async def list_vector_stores(
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """List the user's own vector stores."""
    rows = await db.fetch(
        """SELECT vs.id, vs.name, vs.created_at, COUNT(d.id) AS file_counts
           FROM vector_stores vs
           LEFT JOIN documents d ON d.store_id = vs.id
           WHERE vs.owner_user_id = $1
           GROUP BY vs.id, vs.name, vs.created_at
           ORDER BY vs.created_at DESC""",
        user["user_id"]
    )
    return {
        "object": "list",
        "data": [
            {
                "id": str(r["id"]),
                "object": "vector_store",
                "name": r["name"],
                "created_at": int(r["created_at"].timestamp()),
                "file_counts": {"total": r["file_counts"]}
            }
            for r in rows
        ]
    }


@router.get("/vector_stores/{store_id}")
async def get_vector_store(
    store_id: str,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Fetch a single vector store."""
    await _check_access(db, store_id, user["user_id"])
    row = await db.fetchrow(
        """SELECT vs.id, vs.name, vs.created_at, COUNT(d.id) AS file_counts
           FROM vector_stores vs
           LEFT JOIN documents d ON d.store_id = vs.id
           WHERE vs.id = $1
           GROUP BY vs.id, vs.name, vs.created_at""",
        store_id
    )
    return {
        "id": str(row["id"]),
        "object": "vector_store",
        "name": row["name"],
        "created_at": int(row["created_at"].timestamp()),
        "file_counts": {"total": row["file_counts"]}
    }


@router.delete("/vector_stores/{store_id}")
async def delete_vector_store(
    store_id: str,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Delete a vector store (owner only)."""
    deleted = await db.fetchval(
        """DELETE FROM vector_stores
           WHERE id=$1 AND owner_user_id=$2
           RETURNING id""",
        store_id, user["user_id"]
    )
    if not deleted:
        raise HTTPException(404, "Vector store not found")
    return {
        "id": store_id,
        "object": "vector_store.deleted",
        "deleted": True
    }


# Files Endpoints

@router.post("/vector_stores/{store_id}/files")
async def add_files(
    store_id: str,
    body: FileUploadRequest,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Insert documents into a vector store."""
    start = time.time()
    await _check_access(db, store_id, user["user_id"])

    ids = []
    for i, text in enumerate(body.texts):
        embedding = await _embed(text, user["token"])
        # The metadata list may be shorter than texts; missing entries default to {}.
        meta = body.metadata[i] if i < len(body.metadata) else {}
        doc_id = await db.fetchval(
            """INSERT INTO documents (store_id, content, metadata, embedding)
               VALUES ($1, $2, $3, $4::vector)
               RETURNING id""",
            store_id, text, json.dumps(meta), str(embedding)
        )
        ids.append(str(doc_id))

    await track_usage(
        user_id=user["user_id"],
        action="upsert",
        store_id=store_id,
        duration=time.time() - start
    )
    return {
        "object": "vector_store.file_batch",
        "counts": {
            "completed": len(ids),
            "failed": 0,
            "total": len(body.texts)
        },
        "ids": ids
    }
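# The INSERT above, and the queries elsewhere in this module, assume a
# pgvector-backed Postgres schema roughly like the sketch below. This DDL is
# an illustration inferred from the queries, not the project's actual
# migration; in particular the embedding dimension (1024, matching
# multilingual-e5-large) and the store_permissions shape are assumptions:
#
#   CREATE EXTENSION IF NOT EXISTS vector;
#
#   CREATE TABLE vector_stores (
#       id            UUID PRIMARY KEY DEFAULT gen_random_uuid(),
#       name          TEXT NOT NULL,
#       owner_user_id TEXT NOT NULL,
#       created_at    TIMESTAMPTZ NOT NULL DEFAULT now()
#   );
#
#   CREATE TABLE documents (
#       id         UUID PRIMARY KEY DEFAULT gen_random_uuid(),
#       store_id   UUID NOT NULL REFERENCES vector_stores(id) ON DELETE CASCADE,
#       content    TEXT NOT NULL,
#       metadata   JSONB,
#       embedding  vector(1024),
#       created_at TIMESTAMPTZ NOT NULL DEFAULT now()
#   );
#
#   CREATE TABLE store_permissions (
#       store_id UUID NOT NULL REFERENCES vector_stores(id) ON DELETE CASCADE,
#       user_id  TEXT NOT NULL,
#       PRIMARY KEY (store_id, user_id)
#   );
#
# Embeddings are bound as str(embedding): pgvector accepts the Python list
# repr "[0.1, 0.2, ...]" as the text input format for the ::vector cast.
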
@router.get("/vector_stores/{store_id}/files")
async def list_files(
    store_id: str,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """List all documents in a vector store."""
    await _check_access(db, store_id, user["user_id"])
    rows = await db.fetch(
        """SELECT id, content, metadata, created_at
           FROM documents
           WHERE store_id=$1
           ORDER BY created_at DESC""",
        store_id
    )
    return {
        "object": "list",
        "data": [
            {
                "id": str(r["id"]),
                "object": "vector_store.file",
                # Preview only; append an ellipsis only when content was truncated.
                "content": r["content"][:100] + ("..." if len(r["content"]) > 100 else ""),
                "metadata": r["metadata"],
                "created_at": int(r["created_at"].timestamp())
            }
            for r in rows
        ]
    }


@router.delete("/vector_stores/{store_id}/files/{file_id}")
async def delete_file(
    store_id: str,
    file_id: str,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Delete a single document."""
    await _check_access(db, store_id, user["user_id"])
    deleted = await db.fetchval(
        "DELETE FROM documents WHERE id=$1 AND store_id=$2 RETURNING id",
        file_id, store_id
    )
    if not deleted:
        raise HTTPException(404, "File not found")
    return {
        "id": file_id,
        "object": "vector_store.file.deleted",
        "deleted": True
    }


# Search Endpoint

@router.post("/vector_stores/{store_id}/search")
async def search(
    store_id: str,
    body: SearchRequest,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Search for similar documents in a vector store."""
    start = time.time()
    await _check_access(db, store_id, user["user_id"])
    q_emb = await _embed(body.query, user["token"])

    # Over-fetch when reranking so the reranker has candidates to reorder.
    fetch_k = body.top_k * 3 if body.rerank else body.top_k
    rows = await db.fetch(
        """SELECT id, content, metadata,
                  1 - (embedding <=> $1::vector) AS score
           FROM documents
           WHERE store_id = $2
           ORDER BY embedding <=> $1::vector
           LIMIT $3""",
        str(q_emb), store_id, fetch_k
    )

    results = []
    for r in rows:
        metadata = r["metadata"]
        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except Exception:
                metadata = {}
        if metadata is None:
            metadata = {}
        results.append({
            "id": str(r["id"]),
            "object": "vector_store.search_result",
            "score": float(r["score"]),
            "content": [{"type": "text", "text": r["content"]}],
            "metadata": metadata
        })

    if body.rerank:
        rerank_model = body.rerank_model or "cosair/bge-reranker-v2-m3"
        results = await _rerank(body.query, results, rerank_model, user["token"])
        results = results[:body.top_k]

    await track_usage(
        user_id=user["user_id"],
        action="search",
        store_id=store_id,
        duration=time.time() - start
    )
    return {"object": "list", "data": results}


# RAG Endpoint

@router.post("/vector_stores/{store_id}/rag")
async def rag(
    store_id: str,
    body: RAGRequest,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Retrieval-Augmented Generation."""
    start = time.time()
    await _check_access(db, store_id, user["user_id"])
    q_emb = await _embed(body.query, user["token"])

    fetch_k = body.top_k * 3 if body.rerank else body.top_k
    rows = await db.fetch(
        """SELECT id, content, metadata,
                  1 - (embedding <=> $1::vector) AS score
           FROM documents
           WHERE store_id = $2
           ORDER BY embedding <=> $1::vector
           LIMIT $3""",
        str(q_emb), store_id, fetch_k
    )
    results = [
        {
            "id": str(r["id"]),
            "content": r["content"],
            "score": float(r["score"]),
        }
        for r in rows
    ]

    if body.rerank:
        results = await _rerank(
            body.query, results,
            "cosair/bge-reranker-v2-m3", user["token"]
        )
        results = results[:body.top_k]

    context = "\n\n".join([
        f"[{i+1}] {r['content']}" for i, r in enumerate(results)
    ])
    # A caller-supplied system prompt still gets the retrieved context appended;
    # otherwise the retrieval step would be silently discarded.
    if body.system_prompt:
        system_prompt = f"{body.system_prompt}\n\nContext:\n{context}"
    else:
        system_prompt = (
            "You are a helpful assistant. "
            "Answer questions based solely on the provided context. "
            "If the answer cannot be found in the context, say so honestly.\n\n"
            f"Context:\n{context}"
        )

    messages = [
        {"role": "system", "content": system_prompt},
        *body.messages,
        {"role": "user", "content": body.query}
    ]

    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{LITELLM_URL}/chat/completions",
            headers={"Authorization": f"Bearer {user['token']}"},
            json={"model": body.model, "messages": messages},
            timeout=60.0
        )
    if resp.status_code != 200:
        raise HTTPException(502, f"LLM error: {resp.text}")

    llm_data = resp.json()
    answer = llm_data["choices"][0]["message"]["content"]
    total_tokens = llm_data.get("usage", {}).get("total_tokens", 0)

    await track_usage(
        user_id=user["user_id"],
        action="rag",
        store_id=store_id,
        tokens=total_tokens,
        duration=time.time() - start
    )
    return {
        "object": "rag.response",
        "answer": answer,
        "sources": [
            {
                "id": r["id"],
                "content": r["content"][:200] + ("..." if len(r["content"]) > 200 else ""),
                "score": r["score"]
            }
            for r in results
        ],
        "model": body.model,
        "usage": llm_data.get("usage", {})
    }


@router.post("/vector_stores/{store_id}/upload")
async def upload_file(
    store_id: str,
    file: UploadFile = File(...),
    chunk_size: int = Form(default=512),
    chunk_overlap: int = Form(default=50),
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Upload a file, chunk it, and store it in a vector store."""
    start = time.time()
    await _check_access(db, store_id, user["user_id"])

    content = await file.read()
    filename = (file.filename or "").lower()

    try:
        if filename.endswith(".pdf"):
            pdf = pypdf.PdfReader(io.BytesIO(content))
            # Extract each page once; skip pages with no extractable text.
            text = "\n".join(
                t for page in pdf.pages if (t := page.extract_text())
            )
        elif filename.endswith(".docx"):
            doc = docx.Document(io.BytesIO(content))
            text = "\n".join(
                p.text for p in doc.paragraphs if p.text.strip()
            )
        elif filename.endswith((".txt", ".md")):
            text = content.decode("utf-8")
        else:
            raise HTTPException(
                400,
                f"Unsupported format: {file.filename}. "
                f"Supported: .pdf, .docx, .txt, .md"
            )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(422, f"File could not be read: {e}")

    if not text.strip():
        raise HTTPException(422, "File contains no text")

    chunks = chunk_text(
        text=text,
        chunk_size=chunk_size,
        overlap=chunk_overlap
    )

    ids = []
    failed = 0
    for chunk in chunks:
        try:
            embedding = await _embed(chunk["text"], user["token"])
            doc_id = await db.fetchval(
                """INSERT INTO documents (store_id, content, metadata, embedding)
                   VALUES ($1, $2, $3, $4::vector)
                   RETURNING id""",
                store_id, chunk["text"],
                json.dumps({
                    "source": file.filename,
                    "chunk": chunk["index"],
                    "start": chunk.get("start", 0),
                }),
                str(embedding)
            )
            ids.append(str(doc_id))
        except Exception as e:
            logger.error(f"Chunk {chunk['index']} failed: {e}")
            failed += 1

    await track_usage(
        user_id=user["user_id"],
        action="upload",
        store_id=store_id,
        duration=time.time() - start
    )
    return {
        "object": "vector_store.file_batch",
        "filename": file.filename,
        "counts": {
            "completed": len(ids),
            "failed": failed,
            "total": len(chunks)
        },
        "ids": ids
    }