litellm-vector-store/app/routers/openai_compat.py

import json
import httpx
import os
import logging
import time
import pypdf
import docx
import io
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
from pydantic import BaseModel, Field
from typing import Optional
from app.auth import verify_api_key
from app.database import get_db
from app.utils.stats import track_usage
from app.utils.chunking import chunk_text
router = APIRouter()
logger = logging.getLogger(__name__)
LITELLM_URL = os.getenv("LITELLM_PROXY_URL", "http://litellm:4000")
LITELLM_MASTER = os.getenv("LITELLM_MASTER_KEY")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "cosair/multilingual-e5-large-instruct")
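
# Configuration is read from environment variables (defaults are shown in the
# os.getenv() calls above). A minimal illustrative .env sketch - the key value
# is a placeholder, not a real credential:
#
#   LITELLM_PROXY_URL=http://litellm:4000
#   LITELLM_MASTER_KEY=sk-xxxx
#   EMBEDDING_MODEL=cosair/multilingual-e5-large-instruct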

class VectorStoreCreate(BaseModel):
    name: str
    metadata: dict = {}

class VectorStoreResponse(BaseModel):
    id: str
    object: str = "vector_store"
    name: str
    metadata: dict = {}
    created_at: int

class FileUploadRequest(BaseModel):
    texts: list[str]
    metadata: list[dict] = []

class SearchRequest(BaseModel):
    query: str
    top_k: int = Field(default=5, ge=1, le=50)
    rerank: bool = False
    rerank_model: Optional[str] = None
    filters: Optional[dict] = None

class EmbeddingRequest(BaseModel):
    input: str | list[str]
    model: Optional[str] = None
    encoding_format: Optional[str] = "float"

class RAGRequest(BaseModel):
    query: str
    model: str = "cosair/gemma4:31b"
    top_k: int = 5
    rerank: bool = False
    system_prompt: Optional[str] = None
    messages: list[dict] = []
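
# Illustrative request bodies for the search and RAG endpoints defined below
# (field names follow the Pydantic models above; the values are examples only):
#
#   SearchRequest: {"query": "How do I reset my password?", "top_k": 5,
#                   "rerank": true, "rerank_model": "cosair/bge-reranker-v2-m3"}
#   RAGRequest:    {"query": "How do I reset my password?", "model": "cosair/gemma4:31b",
#                   "top_k": 5, "rerank": false}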

# Helper functions
def is_embedding_model(model: dict) -> bool:
    """Check whether a model is an embedding model - based on its mode only"""
    mode = (
        model.get("mode") or
        model.get("model_info", {}).get("mode")
    )
    return mode == "embedding"

async def _get_all_models() -> list[dict]:
    """
    Fetch all models using the master key.
    The master key returns correct mode info for every model.
    """
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.get(
                f"{LITELLM_URL}/model_group/info",
                headers={"Authorization": f"Bearer {LITELLM_MASTER}"},
                timeout=10.0
            )
        except httpx.RequestError as e:
            raise HTTPException(503, f"LiteLLM nicht erreichbar: {e}")
    if resp.status_code != 200:
        raise HTTPException(502, f"Modelle konnten nicht abgerufen werden: {resp.text}")
    models = []
    for m in resp.json().get("data", []):
        models.append({
            **m,
            # expose the model group name as "id" for OpenAI-style responses
            "id": m.get("model_group", m.get("id", "")),
        })
    return models

async def _embed(
    text: str,
    token: str,
    model: Optional[str] = None
) -> list[float]:
    """Generate an embedding via LiteLLM"""
    use_model = model or EMBEDDING_MODEL
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{LITELLM_URL}/embeddings",
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json"
            },
            json={
                "model": use_model,
                "input": text
            },
            timeout=30.0
        )
        if resp.status_code != 200:
            logger.error(f"Embedding Fehler: {resp.status_code} - {resp.text}")
            raise HTTPException(502, f"Embedding fehlgeschlagen: {resp.text}")
        return resp.json()["data"][0]["embedding"]

async def _check_access(db, store_id: str, user_id: str):
    """Check that the user may access the store"""
    row = await db.fetchrow(
        "SELECT owner_user_id FROM vector_stores WHERE id=$1", store_id
    )
    if not row:
        raise HTTPException(404, detail={
            "error": {
                "message": f"No vector store found with id '{store_id}'",
                "type": "invalid_request_error",
                "code": "not_found"
            }
        })
    if row["owner_user_id"] != user_id:
        shared = await db.fetchval(
            "SELECT 1 FROM store_permissions WHERE store_id=$1 AND user_id=$2",
            store_id, user_id
        )
        if not shared:
            raise HTTPException(403, detail={
                "error": {
                    "message": "You don't have access to this vector store",
                    "type": "invalid_request_error",
                    "code": "permission_denied"
                }
            })

async def _rerank(
    query: str,
    results: list[dict],
    model: str,
    token: str
) -> list[dict]:
    """Re-rank results with a reranker model"""
    # "content" is either a plain string (rag) or a list of {"type", "text"} parts (search)
    documents = [
        r["content"][0]["text"] if isinstance(r["content"], list) else r["content"]
        for r in results
    ]
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{LITELLM_URL}/rerank",
            headers={"Authorization": f"Bearer {token}"},
            json={
                "model": model,
                "query": query,
                "documents": documents
            },
            timeout=30.0
        )
        if resp.status_code != 200:
            logger.error(f"Rerank Fehler: {resp.text}")
            return results
        reranked = resp.json()["results"]
        return [
            {**results[r["index"]], "score": r["relevance_score"]}
            for r in sorted(reranked, key=lambda x: x["relevance_score"], reverse=True)
        ]
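
# The /rerank call above expects a response of the shape
#   {"results": [{"index": 0, "relevance_score": 0.97}, ...]}
# where "index" refers to the position in the submitted documents list.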

# Models Endpoints
@router.get("/models")
async def list_models(
    user: dict = Depends(verify_api_key),
):
    """All models available through LiteLLM"""
    models = await _get_all_models()
    return {
        "object": "list",
        "data": [
            {
                "id": m["id"],
                "object": "model",
                "mode": m.get("mode"),
                "owned_by": "system",
            }
            for m in models
        ]
    }

@router.get("/models/{model_id:path}")
async def get_model(
    model_id: str,
    user: dict = Depends(verify_api_key),
):
    """Single model from LiteLLM"""
    all_models = await _get_all_models()
    model_lookup = {m["id"]: m for m in all_models}
    if model_id not in model_lookup:
        raise HTTPException(404, {
            "error": {
                "message": f"Modell '{model_id}' nicht gefunden",
                "type": "invalid_request_error",
                "code": "not_found"
            }
        })
    m = model_lookup[model_id]
    return {
        "id": m["id"],
        "object": "model",
        "mode": m.get("mode"),
        "owned_by": "system",
    }

# Embedding Endpoints
@router.get("/embeddings/models")
async def list_embedding_models(
    user: dict = Depends(verify_api_key),
):
    """Embedding models only - filtered by mode using the master key"""
    all_models = await _get_all_models()
    embedding_models = [
        {
            "id": m["id"],
            "object": "model",
            "owned_by": "system",
            "default": m["id"] == EMBEDDING_MODEL,
        }
        for m in all_models
        if is_embedding_model(m)
    ]
    return {
        "object": "list",
        "default": EMBEDDING_MODEL,
        "data": embedding_models
    }

@router.post("/embeddings")
async def create_embeddings(
    body: EmbeddingRequest,
    user: dict = Depends(verify_api_key),
):
    """Create embeddings - for a single input or a list"""
    start = time.time()
    model = body.model or EMBEDDING_MODEL
    inputs = body.input if isinstance(body.input, list) else [body.input]
    all_models = await _get_all_models()
    model_lookup = {m["id"]: m for m in all_models}
    if model in model_lookup and not is_embedding_model(model_lookup[model]):
        raise HTTPException(400, {
            "error": {
                "message": f"'{model}' ist kein Embedding Modell",
                "type": "invalid_request_error",
                "code": "invalid_model"
            }
        })
    embeddings = []
    total_tokens = 0
    async with httpx.AsyncClient() as client:
        # one request per input; the original index is preserved in the response
        for i, text in enumerate(inputs):
            resp = await client.post(
                f"{LITELLM_URL}/embeddings",
                headers={
                    "Authorization": f"Bearer {user['token']}",
                    "Content-Type": "application/json"
                },
                json={"model": model, "input": text},
                timeout=30.0
            )
            if resp.status_code != 200:
                logger.error(f"Embedding Fehler: {resp.status_code} - {resp.text}")
                raise HTTPException(502, f"Embedding fehlgeschlagen: {resp.text}")
            data = resp.json()
            total_tokens += data.get("usage", {}).get("total_tokens", 0)
            embeddings.append({
                "object": "embedding",
                "index": i,
                "embedding": data["data"][0]["embedding"]
            })
    await track_usage(
        user_id=user["user_id"],
        action="embed",
        tokens=total_tokens,
        duration=time.time() - start
    )
    return {
        "object": "list",
        "model": model,
        "data": embeddings,
        "usage": {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens
        }
    }
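
# Example call (illustrative; the path prefix depends on where this router is
# mounted in the application):
#
#   curl -X POST "$BASE_URL/embeddings" \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"input": ["hello world"], "model": "cosair/multilingual-e5-large-instruct"}'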

# Vector Store Endpoints
@router.post("/vector_stores", response_model=VectorStoreResponse)
async def create_vector_store(
    body: VectorStoreCreate,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Create a new vector store"""
    row = await db.fetchrow(
        """INSERT INTO vector_stores (name, owner_user_id)
           VALUES ($1, $2)
           RETURNING id, name, created_at""",
        body.name, user["user_id"]
    )
    return VectorStoreResponse(
        id=str(row["id"]),
        name=row["name"],
        metadata=body.metadata,
        created_at=int(row["created_at"].timestamp())
    )

@router.get("/vector_stores")
async def list_vector_stores(
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """List the user's own vector stores"""
    rows = await db.fetch(
        """SELECT vs.id, vs.name, vs.created_at,
                  COUNT(d.id) AS file_counts
           FROM vector_stores vs
           LEFT JOIN documents d ON d.store_id = vs.id
           WHERE vs.owner_user_id = $1
           GROUP BY vs.id, vs.name, vs.created_at
           ORDER BY vs.created_at DESC""",
        user["user_id"]
    )
    return {
        "object": "list",
        "data": [
            {
                "id": str(r["id"]),
                "object": "vector_store",
                "name": r["name"],
                "created_at": int(r["created_at"].timestamp()),
                "file_counts": {"total": r["file_counts"]}
            }
            for r in rows
        ]
    }

@router.get("/vector_stores/{store_id}")
async def get_vector_store(
    store_id: str,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Retrieve a single vector store"""
    await _check_access(db, store_id, user["user_id"])
    row = await db.fetchrow(
        """SELECT vs.id, vs.name, vs.created_at,
                  COUNT(d.id) AS file_counts
           FROM vector_stores vs
           LEFT JOIN documents d ON d.store_id = vs.id
           WHERE vs.id = $1
           GROUP BY vs.id, vs.name, vs.created_at""",
        store_id
    )
    return {
        "id": str(row["id"]),
        "object": "vector_store",
        "name": row["name"],
        "created_at": int(row["created_at"].timestamp()),
        "file_counts": {"total": row["file_counts"]}
    }

@router.delete("/vector_stores/{store_id}")
async def delete_vector_store(
    store_id: str,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Delete a vector store"""
    deleted = await db.fetchval(
        """DELETE FROM vector_stores
           WHERE id=$1 AND owner_user_id=$2
           RETURNING id""",
        store_id, user["user_id"]
    )
    if not deleted:
        raise HTTPException(404, "Vector store nicht gefunden")
    return {
        "id": store_id,
        "object": "vector_store.deleted",
        "deleted": True
    }
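
# Example usage (illustrative; adjust the path prefix to the router mount point):
#
#   curl -X POST "$BASE_URL/vector_stores" \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"name": "my-store"}'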

# Files Endpoints
@router.post("/vector_stores/{store_id}/files")
async def add_files(
    store_id: str,
    body: FileUploadRequest,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Insert documents into a vector store"""
    start = time.time()
    await _check_access(db, store_id, user["user_id"])
    ids = []
    for i, text in enumerate(body.texts):
        embedding = await _embed(text, user["token"])
        meta = body.metadata[i] if i < len(body.metadata) else {}
        doc_id = await db.fetchval(
            """INSERT INTO documents (store_id, content, metadata, embedding)
               VALUES ($1, $2, $3, $4::vector) RETURNING id""",
            store_id, text, json.dumps(meta), str(embedding)
        )
        ids.append(str(doc_id))
    await track_usage(
        user_id=user["user_id"],
        action="upsert",
        store_id=store_id,
        duration=time.time() - start
    )
    return {
        "object": "vector_store.file_batch",
        "counts": {
            "completed": len(ids),
            "failed": 0,
            "total": len(body.texts)
        },
        "ids": ids
    }

@router.get("/vector_stores/{store_id}/files")
async def list_files(
    store_id: str,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """List all documents in a vector store"""
    await _check_access(db, store_id, user["user_id"])
    rows = await db.fetch(
        """SELECT id, content, metadata, created_at
           FROM documents
           WHERE store_id=$1
           ORDER BY created_at DESC""",
        store_id
    )
    return {
        "object": "list",
        "data": [
            {
                "id": str(r["id"]),
                "object": "vector_store.file",
                "content": r["content"][:100] + "...",
                "metadata": r["metadata"],
                "created_at": int(r["created_at"].timestamp())
            }
            for r in rows
        ]
    }

@router.delete("/vector_stores/{store_id}/files/{file_id}")
async def delete_file(
    store_id: str,
    file_id: str,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Delete a single document"""
    await _check_access(db, store_id, user["user_id"])
    deleted = await db.fetchval(
        "DELETE FROM documents WHERE id=$1 AND store_id=$2 RETURNING id",
        file_id, store_id
    )
    if not deleted:
        raise HTTPException(404, "File nicht gefunden")
    return {
        "id": file_id,
        "object": "vector_store.file.deleted",
        "deleted": True
    }
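
# Example usage (illustrative): texts and metadata are parallel lists, so the
# i-th metadata dict is attached to the i-th text.
#
#   curl -X POST "$BASE_URL/vector_stores/$STORE_ID/files" \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"texts": ["first document", "second document"],
#          "metadata": [{"source": "faq"}, {"source": "manual"}]}'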

# Search Endpoint
@router.post("/vector_stores/{store_id}/search")
async def search(
    store_id: str,
    body: SearchRequest,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Search for similar documents in the vector store"""
    start = time.time()
    await _check_access(db, store_id, user["user_id"])
    q_emb = await _embed(body.query, user["token"])
    # over-fetch when reranking so the reranker has more candidates to reorder
    fetch_k = body.top_k * 3 if body.rerank else body.top_k
    rows = await db.fetch(
        """SELECT id, content, metadata,
                  1 - (embedding <=> $1::vector) AS score
           FROM documents
           WHERE store_id = $2
           ORDER BY embedding <=> $1::vector
           LIMIT $3""",
        str(q_emb), store_id, fetch_k
    )
    results = []
    for r in rows:
        metadata = r["metadata"]
        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except Exception:
                metadata = {}
        if metadata is None:
            metadata = {}
        results.append({
            "id": str(r["id"]),
            "object": "vector_store.search_result",
            "score": float(r["score"]),
            "content": [{"type": "text", "text": r["content"]}],
            "metadata": metadata
        })
    if body.rerank:
        rerank_model = body.rerank_model or "cosair/bge-reranker-v2-m3"
        results = await _rerank(body.query, results, rerank_model, user["token"])
        results = results[:body.top_k]
    await track_usage(
        user_id=user["user_id"],
        action="search",
        store_id=store_id,
        duration=time.time() - start
    )
    return {"object": "list", "data": results}

# RAG Endpoint
@router.post("/vector_stores/{store_id}/rag")
async def rag(
    store_id: str,
    body: RAGRequest,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db)
):
    """Retrieval Augmented Generation"""
    start = time.time()
    await _check_access(db, store_id, user["user_id"])
    q_emb = await _embed(body.query, user["token"])
    fetch_k = body.top_k * 3 if body.rerank else body.top_k
    rows = await db.fetch(
        """SELECT id, content, metadata,
                  1 - (embedding <=> $1::vector) AS score
           FROM documents
           WHERE store_id = $2
           ORDER BY embedding <=> $1::vector
           LIMIT $3""",
        str(q_emb), store_id, fetch_k
    )
    results = [
        {
            "id": str(r["id"]),
            "content": r["content"],
            "score": float(r["score"]),
        }
        for r in rows
    ]
    if body.rerank:
        results = await _rerank(
            body.query, results,
            "cosair/bge-reranker-v2-m3",
            user["token"]
        )
        results = results[:body.top_k]
    # numbered context blocks so the model can refer to sources as [1], [2], ...
    context = "\n\n".join([
        f"[{i+1}] {r['content']}"
        for i, r in enumerate(results)
    ])
    system_prompt = body.system_prompt or (
        "Du bist ein hilfreicher Assistent. "
        "Beantworte Fragen ausschließlich basierend auf dem gegebenen Kontext. "
        "Wenn die Antwort nicht im Kontext zu finden ist, sage das ehrlich.\n\n"
        f"Kontext:\n{context}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        *body.messages,
        {"role": "user", "content": body.query}
    ]
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{LITELLM_URL}/chat/completions",
            headers={"Authorization": f"Bearer {user['token']}"},
            json={"model": body.model, "messages": messages},
            timeout=60.0
        )
    if resp.status_code != 200:
        raise HTTPException(502, f"LLM Fehler: {resp.text}")
    llm_data = resp.json()
    answer = llm_data["choices"][0]["message"]["content"]
    total_tokens = llm_data.get("usage", {}).get("total_tokens", 0)
    await track_usage(
        user_id=user["user_id"],
        action="rag",
        store_id=store_id,
        tokens=total_tokens,
        duration=time.time() - start
    )
    return {
        "object": "rag.response",
        "answer": answer,
        "sources": [
            {
                "id": r["id"],
                "content": r["content"][:200] + "...",
                "score": r["score"]
            }
            for r in results
        ],
        "model": body.model,
        "usage": llm_data.get("usage", {})
    }
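
# Example RAG call (illustrative): previous turns can be passed via "messages";
# they are placed between the system prompt and the current query.
#
#   curl -X POST "$BASE_URL/vector_stores/$STORE_ID/rag" \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"query": "What does the contract say about notice periods?", "top_k": 5, "rerank": true}'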
@router.post("/vector_stores/{store_id}/upload")
async def upload_file(
store_id: str,
file: UploadFile = File(...),
chunk_size: int = Form(default=512),
chunk_overlap: int = Form(default=50),
user: dict = Depends(verify_api_key),
db=Depends(get_db)
):
"""Datei hochladen, chunken und in Vector Store speichern"""
start = time.time()
await _check_access(db, store_id, user["user_id"])
content = await file.read()
filename = file.filename.lower()
try:
if filename.endswith(".pdf"):
pdf = pypdf.PdfReader(io.BytesIO(content))
text = "\n".join(
page.extract_text()
for page in pdf.pages
if page.extract_text()
)
elif filename.endswith(".docx"):
doc = docx.Document(io.BytesIO(content))
text = "\n".join(
p.text for p in doc.paragraphs if p.text.strip()
)
elif filename.endswith(".txt"):
text = content.decode("utf-8")
elif filename.endswith(".md"):
text = content.decode("utf-8")
else:
raise HTTPException(
400,
f"Nicht unterstütztes Format: {file.filename}. "
f"Unterstützt: .pdf, .docx, .txt, .md"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(422, f"Datei konnte nicht gelesen werden: {e}")
if not text.strip():
raise HTTPException(422, "Datei enthaelt keinen Text")
chunks = chunk_text(
text=text,
chunk_size=chunk_size,
overlap=chunk_overlap
)
ids = []
failed = 0
for chunk in chunks:
try:
embedding = await _embed(chunk["text"], user["token"])
doc_id = await db.fetchval(
"""INSERT INTO documents (store_id, content, metadata, embedding)
VALUES ($1, $2, $3, $4::vector) RETURNING id""",
store_id,
chunk["text"],
json.dumps({
"source": file.filename,
"chunk": chunk["index"],
"start": chunk.get("start", 0),
}),
str(embedding)
)
ids.append(str(doc_id))
except Exception as e:
logger.error(f"Chunk {chunk['index']} fehlgeschlagen: {e}")
failed += 1
await track_usage(
user_id=user["user_id"],
action="upload",
store_id=store_id,
duration=time.time() - start
)
return {
"object": "vector_store.file_batch",
"filename": file.filename,
"counts": {
"completed": len(ids),
"failed": failed,
"total": len(chunks)
},
"ids": ids
}
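
# Example upload (illustrative; multipart form data, the path prefix depends on
# the router mount point):
#
#   curl -X POST "$BASE_URL/vector_stores/$STORE_ID/upload" \
#     -H "Authorization: Bearer $API_KEY" \
#     -F "file=@manual.pdf" \
#     -F "chunk_size=512" \
#     -F "chunk_overlap=50"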