Files
langchain-hn-rag/indexing.py
2025-07-01 13:32:18 +02:00

253 lines
8.3 KiB
Python

import asyncio
import logging
import os
from typing import Iterable, TypedDict

import langchain
import langchain.chat_models
import langchain.prompts
import langchain.text_splitter
import langchain_core
import langchain_core.documents
import langchain_openai
import langchain_weaviate
import langgraph.graph
import slack
import weaviate

from hn import HackerNewsClient, Story
from scrape import JinaScraper
# Number of Hacker News top stories requested per run (passed as `limit` in main()).
NUM_STORIES = 20
# Topics handed to run_query(); they form the retrieval query and appear in the prompt.
USER_PREFERENCES = ["Machine Learning", "Linux", "Open-Source"]
ENABLE_SLACK = False # Send updates to Slack, need to set SLACK_BOT_TOKEN env var
ENABLE_MLFLOW_TRACING = False # Use MLflow (at http://localhost:5000) for tracing
# Chat model shared by story categorization (in main) and answer generation.
llm = langchain.chat_models.init_chat_model(
    model="gpt-4o-mini", model_provider="openai"
)
# Embedding model handed to the vector store below.
embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-large")
# NOTE: the connection is opened when this module is imported; it is only
# closed by the `__main__` guard at the bottom of the file.
weaviate_client = weaviate.connect_to_local()
# Vector store holding the stories; the document text lives under the
# "page_content" key of the "hn_stories" index.
vector_store = langchain_weaviate.WeaviateVectorStore(
    weaviate_client,
    index_name="hn_stories",
    text_key="page_content",
    embedding=embeddings,
)
class State(TypedDict):
    """State passed between the nodes of the retrieve/generate graph."""

    # Topics of interest; joined into the similarity-search query by `retrieve`
    # and inserted into the prompt by `generate`.
    preferences: Iterable[str]
    # Documents produced by `retrieve` and consumed by `generate`.
    context: list[langchain_core.documents.Document]
    # Content of the LLM response produced by `generate`.
    answer: str
def retrieve(state: State):
    """Retrieve stories matching the user's preferences from the vector store.

    Runs a similarity search for the joined preferences, groups the hits by
    their ``story_id`` metadata and merges several hits of the same story into
    one document. Only the chunks that were actually returned by the search
    are merged, so a merged document is not necessarily the whole article.

    Args:
        state: Graph state; only ``state["preferences"]`` is read.

    Returns:
        A partial state update ``{"context": [...]}`` with at most five
        documents, ordered by the position of each story's first hit in the
        search result.
    """
    query = "Categories: " + ", ".join(state["preferences"])
    retrieved_docs = vector_store.similarity_search(query, k=10)

    # Group hits per story. Dicts preserve insertion order, so the groups stay
    # in the order in which the search returned their first hit.
    story_groups = {}
    for position, doc in enumerate(retrieved_docs):
        story_id = doc.metadata.get("story_id")
        if story_id is None:
            # Hits without a story_id are unrelated to each other; grouping
            # them under a shared None key would merge them into one bogus
            # story, so each gets a group of its own.
            group_key = ("doc", position)
        else:
            group_key = ("story", story_id)
        story_groups.setdefault(group_key, []).append(doc)

    complete_stories = []
    for chunks in story_groups.values():
        if len(chunks) == 1:
            complete_stories.append(chunks[0])
            continue

        # `or 0` also covers a chunk_index that is present but None, which
        # would otherwise make the sort raise TypeError.
        ordered_chunks = sorted(
            chunks, key=lambda chunk: chunk.metadata.get("chunk_index") or 0
        )
        # Take the story-level metadata from the lowest-indexed chunk and drop
        # the keys that only describe a single chunk's position.
        story_metadata = {
            key: value
            for key, value in ordered_chunks[0].metadata.items()
            if key not in ("chunk_index", "start_index")
        }
        complete_stories.append(
            langchain_core.documents.Document(
                page_content="\n\n".join(
                    chunk.page_content for chunk in ordered_chunks
                ),
                metadata=story_metadata,
            )
        )

    return {"context": complete_stories[:5]}  # Limit to top 5 stories
