114 lines
3.2 KiB
Python
114 lines
3.2 KiB
Python
import json
|
|
import httpx
|
|
import os
|
|
import logging
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from app.auth import verify_api_key
|
|
from app.database import get_db
|
|
from app.models import UpsertRequest, QueryRequest
|
|
|
|
# Router on which the endpoints of this module are registered.
router = APIRouter()

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Base URL that embedding requests are sent to; override with the
# LITELLM_PROXY_URL environment variable.
LITELLM_URL = os.environ.get("LITELLM_PROXY_URL", "http://litellm:4000")

# Model name included in every embedding request; override with the
# EMBEDDING_MODEL environment variable.
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-ada-002")
|
|
|
|
|
|
async def _embed(text: str, token: str) -> list[float]:
    """Request an embedding vector for ``text`` from the embeddings endpoint.

    Sends a POST to ``{LITELLM_URL}/embeddings`` with ``EMBEDDING_MODEL`` as
    the model and the caller's token as bearer credential, then returns the
    first embedding contained in the JSON reply.

    Args:
        text: The text to embed.
        token: Token forwarded in the ``Authorization: Bearer`` header.

    Returns:
        The value found at ``data[0]["embedding"]`` in the response body.

    Raises:
        HTTPException: 502 when the request cannot be completed (connection
            error, timeout), when the endpoint answers with a non-200
            status, or when the response body lacks the expected structure.
    """
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.post(
                f"{LITELLM_URL}/embeddings",
                headers={
                    "Authorization": f"Bearer {token}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": EMBEDDING_MODEL,
                    "input": text,
                },
                timeout=30.0,
            )
        except httpx.RequestError as exc:
            # Without this, connection failures and timeouts would escape as
            # an unhandled exception instead of the 502 used for every other
            # upstream failure in this function.
            logger.error("Embedding Fehler: %r", exc)
            raise HTTPException(
                502,
                f"Embedding fehlgeschlagen: {type(exc).__name__}",
            ) from exc

    if resp.status_code != 200:
        logger.error("Embedding Fehler: %s - %s", resp.status_code, resp.text)
        raise HTTPException(
            502,
            f"Embedding fehlgeschlagen: {resp.status_code} - {resp.text}",
        )

    try:
        return resp.json()["data"][0]["embedding"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        # A 200 reply with invalid JSON or a missing data[0].embedding is
        # still an upstream failure, not a bug in the caller's request.
        logger.error("Embedding Fehler: unerwartete Antwort - %s", resp.text)
        raise HTTPException(
            502,
            "Embedding fehlgeschlagen: unerwartete Antwort",
        ) from exc
|
|
|
|
|
|
async def _check_access(db, store_id: str, user_id: str):
    """Verify that ``user_id`` is allowed to use the store ``store_id``.

    Access is granted when the user is the store's ``owner_user_id`` or when
    a matching row exists in ``store_permissions``.

    Args:
        db: Database handle exposing awaitable ``fetchrow`` and ``fetchval``.
        store_id: Identifier looked up in ``vector_stores.id``.
        user_id: Identifier of the calling user.

    Raises:
        HTTPException: 404 if no store with ``store_id`` exists; 403 if the
            user is neither the owner nor listed in ``store_permissions``.
    """
    row = await db.fetchrow(
        "SELECT owner_user_id FROM vector_stores WHERE id=$1", store_id
    )
    if not row:
        raise HTTPException(404, "Store nicht gefunden")

    owner_id = row["owner_user_id"]
    # Besides plain equality, also compare the string forms: the database
    # driver may hand back the owner id as a non-str type (for example
    # uuid.UUID) while the caller supplies a str, in which case "==" alone
    # would never match and owners would be treated as strangers.
    # NOTE(review): depends on the column type and on what verify_api_key
    # puts into user["user_id"] - confirm.
    is_owner = owner_id == user_id or (
        owner_id is not None
        and user_id is not None
        and str(owner_id) == str(user_id)
    )
    if is_owner:
        return

    shared = await db.fetchval(
        "SELECT 1 FROM store_permissions WHERE store_id=$1 AND user_id=$2",
        store_id,
        user_id,
    )
    if not shared:
        raise HTTPException(403, "Kein Zugriff")
|
|
|
|
|
|
@router.post("/upsert")
async def upsert(
    body: UpsertRequest,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db),
):
    """Embed the submitted texts and insert them into ``documents``.

    After the access check, every entry of ``body.texts`` is embedded and
    stored together with the metadata entry at the same index (an empty
    object when no such entry exists).

    Args:
        body: Request payload; ``store_id``, ``texts`` and ``metadata`` are
            read from it.
        user: Result of ``verify_api_key``; ``user_id`` and ``token`` are
            read from it.
        db: Database handle provided by ``get_db``.

    Returns:
        ``{"inserted": <count>, "ids": [<id as str>, ...]}`` for the rows
        that were written.

    Raises:
        HTTPException: Propagated from the access check (403/404) or from
            the embedding call (502).
    """
    store_id = str(body.store_id)
    await _check_access(db, store_id, user["user_id"])

    # Tolerate a missing metadata list instead of failing on len(None).
    # NOTE(review): whether UpsertRequest.metadata can be None is not
    # visible here - confirm against the model.
    metadata = body.metadata or []

    # Embed everything before the first INSERT so that a failing embedding
    # call aborts the request without leaving a partially written batch.
    embeddings = [await _embed(text, user["token"]) for text in body.texts]

    ids = []
    for i, (text, embedding) in enumerate(zip(body.texts, embeddings)):
        meta = metadata[i] if i < len(metadata) else {}
        doc_id = await db.fetchval(
            """INSERT INTO documents (store_id, content, metadata, embedding)
               VALUES ($1, $2, $3, $4::vector) RETURNING id""",
            store_id,
            text,
            json.dumps(meta),
            # The embedding is passed in its textual list form and cast to
            # the vector type by the statement itself.
            str(embedding),
        )
        ids.append(str(doc_id))

    return {"inserted": len(ids), "ids": ids}
