self.persona_history = []
def _initialize_config(self):
    """Private method to load configurations.

    Loads the Hugging Face token from the environment, reads
    ``config_cannes.yaml``, and wires up the LLM client, the RAG retriever
    and the conversation memory.

    Raises:
        ValueError: when the HUGGINGFACE_TOKEN environment variable is unset.
    """
    # Load environment variables
    huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
    if not huggingface_token:
        raise ValueError("HUGGINGFACE_TOKEN is not set! Make sure to define it in .env.")
    with open("config_cannes.yaml", "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    self.parameters = config["parameters"]
    self.roles = config["roles"]
    self.manifests = config["manifest"]
    self.responses = config["responses"]
    self.prompts = config["prompts"]
    self.rag_queries = config["rag"]
    self.instruction = self.select_regional_instruction(self.region)
    self.manifest = self.select_regional_manifest(self.region)
    # Initialize Hugging Face LLM
    self.llm = HuggingFaceAPI(api_url=config["API"]["model_url"], api_token=huggingface_token)
    # Initialize RAG Retriever
    # NOTE(review): `embedding_model` is created but never used in this copy of
    # the file -- presumably it was meant to be passed to RagRetrieveWithMeta;
    # confirm against the original source.
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    self.rag_retriever = RagRetrieveWithMeta(
        config['rag']['scientific']['faiss'],
        config['rag']['scientific']['embeddings'],
        config['rag']['diary']['faiss'],
        config['rag']['diary']['embeddings'],
        config['rag']['weather']['faiss'],
        config['rag']['weather']['embeddings'],
        config['rag']['insights']['faiss'],
        config['rag']['insights']['embeddings'],
    )  # BUG FIX: this call was missing its closing parenthesis
    # Initialize Memory for Chat History
    self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    self.persona_history = []
def update_region(self,region):
    """Update the active region for this bot.

    NOTE(review): the method body appears to be missing in this copy of the
    file -- presumably it should set ``self.region`` and refresh the regional
    instruction/manifest; restore it from the original source.
    """
def save_history(self, history):
    """Serialize the chat session to a timestamped JSON file.

    Args:
        history: list of (human_message, tree_message) tuples.

    Returns:
        str: path of the JSON file written (for download).
    """
    file_path = f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    # Convert history to structured JSON format
    # BUG FIX: the `history_json = {` / `}` lines were missing, leaving the
    # dictionary literal dangling; reconstructed from the surviving entries.
    history_json = {
        "persona_history": self.persona_history,  # 🔹 Now includes both persona and the refinement prompt used
        "chat_history": [{"--HUMAN--": msg[0], "--TREE--": msg[1]} for msg in history],
        "process_logs": self.process_logs,
    }
    # Save the JSON data to a file
    with open(file_path, "w", encoding="utf-8") as file:
        json.dump(history_json, file, indent=4)
    return file_path  # Returns the file path for download
def get_step_inputs(self, step_inputs):
    """Fetch and format the required inputs for a step.

    - Uses dictionary-style input definitions from `parameters` in `config.yaml`.
    - If an input isn't found, it checks `generated_inputs` for dynamically created values.

    NOTE(review): the accumulator initialisation, the loop over
    ``step_inputs`` and the membership test were truncated in this copy of
    the file; reconstructed conservatively -- confirm against the original.

    Args:
        step_inputs: iterable of input names required by the step.

    Returns:
        dict: mapping of input name -> resolved value.
    """
    input_data = {}
    for item in step_inputs:
        # Check in parameters first
        for category in self.parameters.values():  # `parameters` now has grouped dictionaries
            if item in category:
                input_data[item] = category[item]
        # If not found, check in dynamically generated inputs
        if item in self.generated_inputs:
            input_data[item] = self.generated_inputs[item]
    return input_data
# ======== FUNCTION to store and return llm outputs ===============
# =================================================================
def store_result_and_return(self, response, _output=None):
    """Store an LLM response under its output key and echo it back.

    Args:
        response: the LLM output to store.
        _output: name of the generated-input slot to store it under.

    Returns:
        dict: ``{_output: response}``.

    Raises:
        ValueError: when ``_output`` is not supplied.
    """
    # 🛑 Debug: Check if step_output is set correctly
    # BUG FIX: the guard condition before this raise was missing in this copy.
    if _output is None:
        raise ValueError("⚠️ ERROR: `step_output` is missing in store_result_and_return()!")
    # ✅ Store response in generated_inputs dictionary
    self.generated_inputs[_output] = response
    print(f"\n📝 Stored Result -> `{_output}`:\n{response}\n")
    return {_output: response}
# ======== FUNCTION to format llm response ========================
# =================================================================
# def extract_response(self, response: str):
# # Use regex to find everything AFTER '</think>'
# match = re.search(r"</think>\s*(.*)", response, re.DOTALL)
# # If there's a match, return only the actual response (after '</think>')
# return match.group(1).strip() if match else response
def generate_persona_step(self, _step_index):
    """Run one persona-generation step defined in ``persona_steps``.

    Resolves the step's inputs, formats its prompt template, invokes the
    LLM, records the exchange in ``persona_history`` and stores the result.

    Args:
        _step_index: index into ``self.persona_steps``.

    Returns:
        dict: ``{output_name: cleaned_response}`` from store_result_and_return.
    """
    _inputs = self.persona_steps[_step_index]["inputs"]
    _output = self.persona_steps[_step_index]["output"]
    _prompt = self.prompts[self.persona_steps[_step_index]["prompt"]]["template"]
    required_inputs = self.get_step_inputs(_inputs)
    formatted_prompt = _prompt.format(**required_inputs)
    response = self.llm.invoke(formatted_prompt)
    clean_response = self.extract_response(response)
    # NOTE(review): the dict appended here was truncated in this copy; per the
    # save_history comment it records both the persona and the prompt used --
    # the exact keys should be confirmed against the original source.
    self.persona_history.append({
        "prompt": formatted_prompt,
        "persona": clean_response,
    })
    return self.store_result_and_return(clean_response, _output=_output)
def get_refinement_prompt(self):
    """Return the prompt template used by the persona-refinement step (step 1)."""
    refinement_step = self.persona_steps[1]
    return self.prompts[refinement_step["prompt"]]["template"]
def refine_persona(self):
    """Execute persona step 1 and return the refined persona text."""
    step_result = self.generate_persona_step(1)
    return step_result["refined_persona"]
def select_regional_instruction(self, region):
    """Return (and echo to stdout) the role instruction for the given region.

    Unknown regions silently fall back to the Czech instruction.
    """
    region_to_role = {
        "Czechia/Travný": "czechia",
        "Brazil/Viçosa": "brazil",
        "Korea/Hongcheon": "korea",
    }
    role_key = region_to_role.get(region)
    if role_key is None:
        # Default: no echo, just return the Czech role.
        return self.roles['czechia']
    print(self.roles[role_key])
    return self.roles[role_key]
def select_regional_manifest(self, region):
    """Return the manifest template for the given region (Czech by default)."""
    manifest_key = {
        "Czechia/Travný": "czechia",
        "Brazil/Viçosa": "brazil",
        "Korea/Hongcheon": "korea",
    }.get(region, "czechia")
    return self.manifests[manifest_key]['template']
def clean_text(self, text):
    """Strip '---' separators and collapse all whitespace to single spaces.

    Args:
        text: raw LLM output.

    Returns:
        str: the cleaned, single-line text.
    """
    # Remove occurrences of '---' and any whitespace immediately following it
    plain_text = re.sub(r'---\s*', '', text)
    # Remove excessive whitespace (including newlines) and trim leading/trailing spaces
    collapsed = re.sub(r'\s+', ' ', plain_text).strip()
    # BUG FIX: the original computed the cleaned string but never returned it,
    # so callers received None.
    return collapsed
# ===== QUESTION CLASSIFICATION ==========================
# ========================================================
# Function to format response - eliminate reasoning
def extract_response(self,response: str):
    """Return only the text after a '</think>' reasoning marker, if present.

    When no marker exists, the response is returned unchanged.
    """
    tail = re.search(r"</think>\s*(.*)", response, re.DOTALL)
    if tail is None:
        return response
    return tail.group(1).strip()
def validate_question(self, question, history, region, manifest):
    """Ask the LLM whether a question is valid for this region.

    Returns:
        str: the lowercased verdict extracted from the LLM response
        (callers compare it against "invalid").
    """
    validation_prompt_template = self.prompts['validator']['template']
    prompt = validation_prompt_template.format(question=question, history=history, region=region, manifest=manifest)
    validation = self.extract_response(self.llm.invoke(prompt).strip().lower())
    # BUG FIX: the verdict was computed but never returned, so callers
    # always received None and the "invalid" branch could never trigger.
    return validation
# Function to classify question
def classify_question(self, question, classifier_prompt_template):
    """Classify a question into one or more RAG retrieval modes.

    Args:
        question: the user's question.
        classifier_prompt_template: format template with a ``{question}`` slot.

    Returns:
        list[str]: the recognised classification labels.
    """
    # classifier_prompt_template = prompts["hop_classification"]["template"]
    prompt = classifier_prompt_template.format(question=question)
    classification_raw = self.extract_response(self.llm.invoke(prompt).strip().lower())
    classification = self.extract_classification(classification_raw)
    results = [c.strip() for c in classification.split(",")]
    # BUG FIX: the parsed label list was computed but never returned.
    return results
def filter_question_type(self, question):
    """Return 'factcheck' for FACTCHECK-tagged questions, otherwise 'general'.

    NOTE(review): the original body was truncated after the FACTCHECK test in
    this copy of the file. The visible caller only compares the result to
    "factcheck", so the non-factcheck return value is a conservative
    reconstruction -- confirm against the original source.
    """
    if "FACTCHECK" in question:
        return "factcheck"
    return "general"
def extract_classification(self, classification_raw):
    """Extract known RAG category labels from raw classifier output.

    Args:
        classification_raw: lowercased raw text from the classifier LLM.

    Returns:
        str: comma-separated recognised labels, in a fixed canonical order
        (empty string when none match).
    """
    # BUG FIX: the accumulator initialisation and the final return were
    # missing in this copy of the file.
    classification = []
    # Fixed order matches the original if-chain: historical, scientific,
    # weather, insights, no_rag.
    for label in ("historical", "scientific", "weather", "insights", "no_rag"):
        if label in classification_raw:
            classification.append(label)
    results = ", ".join(classification)
    return results
# ===== RAG RETRIEVAL ====================================
# ========================================================
# Function to Handle Multi-Hop RAG Queries
def retrieve_information(self, query, classifications):
    """Retrieve documents and metadata for every classification mode.

    Args:
        query: the question to retrieve for.
        classifications: list of retrieval modes from classify_question.

    Returns:
        tuple(dict, dict): (texts-by-mode, metadata-by-mode); both empty
        when 'no_rag' is among the classifications.
    """
    if "no_rag" in classifications:
        return {}, {}  # Skip retrieval
    retrieved_docs = {}      # ✅ Store results by category
    # BUG FIX: retrieved_metadata was used below without ever being
    # initialised, which would raise NameError on the first mode.
    retrieved_metadata = {}
    for mode in classifications:
        print(f"----> current mode: {mode}")
        retrieved_texts = self.rag_retriever.rag_wrapper(query, mode=mode)
        # (dropped the unused `query_segment`/`key` locals from the original)
        retrieved_docs[mode] = [doc['content'] for doc in retrieved_texts]  # ✅ Organized by mode
        retrieved_metadata[mode] = [doc['metadata'] for doc in retrieved_texts]
    return retrieved_docs, retrieved_metadata
# ===== GENERATE TREE RESPONSE ===========================
# ========================================================
# Orchestrates one full answer turn: validate -> classify -> retrieve ->
# think -> respond, logging each stage into a per-turn process log.
# Returns (updated_history, "") for the chat UI.
# NOTE(review): several lines appear to be missing from this copy of the
# file and must be restored from the original source before this runs:
#   - a `region = ...` binding (region is referenced but never assigned),
#   - a `process_log = {}` initialisation,
#   - the closing `)` of the `formatted_prompt = ...format(` call (and
#     presumably a `question=question` keyword),
#   - the closing `})` of the final `self.process_logs.append({` call.
def generate_tree_response(self, question, history ):
# init dictionary to hold full llm process logs
print(f"-----> RESPONDING FOR {region}")
print (self.instruction + self.prompts['tree_response']['template'])
# Add question to process log
process_log['journalist_question'] = question
print(f"Validating question accross region {region}")
# Step 00: Validate prompt
validation_result = self.validate_question(question, history, region, self.manifest)
if(validation_result=="invalid"):
response_template = self.responses['invalid']['template']
response = response_template.format(region=region)
return history + [(question, response)], ""
# Step 0: Get prompt type
prompt_type = self.filter_question_type(question)
# Step 1: Classify Question (Possibly Multi-Hop)
classifications = self.classify_question(question, self.prompts["hop_classification"]["template"])
print(f"1. CLASSIFICATION -- {classifications}")
process_log['classification']=classifications
# Step 2: Retrieve Relevant Knowledge (Categorized)
retrieved_info, retrieved_meta = self.retrieve_information(question, classifications)
print(f"2. RAG RESULTS -- {retrieved_info}")
print(f"2. RAG METADATA -- {retrieved_meta}")
process_log['retrieved_data'] = retrieved_info
process_log['retrieved_metadata'] = retrieved_meta
# Step 3: Get Chat History for Context
chat_history = self.memory.load_memory_variables({})["chat_history"]
# ====================================================================
# ========= PROCEED ACCORDING TO PROMPT TYPE =========================
if prompt_type == "factcheck":
clean_paragraph = question.split("FACTCHECK", 1)[-1].strip()
factcheck_result = self.fact_check_paragraph(clean_paragraph, retrieved_info, retrieved_meta)
#factcheck_result = self.fact_check_paragraph_meta(clean_paragraph, retrieved_info, retrieved_meta)
process_log['factcheck_result'] = factcheck_result
self.process_logs.append({f"_process_step_{self.process_step}": process_log})
return history + [(question, factcheck_result)], ""
# Step 4: Tree Generates Thoughts Before Answering (Always Challenges & Asks a Question)
tree_thoughts = self.generate_tree_thoughts(self.instruction, question, retrieved_info, chat_history)
print(f"3. INTERNAL THOUGHTS -- {tree_thoughts}")
process_log['tree_thoughts'] = tree_thoughts
tree_response_prompt_template = self.instruction + self.prompts['tree_response']['template']
# Step 5: Format Final Response Prompt
formatted_prompt = tree_response_prompt_template.format(
chat_history=chat_history,
retrieved_info="\n".join(retrieved_info) if retrieved_info else "No external knowledge needed.",
tree_thoughts=tree_thoughts
# Step 6: Invoke HuggingFaceAPI for Response
#raw = self.llm.invoke(formatted_prompt)
#print(f"================== RAW DEEPSEEK ANSWER - {raw}")
response = self.extract_response(self.llm.invoke(formatted_prompt).strip())
clean_response = self.clean_text(response)
print(f"4. RAW RESPONSE -- {clean_response}")
process_log['tree_response'] = clean_response
# Step 7: Save chat memory
self.memory.chat_memory.add_user_message(question)
self.memory.chat_memory.add_ai_message(clean_response)
# add process log to history
self.process_logs.append({
f"_process_step_{self.process_step}": process_log
# print(f"TREE: {stylized_response}")
return history + [(question, clean_response)], "" #response
# Fact-checks a paragraph against all RAG results, then enriches the
# [Source X] citations in the verdict with source metadata.
# NOTE(review): several lines appear to be missing from this copy of the
# file and must be restored from the original source:
#   - a `sources = []` initialisation,
#   - the inner `for idx, doc in enumerate(...)` loop that defines idx/doc,
#   - the `paragraph=paragraph` keyword and closing `)` of the
#     `fact_check_template.format(` call,
#   - the inner loop binding idx/meta before `source_metadata[idx] = meta`,
#   - a final `return enriched_response`.
def fact_check_paragraph(self, paragraph: str, retrieved_info: dict, retrieved_meta: dict) -> str:
"""Run fact-checking on a paragraph using all retrieved info by classification mode."""
# Step 1: Format all RAG results
for classification, docs in retrieved_info.items():
sources.append(f"[Source {idx} - {classification.upper()}]\n{doc}")
# Combine sources for prompt
formatted_rag = "\n\n".join(sources)
# Step 2: Load prompt template
fact_check_template = self.prompts['tree_validate']['template']
formatted_prompt = fact_check_template.format(
retrieved_info=formatted_rag
#print("==== FACTCHECK PROMPT ====\n", formatted_prompt[:3000]) # Trimmed for log readability
response = self.extract_response(self.llm.invoke(formatted_prompt).strip())
source_metadata = {} # e.g., {1: {...}, 2: {...}} from zip()
for classification, metas in retrieved_meta.items():
source_metadata[idx] = meta
enriched_response = self.enrich_factcheck_sources(response, source_metadata)
print("==== FACTCHECK RESPONSE ====\n", enriched_response)
def enrich_factcheck_sources(self, response_text: str, source_metadata: dict) -> str:
    """Replaces [Source X] in the LLM response with enriched metadata info.

    Args:
        response_text: LLM verdict containing plain ``[Source X]`` citations.
        source_metadata: mapping of source index -> metadata dict with
            ``item_name`` / ``item_author`` / ``item_year`` keys.

    Returns:
        str: the response with every known citation expanded to
        ``[Source X – Title, Author, Year]``.
    """
    for idx, meta in source_metadata.items():
        title = meta.get("item_name", "Unknown Title")  # or meta.get("title", ...)
        author = meta.get("item_author", "Unknown Author")  # or meta.get("author", ...)
        year = meta.get("item_year", "Unknown Year")  # or meta.get("year", ...)
        enriched = f"[Source {idx} – {title}, {author}, {year}]"
        # Replace plain [Source X] with enriched version
        response_text = re.sub(rf"\[Source {idx}\]", enriched, response_text)
    # BUG FIX: the enriched text was built but never returned.
    return response_text
# Alternate fact-checker that inlines source metadata into each RAG block
# before prompting (currently commented out at the call site).
# NOTE(review): several lines appear to be missing from this copy of the
# file and must be restored from the original source:
#   - a `sources = []` initialisation and an `idx` counter,
#   - a `url = meta.get(...)` binding (url is used but never assigned) and
#     presumably an `if url:` guard around the URL line,
#   - the `paragraph=paragraph` keyword and closing `)` of the
#     `fact_check_template.format(` call,
#   - a final `return response`.
def fact_check_paragraph_meta(self, paragraph: str, retrieved_info: dict, retrieved_meta: dict = None) -> str:
"""Run fact-checking on a paragraph using all retrieved info and link to source metadata."""
for classification, docs in retrieved_info.items():
metadata_list = retrieved_meta.get(classification, []) if retrieved_meta else [{}] * len(docs)
for doc, meta in zip(docs, metadata_list):
title = meta.get('item_name') or meta.get('title', 'Unknown Title')
author = meta.get('item_author') or meta.get('author', 'Unknown Author')
year = meta.get('item_year') or meta.get('year', 'Unknown Year')
meta_info = f"Title: {title}\nAuthor: {author}\nYear: {year}"
meta_info += f"\nURL: {url}"
source_block = f"[Source {idx} - {classification.upper()}]\n{meta_info}\n\n{doc}"
sources.append(source_block)
formatted_rag = "\n\n".join(sources)
print(f"==== FACTCHECK RESPONSE ==== {formatted_rag}")
fact_check_template = self.prompts['tree_validate']['template']
formatted_prompt = fact_check_template.format(
retrieved_info=formatted_rag
#print("==== FACTCHECK PROMPT ====\n", formatted_prompt[:3000]) # Truncate if long
response = self.extract_response(self.llm.invoke(formatted_prompt).strip())
print("==== FACTCHECK RESPONSE ====\n", response)
# NOTE(review): this copy appears to be missing the closing of the
# `.format(` call (presumably `question=question)` given the parameter) and
# a final `return tree_thoughts`; restore from the original source.
def generate_tree_thoughts(self,instruction,question, retrieved_info, chat_history):
"""Before answering, the tree decides how to challenge and what question to ask."""
#instruction = self.parameters['general_instruction']
thought_prompt_template = instruction + self.prompts['tree_thought']['template']
formatted_thought_prompt = thought_prompt_template.format(
chat_history=chat_history,
retrieved_info=retrieved_info,
print(formatted_thought_prompt)
# ✅ Now we correctly invoke the LLM using the full thought template
tree_thoughts = self.extract_response(
self.llm.invoke(formatted_thought_prompt).strip().lower()) # llm.invoke(thought_prompt)
"""Resets the conversation history in memory."""
self.persona_history = []
return [] # Return empty history to reset UI