Files
langchain-hn-rag/indexing.py
2025-07-01 14:37:52 +02:00

372 lines
13 KiB
Python

import asyncio
import logging
import os
from typing import Iterable, TypedDict

import langchain
import langchain.chat_models
import langchain.prompts
import langchain.text_splitter
import langchain_core
import langchain_core.documents
import langchain_openai
import langchain_weaviate
import langgraph.graph
import weaviate

import slack
from hn import HackerNewsClient, Story
from scrape import JinaScraper
# --- Configuration -----------------------------------------------------------
# Number of Hacker News top stories requested per run (passed as `limit`).
NUM_STORIES = 20
# Topics used to query the vector store and to bucket the summaries; its length
# also sets the default number of stories `retrieve` returns (2 per topic).
USER_PREFERENCES = ["Machine Learning", "Programming", "Robotics"]
ENABLE_SLACK = True # Send updates to Slack; requires the SLACK_BOT_TOKEN env var
ENABLE_MLFLOW_TRACING = False # Use MLflow (at http://localhost:5000) for tracing

# --- Shared clients (all created as soon as the module is imported) ----------
# Chat model used both for categorizing stories and for summarizing them.
llm = langchain.chat_models.init_chat_model(
    model="gpt-4o-mini", model_provider="openai"
)
# Embedding model used by the vector store for indexing and querying.
# NOTE(review): presumably needs OPENAI_API_KEY in the environment — confirm.
embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-large")
# Connection to a local Weaviate instance; closed in the `__main__` block.
weaviate_client = weaviate.connect_to_local()
# Vector store over the "hn_stories" index; document text is stored in the
# "page_content" property.
vector_store = langchain_weaviate.WeaviateVectorStore(
    weaviate_client,
    index_name="hn_stories",
    text_key="page_content",
    embedding=embeddings,
)
class State(TypedDict):
    """State dict threaded through the LangGraph query pipeline.

    Built by `run_structured_query`, filled in by `retrieve` (context) and
    `generate_structured_summaries` (summaries).
    """

    # Topics the user cares about; joined into the search query and the prompt.
    preferences: Iterable[str]
    # Story documents selected by `retrieve` for summarization.
    context: list[langchain_core.documents.Document]
    # Passed through unchanged; neither node in this module sets it.
    answer: str
    # One dict per summarized story, produced by `generate_structured_summaries`.
    summaries: list[dict]
def _search_with_scores(
    query: str, k: int
) -> "list[tuple[langchain_core.documents.Document, float | None]]":
    """Run a similarity search returning (document, score) pairs; score is None if unsupported."""
    try:
        return list(vector_store.similarity_search_with_score(query, k=k))
    except NotImplementedError:
        # LangChain's base VectorStore raises this when a store cannot score
        # results; fall back to a plain search with unknown scores.
        return [(doc, None) for doc in vector_store.similarity_search(query, k=k)]


def _merge_story_chunks(
    chunks_scores: "list[tuple[langchain_core.documents.Document, float | None]]",
) -> "tuple[langchain_core.documents.Document, float | None]":
    """Combine one story's retrieved chunks into a single Document plus its best score."""
    scores = [score for _, score in chunks_scores if score is not None]
    # NOTE(review): assumes a higher score means more relevant — confirm for the
    # search mode the vector store uses.
    best_score = max(scores) if scores else None
    # Restore reading order; stories stored unchunked carry no chunk_index.
    chunks = sorted(
        (doc for doc, _ in chunks_scores),
        key=lambda doc: doc.metadata.get("chunk_index", 0),
    )
    if len(chunks) == 1:
        return chunks[0], best_score
    # NOTE: only the *retrieved* chunks are joined, so the text may be partial,
    # and the splitter's overlap between adjacent chunks appears twice.
    merged = langchain_core.documents.Document(
        page_content="\n\n".join(chunk.page_content for chunk in chunks),
        # Copy so the merged story never shares a metadata dict with a chunk.
        metadata=dict(chunks[0].metadata),
    )
    return merged, best_score


