From 311c332b109e409b67fb01aa8a5a2109868f1c37 Mon Sep 17 00:00:00 2001 From: Adrian Rumpold Date: Wed, 2 Jul 2025 12:22:50 +0200 Subject: [PATCH] feat: Use langchain parent document retriever to simplify retrieval logic --- .gitignore | 4 +- indexing.py | 159 ++++++++++++++++------------------------------------ slack.py | 1 + util.py | 48 ++++++++++++++++ 4 files changed, 101 insertions(+), 111 deletions(-) create mode 100644 util.py diff --git a/.gitignore b/.gitignore index 2a5231c..3961ac4 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,6 @@ mlruns/ mlartifacts/ # Weaviate vector store -weaviate/ \ No newline at end of file +weaviate/ + +data/ \ No newline at end of file diff --git a/indexing.py b/indexing.py index 2f0c94f..615680e 100644 --- a/indexing.py +++ b/indexing.py @@ -1,3 +1,4 @@ +import asyncio import logging import os from typing import Iterable, TypedDict @@ -5,6 +6,7 @@ from typing import Iterable, TypedDict import langchain import langchain.chat_models import langchain.prompts +import langchain.retrievers import langchain.text_splitter import langchain_core import langchain_core.documents @@ -16,11 +18,12 @@ import slack import weaviate from hn import HackerNewsClient, Story from scrape import JinaScraper +from util import DocumentLocalFileStore -NUM_STORIES = 20 -USER_PREFERENCES = ["Machine Learning", "Programming", "Robotics"] +NUM_STORIES = 40 # Number of top stories to fetch from Hacker News +USER_PREFERENCES = ["Machine Learning", "Programming", "DevOps"] ENABLE_SLACK = True # Send updates to Slack, need to set SLACK_BOT_TOKEN env var -ENABLE_MLFLOW_TRACING = False # Use MLflow (at http://localhost:5000) for tracing +ENABLE_MLFLOW_TRACING = True # Use MLflow (at http://localhost:5000) for tracing llm = langchain.chat_models.init_chat_model( @@ -36,6 +39,18 @@ vector_store = langchain_weaviate.WeaviateVectorStore( embedding=embeddings, ) +splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( + chunk_size=2000, + chunk_overlap=200, +) + +doc_store = DocumentLocalFileStore(root_path="data/documents") +retriever = langchain.retrievers.ParentDocumentRetriever( + vectorstore=vector_store, + docstore=doc_store, + child_splitter=splitter, +) + class State(TypedDict): preferences: Iterable[str] @@ -44,70 +59,15 @@ class State(TypedDict): summaries: list[dict] -def retrieve(state: State, top_n: int = 2 * len(USER_PREFERENCES)) -> State: - # Search for relevant documents (with scores if available) - retrieved_docs = vector_store.similarity_search( - "Show the most interesting articles about the following topics: " - + ", ".join(state["preferences"]), - k=top_n * 20, # Chunks, not complete stories - return_score=True - if hasattr(vector_store, "similarity_search_with_score") - else False, - ) +def retrieve_docs(state: State, top_n: int = 3): + """Retrieve relevant documents based on user preferences.""" - # If scores are returned, unpack (doc, score) tuples; else, set score to None - docs_with_scores = [] - if retrieved_docs and isinstance(retrieved_docs[0], tuple): - for doc, score in retrieved_docs: - docs_with_scores.append((doc, score)) - else: - for doc in retrieved_docs: - docs_with_scores.append((doc, None)) + docs = [] + for preference in state["preferences"]: + logging.info(f"Retrieving documents for preference: {preference}") + docs.extend(retriever.invoke(preference)[:top_n]) - # Group chunks by story_id and collect their scores - story_groups = {} - for doc, score in docs_with_scores: - story_id = doc.metadata.get("story_id") - if story_id not in 
story_groups: - story_groups[story_id] = [] - story_groups[story_id].append((doc, score)) - - # Aggregate max score per story and reconstruct complete stories - story_scores = {} - complete_stories = [] - for story_id, chunks_scores in story_groups.items(): - chunks = [doc for doc, _ in chunks_scores] - scores = [s for _, s in chunks_scores if s is not None] - max_score = max(scores) if scores else None - story_scores[story_id] = max_score - if len(chunks) == 1: - complete_stories.append((chunks[0], max_score)) - else: - combined_content = "\n\n".join( - chunk.page_content - for chunk in sorted( - chunks, key=lambda x: x.metadata.get("chunk_index", 0) - ) - ) - complete_story = langchain_core.documents.Document( - page_content=combined_content, - metadata=chunks[0].metadata, # Use metadata from first chunk - ) - complete_stories.append((complete_story, max_score)) - - # Sort stories by max_score descending (None scores go last) - complete_stories_sorted = sorted( - complete_stories, key=lambda x: (x[1] is not None, x[1]), reverse=True - ) - - # Return top_n stories - top_stories = [doc for doc, _ in complete_stories_sorted[:top_n]] - return { - "preferences": state["preferences"], - "context": top_stories, - "answer": state.get("answer", ""), - "summaries": state.get("summaries", []), - } + return {"context": docs} def generate_structured_summaries(state: State): @@ -217,9 +177,9 @@ def run_structured_query( ) -> list[dict]: """Run query and return structured summary data.""" graph_builder = langgraph.graph.StateGraph(State).add_sequence( - [retrieve, generate_structured_summaries] + [retrieve_docs, generate_structured_summaries] ) - graph_builder.add_edge(langgraph.graph.START, "retrieve") + graph_builder.add_edge(langgraph.graph.START, "retrieve_docs") graph = graph_builder.compile() response = graph.invoke( @@ -246,6 +206,7 @@ def get_existing_story_ids() -> set[str]: async def fetch_hn_top_stories( limit: int = 10, + force_fetch: bool = False, ) -> list[langchain_core.documents.Document]: hn = HackerNewsClient() stories = await hn.get_top_stories(limit=limit) @@ -256,11 +217,15 @@ async def fetch_hn_top_stories( new_stories = [story for story in stories if story.id not in existing_ids] - print(f"Found {len(stories)} top stories, {len(new_stories)} are new") + logging.info(f"Found {len(stories)} top stories, {len(new_stories)} are new") if not new_stories: - print("No new stories to fetch") - return [] + if not force_fetch: + logging.info("No new stories to fetch") + return [] + else: + logging.info("Force fetching all top stories regardless of existing IDs") + new_stories = stories contents = {} @@ -268,26 +233,18 @@ async def fetch_hn_top_stories( scraper = JinaScraper(os.getenv("JINA_API_KEY")) async def _fetch_content(story: Story) -> tuple[str, str]: - try: - if not story.url: - return story.id, story.title - return story.id, await scraper.get_content(story.url) - except Exception as e: - logging.warning(f"Failed to fetch content for story {story.id}: {e}") - return story.id, story.title # Fallback to title if content fetch fails + if not story.url: + return story.id, story.title + return story.id, await scraper.get_content(story.url) tasks = [_fetch_content(story) for story in new_stories] - results = await asyncio.gather(*tasks, return_exceptions=True) + results = await asyncio.gather(*tasks) # Filter out exceptions and convert to dict contents = {} for result in results: - if isinstance(result, Exception): - logging.error(f"Task failed with exception: {result}") - continue - if 
isinstance(result, tuple) and len(result) == 2:
-            story_id, content = result
-            contents[story_id] = content
+        story_id, content = result
+        contents[story_id] = content
 
     documents = [
         langchain_core.documents.Document(
@@ -324,31 +281,15 @@ async def main():
         for story in new_stories:
             categories = categorize(story, llm)
             story.metadata["categories"] = list(categories)
-            print(f"Story ID {story.metadata["story_id"]} categorized as: {categories}")
+            print(
+                f"Story ID {story.metadata['story_id']} ({story.metadata['title']}) categorized as: {categories}"
+            )
 
-        # 2. Split
-        documents_to_store = []
-        for story in new_stories:
-            # If article is short enough, store as-is
-            if len(story.page_content) <= 3000:
-                documents_to_store.append(story)
-            else:
-                # For very long articles, chunk but keep story metadata
-                splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
-                    chunk_size=2000,
-                    chunk_overlap=200,
-                    add_start_index=True,
-                )
-                chunks = splitter.split_documents([story])
-                # Add chunk info to metadata
-                for i, chunk in enumerate(chunks):
-                    chunk.metadata["chunk_index"] = i
-                    chunk.metadata["total_chunks"] = len(chunks)
-                documents_to_store.extend(chunks)
-
-        # 3. Store
-        _ = vector_store.add_documents(documents_to_store)
-        print(f"Added {len(documents_to_store)} documents to vector store")
+        # 2. Split & 3. Store
+        retriever.add_documents(
+            new_stories, ids=[doc.metadata["story_id"] for doc in new_stories]
+        )
+        print(f"Added {len(new_stories)} documents to document store")
     else:
         print("No new stories to process")
 
@@ -361,8 +302,6 @@ async def main():
 
 
 if __name__ == "__main__":
-    import asyncio
-
     logging.basicConfig(level=logging.INFO)
 
     try:
diff --git a/slack.py b/slack.py
index 699ee18..ac63f11 100644
--- a/slack.py
+++ b/slack.py
@@ -77,6 +77,7 @@ def send_message(channel: str, blocks: list) -> None:
         text="Tech updates",
         blocks=blocks,
         unfurl_links=False,
+        unfurl_media=False,
     )
     response.validate()
     logging.info(f"Message sent successfully to channel {channel}")
diff --git a/util.py b/util.py
new file mode 100644
index 0000000..ba072c8
--- /dev/null
+++ b/util.py
@@ -0,0 +1,48 @@
+import json
+from collections.abc import Iterator
+
+from langchain.schema import BaseStore, Document
+from langchain.storage import LocalFileStore
+
+
+class DocumentLocalFileStore(BaseStore[str, Document]):
+    def __init__(self, root_path: str):
+        self.store = LocalFileStore(root_path)
+
+    def mget(self, keys):
+        """Get multiple documents by keys"""
+        results = []
+        for key in keys:
+            try:
+                serialized = self.store.mget([key])[0]
+                if serialized:
+                    # Deserialize the document
+                    doc_dict = json.loads(serialized.decode("utf-8"))
+                    doc = Document(
+                        page_content=doc_dict["page_content"],
+                        metadata=doc_dict.get("metadata", {}),
+                    )
+                    results.append(doc)
+                else:
+                    results.append(None)
+            except Exception:
+                results.append(None)
+        return results
+
+    def mset(self, key_value_pairs):
+        """Set multiple key-value pairs"""
+        serialized_pairs = []
+        for key, doc in key_value_pairs:
+            # Serialize the document
+            doc_dict = {"page_content": doc.page_content, "metadata": doc.metadata}
+            serialized = json.dumps(doc_dict).encode("utf-8")
+            serialized_pairs.append((key, serialized))
+
+        self.store.mset(serialized_pairs)
+
+    def mdelete(self, keys):
+        """Delete multiple keys"""
+        self.store.mdelete(keys)
+
+    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
+        return self.store.yield_keys(prefix=prefix)
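
A minimal round-trip sketch for the DocumentLocalFileStore added in util.py above, useful for sanity-checking the docstore that backs the parent document retriever. This is only a sketch, not part of the patch: the "data/demo" path, the example story ID, and the document contents are illustrative, and it assumes the patched util.py is on the import path.

from langchain.schema import Document

from util import DocumentLocalFileStore

# Illustrative location; the patch itself uses data/documents (gitignored via data/).
store = DocumentLocalFileStore(root_path="data/demo")

doc = Document(
    page_content="Example article body",
    metadata={"story_id": "12345", "title": "Example story"},
)

# Persist the parent document under its story ID, then read it back.
store.mset([("12345", doc)])
restored = store.mget(["12345"])[0]

assert restored is not None
assert restored.page_content == doc.page_content
assert restored.metadata["story_id"] == "12345"

# Missing keys come back as None rather than raising.
assert store.mget(["missing-key"]) == [None]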