#!/usr/bin/env python3
"""FastAPI service exposing Vertex AI Gemini chat sessions.

Sessions are created via POST /generate_session, held in a small LRU
cache keyed by a client-supplied UUID, and consumed via
POST /generate_text_stream, which streams the model's reply as plain
text. A static frontend is served from ./static at the root path.
"""
import logging
import os

import vertexai
from cachetools import LRUCache
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from vertexai.generative_models import (
    ChatSession,
    GenerativeModel,
    Part,
    SafetySetting,
)

load_dotenv()

PROJECT_ID = os.environ.get("PROJECT_ID")

# At most 15 live chat sessions; least-recently-used entries are evicted,
# which silently ends those conversations (clients then get a 404).
SESSION_CACHE = LRUCache(15)

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)


def generate(model: str, instruction: str | None) -> ChatSession:
    """Create a new chat session for *model* with an optional system instruction.

    Args:
        model: Vertex AI model resource name (e.g. "gemini-1.5-flash-001").
        instruction: System instruction applied to the whole conversation,
            or None for the model default.

    Returns:
        A fresh ChatSession ready to receive messages.
    """
    # NOTE(review): init() on every call is redundant after the first
    # request, but it is idempotent and cheap relative to a model call,
    # so the original per-call placement is kept.
    vertexai.init(project=PROJECT_ID, location="us-central1")
    chat_model = GenerativeModel(model, system_instruction=instruction)
    return chat_model.start_chat()


generation_config = {
    "max_output_tokens": 4096,
    "temperature": 0.2,
    "top_p": 0,
}

# Block medium-or-worse content across all four harm categories.
# (Comprehension replaces four copy-pasted SafetySetting constructions.)
safety_settings = [
    SafetySetting(
        category=category,
        threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    )
    for category in (
        SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
    )
]


def stream_chat(model: ChatSession, _input):
    """Yield the text of each streamed response chunk for *_input*."""
    responses = model.send_message(
        [_input],
        generation_config=generation_config,
        safety_settings=safety_settings,
        stream=True,
    )
    for chunk in responses:
        yield chunk.text


app = FastAPI(title="AAII", version="0.0.1")


class SessionRequest(BaseModel):
    # Optional system instruction for the new session.
    instruction: str | None = None
    # Client-chosen key identifying the session in the cache.
    session_uuid: str
    model: str = "gemini-1.5-flash-001"


@app.post("/generate_session")
async def generate_session(request: SessionRequest):
    """Create and cache a chat session for the given UUID.

    Raises:
        HTTPException: 409 if a session with this UUID already exists.
    """
    # Bug fix: was HTTPException(500) — a duplicate UUID is a client
    # error (Conflict), not a server failure. Membership test also avoids
    # promoting the existing entry in the LRU on a failed create.
    if request.session_uuid in SESSION_CACHE:
        raise HTTPException(409, detail="session already exists")
    SESSION_CACHE[request.session_uuid] = generate(
        request.model, request.instruction
    )
    # Bug fix: log message typo ("Buid"); lazy %-args instead of an
    # f-string so the message is only built when the level is enabled.
    logger.info("Build new session: %s", request.session_uuid)


class ChatRequest(BaseModel):
    prompt: str
    session_uuid: str


@app.post("/generate_text_stream")
async def generate_text(request: ChatRequest):
    """Stream the model's reply to *prompt* on a previously created session.

    Raises:
        HTTPException: 404 if the session UUID is unknown (never created,
            or evicted from the LRU cache).
    """
    if session := SESSION_CACHE.get(request.session_uuid):
        return StreamingResponse(
            stream_chat(session, request.prompt), media_type="text/plain"
        )
    raise HTTPException(404)


# Serve the frontend last so the API routes above take precedence.
app.mount("/", StaticFiles(directory="static", html=True), name="static")