from requests_futures.sessions import FuturesSession
from typing import List, Dict, Tuple
from fastapi import FastAPI, Body, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from gradio.themes import Base
from celery_app import celery_app
from llm_chat import ChatManager
from configuration import Config as cf # Contains cf.USERNAME, cf.PASSWORD, cf.GOOGLE_URL, etc.
# Set multiprocessing start method
# "spawn" starts child processes from a fresh interpreter; force=True overrides any start
# method that was already set. NOTE(review): presumably chosen to avoid fork-related issues
# with inherited thread/connection state — confirm the reason.
multiprocessing.set_start_method("spawn", force=True)
# Root logging at INFO; `logger` is the module logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# The FastAPI application; the Gradio UI is mounted into it at /gradio further down.
app = FastAPI(title="Tree Correspondents API")
# OPTIONAL: Add CORS middleware if needed (CORSMiddleware is imported above for this purpose)
# --- Startup health checks ---
def check_ports_and_services():
    """Log best-effort reachability checks for FastAPI, Redis and the Gradio port.

    Each check is independent and only logs its outcome. Nothing here raises, so a
    failing dependency never prevents the module from importing.
    """
    logger.info("🔍 Running startup service checks...")

    # FastAPI: expects a `/health` route that answers 200.
    fastapi_url = os.getenv("FAST_API_URL", "http://localhost:8000")
    try:
        res = requests.get(f"{fastapi_url}/health", timeout=5)
    except requests.RequestException as e:
        logger.error(f"❌ FastAPI check failed: {e}")
    else:
        if res.status_code == 200:
            logger.info(f"✅ FastAPI is reachable at {fastapi_url}")
        else:
            logger.warning(f"⚠️ FastAPI at {fastapi_url} responded with status: {res.status_code}")

    # Redis: from_url() only builds a lazy client, so ping() is what actually proves
    # the server can be reached (and authenticated against).
    redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
    try:
        import redis  # local import to avoid conflicts if not needed early
        redis_client = redis.Redis.from_url(redis_url, socket_connect_timeout=5)
        redis_client.ping()
        logger.info(f"✅ Redis is reachable at {redis_url}")
    except Exception as e:  # best-effort probe: ImportError, bad URL, connection/auth errors
        logger.error(f"❌ Redis check failed: {e}")

    # Check Gradio port (mostly informational)
    port = os.getenv("PORT", 7860)
    try:
        # `with` guarantees the probe socket is closed. 0.0.0.0 is a bind address, not
        # a portable connect target, so probe loopback. The timeout keeps a filtered
        # port from stalling startup.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(2)
            result = sock.connect_ex(("127.0.0.1", int(port)))
    except (OSError, ValueError) as e:  # ValueError: PORT is not an integer
        logger.warning(f"⚠️ Gradio port check failed: {e}")
    else:
        if result == 0:
            logger.info(f"✅ Gradio port {port} is open")
        else:
            logger.warning(f"⚠️ Gradio port {port} seems closed (might not be an issue inside container)")

    logger.info("✅ Startup checks completed.\n")


# NOTE(review): this runs at import time. When the importing process is the FastAPI server
# itself, its own /health route is not listening yet, so that check is expected to fail.
check_ports_and_services()
# --- Load YAML Config ---
# Every key read below is required. A missing, unreadable or malformed config file is
# therefore fatal: log the reason and re-raise, instead of failing later with a NameError
# on `llm_config`.
try:
    with open("config_article.yml", "r", encoding="utf-8") as file:
        llm_config = yaml.safe_load(file)
except (OSError, yaml.YAMLError) as e:
    logger.error(f"Error loading YAML config: {e}")
    raise
# safe_load() returns None for an empty file; anything other than a mapping cannot be indexed below.
if not isinstance(llm_config, dict):
    logger.error(f"Error loading YAML config: expected a mapping, got {type(llm_config).__name__}")
    raise ValueError("config_article.yml must contain a YAML mapping at the top level")
