#!/usr/bin/env python3
import logging
import os

import vertexai
from cachetools import LRUCache
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from vertexai.generative_models import ChatSession, GenerativeModel, SafetySetting

load_dotenv()
PROJECT_ID = os.environ.get('PROJECT_ID')
# Keep at most 15 live chat sessions; the least-recently-used one is evicted.
SESSION_CACHE = LRUCache(15)
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)


def generate():
    """Create a fresh Gemini chat session."""
    vertexai.init(project=PROJECT_ID, location="us-central1")
    model = GenerativeModel(
        "gemini-1.5-flash-001",
        # system_instruction=[textsi_1]
    )
    return model.start_chat()


# textsi_1 = """Respond in Chinese to a beginner who has just started learning German:
# answer in German first, then explain that answer in Chinese and give suggestions."""

generation_config = {
    "max_output_tokens": 1024,
    "temperature": 0.2,
    "top_p": 0,
}

safety_settings = [
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    ),
]


def stream_chat(chat: ChatSession, prompt: str):
    """Send one prompt to the chat session and yield plain-text chunks as they stream back."""
    responses = chat.send_message(
        [prompt],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=True,
    )
    for chunk in responses:
        yield chunk.text


app = FastAPI(
    title="AAII",
    version="0.0.1"
)


class ChatRequest(BaseModel):
    prompt: str
    session_uuid: str


@app.post("/generate_text_stream")
async def generate_text(request: ChatRequest):
    # Reuse the cached chat session for this UUID, or build a new one on first use.
    if (model := SESSION_CACHE.get(request.session_uuid)) is None:
        model = generate()
        SESSION_CACHE[request.session_uuid] = model
        logging.info(f"Built new session: {request.session_uuid}")
    return StreamingResponse(stream_chat(model, request.prompt), media_type="text/plain")


# Serve the front-end from ./static; mounted last so the API route above takes precedence.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
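
# A minimal way to try the endpoint locally (a sketch, not part of the app: it assumes
# uvicorn and requests are installed, that this file is saved as a hypothetical "main.py",
# and that PROJECT_ID is set in .env or the environment):
#
#   uvicorn main:app --reload
#
#   curl -N -X POST http://localhost:8000/generate_text_stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Wie geht es dir?", "session_uuid": "demo-1"}'
#
# Reusing the same session_uuid keeps the chat history; a new UUID starts a fresh session.
# The same call from Python, consuming the streamed chunks as they arrive:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate_text_stream",
#       json={"prompt": "Hallo!", "session_uuid": "demo-1"},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)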