def retrieve(state: State, top_n: int = 2 * len(USER_PREFERENCES)) -> State:
    """Select the stories most relevant to the user's preferences.

    Searches the vector store at chunk level, regroups the chunks by
    ``story_id`` (long stories were split at indexing time), ranks stories by
    their best chunk score and keeps the ``top_n`` best.

    Args:
        state: Graph state; ``preferences`` is read, ``answer`` and
            ``summaries`` are passed through.
        top_n: Number of stories to return.

    Returns:
        A full ``State`` whose ``context`` holds the selected stories, best
        first; stories without a score are ranked last.
    """
    # Materialize once so a one-shot iterable survives both this join and the
    # downstream node that reads the preferences again.
    preferences = list(state["preferences"])
    query = (
        "Show the most interesting articles about the following topics: "
        + ", ".join(preferences)
    )
    # k counts chunks, not stories: over-fetch so enough distinct stories remain
    # after chunks of the same story are merged.
    docs_with_scores = _search_with_scores(query, k=top_n * 20)

    story_groups: dict = {}
    for position, (doc, score) in enumerate(docs_with_scores):
        story_id = doc.metadata.get("story_id")
        # Documents lacking a story_id stay separate instead of being merged
        # together under a shared None key.
        key = story_id if story_id is not None else ("<no story_id>", position)
        story_groups.setdefault(key, []).append((doc, score))

    stories = [_merge_story_chunks(group) for group in story_groups.values()]
    # Best score first, unscored stories last; the sort is stable, so ties keep
    # the search order.
    stories.sort(key=lambda item: (item[1] is not None, item[1]), reverse=True)

    return {
        "preferences": preferences,
        "context": [doc for doc, _ in stories[:top_n]],
        "answer": state.get("answer", ""),
        "summaries": state.get("summaries", []),
    }
def _parse_summary_response(
    response_text: str, preferences: list[str]
) -> tuple[str, str]:
    """Split an LLM reply into ``(matched_preference, summary)``.

    Expects the ``PREFERENCE: ...`` / ``SUMMARY: ...`` layout requested by the
    prompt, but tolerates surrounding whitespace, markdown emphasis or heading
    marks around the labels, and a summary that continues on following lines.
    The preference is mapped case-insensitively onto one of ``preferences``
    (keeping the user's spelling); anything else becomes ``"Other"``. Without
    a ``SUMMARY:`` line the whole (stripped) reply is used as the summary.
    """
    canonical = {pref.casefold(): pref for pref in preferences}
    matched_preference = "Other"
    summary_lines: list[str] | None = None  # None until a SUMMARY label is seen
    in_summary = False
    for raw_line in response_text.strip().splitlines():
        label, sep, value = raw_line.partition(":")
        label = label.strip().strip("*#").strip().upper() if sep else ""
        if label == "PREFERENCE":
            candidate = value.strip().strip("*[]'\".").strip()
            matched_preference = canonical.get(candidate.casefold(), "Other")
            in_summary = False
        elif label == "SUMMARY":
            summary_lines = [value.strip().lstrip("*").strip()]
            in_summary = True
        elif in_summary:
            # Continuation of a summary that spans several lines.
            summary_lines.append(raw_line.strip())
    if summary_lines is None:
        return matched_preference, response_text.strip()
    return matched_preference, "\n".join(line for line in summary_lines if line)


def generate_structured_summaries(state: State):
    """Summarize each retrieved story and tag it with its best-matching preference.

    Makes one LLM call per document in ``state["context"]``.

    Args:
        state: Graph state; reads ``preferences`` and ``context``.

    Returns:
        Partial state update ``{"summaries": [...]}`` with one dict per story
        holding ``title``, ``summary``, ``source_url``, ``categories``,
        ``story_id`` and ``matched_preference``. ``matched_preference`` is one
        of the user's preferences (their spelling) or ``"Other"``.
    """
    preferences = list(state["preferences"])
    preferences_text = ", ".join(preferences)
    # The template is the same for every story, so build it once.
    prompt = langchain.prompts.PromptTemplate(
        input_variables=["preferences", "title", "content", "source", "categories"],
        template=(
            "You are a helpful assistant that summarizes technology articles.\n\n"
            "User preferences: {preferences}\n\n"
            "Article title: {title}\n"
            "Article categories: {categories}\n"
            "Article content: {content}\n"
            "Source URL: {source}\n\n"
            "Use an informative but not too formal tone.\n"
            "Please provide:\n"
            "1. A concise summary (around 50 words) that highlights the key insights from the article.\n"
            "2. The single user preference that this article best matches (or 'Other' if none match well)\n\n"
            "Format your response as:\n"
            "PREFERENCE: [preference name or 'Other']\n"
            "SUMMARY: [your summary here]\n"
        ),
    )
    summaries = []
    for doc in state["context"]:
        metadata = doc.metadata
        title = metadata.get("title", "Unknown Title")
        # Stories without a URL are indexed with source=None; normalize to "".
        source = metadata.get("source") or ""
        categories = metadata.get("categories") or []
        messages = prompt.invoke(
            {
                "preferences": preferences_text,
                "title": title,
                "content": doc.page_content[:5000],  # Limit content length for LLM
                "source": source,
                "categories": ", ".join(categories),
            }
        )
        content = llm.invoke(messages).content
        # Message content is not guaranteed to be a plain string.
        response_text = content if isinstance(content, str) else str(content)
        matched_preference, summary_text = _parse_summary_response(
            response_text, preferences
        )
        summaries.append(
            {
                "title": title,
                "summary": summary_text,
                "source_url": source,
                "categories": categories,
                "story_id": metadata.get("story_id"),
                "matched_preference": matched_preference,
            }
        )
    return {"summaries": summaries}