|
|
|
|
|
|
@router.post("/query")
async def query(
    body: QueryRequest,
    user: dict = Depends(verify_api_key),
    db=Depends(get_db),
):
    """Return the stored documents closest to the query text.

    After the access check, ``body.query`` is embedded and the ``documents``
    rows of the store are ordered by the ``<=>`` distance between their
    embedding and the query embedding (``<=>`` is pgvector's cosine-distance
    operator). At most ``body.top_k`` rows are returned.

    Args:
        body: Request payload; ``store_id``, ``query`` and ``top_k`` are
            read from it.
        user: Result of ``verify_api_key``; ``user_id`` and ``token`` are
            read from it.
        db: Database handle provided by ``get_db``.

    Returns:
        ``{"results": [...]}`` where each entry carries ``id`` (as str),
        ``content``, ``metadata`` (as returned by the database) and
        ``similarity`` (``1 - distance`` as float).

    Raises:
        HTTPException: Propagated from the access check (403/404) or from
            the embedding call (502).
    """
    store_id = str(body.store_id)
    await _check_access(db, store_id, user["user_id"])

    query_vector = await _embed(body.query, user["token"])

    sql = """SELECT id, content, metadata,
                    1 - (embedding <=> $1::vector) AS similarity
             FROM documents
             WHERE store_id = $2
             ORDER BY embedding <=> $1::vector
             LIMIT $3"""
    matches = await db.fetch(sql, str(query_vector), store_id, body.top_k)

    def _as_result(match):
        # Shape one database row into the response entry.
        return {
            "id": str(match["id"]),
            "content": match["content"],
            "metadata": match["metadata"],
            "similarity": float(match["similarity"]),
        }

    return {"results": [_as_result(match) for match in matches]}
|