feat: Reworked chunking and retrieval logic to operate on entire stories instead of chunks.
This commit is contained in:
		
							
								
								
									
										205
									
								
								indexing.py
									
									
									
									
									
								
							
							
						
						
									
										205
									
								
								indexing.py
									
									
									
									
									
								
							| @@ -1,119 +1,242 @@ | ||||
| import logging | ||||
| import os | ||||
| from typing import TypedDict | ||||
| from typing import Iterable, TypedDict | ||||
|  | ||||
| import langchain | ||||
| import langchain.chat_models | ||||
| import langchain.hub | ||||
| import langchain.prompts | ||||
| import langchain.text_splitter | ||||
| import langchain_chroma | ||||
| import langchain_core | ||||
| import langchain_core.documents | ||||
| import langchain_openai | ||||
| import langchain_weaviate | ||||
| import langgraph.graph | ||||
| import mlflow | ||||
|  | ||||
| import weaviate | ||||
| from hn import HackerNewsClient, Story | ||||
| from scrape import JinaScraper | ||||
|  | ||||
| llm = langchain.chat_models.init_chat_model( | ||||
|     model="gpt-4.1-nano", model_provider="openai" | ||||
|     model="gpt-4o-mini", model_provider="openai" | ||||
| ) | ||||
| embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small") | ||||
| vector_store = langchain_chroma.Chroma( | ||||
|     collection_name="hn_stories", | ||||
|     embedding_function=embeddings, | ||||
|     persist_directory="./chroma_db", | ||||
|     create_collection_if_not_exists=True, | ||||
| embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-large") | ||||
|  | ||||
| weaviate_client = weaviate.connect_to_local() | ||||
| vector_store = langchain_weaviate.WeaviateVectorStore( | ||||
|     weaviate_client, | ||||
|     index_name="hn_stories", | ||||
|     text_key="page_content", | ||||
|     embedding=embeddings, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class State(TypedDict): | ||||
|     question: str | ||||
|     preferences: Iterable[str] | ||||
|     context: list[langchain_core.documents.Document] | ||||
|     answer: str | ||||
|  | ||||
|  | ||||
| # Define application steps | ||||
| def retrieve(state: State): | ||||
|     retrieved_docs = vector_store.similarity_search(state["question"], k=10) | ||||
|     return {"context": retrieved_docs} | ||||
|     # Search for relevant documents | ||||
|     retrieved_docs = vector_store.similarity_search( | ||||
|         "Categories: " + ", ".join(state["preferences"]), k=10 | ||||
|     ) | ||||
|  | ||||
|     # If you're using chunks, group them back into complete stories | ||||
|     story_groups = {} | ||||
|     for doc in retrieved_docs: | ||||
|         story_id = doc.metadata.get("story_id") | ||||
|         if story_id not in story_groups: | ||||
|             story_groups[story_id] = [] | ||||
|         story_groups[story_id].append(doc) | ||||
|  | ||||
|     # Reconstruct complete stories or use the best chunk per story | ||||
|     complete_stories = [] | ||||
|     for story_id, chunks in story_groups.items(): | ||||
|         if len(chunks) == 1: | ||||
|             complete_stories.append(chunks[0]) | ||||
|         else: | ||||
|             # Combine chunks back into complete story | ||||
|             combined_content = "\n\n".join( | ||||
|                 chunk.page_content | ||||
|                 for chunk in sorted( | ||||
|                     chunks, key=lambda x: x.metadata.get("chunk_index", 0) | ||||
|                 ) | ||||
|             ) | ||||
|             complete_story = langchain_core.documents.Document( | ||||
|                 page_content=combined_content, | ||||
|                 metadata=chunks[0].metadata,  # Use metadata from first chunk | ||||
|             ) | ||||
|             complete_stories.append(complete_story) | ||||
|  | ||||
|     return {"context": complete_stories[:5]}  # Limit to top 5 stories | ||||
|  | ||||
|  | ||||
| def generate(state: State): | ||||
|     docs_content = "\n\n".join(doc.page_content for doc in state["context"]) | ||||
|     prompt = langchain.hub.pull("rlm/rag-prompt") | ||||
|     messages = prompt.invoke({"question": state["question"], "context": docs_content}) | ||||
|  | ||||
|     prompt = langchain.prompts.PromptTemplate( | ||||
|         input_variables=["preferences", "context"], | ||||
|         template=( | ||||
|             "You are a helpful assistant that can provide updates on technology topics based on the topics a user has expressed interest in and additional context.\n\n" | ||||
|             "Please respond in Markdown format and group your answers based on the categories of the items in the context.\n" | ||||
|             "If applicable, add hyperlinks to the original source as part of the headline for each story.\n" | ||||
|             "Limit your summaries to approximately 100 words per item.\n\n" | ||||
|             "Preferences: {preferences}\n\n" | ||||
|             "Context:\n{context}\n\n" | ||||
|             "Answer:" | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     messages = prompt.invoke( | ||||
|         {"preferences": state["preferences"], "context": docs_content} | ||||
|     ) | ||||
|     response = llm.invoke(messages) | ||||
|     return {"answer": response.content} | ||||
|  | ||||
|  | ||||
| def run_query(question: str): | ||||
| def run_query(preferences: Iterable[str]): | ||||
|     graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate]) | ||||
|     graph_builder.add_edge(langgraph.graph.START, "retrieve") | ||||
|     graph = graph_builder.compile() | ||||
|  | ||||
|     response = graph.invoke({"question": question}) | ||||
|     response = graph.invoke(State(preferences=preferences, context=[], answer="")) | ||||
|     print(response["answer"]) | ||||
|  | ||||
|  | ||||
| def get_existing_story_ids() -> set[str]: | ||||
|     """Get the IDs of stories that already exist in the vector store.""" | ||||
|     try: | ||||
|         collection = vector_store._collection | ||||
|         existing_ids = set() | ||||
|         for doc in collection.iterator(): | ||||
|             story_id = doc.properties.get("story_id") | ||||
|             if story_id: | ||||
|                 existing_ids.add(story_id) | ||||
|         return existing_ids | ||||
|     except Exception: | ||||
|         logging.warning("Could not retrieve existing story IDs", exc_info=True) | ||||
|         return set() | ||||
|  | ||||
|  | ||||
| async def fetch_hn_top_stories( | ||||
|     limit: int = 10, | ||||
| ) -> list[langchain_core.documents.Document]: | ||||
|     hn = HackerNewsClient() | ||||
|     stories = hn.get_top_stories(limit=limit) | ||||
|     stories = await hn.get_top_stories(limit=limit) | ||||
|  | ||||
|     # Get existing story IDs to avoid re-fetching | ||||
|     existing_ids = get_existing_story_ids() | ||||
|     logging.info(f"Existing story IDs: {len(existing_ids)} found in vector store") | ||||
|  | ||||
|     new_stories = [story for story in stories if story.id not in existing_ids] | ||||
|  | ||||
|     print(f"Found {len(stories)} top stories, {len(new_stories)} are new") | ||||
|  | ||||
|     if not new_stories: | ||||
|         print("No new stories to fetch") | ||||
|         return [] | ||||
|  | ||||
|     contents = {} | ||||
|  | ||||
|     # Fetch content for each story asynchronously | ||||
|     # Fetch content for each new story asynchronously | ||||
|     scraper = JinaScraper(os.getenv("JINA_API_KEY")) | ||||
|  | ||||
|     async def _fetch_content(story: Story) -> tuple[str, str]: | ||||
|         if not story.url: | ||||
|             return story.id, story.title | ||||
|         return story.id, await scraper.get_content(story.url) | ||||
|         try: | ||||
|             if not story.url: | ||||
|                 return story.id, story.title | ||||
|             return story.id, await scraper.get_content(story.url) | ||||
|         except Exception as e: | ||||
|             logging.warning(f"Failed to fetch content for story {story.id}: {e}") | ||||
|             return story.id, story.title  # Fallback to title if content fetch fails | ||||
|  | ||||
|     tasks = [_fetch_content(story) for story in stories] | ||||
|     results = await asyncio.gather(*tasks) | ||||
|     contents = dict(results) | ||||
|     tasks = [_fetch_content(story) for story in new_stories] | ||||
|     results = await asyncio.gather(*tasks, return_exceptions=True) | ||||
|  | ||||
|     # Filter out exceptions and convert to dict | ||||
|     contents = {} | ||||
|     for result in results: | ||||
|         if isinstance(result, Exception): | ||||
|             logging.error(f"Task failed with exception: {result}") | ||||
|             continue | ||||
|         if isinstance(result, tuple) and len(result) == 2: | ||||
|             story_id, content = result | ||||
|             contents[story_id] = content | ||||
|  | ||||
|     documents = [ | ||||
|         langchain_core.documents.Document( | ||||
|             page_content=contents[story.id], | ||||
|             metadata={ | ||||
|                 "id": story.id, | ||||
|                 "story_id": story.id, | ||||
|                 "title": story.title, | ||||
|                 "source": story.url, | ||||
|                 "created_at": story.created_at.isoformat(), | ||||
|             }, | ||||
|         ) | ||||
|         for story in stories | ||||
|         for story in new_stories | ||||
|     ] | ||||
|     return documents | ||||
|  | ||||
|  | ||||
| async def main(): | ||||
|     import mlflow | ||||
|  | ||||
|     mlflow.set_tracking_uri("http://localhost:5000") | ||||
|     mlflow.set_experiment("langchain-rag-hn") | ||||
|     mlflow.langchain.autolog() | ||||
|  | ||||
|     # 1. Load | ||||
|     stories = await fetch_hn_top_stories(limit=3) | ||||
|     # 1. Load only new stories | ||||
|     new_stories = await fetch_hn_top_stories(limit=20) | ||||
|  | ||||
|     # 2. Split | ||||
|     splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( | ||||
|         chunk_size=1000, chunk_overlap=200 | ||||
|     ) | ||||
|     all_splits = splitter.split_documents(stories) | ||||
|     if new_stories: | ||||
|         print(f"Processing {len(new_stories)} new stories") | ||||
|  | ||||
|     # 3. Store | ||||
|     _ = vector_store.add_documents(all_splits) | ||||
|         # Categorize stories (optional) | ||||
|         from classify import categorize | ||||
|  | ||||
|         for story in new_stories: | ||||
|             categories = categorize(story, llm) | ||||
|             story.metadata["categories"] = list(categories) | ||||
|             print(f"Story ID {story.metadata["story_id"]} categorized as: {categories}") | ||||
|  | ||||
|         # 2. Split | ||||
|         documents_to_store = [] | ||||
|         for story in new_stories: | ||||
|             # If article is short enough, store as-is | ||||
|             if len(story.page_content) <= 3000:  # Adjust threshold as needed | ||||
|                 documents_to_store.append(story) | ||||
|             else: | ||||
|                 # For very long articles, chunk but keep story metadata | ||||
|                 splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( | ||||
|                     chunk_size=2000, | ||||
|                     chunk_overlap=200, | ||||
|                     add_start_index=True, | ||||
|                 ) | ||||
|                 chunks = splitter.split_documents([story]) | ||||
|                 # Add chunk info to metadata | ||||
|                 for i, chunk in enumerate(chunks): | ||||
|                     chunk.metadata["chunk_index"] = i | ||||
|                     chunk.metadata["total_chunks"] = len(chunks) | ||||
|                 documents_to_store.extend(chunks) | ||||
|  | ||||
|         # 3. Store | ||||
|         _ = vector_store.add_documents(documents_to_store) | ||||
|         print(f"Added {len(documents_to_store)} documents to vector store") | ||||
|     else: | ||||
|         print("No new stories to process") | ||||
|  | ||||
|     # 4. Query | ||||
|     question = "What are the top stories related to AI and Machine Learning right now?" | ||||
|     run_query(question) | ||||
|     preferences = ["Software Engineering", "Machine Learning", "Games"] | ||||
|     run_query(preferences) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     import asyncio | ||||
|  | ||||
|     asyncio.run(main()) | ||||
|     logging.basicConfig(level=logging.INFO) | ||||
|  | ||||
|     try: | ||||
|         asyncio.run(main()) | ||||
|     finally: | ||||
|         weaviate_client.close() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user