def _format_story_for_prompt(doc: langchain_core.documents.Document) -> str:
    """Render one story as prompt context, prefixed with its known metadata."""
    metadata = doc.metadata or {}
    header_lines = []
    if metadata.get("title"):
        header_lines.append(f"Title: {metadata['title']}")
    if metadata.get("source"):
        header_lines.append(f"Source: {metadata['source']}")
    categories = metadata.get("categories")
    if categories:
        if not isinstance(categories, str):
            categories = ", ".join(str(category) for category in categories)
        header_lines.append(f"Categories: {categories}")
    if not header_lines:
        return doc.page_content
    return "\n".join(header_lines) + "\n\n" + doc.page_content


def generate(state: State):
    """Ask the LLM for a Markdown news digest based on the retrieved stories.

    The prompt asks the model to link each headline to its source and to group
    items by category. That information lives in the documents' metadata, not
    in their text, so every story is rendered with a ``Title``/``Source``/
    ``Categories`` header (for whichever of those values are present) before
    its content.

    Args:
        state: Graph state; ``state["preferences"]`` and ``state["context"]``
            are read.

    Returns:
        A partial state update ``{"answer": ...}`` holding the content of the
        LLM response.
    """
    # A visible separator keeps adjacent stories from running into each other.
    docs_content = "\n\n---\n\n".join(
        _format_story_for_prompt(doc) for doc in state["context"]
    )
    prompt = langchain.prompts.PromptTemplate(
        input_variables=["preferences", "context"],
        template=(
            "You are a helpful assistant that can provide updates on technology topics based on the topics a user has expressed interest in and additional context.\n\n"
            "Please respond in Markdown format and group your answers based on the categories of the items in the context.\n"
            "If applicable, add hyperlinks to the original source as part of the headline for each story.\n"
            "Limit your summaries to approximately 100 words per item.\n\n"
            "Preferences: {preferences}\n\n"
            "Context:\n{context}\n\n"
            "Answer:"
        ),
    )
    messages = prompt.invoke(
        {
            # Joined the same way as in the retrieval query, instead of
            # rendering the Python repr of the iterable.
            "preferences": ", ".join(state["preferences"]),
            "context": docs_content,
        }
    )
    response = llm.invoke(messages)
    return {"answer": response.content}
def run_query(preferences: Iterable[str]) -> str:
    """Run the retrieve -> generate graph once and return the generated answer.

    Args:
        preferences: Topics of interest used to seed the graph state.

    Returns:
        The ``answer`` entry of the final graph state.
    """
    builder = langgraph.graph.StateGraph(State)
    builder.add_sequence([retrieve, generate])
    # Wire the entry point to the first node of the sequence.
    builder.add_edge(langgraph.graph.START, "retrieve")

    initial_state = State(preferences=preferences, context=[], answer="")
    final_state = builder.compile().invoke(initial_state)
    return final_state["answer"]
def get_existing_story_ids() -> set[str]:
    """Collect the ``story_id`` of every object already in the vector store.

    Best effort: if the collection cannot be read for any reason, a warning
    (with traceback) is logged and an empty set is returned.
    """
    try:
        stored_objects = vector_store._collection.iterator()
        candidate_ids = (
            stored.properties.get("story_id") for stored in stored_objects
        )
        # Objects without a (truthy) story_id are ignored.
        return {candidate for candidate in candidate_ids if candidate}
    except Exception:
        logging.warning("Could not retrieve existing story IDs", exc_info=True)
        return set()