api_info = llm_config["API"]
workflow_steps = llm_config["workflow"]
persona_steps = llm_config["persona_workflow"]
prompts = llm_config["prompts"]
rag_queries = llm_config["rag"]
parameters = llm_config["parameters"]
custom_head = llm_config['HTML']['head']['template']
# --- Authentication Functions ---
def authenticate(user, pwd):
    """Check *user*/*pwd* against the configured ``cf.USERNAME``/``cf.PASSWORD``.

    Returns:
        tuple[str, bool]: a user-facing status message and whether login succeeded.
    """
    # Constant-time comparison, so response timing does not reveal how much of a guess matched.
    import hmac

    expected_user, expected_pwd = cf.USERNAME, cf.PASSWORD
    # compare_digest needs str/bytes. Anything else (e.g. an unset config value) fails closed.
    if not all(isinstance(v, str) for v in (user, pwd, expected_user, expected_pwd)):
        return "Invalid credentials. Please try again.", False
    # Encode to bytes so non-ASCII input is accepted. Compute both results so the
    # username check does not short-circuit the password check.
    user_ok = hmac.compare_digest(user.encode("utf-8"), expected_user.encode("utf-8"))
    pwd_ok = hmac.compare_digest(pwd.encode("utf-8"), expected_pwd.encode("utf-8"))
    if user_ok and pwd_ok:
        return "Login Successful", True
    return "Invalid credentials. Please try again.", False
def check_login(user, pwd):
    """Gradio handler for the login button.

    Returns visibility updates for ``[login_page, landing_page]`` followed by the status
    message from ``authenticate``. On success the login page is hidden and the landing
    page shown; on failure the login page stays visible.
    """
    msg, success = authenticate(user, pwd)
    # Only reveal the landing page when authentication actually succeeded.
    if success:
        return gr.update(visible=False), gr.update(visible=True), msg
    return gr.update(visible=True), gr.update(visible=False), msg
# --- Gradio UI and Interaction Functions ---
# NOTE(review): the enclosing `def` header for the statements below is not visible here.
# Judging by `chat_button.click(fn=show_chat_ui, outputs=[chat_ui, base_ui, chat_state])`,
# this is presumably the body of `show_chat_ui(state)` — confirm and restore the header.
# Reuse the session's ChatManager (or create one) and call its load_config(). Then show
# chat_ui, hide base_ui, and hand the manager back to be stored in chat_state.
chat_manager = state or ChatManager()
chat_manager.load_config()
return gr.update(visible=True), gr.update(visible=False), chat_manager
def record_user_info(email, organization):
# Send the landing-page signup details to an external endpoint without blocking the UI.
# NOTE(review): `base_url`, the opening of the `payload` dict (presumably with an "email"
# entry) and its closing brace are not visible here. Per the config import comment,
# base_url is presumably cf.GOOGLE_URL — confirm.
url = f"{base_url}?apiKey={cf.GOOGLE_API_KEY}"
"organization": organization,
# Naive local timestamp (no timezone attached).
"time": str(datetime.datetime.now()),
logger.info(f"Recording user info: {payload}")
# NOTE(review): a new FuturesSession (with its own worker pool) is created on every call and
# never closed, and the POST has no timeout.
session = FuturesSession()
future = session.post(url, json=payload)
# Fire-and-forget: the response is only logged. If the request failed, f.result() re-raises
# inside the callback, so the failure is not reported through `logger`.
future.add_done_callback(lambda f: logger.info(f"User info response: {f.result().text}"))
def start_app(email, organization):
# Validate the landing-page form: both fields must be non-empty and the email must pass
# is_valid_email. handle_click receives and logs the returned message.
# NOTE(review): the bodies (return statements) of the two checks below are not visible in
# this view — confirm what each branch returns.
if not (email and organization):
if not is_valid_email(email):
def is_valid_email(email):
    """Return True if *email* looks like ``local@domain.tld``.

    This is a lightweight syntactic check, not full RFC 5322 validation. Non-string
    input (e.g. ``None``) is rejected instead of raising ``TypeError``.
    """
    if not isinstance(email, str):
        return False
    email_regex = r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+"
    # fullmatch anchors at both ends. A `...$` anchor would also accept a trailing "\n".
    return re.fullmatch(email_regex, email) is not None