def group_stories_by_preference(
    summaries: list[dict], preferences: list[str]
) -> dict[str, list[dict]]:
    """Bucket summaries by matched preference, ordered like ``preferences``.

    Each summary goes under its ``matched_preference``. A value that is
    missing, not a string, or not among ``preferences`` (compared exactly,
    then case-insensitively, since LLM output is not always spelled exactly)
    is placed under ``"Other"`` rather than silently dropped.

    Args:
        summaries: Summary dicts; only ``matched_preference`` is read.
        preferences: User preferences in display order.

    Returns:
        Mapping preference -> summaries (input order kept inside each group),
        with keys in ``preferences`` order followed by ``"Other"``. Empty
        groups are omitted.
    """
    preferences = list(preferences)
    exact = set(preferences)
    by_casefold = {pref.casefold(): pref for pref in preferences}

    groups: dict[str, list[dict]] = {}
    for summary in summaries:
        matched = summary.get("matched_preference", "Other")
        if not isinstance(matched, str):
            key = "Other"
        elif matched in exact:
            key = matched
        else:
            key = by_casefold.get(matched.strip().casefold(), "Other")
        groups.setdefault(key, []).append(summary)

    ordered = {pref: groups[pref] for pref in preferences if pref in groups}
    # "Other" always comes last (re-assigning an existing key keeps its slot).
    if "Other" in groups:
        ordered["Other"] = groups["Other"]
    return ordered
def create_slack_blocks(summaries: list[dict], preferences: list[str]) -> list[dict]:
    """Render story summaries as Slack message blocks.

    Summaries are first bucketed with `group_stories_by_preference` (user
    preference order, then "Other"); the buckets are then handed to
    `slack.format_slack_blocks`, whose return value is passed back unchanged.
    """
    return slack.format_slack_blocks(
        group_stories_by_preference(summaries, preferences)
    )
def run_structured_query(
    preferences: Iterable[str],
) -> list[dict]:
    """Run the retrieve -> summarize graph and return the structured summaries.

    Args:
        preferences: User topics; any iterable is accepted.

    Returns:
        The ``summaries`` list produced by `generate_structured_summaries`.
    """
    # Both graph nodes read the preferences; a one-shot iterator (e.g. a
    # generator) would already be exhausted for the second, so copy it once.
    preference_list = list(preferences)
    graph_builder = langgraph.graph.StateGraph(State).add_sequence(
        [retrieve, generate_structured_summaries]
    )
    graph_builder.add_edge(langgraph.graph.START, "retrieve")
    graph = graph_builder.compile()
    response = graph.invoke(
        State(preferences=preference_list, context=[], answer="", summaries=[])
    )
    return response["summaries"]
def get_existing_story_ids() -> set[str]:
    """Return the ``story_id`` values already present in the vector store.

    Best effort: any failure (for example, the collection does not exist yet)
    is logged as a warning and an empty set is returned, so every fetched
    story is then treated as new.
    """
    try:
        # NOTE: `_collection` is a private attribute of WeaviateVectorStore and
        # may change between langchain_weaviate releases.
        collection = vector_store._collection
        # Only `story_id` is needed: skip page content and other properties so
        # the full-collection scan stays cheap.
        return {
            story_id
            for obj in collection.iterator(return_properties=["story_id"])
            if (story_id := obj.properties.get("story_id"))
        }
    except Exception:
        logging.warning("Could not retrieve existing story IDs", exc_info=True)
        return set()
