#!/usr/bin/env python3
import logging
import os

import vertexai
from cachetools import LRUCache
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from vertexai.generative_models import ChatSession, GenerativeModel, SafetySetting

load_dotenv()

PROJECT_ID = os.environ.get('PROJECT_ID')
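
# The .env file loaded above is expected to define the Google Cloud project
# id; a minimal example (the value below is a placeholder):
#
#   PROJECT_ID=my-gcp-project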

# Cache at most 15 chat sessions; the least-recently-used session (and its
# conversation history) is evicted once the cache fills up.
SESSION_CACHE = LRUCache(maxsize=15)

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)


def generate() -> ChatSession:
    """Initialise Vertex AI and start a fresh Gemini chat session."""
    vertexai.init(project=PROJECT_ID, location="us-central1")
    model = GenerativeModel(
        "gemini-1.5-flash-001",
        # system_instruction=[textsi_1]
    )
    return model.start_chat()


# Optional system instruction, in Chinese (translation: "Answer a beginner who
# has just started learning German: reply in German first, then explain the
# reply in Chinese and give suggestions").
# textsi_1 = """你要用中文回答一个德语刚入门的新手的回答,先用德语回答然后在用中文解释这个回答并给出建议"""


# Deterministic-leaning decoding: a low temperature plus top_p = 0 makes the
# model strongly favour its highest-probability tokens.
generation_config = {
    "max_output_tokens": 1024,
    "temperature": 0.2,
    "top_p": 0,
}


# Block any response rated medium-or-higher risk in each harm category.
safety_settings = [
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    ),
]


def stream_chat(model: ChatSession, _input):
    """Send a message on the chat session and yield the reply chunk by chunk."""
    responses = model.send_message(
        [_input],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=True,
    )
    for chunk in responses:
        yield chunk.text
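

# A minimal sketch of driving the chat directly, assuming PROJECT_ID is set
# and Vertex AI credentials are configured; kept commented out so importing
# this module has no side effects:
#
#   chat = generate()
#   for text in stream_chat(chat, "Hallo, wie geht's?"):
#       print(text, end="", flush=True)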


app = FastAPI(
    title="AAII",
    version="0.0.1",
)


class ChatRequest(BaseModel):
    prompt: str
    session_uuid: str


@app.post("/generate_text_stream")
async def generate_text(request: ChatRequest):
    # Reuse the cached chat session for this client, if it has one.
    if model := SESSION_CACHE.get(request.session_uuid):
        return StreamingResponse(stream_chat(model, request.prompt), media_type="text/plain")

    # Otherwise start a fresh session and cache it under the client's UUID.
    model = generate()
    SESSION_CACHE[request.session_uuid] = model
    logging.info(f"Built new session: {request.session_uuid}")
    return StreamingResponse(stream_chat(model, request.prompt), media_type="text/plain")
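

# Example request against a local server (hypothetical host, port, and session
# id); curl's --no-buffer flag prints the chunks as they stream in:
#
#   curl --no-buffer -X POST http://localhost:8000/generate_text_stream \
#        -H 'Content-Type: application/json' \
#        -d '{"prompt": "Hallo!", "session_uuid": "demo-session-1"}'


# Serve the front-end from ./static; mounted last so the API route above takes
# precedence over this catch-all static handler.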
app.mount("/", StaticFiles(directory="static", html=True), name="static")
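

# A minimal sketch for running the app directly, assuming uvicorn is installed
# (it is not imported at the top of this file); in practice the server would
# typically be launched with `uvicorn <module>:app` instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)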