feat: Reworked chunking and retrieval logic to operate on entire stories instead of chunks.

Adrian Rumpold
2025-07-01 13:08:45 +02:00
parent f093a488f3
commit b55fd6a021
8 changed files with 442 additions and 920 deletions

4 .gitignore vendored

@@ -13,5 +13,5 @@ wheels/
 mlruns/
 mlartifacts/
-# ChromaDB
-chroma_db/
+# Weaviate vector store
+weaviate/

21 classify.py Normal file

@@ -0,0 +1,21 @@
+from langchain_core.documents import Document
+from langchain_core.language_models import BaseChatModel
+
+
+def categorize(doc: Document, llm: BaseChatModel) -> set[str]:
+    # Create a prompt for category extraction
+    prompt = f"""
+Extract up to 3 relevant categories from the following document.
+Return a JSON object with a "categories" key containing a list of category name strings.
+If you cannot find any relevant categories, return an empty list.
+
+Title: {doc.metadata.get('title', 'No title')}
+Content: {doc.page_content}...
+
+Categories:"""
+
+    # Get a JSON response from the LLM
+    result = llm.with_structured_output(method="json_mode").invoke(prompt)
+    categories = result.get("categories", [])
+    return set(categories)
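
For reference, a minimal usage sketch for classify.categorize (not part of this commit; assumes OPENAI_API_KEY is set, and the document content is invented):

import langchain.chat_models
from langchain_core.documents import Document

from classify import categorize

# Same model the pipeline uses elsewhere in this commit
llm = langchain.chat_models.init_chat_model(
    model="gpt-4o-mini", model_provider="openai"
)
doc = Document(
    page_content="Rust 1.80 ships with faster compile times ...",  # made-up snippet
    metadata={"title": "Rust 1.80 released"},
)
print(categorize(doc, llm))  # e.g. {"Programming Languages", "Rust"}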

25 compose.yml Normal file

@@ -0,0 +1,25 @@
+services:
+  weaviate:
+    command:
+      - --host
+      - 0.0.0.0
+      - --port
+      - "8080"
+      - --scheme
+      - http
+    image: cr.weaviate.io/semitechnologies/weaviate:1.31.4
+    ports:
+      - 8080:8080
+      - 50051:50051
+    volumes:
+      - ./weaviate:/var/lib/weaviate
+    restart: on-failure
+    environment:
+      QUERY_DEFAULTS_LIMIT: 25
+      AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: "true"
+      PERSISTENCE_DATA_PATH: "/var/lib/weaviate"
+      ENABLE_API_BASED_MODULES: "true"
+      ENABLE_MODULES: "text2vec-openai,generative-openai"
+      CLUSTER_HOSTNAME: "node1"
+      OPENAI_APIKEY: ${OPENAI_API_KEY}
+      DISABLE_TELEMETRY: "true"
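
A quick connectivity check against this compose setup (illustrative, not part of the commit; assumes docker compose up -d has already been run):

import weaviate

# Defaults match the published ports: HTTP on 8080, gRPC on 50051
client = weaviate.connect_to_local()
print(client.is_ready())  # True once the container is healthy
client.close()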

30 hn.py

@@ -1,3 +1,4 @@
+import asyncio
 from datetime import datetime
 import dateutil.parser
@@ -16,16 +17,16 @@ class Story(BaseModel):
 class HackerNewsClient:
-    def __init__(self, client: httpx.Client | None = None):
+    def __init__(self, client: httpx.AsyncClient | None = None):
         base_url = "https://hn.algolia.com/api/v1"
         if client:
             client.base_url = base_url
             self._client = client
         else:
-            self._client = httpx.Client(base_url=base_url)
+            self._client = httpx.AsyncClient(base_url=base_url)

-    def get_top_stories(self, limit: int = 10) -> list[Story]:
-        resp = self._client.get(
+    async def get_top_stories(self, limit: int = 10) -> list[Story]:
+        resp = await self._client.get(
             "search",
             params={"tags": "front_page", "hitsPerPage": limit, "page": 0},
         )
@@ -40,7 +41,24 @@ class HackerNewsClient:
             for hit in resp.json().get("hits", [])
         ]

-    def get_item(self, item_id):
-        resp = self._client.get(f"items/{item_id}")
+    async def get_item(self, item_id):
+        resp = await self._client.get(f"items/{item_id}")
         resp.raise_for_status()
         return resp.json()
+
+    async def close(self):
+        """Close the underlying HTTP client."""
+        if self._client and not self._client.is_closed:
+            await self._client.aclose()
+
+    def __del__(self):
+        """Ensure the HTTP client is closed when the object is deleted."""
+        try:
+            loop = asyncio.get_event_loop()
+            if loop.is_running():
+                loop.create_task(self.close())
+            else:
+                loop.run_until_complete(self.close())
+        except Exception:
+            pass
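
Illustrative driver for the now-async client (not part of the commit; mirrors the awaited calls introduced above):

import asyncio

from hn import HackerNewsClient

async def show_top() -> None:
    hn = HackerNewsClient()
    try:
        # get_top_stories is a coroutine after this change
        for story in await hn.get_top_stories(limit=5):
            print(story.title)
    finally:
        await hn.close()

asyncio.run(show_top())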

main.py

@@ -1,119 +1,242 @@
+import logging
 import os
-from typing import TypedDict
+from typing import Iterable, TypedDict

 import langchain
 import langchain.chat_models
-import langchain.hub
+import langchain.prompts
 import langchain.text_splitter
-import langchain_chroma
 import langchain_core
 import langchain_core.documents
 import langchain_openai
+import langchain_weaviate
 import langgraph.graph
+import mlflow
+import weaviate

 from hn import HackerNewsClient, Story
 from scrape import JinaScraper

 llm = langchain.chat_models.init_chat_model(
-    model="gpt-4.1-nano", model_provider="openai"
+    model="gpt-4o-mini", model_provider="openai"
 )
-embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
+embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-large")

-vector_store = langchain_chroma.Chroma(
-    collection_name="hn_stories",
-    embedding_function=embeddings,
-    persist_directory="./chroma_db",
-    create_collection_if_not_exists=True,
+weaviate_client = weaviate.connect_to_local()
+vector_store = langchain_weaviate.WeaviateVectorStore(
+    weaviate_client,
+    index_name="hn_stories",
+    text_key="page_content",
+    embedding=embeddings,
 )


 class State(TypedDict):
-    question: str
+    preferences: Iterable[str]
     context: list[langchain_core.documents.Document]
     answer: str


-# Define application steps
 def retrieve(state: State):
-    retrieved_docs = vector_store.similarity_search(state["question"], k=10)
-    return {"context": retrieved_docs}
+    # Search for relevant documents
+    retrieved_docs = vector_store.similarity_search(
+        "Categories: " + ", ".join(state["preferences"]), k=10
+    )
+
+    # If you're using chunks, group them back into complete stories
+    story_groups = {}
+    for doc in retrieved_docs:
+        story_id = doc.metadata.get("story_id")
+        if story_id not in story_groups:
+            story_groups[story_id] = []
+        story_groups[story_id].append(doc)
+
+    # Reconstruct complete stories or use the best chunk per story
+    complete_stories = []
+    for story_id, chunks in story_groups.items():
+        if len(chunks) == 1:
+            complete_stories.append(chunks[0])
+        else:
+            # Combine chunks back into complete story
+            combined_content = "\n\n".join(
+                chunk.page_content
+                for chunk in sorted(
+                    chunks, key=lambda x: x.metadata.get("chunk_index", 0)
+                )
+            )
+            complete_story = langchain_core.documents.Document(
+                page_content=combined_content,
+                metadata=chunks[0].metadata,  # Use metadata from first chunk
+            )
+            complete_stories.append(complete_story)
+
+    return {"context": complete_stories[:5]}  # Limit to top 5 stories


 def generate(state: State):
     docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-    prompt = langchain.hub.pull("rlm/rag-prompt")
-    messages = prompt.invoke({"question": state["question"], "context": docs_content})
+
+    prompt = langchain.prompts.PromptTemplate(
+        input_variables=["preferences", "context"],
+        template=(
+            "You are a helpful assistant that can provide updates on technology topics based on the topics a user has expressed interest in and additional context.\n\n"
+            "Please respond in Markdown format and group your answers based on the categories of the items in the context.\n"
+            "If applicable, add hyperlinks to the original source as part of the headline for each story.\n"
+            "Limit your summaries to approximately 100 words per item.\n\n"
+            "Preferences: {preferences}\n\n"
+            "Context:\n{context}\n\n"
+            "Answer:"
+        ),
+    )
+    messages = prompt.invoke(
+        {"preferences": state["preferences"], "context": docs_content}
+    )
     response = llm.invoke(messages)
     return {"answer": response.content}


-def run_query(question: str):
+def run_query(preferences: Iterable[str]):
     graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate])
     graph_builder.add_edge(langgraph.graph.START, "retrieve")
     graph = graph_builder.compile()
-    response = graph.invoke({"question": question})
+    response = graph.invoke(State(preferences=preferences, context=[], answer=""))
     print(response["answer"])


+def get_existing_story_ids() -> set[str]:
+    """Get the IDs of stories that already exist in the vector store."""
+    try:
+        collection = vector_store._collection
+        existing_ids = set()
+        for doc in collection.iterator():
+            story_id = doc.properties.get("story_id")
+            if story_id:
+                existing_ids.add(story_id)
+        return existing_ids
+    except Exception:
+        logging.warning("Could not retrieve existing story IDs", exc_info=True)
+        return set()


 async def fetch_hn_top_stories(
     limit: int = 10,
 ) -> list[langchain_core.documents.Document]:
     hn = HackerNewsClient()
-    stories = hn.get_top_stories(limit=limit)
+    stories = await hn.get_top_stories(limit=limit)
+
+    # Get existing story IDs to avoid re-fetching
+    existing_ids = get_existing_story_ids()
+    logging.info(f"Existing story IDs: {len(existing_ids)} found in vector store")
+    new_stories = [story for story in stories if story.id not in existing_ids]
+    print(f"Found {len(stories)} top stories, {len(new_stories)} are new")
+
+    if not new_stories:
+        print("No new stories to fetch")
+        return []

     contents = {}

-    # Fetch content for each story asynchronously
+    # Fetch content for each new story asynchronously
     scraper = JinaScraper(os.getenv("JINA_API_KEY"))

     async def _fetch_content(story: Story) -> tuple[str, str]:
-        if not story.url:
-            return story.id, story.title
-        return story.id, await scraper.get_content(story.url)
+        try:
+            if not story.url:
+                return story.id, story.title
+            return story.id, await scraper.get_content(story.url)
+        except Exception as e:
+            logging.warning(f"Failed to fetch content for story {story.id}: {e}")
+            return story.id, story.title  # Fallback to title if content fetch fails

-    tasks = [_fetch_content(story) for story in stories]
-    results = await asyncio.gather(*tasks)
-    contents = dict(results)
+    tasks = [_fetch_content(story) for story in new_stories]
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Filter out exceptions and convert to dict
+    contents = {}
+    for result in results:
+        if isinstance(result, Exception):
+            logging.error(f"Task failed with exception: {result}")
+            continue
+        if isinstance(result, tuple) and len(result) == 2:
+            story_id, content = result
+            contents[story_id] = content

     documents = [
         langchain_core.documents.Document(
             page_content=contents[story.id],
             metadata={
-                "id": story.id,
+                "story_id": story.id,
                 "title": story.title,
                 "source": story.url,
                 "created_at": story.created_at.isoformat(),
             },
         )
-        for story in stories
+        for story in new_stories
     ]
     return documents


 async def main():
-    import mlflow
     mlflow.set_tracking_uri("http://localhost:5000")
     mlflow.set_experiment("langchain-rag-hn")
     mlflow.langchain.autolog()

-    # 1. Load
-    stories = await fetch_hn_top_stories(limit=3)
+    # 1. Load only new stories
+    new_stories = await fetch_hn_top_stories(limit=20)

-    # 2. Split
-    splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
-        chunk_size=1000, chunk_overlap=200
-    )
-    all_splits = splitter.split_documents(stories)
+    if new_stories:
+        print(f"Processing {len(new_stories)} new stories")

-    # 3. Store
-    _ = vector_store.add_documents(all_splits)
+        # Categorize stories (optional)
+        from classify import categorize
+
+        for story in new_stories:
+            categories = categorize(story, llm)
+            story.metadata["categories"] = list(categories)
+            print(f"Story ID {story.metadata['story_id']} categorized as: {categories}")
+
+        # 2. Split
+        documents_to_store = []
+        for story in new_stories:
+            # If article is short enough, store as-is
+            if len(story.page_content) <= 3000:  # Adjust threshold as needed
+                documents_to_store.append(story)
+            else:
+                # For very long articles, chunk but keep story metadata
+                splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
+                    chunk_size=2000,
+                    chunk_overlap=200,
+                    add_start_index=True,
+                )
+                chunks = splitter.split_documents([story])
+
+                # Add chunk info to metadata
+                for i, chunk in enumerate(chunks):
+                    chunk.metadata["chunk_index"] = i
+                    chunk.metadata["total_chunks"] = len(chunks)
+                documents_to_store.extend(chunks)
+
+        # 3. Store
+        _ = vector_store.add_documents(documents_to_store)
+        print(f"Added {len(documents_to_store)} documents to vector store")
+    else:
+        print("No new stories to process")

     # 4. Query
-    question = "What are the top stories related to AI and Machine Learning right now?"
-    run_query(question)
+    preferences = ["Software Engineering", "Machine Learning", "Games"]
+    run_query(preferences)


 if __name__ == "__main__":
     import asyncio

-    asyncio.run(main())
+    logging.basicConfig(level=logging.INFO)
+    try:
+        asyncio.run(main())
+    finally:
+        weaviate_client.close()
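
The chunk-regrouping step in retrieve() can be exercised on its own; a toy sketch (not part of this commit) with fabricated chunk metadata shows the reassembly order:

from langchain_core.documents import Document

chunks = [
    Document(page_content="part B", metadata={"story_id": "1", "chunk_index": 1}),
    Document(page_content="part A", metadata={"story_id": "1", "chunk_index": 0}),
    Document(page_content="solo story", metadata={"story_id": "2"}),
]

# Group by story_id, as retrieve() does
groups: dict[str, list[Document]] = {}
for doc in chunks:
    groups.setdefault(doc.metadata["story_id"], []).append(doc)

# Reassemble each story in chunk_index order
for story_id, docs in groups.items():
    docs.sort(key=lambda d: d.metadata.get("chunk_index", 0))
    print(story_id, "->", "\n\n".join(d.page_content for d in docs))
# story 1 prints "part A" then "part B"; story 2 passes through unchanged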

pyproject.toml

@@ -8,7 +8,7 @@ dependencies = [
     "hackernews>=2.0.0",
     "html2text>=2025.4.15",
     "httpx>=0.28.1",
-    "langchain-chroma>=0.2.4",
+    "langchain-weaviate>=0.0.5",
     "langchain[openai]>=0.3.26",
     "langgraph>=0.5.0",
     "mlflow>=3.1.1",

scrape.py

@@ -1,3 +1,5 @@
+import asyncio
+import logging
 from abc import ABC, abstractmethod
 from typing import override
@@ -6,17 +8,41 @@ import httpx
 class TextScraper(ABC):
     def __init__(self):
-        self._client = httpx.AsyncClient()
+        self._client = httpx.AsyncClient(timeout=httpx.Timeout(5.0))

     async def _fetch_text(self, url: str) -> str:
         """Fetch the raw HTML content from the URL."""
-        response = await self._client.get(url)
-        response.raise_for_status()
-        return response.text
+        response = None
+        try:
+            response = await self._client.get(url)
+            response.raise_for_status()
+            return response.text
+        except Exception:
+            logging.warning(f"Failed to fetch text from {url}", exc_info=True)
+            raise
+        finally:
+            if response:
+                await response.aclose()

     @abstractmethod
     async def get_content(self, url: str) -> str: ...
+
+    async def close(self):
+        """Close the underlying HTTP client."""
+        if self._client and not self._client.is_closed:
+            await self._client.aclose()
+
+    def __del__(self):
+        """Ensure the HTTP client is closed when the object is deleted."""
+        try:
+            loop = asyncio.get_event_loop()
+            if loop.is_running():
+                loop.create_task(self.close())
+            else:
+                loop.run_until_complete(self.close())
+        except Exception:
+            pass

 class Html2textScraper(TextScraper):
     @override
@@ -39,12 +65,13 @@ class JinaScraper(TextScraper):
     def __init__(self, api_key: str | None = None):
         super().__init__()
         if api_key:
-            self._client.headers.update({"Authorization": "Bearer {api_key}"})
+            self._client.headers.update({"Authorization": f"Bearer {api_key}"})

     @override
     async def get_content(self, url: str) -> str:
         print(f"Fetching content from: {url}")
         try:
             return await self._fetch_text(f"https://r.jina.ai/{url}")
-        except httpx.HTTPStatusError:
+        except Exception:
+            logging.warning(f"Failed to fetch content from {url}", exc_info=True)
             return ""

1036 uv.lock generated

File diff suppressed because it is too large