def handle_click(email, organization):
# Handler for the landing-page "start" button. The four returned updates map to
# [error_message, bad_email, landing_page, base_ui] (see begin_button.click).
msg = start_app(email, organization)
logger.info(f"Start app message: {msg}")
# NOTE(review): the conditions choosing between the returns below are not visible here.
# By their visibility flags they appear to be:
#   missing fields -> show error_message
#   invalid email  -> show bad_email
#   otherwise      -> record the user and switch from landing_page to base_ui
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
record_user_info(email, organization)
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
# Load style.css and inject it into the page as an inline <style> element.
# NOTE(review): the enclosing `def` header and the statement that reads `f` into `css` are
# not visible here — confirm them. The file is opened without an explicit encoding.
with open("style.css", "r") as f:
return gr.HTML('<style>' + css + '</style>')
def handle_region_selection(choice, state):
    """Apply the region picked in the dropdown to this session's ChatManager.

    A new ChatManager is created when the session has none yet. Returns the update
    for the region label and the manager to store back into the session state.
    """
    manager = state if state else ChatManager()
    manager.update_region(choice)
    region_label = gr.update(value=choice)
    return region_label, manager
def handle_send(_input: str, history: List[Tuple[str, str]], state: gr.State):
# Submit the user's message to the FastAPI /start endpoint, poll /status/{job_id} until the
# Celery job reports SUCCESS, and return (response, cleared_input, chat_manager).
# NOTE(review): several lines are not visible in this view: the `try:`, the start of the POST
# call and its JSON body, the polling-loop header and `last_progress` bookkeeping. The
# comments below describe only the visible statements.
# NOTE(review): `state` carries the ChatManager value (or None), not a gr.State component.
# Also, conversation_box is a Chatbot with type="messages" (role/content dicts), while
# `history` is annotated as tuples — confirm which format generate_tree_response expects.
chat_manager: ChatManager = state or ChatManager()
logger.info(f"Processing input. Session: {id(chat_manager)}, Region: {chat_manager.region}")
f"{cf.FAST_API_URL}/start",
"chat_man": chat_manager.to_dict(),
job_id = res.json().get("job_id")
except requests.RequestException as e:
logger.error(f"Error communicating with FastAPI: {e}")
# NOTE(review): the first output is the Chatbot, which receives a plain string here. The
# original `state` is returned, so a ChatManager freshly created above is discarded.
return "Error communicating with backend.", _input, state
# tqdm draws its bar on the server's stderr, not in the browser.
with tqdm(total=100, desc="Processing Task", unit="%", leave=True) as pbar:
# NOTE(review): no timeout on the status request; a hung backend blocks this handler.
# A FAILED job makes /status raise HTTP 500, whose body has no "state" key — confirm the
# (not visible) loop terminates in that case.
status_res = requests.get(f"{cf.FAST_API_URL}/status/{job_id}")
status_data = status_res.json()
state_val = status_data.get("state")
# "progress" is only present for PROGRESS responses; default to 0 otherwise.
progress_info = status_data.get("progress") or {}
progress = progress_info.get("progress", 0)
if progress > last_progress:
pbar.update(progress - last_progress)
if state_val == "SUCCESS":
_result = status_data.get("result")
# The task returns (response, cleared_input); over JSON the tuple arrives as a 2-item list.
_response, _cleared_input = _result
pbar.update(100 - last_progress)
return _response, _cleared_input, chat_manager
# --- Define FastAPI Endpoints ---
# NOTE(review): the route decorator and `def` header for this endpoint are not visible here.
# handle_send posts to f"{cf.FAST_API_URL}/start", so this is presumably the POST /start
# handler. Each parameter is a required field of the JSON request body (Body(...)).
chat_man: Dict = Body(...),
chat_input: str = Body(...),
chat_history: List[Tuple[str, str]] = Body(...)
# Enqueue the Celery task and return its id, so the client can poll /status/{job_id}.
task = generate_tree_response_task.delay(chat_man, chat_input, chat_history)
logger.info(f"Started a new job with id {task.id}")
return {"job_id": task.id}
@app.get("/status/{job_id}")
def get_status(job_id: str):
    """Report the current state of Celery job *job_id* for the polling client.

    Returns:
        dict: always has ``"state"``, plus one of:
        - ``"status"`` while the job is queued or running;
        - ``"progress"`` (the task's PROGRESS meta) while it reports progress;
        - ``"result"`` (the task's return value) once it has succeeded.

    Raises:
        HTTPException: 500 for any other state (e.g. FAILURE, REVOKED), with
        ``str(task.info)`` as the detail.
    """
    task = celery_app.AsyncResult(job_id)
    # Read the state once. For unfinished jobs every `.state` access re-queries the result
    # backend, so a transition between comparisons (e.g. PROGRESS -> SUCCESS) could
    # otherwise fall through to the 500 branch.
    state = task.state
    if state == "PENDING":
        return {"state": state, "status": "Job is pending..."}
    if state in ("RECEIVED", "STARTED", "RETRY"):
        # Not finished and not failed: report it as in flight rather than as an error.
        return {"state": state, "status": "Job is running..."}
    if state == "PROGRESS":
        return {"state": state, "progress": task.info}
    if state == "SUCCESS":
        return {"state": state, "result": task.result}
    raise HTTPException(status_code=500, detail=str(task.info))