async def fetch_hn_top_stories(
    limit: int = 10,
) -> list[langchain_core.documents.Document]:
    """Fetch the Hacker News top stories that are not in the vector store yet.

    Stories whose id is already stored are skipped. For the remaining ones the
    linked page is scraped concurrently; a story without a URL, with empty
    scraped content, or whose scrape fails falls back to its title as content,
    so every new story yields exactly one document.

    Args:
        limit: Maximum number of top stories to request from Hacker News.

    Returns:
        One document per new story, with ``story_id``, ``title``, ``source``
        and ``created_at`` (ISO format) in its metadata. Empty if there is
        nothing new.
    """
    hn = HackerNewsClient()
    stories = await hn.get_top_stories(limit=limit)

    # Get existing story IDs to avoid re-fetching
    existing_ids = get_existing_story_ids()
    logging.info(f"Existing story IDs: {len(existing_ids)} found in vector store")

    new_stories = [story for story in stories if story.id not in existing_ids]
    print(f"Found {len(stories)} top stories, {len(new_stories)} are new")
    if not new_stories:
        print("No new stories to fetch")
        return []

    scraper = JinaScraper(os.getenv("JINA_API_KEY"))

    async def _fetch_content(story: Story) -> str:
        """Return the scraped text for *story*, or its title as a fallback."""
        if not story.url:
            return story.title
        try:
            content = await scraper.get_content(story.url)
        except Exception as e:
            logging.warning(f"Failed to fetch content for story {story.id}: {e}")
            return story.title  # Fallback to title if content fetch fails
        # An empty or missing result would produce a document with no text.
        return content or story.title

    # gather() returns results in the order of its arguments, so they can be
    # paired with the stories positionally instead of through an id lookup.
    results = await asyncio.gather(
        *(_fetch_content(story) for story in new_stories),
        return_exceptions=True,
    )

    documents = []
    for story, result in zip(new_stories, results):
        if isinstance(result, BaseException):
            # BaseException rather than Exception: a cancelled task is
            # reported as CancelledError, which is not an Exception subclass.
            logging.error(f"Task failed with exception: {result}")
            result = story.title
        documents.append(
            langchain_core.documents.Document(
                page_content=result,
                metadata={
                    "story_id": story.id,
                    "title": story.title,
                    "source": story.url,
                    "created_at": story.created_at.isoformat(),
                },
            )
        )
    return documents
async def main():
    """Index new Hacker News stories, then print a digest for USER_PREFERENCES.

    Steps: optionally enable MLflow tracing, load stories that are not yet
    indexed, categorize them, split long ones into chunks, store everything in
    the vector store, run the retrieval/generation query, print the answer and
    optionally post it to Slack.
    """
    if ENABLE_MLFLOW_TRACING:
        import mlflow

        mlflow.set_tracking_uri("http://localhost:5000")
        mlflow.set_experiment("langchain-rag-hn")
        mlflow.langchain.autolog()

    # 1. Load only new stories
    new_stories = await fetch_hn_top_stories(limit=NUM_STORIES)
    if new_stories:
        print(f"Processing {len(new_stories)} new stories")

        # Categorize stories (optional)
        from classify import categorize

        for story in new_stories:
            categories = categorize(story, llm)
            story.metadata["categories"] = list(categories)
            # Looked up outside the f-string: reusing the enclosing quote
            # character inside a replacement field needs Python 3.12+.
            story_id = story.metadata["story_id"]
            print(f"Story ID {story_id} categorized as: {categories}")

        # 2. Split
        # The splitter configuration never changes, so build it once.
        splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=200,
            add_start_index=True,
        )
        documents_to_store = []
        for story in new_stories:
            # If article is short enough, store as-is
            if len(story.page_content) <= 3000:
                documents_to_store.append(story)
                continue
            # For very long articles, chunk but keep story metadata
            chunks = splitter.split_documents([story])
            # Add chunk info to metadata so chunks can be re-ordered later
            for i, chunk in enumerate(chunks):
                chunk.metadata["chunk_index"] = i
                chunk.metadata["total_chunks"] = len(chunks)
            documents_to_store.extend(chunks)

        # 3. Store
        _ = vector_store.add_documents(documents_to_store)
        print(f"Added {len(documents_to_store)} documents to vector store")
    else:
        print("No new stories to process")

    # 4. Query
    answer = run_query(USER_PREFERENCES)
    print(answer)
    if ENABLE_SLACK:
        slack.send_message(channel="#ragpull-demo", text=answer)
if __name__ == "__main__":
    # Script entry point: configure logging, run the async pipeline once.
    logging.basicConfig(level=logging.INFO)

    import asyncio

    try:
        asyncio.run(main())
    finally:
        # Release the module-level Weaviate connection even if main() raised.
        weaviate_client.close()