feat: Use langchain parent document retriever to simplify retrieval logic
This commit is contained in:
		
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -15,3 +15,5 @@ mlartifacts/ | ||||
|  | ||||
| # Weaviate vector store | ||||
| weaviate/ | ||||
|  | ||||
| data/ | ||||
							
								
								
									
										159
									
								
								indexing.py
									
									
									
									
									
								
							
							
						
						
									
										159
									
								
								indexing.py
									
									
									
									
									
								
							| @@ -1,3 +1,4 @@ | ||||
| import asyncio | ||||
| import logging | ||||
| import os | ||||
| from typing import Iterable, TypedDict | ||||
| @@ -5,6 +6,7 @@ from typing import Iterable, TypedDict | ||||
| import langchain | ||||
| import langchain.chat_models | ||||
| import langchain.prompts | ||||
| import langchain.retrievers | ||||
| import langchain.text_splitter | ||||
| import langchain_core | ||||
| import langchain_core.documents | ||||
| @@ -16,11 +18,12 @@ import slack | ||||
| import weaviate | ||||
| from hn import HackerNewsClient, Story | ||||
| from scrape import JinaScraper | ||||
| from util import DocumentLocalFileStore | ||||
|  | ||||
| NUM_STORIES = 20 | ||||
| USER_PREFERENCES = ["Machine Learning", "Programming", "Robotics"] | ||||
| NUM_STORIES = 40  # Number of top stories to fetch from Hacker News | ||||
| USER_PREFERENCES = ["Machine Learning", "Programming", "DevOps"] | ||||
| ENABLE_SLACK = True  # Send updates to Slack, need to set SLACK_BOT_TOKEN env var | ||||
| ENABLE_MLFLOW_TRACING = False  # Use MLflow (at http://localhost:5000) for tracing | ||||
| ENABLE_MLFLOW_TRACING = True  # Use MLflow (at http://localhost:5000) for tracing | ||||
|  | ||||
|  | ||||
| llm = langchain.chat_models.init_chat_model( | ||||
| @@ -36,6 +39,18 @@ vector_store = langchain_weaviate.WeaviateVectorStore( | ||||
|     embedding=embeddings, | ||||
| ) | ||||
|  | ||||
| splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( | ||||
|     chunk_size=2000, | ||||
|     chunk_overlap=200, | ||||
| ) | ||||
|  | ||||
| doc_store = DocumentLocalFileStore(root_path="data/documents") | ||||
| retriever = langchain.retrievers.ParentDocumentRetriever( | ||||
|     vectorstore=vector_store, | ||||
|     docstore=doc_store, | ||||
|     child_splitter=splitter, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class State(TypedDict): | ||||
|     preferences: Iterable[str] | ||||
| @@ -44,70 +59,15 @@ class State(TypedDict): | ||||
|     summaries: list[dict] | ||||
|  | ||||
|  | ||||
| def retrieve(state: State, top_n: int = 2 * len(USER_PREFERENCES)) -> State: | ||||
|     # Search for relevant documents (with scores if available) | ||||
|     retrieved_docs = vector_store.similarity_search( | ||||
|         "Show the most interesting articles about the following topics: " | ||||
|         + ", ".join(state["preferences"]), | ||||
|         k=top_n * 20,  # Chunks, not complete stories | ||||
|         return_score=True | ||||
|         if hasattr(vector_store, "similarity_search_with_score") | ||||
|         else False, | ||||
|     ) | ||||
| def retrieve_docs(state: State, top_n: int = 3): | ||||
|     """Retrieve relevant documents based on user preferences.""" | ||||
|  | ||||
|     # If scores are returned, unpack (doc, score) tuples; else, set score to None | ||||
|     docs_with_scores = [] | ||||
|     if retrieved_docs and isinstance(retrieved_docs[0], tuple): | ||||
|         for doc, score in retrieved_docs: | ||||
|             docs_with_scores.append((doc, score)) | ||||
|     else: | ||||
|         for doc in retrieved_docs: | ||||
|             docs_with_scores.append((doc, None)) | ||||
|     docs = [] | ||||
|     for preference in state["preferences"]: | ||||
|         logging.info(f"Retrieving documents for preference: {preference}") | ||||
|         docs.extend(retriever.invoke(preference)[:top_n]) | ||||
|  | ||||
|     # Group chunks by story_id and collect their scores | ||||
|     story_groups = {} | ||||
|     for doc, score in docs_with_scores: | ||||
|         story_id = doc.metadata.get("story_id") | ||||
|         if story_id not in story_groups: | ||||
|             story_groups[story_id] = [] | ||||
|         story_groups[story_id].append((doc, score)) | ||||
|  | ||||
|     # Aggregate max score per story and reconstruct complete stories | ||||
|     story_scores = {} | ||||
|     complete_stories = [] | ||||
|     for story_id, chunks_scores in story_groups.items(): | ||||
|         chunks = [doc for doc, _ in chunks_scores] | ||||
|         scores = [s for _, s in chunks_scores if s is not None] | ||||
|         max_score = max(scores) if scores else None | ||||
|         story_scores[story_id] = max_score | ||||
|         if len(chunks) == 1: | ||||
|             complete_stories.append((chunks[0], max_score)) | ||||
|         else: | ||||
|             combined_content = "\n\n".join( | ||||
|                 chunk.page_content | ||||
|                 for chunk in sorted( | ||||
|                     chunks, key=lambda x: x.metadata.get("chunk_index", 0) | ||||
|                 ) | ||||
|             ) | ||||
|             complete_story = langchain_core.documents.Document( | ||||
|                 page_content=combined_content, | ||||
|                 metadata=chunks[0].metadata,  # Use metadata from first chunk | ||||
|             ) | ||||
|             complete_stories.append((complete_story, max_score)) | ||||
|  | ||||
|     # Sort stories by max_score descending (None scores go last) | ||||
|     complete_stories_sorted = sorted( | ||||
|         complete_stories, key=lambda x: (x[1] is not None, x[1]), reverse=True | ||||
|     ) | ||||
|  | ||||
|     # Return top_n stories | ||||
|     top_stories = [doc for doc, _ in complete_stories_sorted[:top_n]] | ||||
|     return { | ||||
|         "preferences": state["preferences"], | ||||
|         "context": top_stories, | ||||
|         "answer": state.get("answer", ""), | ||||
|         "summaries": state.get("summaries", []), | ||||
|     } | ||||
|     return {"context": docs} | ||||
|  | ||||
|  | ||||
| def generate_structured_summaries(state: State): | ||||
| @@ -217,9 +177,9 @@ def run_structured_query( | ||||
| ) -> list[dict]: | ||||
|     """Run query and return structured summary data.""" | ||||
|     graph_builder = langgraph.graph.StateGraph(State).add_sequence( | ||||
|         [retrieve, generate_structured_summaries] | ||||
|         [retrieve_docs, generate_structured_summaries] | ||||
|     ) | ||||
|     graph_builder.add_edge(langgraph.graph.START, "retrieve") | ||||
|     graph_builder.add_edge(langgraph.graph.START, "retrieve_docs") | ||||
|     graph = graph_builder.compile() | ||||
|  | ||||
|     response = graph.invoke( | ||||
| @@ -246,6 +206,7 @@ def get_existing_story_ids() -> set[str]: | ||||
|  | ||||
| async def fetch_hn_top_stories( | ||||
|     limit: int = 10, | ||||
|     force_fetch: bool = False, | ||||
| ) -> list[langchain_core.documents.Document]: | ||||
|     hn = HackerNewsClient() | ||||
|     stories = await hn.get_top_stories(limit=limit) | ||||
| @@ -256,11 +217,15 @@ async def fetch_hn_top_stories( | ||||
|  | ||||
|     new_stories = [story for story in stories if story.id not in existing_ids] | ||||
|  | ||||
|     print(f"Found {len(stories)} top stories, {len(new_stories)} are new") | ||||
|     logging.info(f"Found {len(stories)} top stories, {len(new_stories)} are new") | ||||
|  | ||||
|     if not new_stories: | ||||
|         print("No new stories to fetch") | ||||
|         return [] | ||||
|         if not force_fetch: | ||||
|             logging.info("No new stories to fetch") | ||||
|             return [] | ||||
|         else: | ||||
|             logging.info("Force fetching all top stories regardless of existing IDs") | ||||
|             new_stories = stories | ||||
|  | ||||
|     contents = {} | ||||
|  | ||||
| @@ -268,26 +233,18 @@ async def fetch_hn_top_stories( | ||||
|     scraper = JinaScraper(os.getenv("JINA_API_KEY")) | ||||
|  | ||||
|     async def _fetch_content(story: Story) -> tuple[str, str]: | ||||
|         try: | ||||
|             if not story.url: | ||||
|                 return story.id, story.title | ||||
|             return story.id, await scraper.get_content(story.url) | ||||
|         except Exception as e: | ||||
|             logging.warning(f"Failed to fetch content for story {story.id}: {e}") | ||||
|             return story.id, story.title  # Fallback to title if content fetch fails | ||||
|         if not story.url: | ||||
|             return story.id, story.title | ||||
|         return story.id, await scraper.get_content(story.url) | ||||
|  | ||||
|     tasks = [_fetch_content(story) for story in new_stories] | ||||
|     results = await asyncio.gather(*tasks, return_exceptions=True) | ||||
|     results = await asyncio.gather(*tasks) | ||||
|  | ||||
|     # Filter out exceptions and convert to dict | ||||
|     contents = {} | ||||
|     for result in results: | ||||
|         if isinstance(result, Exception): | ||||
|             logging.error(f"Task failed with exception: {result}") | ||||
|             continue | ||||
|         if isinstance(result, tuple) and len(result) == 2: | ||||
|             story_id, content = result | ||||
|             contents[story_id] = content | ||||
|         story_id, content = result | ||||
|         contents[story_id] = content | ||||
|  | ||||
|     documents = [ | ||||
|         langchain_core.documents.Document( | ||||
| @@ -324,31 +281,15 @@ async def main(): | ||||
|         for story in new_stories: | ||||
|             categories = categorize(story, llm) | ||||
|             story.metadata["categories"] = list(categories) | ||||
|             print(f"Story ID {story.metadata["story_id"]} categorized as: {categories}") | ||||
|             print( | ||||
|                 f"Story ID {story.metadata["story_id"]} ({story.metadata["title"]}) categorized as: {categories}" | ||||
|             ) | ||||
|  | ||||
|         # 2. Split | ||||
|         documents_to_store = [] | ||||
|         for story in new_stories: | ||||
|             # If article is short enough, store as-is | ||||
|             if len(story.page_content) <= 3000: | ||||
|                 documents_to_store.append(story) | ||||
|             else: | ||||
|                 # For very long articles, chunk but keep story metadata | ||||
|                 splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( | ||||
|                     chunk_size=2000, | ||||
|                     chunk_overlap=200, | ||||
|                     add_start_index=True, | ||||
|                 ) | ||||
|                 chunks = splitter.split_documents([story]) | ||||
|                 # Add chunk info to metadata | ||||
|                 for i, chunk in enumerate(chunks): | ||||
|                     chunk.metadata["chunk_index"] = i | ||||
|                     chunk.metadata["total_chunks"] = len(chunks) | ||||
|                 documents_to_store.extend(chunks) | ||||
|  | ||||
|         # 3. Store | ||||
|         _ = vector_store.add_documents(documents_to_store) | ||||
|         print(f"Added {len(documents_to_store)} documents to vector store") | ||||
|         # 2. Split & 3. Store | ||||
|         retriever.add_documents( | ||||
|             new_stories, ids=[doc.metadata["story_id"] for doc in new_stories] | ||||
|         ) | ||||
|         print(f"Added {len(new_stories)} documents to document store") | ||||
|     else: | ||||
|         print("No new stories to process") | ||||
|  | ||||
| @@ -361,8 +302,6 @@ async def main(): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     import asyncio | ||||
|  | ||||
|     logging.basicConfig(level=logging.INFO) | ||||
|  | ||||
|     try: | ||||
|   | ||||
							
								
								
									
										1
									
								
								slack.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								slack.py
									
									
									
									
									
								
							| @@ -77,6 +77,7 @@ def send_message(channel: str, blocks: list) -> None: | ||||
|             text="Tech updates", | ||||
|             blocks=blocks, | ||||
|             unfurl_links=False, | ||||
|             unfurl_media=False, | ||||
|         ) | ||||
|         response.validate() | ||||
|         logging.info(f"Message sent successfully to channel {channel}") | ||||
|   | ||||
							
								
								
									
										48
									
								
								util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								util.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | ||||
| import json | ||||
| from collections.abc import Iterator | ||||
|  | ||||
| from langchain.schema import BaseStore, Document | ||||
| from langchain.storage import LocalFileStore | ||||
|  | ||||
|  | ||||
| class DocumentLocalFileStore(BaseStore[str, Document]): | ||||
|     def __init__(self, root_path: str): | ||||
|         self.store = LocalFileStore(root_path) | ||||
|  | ||||
|     def mget(self, keys): | ||||
|         """Get multiple documents by keys""" | ||||
|         results = [] | ||||
|         for key in keys: | ||||
|             try: | ||||
|                 serialized = self.store.mget([key])[0] | ||||
|                 if serialized: | ||||
|                     # Deserialize the document | ||||
|                     doc_dict = json.loads(serialized.decode("utf-8")) | ||||
|                     doc = Document( | ||||
|                         page_content=doc_dict["page_content"], | ||||
|                         metadata=doc_dict.get("metadata", {}), | ||||
|                     ) | ||||
|                     results.append(doc) | ||||
|                 else: | ||||
|                     results.append(None) | ||||
|             except: | ||||
|                 results.append(None) | ||||
|         return results | ||||
|  | ||||
|     def mset(self, key_value_pairs): | ||||
|         """Set multiple key-value pairs""" | ||||
|         serialized_pairs = [] | ||||
|         for key, doc in key_value_pairs: | ||||
|             # Serialize the document | ||||
|             doc_dict = {"page_content": doc.page_content, "metadata": doc.metadata} | ||||
|             serialized = json.dumps(doc_dict).encode("utf-8") | ||||
|             serialized_pairs.append((key, serialized)) | ||||
|  | ||||
|         self.store.mset(serialized_pairs) | ||||
|  | ||||
|     def mdelete(self, keys): | ||||
|         """Delete multiple keys""" | ||||
|         self.store.mdelete(keys) | ||||
|  | ||||
|     def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]: | ||||
|         return self.store.yield_keys(prefix=prefix) | ||||
		Reference in New Issue
	
	Block a user