Compare commits

...

2 Commits

Author           SHA1         Message                                                                      Date
Adrian Rumpold   311c332b10   feat: Use langchain parent document retriever to simplify retrieval logic    2025-07-02 12:22:50 +02:00
Adrian Rumpold   259c9699ad   fix: Fix aio resource leaks                                                  2025-07-02 12:22:05 +02:00
5 changed files with 107 additions and 137 deletions

.gitignore vendored (2 lines changed)

@@ -15,3 +15,5 @@ mlartifacts/
 # Weaviate vector store
 weaviate/
+data/

View File

@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import os
 from typing import Iterable, TypedDict
@@ -5,6 +6,7 @@ from typing import Iterable, TypedDict
 import langchain
 import langchain.chat_models
 import langchain.prompts
+import langchain.retrievers
 import langchain.text_splitter
 import langchain_core
 import langchain_core.documents
@@ -16,11 +18,12 @@ import slack
 import weaviate
 from hn import HackerNewsClient, Story
 from scrape import JinaScraper
+from util import DocumentLocalFileStore

-NUM_STORIES = 20
-USER_PREFERENCES = ["Machine Learning", "Programming", "Robotics"]
+NUM_STORIES = 40  # Number of top stories to fetch from Hacker News
+USER_PREFERENCES = ["Machine Learning", "Programming", "DevOps"]
 ENABLE_SLACK = True  # Send updates to Slack, need to set SLACK_BOT_TOKEN env var
-ENABLE_MLFLOW_TRACING = False  # Use MLflow (at http://localhost:5000) for tracing
+ENABLE_MLFLOW_TRACING = True  # Use MLflow (at http://localhost:5000) for tracing

 llm = langchain.chat_models.init_chat_model(
@@ -36,6 +39,18 @@ vector_store = langchain_weaviate.WeaviateVectorStore(
     embedding=embeddings,
 )

+splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
+    chunk_size=2000,
+    chunk_overlap=200,
+)
+doc_store = DocumentLocalFileStore(root_path="data/documents")
+retriever = langchain.retrievers.ParentDocumentRetriever(
+    vectorstore=vector_store,
+    docstore=doc_store,
+    child_splitter=splitter,
+)
+

 class State(TypedDict):
     preferences: Iterable[str]
@@ -44,70 +59,15 @@ class State(TypedDict):
     summaries: list[dict]


-def retrieve(state: State, top_n: int = 2 * len(USER_PREFERENCES)) -> State:
-    # Search for relevant documents (with scores if available)
-    retrieved_docs = vector_store.similarity_search(
-        "Show the most interesting articles about the following topics: "
-        + ", ".join(state["preferences"]),
-        k=top_n * 20,  # Chunks, not complete stories
-        return_score=True
-        if hasattr(vector_store, "similarity_search_with_score")
-        else False,
-    )
-
-    # If scores are returned, unpack (doc, score) tuples; else, set score to None
-    docs_with_scores = []
-    if retrieved_docs and isinstance(retrieved_docs[0], tuple):
-        for doc, score in retrieved_docs:
-            docs_with_scores.append((doc, score))
-    else:
-        for doc in retrieved_docs:
-            docs_with_scores.append((doc, None))
-
-    # Group chunks by story_id and collect their scores
-    story_groups = {}
-    for doc, score in docs_with_scores:
-        story_id = doc.metadata.get("story_id")
-        if story_id not in story_groups:
-            story_groups[story_id] = []
-        story_groups[story_id].append((doc, score))
-
-    # Aggregate max score per story and reconstruct complete stories
-    story_scores = {}
-    complete_stories = []
-    for story_id, chunks_scores in story_groups.items():
-        chunks = [doc for doc, _ in chunks_scores]
-        scores = [s for _, s in chunks_scores if s is not None]
-        max_score = max(scores) if scores else None
-        story_scores[story_id] = max_score
-        if len(chunks) == 1:
-            complete_stories.append((chunks[0], max_score))
-        else:
-            combined_content = "\n\n".join(
-                chunk.page_content
-                for chunk in sorted(
-                    chunks, key=lambda x: x.metadata.get("chunk_index", 0)
-                )
-            )
-            complete_story = langchain_core.documents.Document(
-                page_content=combined_content,
-                metadata=chunks[0].metadata,  # Use metadata from first chunk
-            )
-            complete_stories.append((complete_story, max_score))
-
-    # Sort stories by max_score descending (None scores go last)
-    complete_stories_sorted = sorted(
-        complete_stories, key=lambda x: (x[1] is not None, x[1]), reverse=True
-    )
-
-    # Return top_n stories
-    top_stories = [doc for doc, _ in complete_stories_sorted[:top_n]]
-
-    return {
-        "preferences": state["preferences"],
-        "context": top_stories,
-        "answer": state.get("answer", ""),
-        "summaries": state.get("summaries", []),
-    }
+def retrieve_docs(state: State, top_n: int = 3):
+    """Retrieve relevant documents based on user preferences."""
+    docs = []
+    for preference in state["preferences"]:
+        logging.info(f"Retrieving documents for preference: {preference}")
+        docs.extend(retriever.invoke(preference)[:top_n])
+
+    return {"context": docs}


 def generate_structured_summaries(state: State):
@@ -217,9 +177,9 @@ def run_structured_query(
 ) -> list[dict]:
     """Run query and return structured summary data."""
     graph_builder = langgraph.graph.StateGraph(State).add_sequence(
-        [retrieve, generate_structured_summaries]
+        [retrieve_docs, generate_structured_summaries]
     )
-    graph_builder.add_edge(langgraph.graph.START, "retrieve")
+    graph_builder.add_edge(langgraph.graph.START, "retrieve_docs")
     graph = graph_builder.compile()

     response = graph.invoke(
@@ -246,6 +206,7 @@ def get_existing_story_ids() -> set[str]:

 async def fetch_hn_top_stories(
     limit: int = 10,
+    force_fetch: bool = False,
 ) -> list[langchain_core.documents.Document]:
     hn = HackerNewsClient()
     stories = await hn.get_top_stories(limit=limit)
@@ -256,11 +217,15 @@ async def fetch_hn_top_stories(
     new_stories = [story for story in stories if story.id not in existing_ids]
-    print(f"Found {len(stories)} top stories, {len(new_stories)} are new")
+    logging.info(f"Found {len(stories)} top stories, {len(new_stories)} are new")

     if not new_stories:
-        print("No new stories to fetch")
-        return []
+        if not force_fetch:
+            logging.info("No new stories to fetch")
+            return []
+        else:
+            logging.info("Force fetching all top stories regardless of existing IDs")
+            new_stories = stories

     contents = {}
@@ -268,26 +233,18 @@ async def fetch_hn_top_stories(
     scraper = JinaScraper(os.getenv("JINA_API_KEY"))

     async def _fetch_content(story: Story) -> tuple[str, str]:
-        try:
-            if not story.url:
-                return story.id, story.title
-            return story.id, await scraper.get_content(story.url)
-        except Exception as e:
-            logging.warning(f"Failed to fetch content for story {story.id}: {e}")
-            return story.id, story.title  # Fallback to title if content fetch fails
+        if not story.url:
+            return story.id, story.title
+        return story.id, await scraper.get_content(story.url)

     tasks = [_fetch_content(story) for story in new_stories]
-    results = await asyncio.gather(*tasks, return_exceptions=True)
+    results = await asyncio.gather(*tasks)

     # Filter out exceptions and convert to dict
     contents = {}
     for result in results:
-        if isinstance(result, Exception):
-            logging.error(f"Task failed with exception: {result}")
-            continue
-        if isinstance(result, tuple) and len(result) == 2:
-            story_id, content = result
-            contents[story_id] = content
+        story_id, content = result
+        contents[story_id] = content

     documents = [
         langchain_core.documents.Document(
@@ -324,31 +281,15 @@ async def main():
         for story in new_stories:
             categories = categorize(story, llm)
             story.metadata["categories"] = list(categories)
-            print(f"Story ID {story.metadata["story_id"]} categorized as: {categories}")
+            print(
+                f"Story ID {story.metadata["story_id"]} ({story.metadata["title"]}) categorized as: {categories}"
+            )

-        # 2. Split
-        documents_to_store = []
-        for story in new_stories:
-            # If article is short enough, store as-is
-            if len(story.page_content) <= 3000:
-                documents_to_store.append(story)
-            else:
-                # For very long articles, chunk but keep story metadata
-                splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
-                    chunk_size=2000,
-                    chunk_overlap=200,
-                    add_start_index=True,
-                )
-                chunks = splitter.split_documents([story])
-
-                # Add chunk info to metadata
-                for i, chunk in enumerate(chunks):
-                    chunk.metadata["chunk_index"] = i
-                    chunk.metadata["total_chunks"] = len(chunks)
-
-                documents_to_store.extend(chunks)
-
-        # 3. Store
-        _ = vector_store.add_documents(documents_to_store)
-        print(f"Added {len(documents_to_store)} documents to vector store")
+        # 2. Split & 3. Store
+        retriever.add_documents(
+            new_stories, ids=[doc.metadata["story_id"] for doc in new_stories]
+        )
+        print(f"Added {len(new_stories)} documents to document store")
     else:
         print("No new stories to process")
@@ -361,8 +302,6 @@ async def main():

 if __name__ == "__main__":
-    import asyncio
-
     logging.basicConfig(level=logging.INFO)
     try:
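Taken together, the changes in this file replace the hand-rolled chunk grouping in retrieve() with LangChain's ParentDocumentRetriever: child chunks are indexed in the vector store for similarity search, while the complete stories live in a separate docstore and are returned whole. A minimal sketch of that pattern, using in-memory stores and a fake embedding model in place of the Weaviate vector store and DocumentLocalFileStore from this repo (everything below except the LangChain classes is illustrative):

from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore

# In-memory stand-ins for the Weaviate vector store and DocumentLocalFileStore.
vector_store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=256))
doc_store = InMemoryStore()

retriever = ParentDocumentRetriever(
    vectorstore=vector_store,
    docstore=doc_store,
    child_splitter=RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200),
)

# add_documents() splits each story into child chunks for the vector store and
# keeps the untouched parent document in the docstore under the given id.
story = Document(
    page_content="A long article about robot arms. " * 200,
    metadata={"story_id": "1", "title": "Robot arms"},
)
retriever.add_documents([story], ids=["1"])

# invoke() searches the child chunks but returns the complete parent documents,
# which is what makes the deleted chunk-regrouping logic unnecessary.
for doc in retriever.invoke("robotics"):
    print(doc.metadata["title"], len(doc.page_content))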

View File

@@ -1,4 +1,3 @@
-import asyncio
 import logging
 from abc import ABC, abstractmethod
 from typing import override
@@ -8,41 +7,22 @@ import httpx

 class TextScraper(ABC):
     def __init__(self):
-        self._client = httpx.AsyncClient(timeout=httpx.Timeout(5.0))
+        self._http_headers = {}

     async def _fetch_text(self, url: str) -> str:
         """Fetch the raw HTML content from the URL."""
-        response = None
         try:
-            response = await self._client.get(url)
-            response.raise_for_status()
-            return response.text
+            async with httpx.AsyncClient(headers=self._http_headers) as client:
+                response = await client.get(url)
+                response.raise_for_status()
+                return response.text
         except Exception:
             logging.warning(f"Failed to fetch text from {url}", exc_info=True)
             raise
-        finally:
-            if response:
-                await response.aclose()

     @abstractmethod
     async def get_content(self, url: str) -> str: ...

-    async def close(self):
-        """Close the underlying HTTP client."""
-        if self._client and not self._client.is_closed:
-            await self._client.aclose()
-
-    def __del__(self):
-        """Ensure the HTTP client is closed when the object is deleted."""
-        try:
-            loop = asyncio.get_event_loop()
-            if loop.is_running():
-                loop.create_task(self.close())
-            else:
-                loop.run_until_complete(self.close())
-        except Exception:
-            pass

 class Html2textScraper(TextScraper):
     @override
@@ -65,7 +45,7 @@ class JinaScraper(TextScraper):
     def __init__(self, api_key: str | None = None):
         super().__init__()
         if api_key:
-            self._client.headers.update({"Authorization": f"Bearer {api_key}"})
+            self._http_headers.update({"Authorization": f"Bearer {api_key}"})

     @override
     async def get_content(self, url: str) -> str:
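The aio fix above drops the long-lived httpx.AsyncClient owned by TextScraper, whose cleanup depended on close() being called or on __del__ finding a usable event loop, and instead opens a client per request. A minimal sketch of that pattern, using only httpx and an example URL (nothing here is code from this repo):

import asyncio

import httpx


async def fetch_text(url: str, headers: dict[str, str] | None = None) -> str:
    # The client exists only inside this coroutine, so its connection pool is
    # always closed on the running event loop and nothing is left for __del__.
    async with httpx.AsyncClient(headers=headers, timeout=httpx.Timeout(5.0)) as client:
        response = await client.get(url)
        response.raise_for_status()
        return response.text


if __name__ == "__main__":
    print(asyncio.run(fetch_text("https://example.com"))[:200])

The trade-off is a fresh connection per request instead of a shared pool, which seems acceptable for the small number of stories fetched per run.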

View File

@@ -77,6 +77,7 @@ def send_message(channel: str, blocks: list) -> None:
         text="Tech updates",
         blocks=blocks,
         unfurl_links=False,
+        unfurl_media=False,
     )
     response.validate()
     logging.info(f"Message sent successfully to channel {channel}")

util.py (new file, 48 lines added)

@@ -0,0 +1,48 @@
+import json
+from collections.abc import Iterator
+
+from langchain.schema import BaseStore, Document
+from langchain.storage import LocalFileStore
+
+
+class DocumentLocalFileStore(BaseStore[str, Document]):
+    def __init__(self, root_path: str):
+        self.store = LocalFileStore(root_path)
+
+    def mget(self, keys):
+        """Get multiple documents by keys"""
+        results = []
+        for key in keys:
+            try:
+                serialized = self.store.mget([key])[0]
+                if serialized:
+                    # Deserialize the document
+                    doc_dict = json.loads(serialized.decode("utf-8"))
+                    doc = Document(
+                        page_content=doc_dict["page_content"],
+                        metadata=doc_dict.get("metadata", {}),
+                    )
+                    results.append(doc)
+                else:
+                    results.append(None)
+            except:
+                results.append(None)
+        return results
+
+    def mset(self, key_value_pairs):
+        """Set multiple key-value pairs"""
+        serialized_pairs = []
+        for key, doc in key_value_pairs:
+            # Serialize the document
+            doc_dict = {"page_content": doc.page_content, "metadata": doc.metadata}
+            serialized = json.dumps(doc_dict).encode("utf-8")
+            serialized_pairs.append((key, serialized))
+        self.store.mset(serialized_pairs)
+
+    def mdelete(self, keys):
+        """Delete multiple keys"""
+        self.store.mdelete(keys)
+
+    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
+        return self.store.yield_keys(prefix=prefix)
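For reference, a hypothetical round trip through the new DocumentLocalFileStore, which backs the ParentDocumentRetriever docstore added above (the root path is an example, not taken from the repo):

from langchain.schema import Document

from util import DocumentLocalFileStore

store = DocumentLocalFileStore(root_path="/tmp/docstore-demo")  # example path
doc = Document(page_content="Hello HN", metadata={"story_id": "42"})

store.mset([("42", doc)])                   # serialize to JSON files on disk
print(store.mget(["42"])[0].page_content)   # -> "Hello HN"
print(list(store.yield_keys()))             # -> ["42"]
store.mdelete(["42"])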