async def fetch_hn_top_stories(
    limit: int = 10,
) -> list[langchain_core.documents.Document]:
    """Fetch Hacker News top stories that are not indexed yet, as Documents.

    Page content is scraped concurrently. A story with no URL, or whose scrape
    fails or returns no text, falls back to its title as content so it is
    still indexed.

    Args:
        limit: Number of top stories to request from Hacker News.

    Returns:
        One Document per new story (in top-story order) with ``story_id``,
        ``title``, ``source`` and ``created_at`` metadata; an empty list when
        every story is already in the vector store.
    """
    hn = HackerNewsClient()
    stories = await hn.get_top_stories(limit=limit)
    # Get existing story IDs to avoid re-fetching
    existing_ids = get_existing_story_ids()
    logging.info(f"Existing story IDs: {len(existing_ids)} found in vector store")
    new_stories = [story for story in stories if story.id not in existing_ids]
    print(f"Found {len(stories)} top stories, {len(new_stories)} are new")
    if not new_stories:
        print("No new stories to fetch")
        return []

    scraper = JinaScraper(os.getenv("JINA_API_KEY"))

    async def _fetch_content(story: Story) -> str:
        """Return the scraped text for `story`, or its title as a fallback."""
        try:
            if not story.url:
                return story.title
            content = await scraper.get_content(story.url)
        except Exception as e:
            logging.warning(f"Failed to fetch content for story {story.id}: {e}")
            return story.title  # Fallback to title if content fetch fails
        # Empty/None content would give an unusable (or invalid) Document.
        return content or story.title

    # gather() returns results in input order, so they line up with new_stories.
    results = await asyncio.gather(
        *(_fetch_content(story) for story in new_stories), return_exceptions=True
    )

    documents = []
    for story, result in zip(new_stories, results):
        if isinstance(result, BaseException):
            # Defensive: _fetch_content already catches Exception, so this is
            # rare; keep the story (title only) instead of dropping it.
            logging.error(f"Task failed with exception: {result}")
            content = story.title
        else:
            content = result
        documents.append(
            langchain_core.documents.Document(
                page_content=content,
                metadata={
                    "story_id": story.id,
                    "title": story.title,
                    "source": story.url,
                    "created_at": story.created_at.isoformat(),
                },
            )
        )
    return documents
def _split_stories_for_storage(stories):
    """Return the documents to index: short stories whole, long ones as numbered chunks."""
    # One splitter serves every story, so build it once rather than per story.
    splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
        chunk_size=2000,
        chunk_overlap=200,
        add_start_index=True,
    )
    documents_to_store = []
    for story in stories:
        # If article is short enough, store as-is
        if len(story.page_content) <= 3000:
            documents_to_store.append(story)
            continue
        # For very long articles, chunk but keep story metadata
        chunks = splitter.split_documents([story])
        # chunk_index lets `retrieve` put chunks back in reading order.
        for index, chunk in enumerate(chunks):
            chunk.metadata["chunk_index"] = index
            chunk.metadata["total_chunks"] = len(chunks)
        documents_to_store.extend(chunks)
    return documents_to_store


async def main():
    """Run one pipeline pass: index new HN stories, then summarize and report.

    Steps: (1) fetch top stories not yet in the vector store, (2) categorize
    and split them, (3) store them, (4) summarize the stories that best match
    ``USER_PREFERENCES``, optionally post them to Slack, and print them.
    """
    if ENABLE_MLFLOW_TRACING:
        import mlflow

        mlflow.set_tracking_uri("http://localhost:5000")
        mlflow.set_experiment("langchain-rag-hn")
        mlflow.langchain.autolog()

    # 1. Load only new stories
    new_stories = await fetch_hn_top_stories(limit=NUM_STORIES)
    if new_stories:
        print(f"Processing {len(new_stories)} new stories")
        # Categorize stories (optional)
        from classify import categorize

        for story in new_stories:
            # Materialize first so an iterator result is not consumed twice.
            categories = list(categorize(story, llm))
            story.metadata["categories"] = categories
            # Single quotes inside the f-string keep this valid before Python 3.12.
            print(f"Story ID {story.metadata['story_id']} categorized as: {categories}")
        # 2. Split
        documents_to_store = _split_stories_for_storage(new_stories)
        # 3. Store
        _ = vector_store.add_documents(documents_to_store)
        print(f"Added {len(documents_to_store)} documents to vector store")
    else:
        print("No new stories to process")

    # 4. Query
    summaries = run_structured_query(USER_PREFERENCES)
    if ENABLE_SLACK:
        blocks = create_slack_blocks(summaries, USER_PREFERENCES)
        slack.send_message(channel="#ragpull-demo", blocks=blocks)
    print(summaries)
if __name__ == "__main__":
    # Set up logging before any pipeline code runs.
    logging.basicConfig(level=logging.INFO)
    import asyncio

    try:
        asyncio.run(main())
    finally:
        # Release the Weaviate connection opened at import time, even if the
        # pipeline raised.
        weaviate_client.close()