@celery_app.task(bind=True)
def generate_tree_response_task(self, chat_man: Dict, chat_input: str, chat_history: List[Tuple[str, str]]):
    """Celery task: rebuild a ChatManager from *chat_man* and generate the tree's reply.

    Before and after generation, a ``PROGRESS`` state is published with
    ``self.update_state``. Progress the manager reports through its callback is
    forwarded the same way. Returns ``(response, cleared_input)``.
    """
    manager = ChatManager.from_dict(chat_man)

    def forward_progress(state, meta):
        # Relay the manager's progress reports to the result backend, where /status reads them.
        self.update_state(state=state, meta=meta)

    self.update_state(state="PROGRESS", meta={"step": "starting generate_tree_response"})
    response, cleared_input = manager.generate_tree_response(
        chat_input,
        chat_history,
        progress_callback=forward_progress,
    )
    self.update_state(
        state="PROGRESS",
        meta={"step": "completed generate_tree_response", "progress": 100},
    )
    return response, cleared_input
# --- Mount Gradio inside FastAPI ---
# This mounts the Gradio Blocks interface at the URL path /gradio.
# NOTE(review): `demo` must already exist at this point, but in the visible code the Blocks
# UI is built further down, after this call. Confirm `demo` is defined before mounting;
# otherwise this line raises NameError at import.
gr.mount_gradio_app(app, demo, path="/gradio")
# --- Setup Gradio Interface ---
# NOTE(review): the enclosing `with gr.Blocks(...) as demo:` line, the `chat_state` gr.State
# used by the event wiring, and some closing parentheses are not visible in this view —
# confirm them.
# NOTE(review): login_page and landing_page start hidden while base_ui starts visible, so the
# app opens directly on region selection and the login/landing flow is skipped. Confirm this
# is intentional (e.g. a development toggle).
# Login page: username/password form, handled by check_login.
with gr.Column(visible=False, elem_id="login_page") as login_page:
gr.Markdown("", elem_id="login-spacer")
gr.Markdown("<h3> Please Log In with the credentials you received</h3>", elem_id="login_title")
log_input = gr.Textbox(label="Username", placeholder="Enter your username", container=False)
pwd_input = gr.Textbox(label="Password", placeholder="Enter your password", type="password", container=False)
login_message = gr.Markdown("")
login_button = gr.Button("Log In")
# Landing page: intro copy plus the email/organization form, handled by handle_click.
with gr.Column(visible=False) as landing_page:
gr.Markdown("", elem_id="land-spacer")
gr.Markdown("<h3>---Welcome to Tree Correspondents</h3>", elem_id="land-intro")
gr.Markdown("<p>Where the wisdom of the forests meets the power of journalism</p>"
"<p>You've arrived at a newsroom like any other - one where trees speak through science<br>"
"and stories. Across the world, sensors placed in the soil, trunks, and canopies of<br>"
"trees are collecting real-time environmental data.</p>"
# NOTE(review): the paragraph under this heading repeats the welcome paragraph above verbatim,
# presumably placeholder copy. Both headings also share elem_id="land-intro". Confirm intent.
gr.Markdown("<h3>---This tool was built for journalists.</h3>", elem_id="land-intro")
gr.Markdown("<p>Where the wisdom of the forests meets the power of journalism</p>"
"<p>You've arrived at a newsroom like any other - one where trees speak through science<br>"
"and stories. Across the world, sensors placed in the soil, trunks, and canopies of<br>"
"trees are collecting real-time environmental data.</p><br>")
# NOTE(review): both textboxes use elem_id="my-textbox"; HTML ids are expected to be unique.
email_input = gr.Textbox(show_label=False, container=False, type="email", placeholder="What is your email?", elem_id="my-textbox")
org_input = gr.Textbox(show_label=False, container=False, type="text", placeholder="What press organization do you work for?", elem_id="my-textbox")
# Validation messages, toggled by handle_click.
# NOTE(review): the "**" before "fill" is never closed, so Markdown bold is unbalanced.
with gr.Row(visible=False) as error_message:
floating_message = gr.Markdown(value="<h5>Please **fill in both Email and Organization fields!</h5>", elem_id="floating-message")
with gr.Row(visible=False) as bad_email:
floating_message_2 = gr.Markdown(value="<h5>This is not a valid email format</h5>", elem_id="floating-message-2")
gr.Markdown("<br>Let's write stories the world needs to read.")
gr.Markdown("", elem_id="one-line-spacer")
begin_button = gr.Button("start", elem_id="start-button")
# Region selection view; "Start Conversation" switches to chat_ui.
with gr.Column(visible=True) as base_ui:
gr.Markdown("", elem_id="top-spacer")
gr.Markdown("<h1 style='text-align:center;'>🌲 Tree Correspondents 🌲</h1>", elem_id="title")
gr.Markdown("", height=30)
gr.Markdown("Select Region")
site_selection = gr.Dropdown(
["Czechia/Travný", "Brazil/Viçosa", "Korea/Hongcheon"],
allow_custom_value=False,
elem_id="country_dropdown",
chat_button = gr.Button("🌿 Start Conversation 🌿", scale=3)
gr.Markdown("", elem_id="bottom-spacer")
# Chat view: the selected-region label, hidden journalist-role instruction, input and chat log.
with gr.Column(visible=False) as chat_ui:
gr.Markdown("## 🌲 Tree Correspondents / 🌿 **In Conversation**<br>")
selected_region = gr.Markdown("Brazil/Viçosa")
general_instruction = gr.Textbox(label="Journalist Role", interactive=True, visible=False, value=parameters['general_instruction'], lines=3)
refine_state = gr.State(True)
user_input = gr.Textbox(label="Enter your message", interactive=True, lines=4)
send_button = gr.Button("Send")
reset_button = gr.Button("Reset Chat")
conversation_box = gr.Chatbot(label="Tree Chat", elem_id="tree_chatbox", type="messages")
# Set up Gradio interactions
# Login form -> check_login toggles login/landing pages and shows the status message.
login_button.click(fn=check_login,
inputs=[log_input, pwd_input],
outputs=[login_page, landing_page, login_message])
# Landing form -> handle_click shows a validation message or switches to base_ui.
begin_button.click(fn=handle_click,
inputs=[email_input, org_input],
outputs=[error_message, bad_email, landing_page, base_ui])
# NOTE(review): no `inputs=` is visible here, yet show_chat_ui reads `state`. Confirm
# chat_state is passed in.
chat_button.click(fn=show_chat_ui,
outputs=[chat_ui, base_ui, chat_state])
# Send button and Enter in the textbox both run handle_send.
send_button.click(fn=handle_send,
inputs=[user_input, conversation_box, chat_state],
outputs=[conversation_box, user_input, chat_state])
# NOTE(review): no `inputs=` is visible here either, although the lambda takes `state`.
# It assumes reset_chat() returns a tuple (required by the `+`); with two outputs, that is
# presumably a 1-tuple. When `state` is falsy, `state or ChatManager()` is evaluated twice:
# one fresh manager is reset, and a *different* fresh one is stored into chat_state.
reset_button.click(fn=lambda state: ((state or ChatManager()).reset_chat() + (state or ChatManager(),)),
outputs=[conversation_box, chat_state])
# Dropdown change -> update the region label and the session's ChatManager.
site_selection.change(fn=handle_region_selection,
inputs=[site_selection, chat_state],
outputs=[selected_region, chat_state])
user_input.submit(fn=handle_send,
inputs=[user_input, conversation_box, chat_state],
outputs=[conversation_box, user_input, chat_state])
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app, with the Gradio UI mounted at /gradio.
    # NOTE(review): the app is passed as the import string "app:app", so uvicorn imports this
    # module again under the name `app` (presumably this file is app.py). That re-runs its
    # import-time side effects in that import — confirm this is intended.
    port = int(os.getenv("FAST_API_PORT", 8000))
    logger.info(f"Starting FastAPI server on port {port} with Gradio mounted at /gradio")
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        reload=False,
    )