self.persona_history = []
def _initialize_config(self):
    """Load environment config, YAML settings, the LLM client, the RAG retriever and chat memory.

    Raises:
        ValueError: if the HUGGINGFACE_TOKEN environment variable is not set.
    """
    hf_token = os.getenv("HUGGINGFACE_TOKEN")
    if not hf_token:
        raise ValueError("HUGGINGFACE_TOKEN is not set! Make sure to define it in .env.")
    with open("config_cannes.yaml", "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    self.parameters = config["parameters"]
    self.roles = config["roles"]
    self.manifests = config["manifest"]
    self.responses = config["responses"]
    self.prompts = config["prompts"]
    self.rag_queries = config["rag"]
    self.instruction = self.select_regional_instruction(self.region)
    self.manifest = self.select_regional_manifest(self.region)
    self.error_message = config["error"]
    # Initialize Hugging Face LLM client
    self.llm = HuggingFaceAPI(api_url=config["API"]["model_url"], api_token=hf_token)
    # NOTE(review): `embedding_model` is constructed but never used in this
    # chunk — confirm whether it should be passed to the retriever.
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    # BUG FIX: the original call was missing its closing parenthesis.
    # Initialize RAG retriever: one (faiss index, embeddings) pair per corpus.
    self.rag_retriever = RagRetrieveWithMeta(
        config['rag']['scientific']['faiss'],
        config['rag']['scientific']['embeddings'],
        config['rag']['diary']['faiss'],
        config['rag']['diary']['embeddings'],
        config['rag']['weather']['faiss'],
        config['rag']['weather']['embeddings'],
        config['rag']['insights']['faiss'],
        config['rag']['insights']['embeddings']
    )
    # Memory buffer for the running chat history (LangChain).
    self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    self.persona_history = []
def select_random_region(self):
    """Pick the active region and refresh the region-specific instruction and manifest.

    NOTE(review): despite the name, selection is currently hard-coded to
    "Czechia/Travný". The original assigned three regions in a row, so only
    the last assignment took effect; the two dead stores ("Korea/Hongcheon"
    and "Brazil/Viçosa") were removed here without changing behavior.
    Confirm whether a genuinely random choice was intended.
    """
    print("------SELECTING NEW REGION --------")
    self.region = "Czechia/Travný"
    print(f"---REGION: {self.region}")
    self.instruction = self.select_regional_instruction(self.region)
    self.manifest = self.select_regional_manifest(self.region)
def update_region(self, region):
    """Switch the active region and refresh the matching instruction and manifest.

    BUG FIX: the original only printed the region and never stored it, so the
    instance kept answering for the previously selected region.

    Args:
        region: display name, e.g. "Brazil/Viçosa".
    """
    print(f"---REGION: {region}")
    self.region = region
    self.instruction = self.select_regional_instruction(region)
    self.manifest = self.select_regional_manifest(region)
def save_history(self, history):
    """Serialize chat history, persona history and process logs to a timestamped JSON file.

    Args:
        history: list of (human_message, tree_message) pairs from the UI.

    Returns:
        str: path of the file written (for download).
    """
    file_path = f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    # BUG FIX: the original referenced `history_json` without ever building it.
    history_json = {
        "persona_history": self.persona_history,  # includes both persona and the refinement prompt used
        "chat_history": [{"--HUMAN--": msg[0], "--TREE--": msg[1]} for msg in history],
        "process_logs": self.process_logs,
    }
    # Save the JSON data to a file
    with open(file_path, "w", encoding="utf-8") as file:
        json.dump(history_json, file, indent=4)
    return file_path
def get_step_inputs(self, step_inputs):
    """Fetch and format the required inputs for a step.

    - Uses dictionary-style input definitions from `parameters` in `config.yaml`.
    - If an input isn't found, it checks `generated_inputs` for dynamically created values.

    Args:
        step_inputs: iterable of input names required by the step.

    Returns:
        dict: input name -> resolved value (names found nowhere are omitted).
    """
    # BUG FIX: the original was missing the accumulator, the per-item loop
    # header, the membership check, and the return statement.
    input_data = {}
    for item in step_inputs:
        resolved = False
        # Check in parameters first (`parameters` holds grouped dictionaries).
        for category in self.parameters.values():
            if item in category:
                input_data[item] = category[item]
                resolved = True
                break
        # If not found, check in dynamically generated inputs.
        if not resolved and item in self.generated_inputs:
            input_data[item] = self.generated_inputs[item]
    return input_data
# ======== FUNCTION to store and return llm outputs ===============
# =================================================================
def store_result_and_return(self, response, _output=None):
    """Store an LLM response under `_output` in `generated_inputs` and return it.

    Args:
        response: the (cleaned) LLM response text.
        _output: key under which to store the response.

    Returns:
        dict: {_output: response}.

    Raises:
        ValueError: if `_output` was not provided.
    """
    # BUG FIX: the guard condition for the raise was missing in the original.
    if _output is None:
        raise ValueError("⚠️ ERROR: `step_output` is missing in store_result_and_return()!")
    # Store response so later steps can resolve it via get_step_inputs().
    self.generated_inputs[_output] = response
    return {_output: response}
def safe_invoke(self, prompt: str) -> str:
    """Invoke the LLM, converting any failure into a RuntimeError with context.

    Args:
        prompt: fully formatted prompt string.

    Returns:
        str: the raw LLM response.

    Raises:
        RuntimeError: wrapping the original exception when invocation fails.
    """
    # BUG FIX: the try/except that the logging and re-raise lines belonged
    # to was missing in the original.
    try:
        return self.llm.invoke(prompt)
    except Exception as e:
        # Log with context about which prompt failed
        print(f"Error during LLM invocation for prompt: {prompt}\nError: {e}")
        raise RuntimeError("An error occurred while processing the LLM response.") from e
# ======== FUNCTION to format llm response ========================
# =================================================================
# def extract_response(self, response: str):
# # Use regex to find everything AFTER '</think>'
# match = re.search(r"</think>\s*(.*)", response, re.DOTALL)
# # If there's a match, return only the actual response (after '</think>')
# return match.group(1).strip() if match else response
def generate_persona_step(self, _step_index):
    """Run one persona-generation step through the LLM.

    Args:
        _step_index: index into `self.persona_steps`.

    Returns:
        dict: {output_name: cleaned_response}, via store_result_and_return().
    """
    _inputs = self.persona_steps[_step_index]["inputs"]
    _output = self.persona_steps[_step_index]["output"]
    _prompt = self.prompts[self.persona_steps[_step_index]["prompt"]]["template"]
    required_inputs = self.get_step_inputs(_inputs)
    formatted_prompt = _prompt.format(**required_inputs)
    response = self.llm.invoke(formatted_prompt)
    clean_response = self.extract_response(response)
    # BUG FIX: the original append was an unterminated dict literal. Record
    # both the generated persona and the prompt that produced it (exported by
    # save_history as "persona_history"). NOTE(review): field names were
    # reconstructed — confirm against the history consumer.
    self.persona_history.append({
        "output": _output,
        "prompt": formatted_prompt,
        "response": clean_response,
    })
    return self.store_result_and_return(clean_response, _output=_output)
def get_refinement_prompt(self):
    """Return the prompt template used by the persona-refinement step (index 1)."""
    refinement_step = self.persona_steps[1]
    template_key = refinement_step["prompt"]
    return self.prompts[template_key]["template"]
def refine_persona(self):
    """Run persona step 1 (refinement) and return the refined persona text."""
    step_result = self.generate_persona_step(1)
    refined = step_result["refined_persona"]
    return refined
def select_regional_instruction(self, region):
    """Return the correspondent role text for `region`.

    Unknown regions silently fall back to the Czech role (no print).
    """
    region_to_key = {
        "Czechia/Travný": "czechia",
        "Brazil/Viçosa": "brazil",
        "Korea/Hongcheon": "korea",
    }
    role_key = region_to_key.get(region)
    if role_key is None:
        return self.roles['czechia']
    selected_role = self.roles[role_key]
    print(f"--CORRESPONDENT ROLE: {selected_role}")
    return selected_role
def select_regional_manifest(self, region):
    """Return the manifest template for `region` (Czech manifest as fallback)."""
    key_by_region = {
        "Czechia/Travný": "czechia",
        "Brazil/Viçosa": "brazil",
        "Korea/Hongcheon": "korea",
    }
    manifest_key = key_by_region.get(region, "czechia")
    return self.manifests[manifest_key]["template"]
def clean_text(self, text):
    """Normalize LLM output: strip '---' separators and collapse whitespace.

    Args:
        text: raw response text.

    Returns:
        str: single-line, trimmed text.
    """
    # Remove occurrences of '---' and any whitespace immediately following it
    plain_text = re.sub(r'---\s*', '', text)
    # Collapse whitespace runs (including newlines) and trim the ends
    collapsed = re.sub(r'\s+', ' ', plain_text).strip()
    # BUG FIX: the original computed the cleaned string but never returned it.
    return collapsed
def format_history_for_prompt(self, chat_history):
    """Render LangChain chat messages as alternating 'Journalist:'/'Tree:' lines.

    Args:
        chat_history: list of message objects (HumanMessage / AIMessage).

    Returns:
        str: newline-joined transcript; messages of other types are skipped.
    """
    # BUG FIX: the original was missing the accumulator and the loop header.
    formatted = []
    for msg in chat_history:
        # Matched by class name so this file need not import LangChain types.
        if "HumanMessage" in str(type(msg)):
            formatted.append(f"Journalist: {msg.content}")
        elif "AIMessage" in str(type(msg)):
            formatted.append(f"Tree: {msg.content}")
    return "\n".join(formatted)
# ===== QUESTION CLASSIFICATION ==========================
# ========================================================
# Function to format response - eliminate reasoning
def extract_response(self, response: str):
    """Strip a chain-of-thought prefix: return only the text after '</think>'.

    When no '</think>' marker is present, the response is returned unchanged.
    """
    marker = re.search(r"</think>\s*(.*)", response, re.DOTALL)
    if marker is None:
        return response
    return marker.group(1).strip()
def validate_question(self, question, history, region, role, manifest):
    """Ask the LLM whether `question` is valid in the current conversational context.

    Returns:
        str: lowercase validation verdict from the LLM (the caller checks for
        "invalid"), or "error" when invocation fails.
    """
    validation_prompt_template = self.prompts['validator']['template']
    prompt = validation_prompt_template.format(question=question, history=history, region=region, role=role, manifest=manifest)
    # BUG FIX: the original lost its try/except header and never returned the verdict.
    try:
        validation = self.extract_response(self.llm.invoke(prompt).strip().lower())
        return validation
    except Exception as e:
        # Log the error with context; this avoids throwing the error directly to the UI.
        print(f"Error during VALIDATION: {e}")
        # Safe fallback ("error" is checked by generate_tree_response) so the UI can continue.
        return "error"
# Function to classify question
def classify_question(self, question, classifier_prompt_template):
    """Classify a question into RAG modes via the LLM.

    Args:
        question: the journalist's question.
        classifier_prompt_template: template with a `{question}` placeholder
            (e.g. prompts["hop_classification"]["template"]).

    Returns:
        list[str]: classification labels, or the string "error" on failure
        (the caller checks `classifications == "error"`).
    """
    prompt = classifier_prompt_template.format(question=question)
    # BUG FIX: the original lost its try/except header and never returned the results.
    try:
        classification_raw = self.extract_response(self.llm.invoke(prompt).strip().lower())
        classification = self.extract_classification(classification_raw)
        results = [c.strip() for c in classification.split(",")]
        return results
    except Exception as e:
        # Log the error with context; this avoids throwing the error directly to the UI.
        print(f"Error during CLASSIFICATION: {e}")
        # Safe fallback so the UI can continue.
        return "error"
def filter_rag_data_by_keywords(self, rag_items, rag_metadata, location_keywords):
    """Filter RAG items and metadata by location keywords.

    For the categories that require filtering ("scientific" and "weather"):
    - Scientific metadata: the location is read from meta["item_site"].
    - Weather metadata: meta["source"] is a list of dicts, each with a "location".
    An entry is kept if any keyword in location_keywords is found
    (case-insensitive) in the target field. Other categories pass through
    unfiltered.

    Args:
        rag_items (dict): RAG type (e.g. "scientific") -> list of content items.
        rag_metadata (dict): RAG type -> list of metadata dictionaries.
        location_keywords (list of str): keywords to filter by (e.g. ["korea", "hongcheon"]).

    Returns:
        tuple: (filtered_items, filtered_metadata), dicts with the same keys
        as the inputs where each value is the list of entries that passed.
    """
    # BUG FIX: result dicts, per-category lists, the `passed` counter and the
    # weather `sources` binding were all missing in the original.
    filtered_items = {}
    filtered_metadata = {}
    # Lowercase all keywords for case-insensitive matching.
    lower_keywords = [kw.lower() for kw in location_keywords]
    for category, items in rag_items.items():
        print(f"searching {category}")
        if category.lower() in ["scientific", "weather"]:
            meta_list = rag_metadata.get(category, [])
            filtered_items_list = []
            filtered_meta_list = []
            passed = 0
            # Assuming items and meta_list are in corresponding order.
            for item, meta in zip(items, meta_list):
                if category.lower() == "scientific":
                    # Keep the entry when any keyword appears in item_site.
                    location_field = meta.get("item_site", "").lower()
                    if any(kw in location_field for kw in lower_keywords):
                        filtered_items_list.append(item)
                        filtered_meta_list.append(meta)
                        passed += 1
                elif category.lower() == "weather":
                    # meta["source"] is a list of dicts; keep the entry when
                    # any of them contains any keyword in its "location".
                    sources = meta.get("source", [])
                    if any(kw in src.get("location", "").lower() for src in sources for kw in lower_keywords):
                        filtered_items_list.append(item)
                        filtered_meta_list.append(meta)
                        passed += 1
            if category.lower() == "scientific":
                print(f"{passed} items passed SCIENTIFIC")
            else:
                print(f"{passed} items passed WEATHER")
            filtered_items[category] = filtered_items_list
            filtered_metadata[category] = filtered_meta_list
        else:
            # For other categories, no filtering is applied.
            filtered_items[category] = items
            filtered_metadata[category] = rag_metadata.get(category, [])
    return filtered_items, filtered_metadata
def region_key(self, region):
    """Map a display region name to its lowercase config key.

    NOTE(review): the return values were missing in the corrupted original;
    they are reconstructed from the config keys used by
    select_regional_instruction / select_regional_manifest — confirm.
    """
    if region == "Czechia/Travný":
        return "czechia"
    elif region == "Brazil/Viçosa":
        return "brazil"
    elif region == "Korea/Hongcheon":
        return "korea"
    # Fallback mirrors the Czech default used elsewhere in this class.
    return "czechia"
def filter_question_type(self, question):
    """Return True when the question carries the FACTCHECK marker.

    NOTE(review): the return statements were missing in the corrupted
    original; a boolean flag is the minimal reconstruction — confirm against
    the caller's expectations.
    """
    if "FACTCHECK" in question:
        return True
    return False
def extract_classification(self, classification_raw):
    """Extract the known classification labels present in a raw LLM reply.

    Args:
        classification_raw: lowercase raw classification text from the LLM.

    Returns:
        str: comma-separated labels in canonical order
        ("historical", "scientific", "weather", "insights", "no_rag").
    """
    # BUG FIX: the accumulator initialization and the return were missing.
    classification = []
    for label in ("historical", "scientific", "weather", "insights", "no_rag"):
        if label in classification_raw:
            classification.append(label)
    results = ", ".join(classification)
    return results
# ===== RAG RETRIEVAL ====================================
# ========================================================
# Function to Handle Multi-Hop RAG Queries
def retrieve_information(self, query, classifications):
    """Run RAG retrieval for every classification mode (multi-hop).

    Args:
        query: the question text to retrieve for.
        classifications: list of RAG modes (e.g. ["scientific", "weather"]);
            a "no_rag" entry short-circuits retrieval entirely.

    Returns:
        tuple(dict, dict): (texts by mode, metadata by mode); both empty when
        no retrieval is needed.
    """
    if "no_rag" in classifications:
        return {}, {}  # Skip retrieval
    retrieved_docs = {}  # Results organized by category/mode
    # BUG FIX: the metadata accumulator was never initialized in the original.
    retrieved_metadata = {}
    for mode in classifications:
        print(f"----> current mode: {mode}")
        retrieved_texts = self.rag_retriever.rag_wrapper(query, mode=mode)
        # Split content and metadata into parallel, mode-keyed dicts.
        # (The original also built an unused `key` from the query; removed.)
        retrieved_docs[mode] = [doc['content'] for doc in retrieved_texts]
        retrieved_metadata[mode] = [doc['metadata'] for doc in retrieved_texts]
    return retrieved_docs, retrieved_metadata
# ===== GENERATE TREE RESPONSE ===========================
# ========================================================
def generate_tree_response(self, question, history):
    """Full pipeline for one journalist question: validate, classify, retrieve,
    think, answer, and log every intermediate step.

    Args:
        question: the journalist's question.
        history: UI chat history as a list of (human, tree) pairs.

    Returns:
        tuple(list, str): (updated history, "") — the empty string clears the
        UI input box.
    """
    # BUG FIX: `process_log`, `region` and the try/except around the RAG
    # pipeline were missing in the original (the except lines existed but had
    # no matching try).
    process_log = {}
    region = self.region
    # Step 3 (context): get chat history for the prompts.
    chat_history = self.memory.load_memory_variables({})["chat_history"]
    formatted_history = self.format_history_for_prompt(chat_history)
    print(f"QUESTION: {question}")
    print(f"-----> RESPONDING FOR {region}")
    # Add question to process log
    process_log['journalist_question'] = question
    print(f"0. Validating question accross region {region}")
    # ===============================
    # Step 0: VALIDATION ============
    # ===============================
    validation_result = self.validate_question(question, formatted_history, region, self.instruction, self.manifest)
    if validation_result == "invalid":
        response_template = self.responses['invalid']['template']
        response = response_template.format(region=region)
        return history + [(question, response)], ""
    if validation_result == "error":
        return history + [(question, self.error_message)], ""
    # Step 1: Classify Question (Possibly Multi-Hop)
    classifications = self.classify_question(question, self.prompts["hop_classification"]["template"])
    print(f"1. CLASSIFICATION -- {classifications}")
    if classifications == "error":
        return history + [(question, self.error_message)], ""
    process_log['classification'] = classifications
    try:
        # Step 2: Retrieve Relevant Knowledge (Categorized) and filter by region.
        retrieved_info, retrieved_meta = self.retrieve_information(question, classifications)
        cleaned_info, cleaned_meta = self.filter_rag_data_by_keywords(
            retrieved_info, retrieved_meta, [self.region_key(self.region), "General"])
        print("-------------------------------------")
        process_log['retrieved_data'] = cleaned_info
        process_log['retrieved_metadata'] = cleaned_meta
        flattened_rag = "\n\n".join(
            f"[{category.upper()}]\n" + "\n\n".join(docs)
            for category, docs in cleaned_info.items()
        ) if cleaned_info else "No external knowledge needed."
        # Step 4: Tree generates thoughts before answering (challenge + question).
        tree_thoughts = self.generate_tree_thoughts(self.instruction, question, flattened_rag, formatted_history)
        print(f"3. INTERNAL THOUGHTS -- {tree_thoughts}")
        process_log['tree_thoughts'] = tree_thoughts
        tree_response_prompt_template = self.instruction + self.prompts['tree_response']['template']
        # Step 5: Format Final Response Prompt. Extra keyword arguments are
        # ignored by str.format, so passing question/region is safe even if a
        # given template does not use them.
        formatted_prompt = tree_response_prompt_template.format(
            question=question,
            region=region,
            chat_history=formatted_history,
            retrieved_info=flattened_rag,
            tree_thoughts=tree_thoughts
        )
        # Step 6: Invoke the LLM for the response and clean it.
        response = self.extract_response(self.llm.invoke(formatted_prompt).strip())
        clean_response = self.clean_text(response)
        print(f"4. RAW RESPONSE -- {clean_response}")
        process_log['tree_response'] = clean_response
        # Step 7: Save chat memory
        self.memory.chat_memory.add_user_message(question)
        self.memory.chat_memory.add_ai_message(clean_response)
        # Append the process log for this turn.
        # NOTE(review): self.process_step is never incremented in this chunk —
        # confirm whether the counter is advanced elsewhere.
        self.process_logs.append({
            f"_process_step_{self.process_step}": process_log
        })
        return history + [(question, clean_response)], ""
    except Exception as e:
        # Log the error details for debugging instead of crashing the UI.
        print("Error in generate_tree_response at RAG or after:", e)
        return history + [(question, self.error_message)], ""
def fact_check_paragraph(self, paragraph: str, retrieved_info: dict, retrieved_meta: dict) -> str:
    """Fact-check a paragraph against all retrieved info, grouped by mode.

    Args:
        paragraph: the text to fact-check.
        retrieved_info: mode -> list of retrieved document texts.
        retrieved_meta: mode -> list of metadata dicts (parallel to the docs).

    Returns:
        str: LLM fact-check response with [Source X] markers enriched with metadata.
    """
    # Step 1: number every retrieved document as [Source N - MODE].
    # BUG FIX: the accumulator, the numbering loops and the return were missing.
    sources = []
    idx = 1
    for classification, docs in retrieved_info.items():
        for doc in docs:
            sources.append(f"[Source {idx} - {classification.upper()}]\n{doc}")
            idx += 1
    # Combine sources for prompt
    formatted_rag = "\n\n".join(sources)
    # Step 2: Load prompt template. Extra kwargs are ignored by str.format,
    # so passing `paragraph` is safe even if the template omits it.
    fact_check_template = self.prompts['tree_validate']['template']
    formatted_prompt = fact_check_template.format(
        paragraph=paragraph,
        retrieved_info=formatted_rag
    )
    response = self.extract_response(self.llm.invoke(formatted_prompt).strip())
    # Step 3: map source numbers to metadata in the same iteration order used above.
    source_metadata = {}
    idx = 1
    for classification, metas in retrieved_meta.items():
        for meta in metas:
            source_metadata[idx] = meta
            idx += 1
    enriched_response = self.enrich_factcheck_sources(response, source_metadata)
    print("==== FACTCHECK RESPONSE ====\n", enriched_response)
    return enriched_response
def enrich_factcheck_sources(self, response_text: str, source_metadata: dict) -> str:
    """Replace each plain [Source X] marker with title/author/year metadata.

    Args:
        response_text: LLM response containing [Source X] markers.
        source_metadata: source number -> metadata dict.

    Returns:
        str: response with enriched source markers.
    """
    for idx, meta in source_metadata.items():
        title = meta.get("item_name", "Unknown Title")
        author = meta.get("item_author", "Unknown Author")
        year = meta.get("item_year", "Unknown Year")
        enriched = f"[Source {idx} – {title}, {author}, {year}]"
        # Replace plain [Source X] with enriched version
        response_text = re.sub(rf"\[Source {idx}\]", enriched, response_text)
    # BUG FIX: the original never returned the enriched text.
    return response_text
def fact_check_paragraph_meta(self, paragraph: str, retrieved_info: dict, retrieved_meta: dict = None) -> str:
    """Fact-check a paragraph, inlining source metadata into the prompt itself.

    Args:
        paragraph: the text to fact-check.
        retrieved_info: mode -> list of retrieved document texts.
        retrieved_meta: mode -> list of metadata dicts (optional; empty dicts
            are substituted when absent).

    Returns:
        str: the raw LLM fact-check response.
    """
    # BUG FIX: the accumulator, the source counter, the `url` binding and the
    # return statement were missing in the original.
    sources = []
    idx = 1
    for classification, docs in retrieved_info.items():
        metadata_list = retrieved_meta.get(classification, []) if retrieved_meta else [{}] * len(docs)
        for doc, meta in zip(docs, metadata_list):
            title = meta.get('item_name') or meta.get('title', 'Unknown Title')
            author = meta.get('item_author') or meta.get('author', 'Unknown Author')
            year = meta.get('item_year') or meta.get('year', 'Unknown Year')
            meta_info = f"Title: {title}\nAuthor: {author}\nYear: {year}"
            # NOTE(review): the line defining `url` was missing; reconstructed
            # as an optional metadata field — confirm the key name upstream.
            url = meta.get('item_url') or meta.get('url')
            if url:
                meta_info += f"\nURL: {url}"
            source_block = f"[Source {idx} - {classification.upper()}]\n{meta_info}\n\n{doc}"
            sources.append(source_block)
            idx += 1
    formatted_rag = "\n\n".join(sources)
    print(f"==== FACTCHECK RESPONSE ==== {formatted_rag}")
    fact_check_template = self.prompts['tree_validate']['template']
    # Extra kwargs are ignored by str.format, so passing `paragraph` is safe.
    formatted_prompt = fact_check_template.format(
        paragraph=paragraph,
        retrieved_info=formatted_rag
    )
    response = self.extract_response(self.llm.invoke(formatted_prompt).strip())
    print("==== FACTCHECK RESPONSE ====\n", response)
    return response
def generate_tree_thoughts(self, instruction, question, retrieved_info, chat_history):
    """Before answering, the tree decides how to challenge and what question to ask.

    Args:
        instruction: the regional role instruction prefix.
        question: the journalist's question.
        retrieved_info: flattened RAG context string.
        chat_history: formatted conversation so far.

    Returns:
        str: the tree's lowercase internal thoughts.
    """
    thought_prompt_template = instruction + self.prompts['tree_thought']['template']
    # BUG FIX: the format call was unterminated and the result was never
    # returned in the original. Extra kwargs are ignored by str.format, so
    # passing `question` is safe even if a template omits it.
    formatted_thought_prompt = thought_prompt_template.format(
        question=question,
        chat_history=chat_history,
        retrieved_info=retrieved_info
    )
    # Invoke the LLM on the full thought template and strip any reasoning prefix.
    tree_thoughts = self.extract_response(
        self.llm.invoke(formatted_thought_prompt).strip().lower())
    return tree_thoughts
"""Resets the conversation history in memory."""
self.persona_history = []
return [] # Return empty history to reset UI