AAII/app.py

#!/usr/bin/env python3
import logging
import os

import vertexai
from cachetools import LRUCache
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from vertexai.generative_models import ChatSession, GenerativeModel, SafetySetting

load_dotenv()
PROJECT_ID = os.environ.get('PROJECT_ID')
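# load_dotenv() fills os.environ from a local .env file, so PROJECT_ID can come
# from either the shell environment or .env; an assumed .env would hold a single
# entry, e.g. (hypothetical value):
#
#   PROJECT_ID=my-gcp-project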

# Keep at most 15 chat sessions; the least recently used one is evicted when full.
SESSION_CACHE = LRUCache(15)
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)


def generate(model: str, instruction: str) -> ChatSession:
    """Create a Gemini chat session primed with the given system instruction."""
    vertexai.init(project=PROJECT_ID, location="us-central1")
    session = GenerativeModel(
        model,
        system_instruction=[instruction]
    )
    return session.start_chat()


generation_config = {
    "max_output_tokens": 1024,
    "temperature": 0.2,
    # top_p of 0 keeps only the single most probable token in the nucleus,
    # making decoding effectively greedy despite the non-zero temperature.
    "top_p": 0,
}
# Block medium-or-higher risk content across all four Vertex AI harm categories.
safety_settings = [
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    ),
]


def stream_chat(session: ChatSession, prompt: str):
    """Send prompt to the chat session and yield response text chunk by chunk."""
    responses = session.send_message(
        [prompt],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=True,
    )
    for chunk in responses:
        yield chunk.text


app = FastAPI(
    title="AAII",
    version="0.0.1"
)


class SessionRequest(BaseModel):
    # Default persona, in Chinese: "You are now a fox that loves Lisp and likes
    # editing with Emacs; when answering, try to use concepts in keeping with
    # the Lisp philosophy, such as recursion."
    instruction: str = ("你现在是一只喜爱 Lisp 并且喜欢用 Emacs 编辑的狐狸,"
                        "回答问题时尽量使用符合 Lisp 哲学例如递归等概念进行回答")
    session_uuid: str
    model: str = 'gemini-1.5-flash-001'


@app.post("/generate_session")
async def generate_session(request: SessionRequest):
    # Reject duplicate UUIDs; 409 Conflict describes this better than a bare 500.
    if SESSION_CACHE.get(request.session_uuid):
        raise HTTPException(409, detail="session already exists")
    model = generate(request.model, request.instruction)
    SESSION_CACHE[request.session_uuid] = model
    logging.info(f"Built new session: {request.session_uuid}")


class ChatRequest(BaseModel):
    prompt: str
    session_uuid: str


@app.post("/generate_text_stream")
async def generate_text(request: ChatRequest):
    if session := SESSION_CACHE.get(request.session_uuid):
        return StreamingResponse(stream_chat(session, request.prompt), media_type="text/plain")
    raise HTTPException(404, detail="session not found")


# Serve the front-end from ./static; mounted last so the API routes take precedence.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
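
# --- Usage sketch (assumption, not part of the original file) ---
# Serve the app with any ASGI server, e.g. uvicorn:
#
#   uvicorn app:app --port 8000
#
# Then create a session and stream a reply against it; "demo-1" is an
# arbitrary, client-chosen key for SESSION_CACHE:
#
#   curl -X POST localhost:8000/generate_session \
#        -H 'Content-Type: application/json' \
#        -d '{"session_uuid": "demo-1"}'
#
#   curl -N -X POST localhost:8000/generate_text_stream \
#        -H 'Content-Type: application/json' \
#        -d '{"session_uuid": "demo-1", "prompt": "Explain lists using recursion"}'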