feat: Reworked chunking and retrieval logic to operate on entire stories instead of chunks.

.gitignore (vendored, 4 lines changed)
@@ -13,5 +13,5 @@ wheels/
 mlruns/
 mlartifacts/
 
-# ChromaDB
-chroma_db/
+# Weaviate vector store
+weaviate/

classify.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from langchain_core.documents import Document
+from langchain_core.language_models import BaseChatModel
+
+
+def categorize(doc: Document, llm: BaseChatModel) -> set[str]:
+    # Create a prompt for category extraction
+    prompt = f"""
+    Extract up to 3 relevant categories from the following document.
+    Return a JSON object with a "categories" key whose value is a list of category name strings.
+
+    If you cannot find any relevant categories, return an empty list.
+
+    Title: {doc.metadata.get('title', 'No title')}
+    Content: {doc.page_content}...
+
+    Categories:"""
+
+    # Get response from LLM (json_mode returns the parsed JSON as a dict)
+    result = llm.with_structured_output(method="json_mode").invoke(prompt)
+    categories = result.get("categories", [])
+    return set(categories)
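
For context, a minimal usage sketch of the new helper (hypothetical wiring: assumes OPENAI_API_KEY is set and a chat model that supports JSON mode; the sample document is invented):

    import langchain.chat_models
    from langchain_core.documents import Document

    from classify import categorize

    llm = langchain.chat_models.init_chat_model(
        model="gpt-4o-mini", model_provider="openai"
    )
    doc = Document(
        page_content="PostgreSQL ships incremental backups and faster vacuum...",
        metadata={"title": "PostgreSQL release notes"},
    )
    print(categorize(doc, llm))  # e.g. {"Databases", "Open Source"}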

compose.yml (new file, 25 lines)
@@ -0,0 +1,25 @@
+services:
+  weaviate:
+    command:
+      - --host
+      - 0.0.0.0
+      - --port
+      - "8080"
+      - --scheme
+      - http
+    image: cr.weaviate.io/semitechnologies/weaviate:1.31.4
+    ports:
+      - 8080:8080
+      - 50051:50051
+    volumes:
+      - ./weaviate:/var/lib/weaviate
+    restart: on-failure
+    environment:
+      QUERY_DEFAULTS_LIMIT: 25
+      AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: "true"
+      PERSISTENCE_DATA_PATH: "/var/lib/weaviate"
+      ENABLE_API_BASED_MODULES: "true"
+      ENABLE_MODULES: "text2vec-openai,generative-openai"
+      CLUSTER_HOSTNAME: "node1"
+      OPENAI_APIKEY: ${OPENAI_API_KEY}
+      DISABLE_TELEMETRY: "true"
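
Once the stack is up (docker compose up -d), a quick readiness check is possible from Python with the v4 weaviate-client; a sketch assuming the default local ports mapped above:

    import weaviate

    client = weaviate.connect_to_local()  # HTTP 8080, gRPC 50051 by default
    try:
        print(client.is_ready())  # True once the node accepts requests
    finally:
        client.close()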

hn.py (30 lines changed)
@@ -1,3 +1,4 @@
+import asyncio
 from datetime import datetime
 
 import dateutil.parser
@@ -16,16 +17,16 @@ class Story(BaseModel):
 
 
 class HackerNewsClient:
-    def __init__(self, client: httpx.Client | None = None):
+    def __init__(self, client: httpx.AsyncClient | None = None):
         base_url = "https://hn.algolia.com/api/v1"
         if client:
             client.base_url = base_url
             self._client = client
         else:
-            self._client = httpx.Client(base_url=base_url)
+            self._client = httpx.AsyncClient(base_url=base_url)
 
-    def get_top_stories(self, limit: int = 10) -> list[Story]:
-        resp = self._client.get(
+    async def get_top_stories(self, limit: int = 10) -> list[Story]:
+        resp = await self._client.get(
             "search",
             params={"tags": "front_page", "hitsPerPage": limit, "page": 0},
         )
@@ -40,7 +41,24 @@ class HackerNewsClient:
             for hit in resp.json().get("hits", [])
         ]
 
-    def get_item(self, item_id):
-        resp = self._client.get(f"items/{item_id}")
+    async def get_item(self, item_id):
+        resp = await self._client.get(f"items/{item_id}")
         resp.raise_for_status()
         return resp.json()
+
+    async def close(self):
+        """Close the underlying HTTP client."""
+
+        if self._client and not self._client.is_closed:
+            await self._client.aclose()
+
+    def __del__(self):
+        """Ensure the HTTP client is closed when the object is deleted."""
+        try:
+            loop = asyncio.get_event_loop()
+            if loop.is_running():
+                loop.create_task(self.close())
+            else:
+                loop.run_until_complete(self.close())
+        except Exception:
+            pass
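
Since the client is now fully asynchronous, callers need a running event loop. A minimal sketch using only methods from this diff:

    import asyncio

    from hn import HackerNewsClient

    async def main():
        hn = HackerNewsClient()
        try:
            for story in await hn.get_top_stories(limit=5):
                print(story.title)
        finally:
            await hn.close()

    asyncio.run(main())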

indexing.py (205 lines changed)
@@ -1,119 +1,243 @@
+import asyncio  # needed at module scope for asyncio.gather()/asyncio.run() below
 import logging
 import os
-from typing import TypedDict
+from typing import Iterable, TypedDict
 
 import langchain
 import langchain.chat_models
-import langchain.hub
+import langchain.prompts
 import langchain.text_splitter
-import langchain_chroma
 import langchain_core
 import langchain_core.documents
 import langchain_openai
+import langchain_weaviate
 import langgraph.graph
 import mlflow
 
+import weaviate
 from hn import HackerNewsClient, Story
 from scrape import JinaScraper
 
 llm = langchain.chat_models.init_chat_model(
-    model="gpt-4.1-nano", model_provider="openai"
+    model="gpt-4o-mini", model_provider="openai"
 )
-embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
-vector_store = langchain_chroma.Chroma(
-    collection_name="hn_stories",
-    embedding_function=embeddings,
-    persist_directory="./chroma_db",
-    create_collection_if_not_exists=True,
+embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-large")
+
+weaviate_client = weaviate.connect_to_local()
+vector_store = langchain_weaviate.WeaviateVectorStore(
+    weaviate_client,
+    index_name="hn_stories",
+    text_key="page_content",
+    embedding=embeddings,
 )
 
 
 class State(TypedDict):
     question: str
+    preferences: Iterable[str]
     context: list[langchain_core.documents.Document]
     answer: str
 
 
 # Define application steps
 def retrieve(state: State):
-    retrieved_docs = vector_store.similarity_search(state["question"], k=10)
-    return {"context": retrieved_docs}
+    # Search for relevant documents
+    retrieved_docs = vector_store.similarity_search(
+        "Categories: " + ", ".join(state["preferences"]), k=10
+    )
+
+    # If you're using chunks, group them back into complete stories
+    story_groups = {}
+    for doc in retrieved_docs:
+        story_id = doc.metadata.get("story_id")
+        if story_id not in story_groups:
+            story_groups[story_id] = []
+        story_groups[story_id].append(doc)
+
+    # Reconstruct complete stories or use the best chunk per story
+    complete_stories = []
+    for story_id, chunks in story_groups.items():
+        if len(chunks) == 1:
+            complete_stories.append(chunks[0])
+        else:
+            # Combine chunks back into complete story
+            combined_content = "\n\n".join(
+                chunk.page_content
+                for chunk in sorted(
+                    chunks, key=lambda x: x.metadata.get("chunk_index", 0)
+                )
+            )
+            complete_story = langchain_core.documents.Document(
+                page_content=combined_content,
+                metadata=chunks[0].metadata,  # Use metadata from first chunk
+            )
+            complete_stories.append(complete_story)
+
+    return {"context": complete_stories[:5]}  # Limit to top 5 stories
 
 
 def generate(state: State):
     docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-    prompt = langchain.hub.pull("rlm/rag-prompt")
-    messages = prompt.invoke({"question": state["question"], "context": docs_content})
+
+    prompt = langchain.prompts.PromptTemplate(
+        input_variables=["preferences", "context"],
+        template=(
+            "You are a helpful assistant that can provide updates on technology topics based on the topics a user has expressed interest in and additional context.\n\n"
+            "Please respond in Markdown format and group your answers based on the categories of the items in the context.\n"
+            "If applicable, add hyperlinks to the original source as part of the headline for each story.\n"
+            "Limit your summaries to approximately 100 words per item.\n\n"
+            "Preferences: {preferences}\n\n"
+            "Context:\n{context}\n\n"
+            "Answer:"
+        ),
+    )
+
+    messages = prompt.invoke(
+        {"preferences": state["preferences"], "context": docs_content}
+    )
    response = llm.invoke(messages)
    return {"answer": response.content}
 
 
-def run_query(question: str):
+def run_query(preferences: Iterable[str]):
     graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate])
     graph_builder.add_edge(langgraph.graph.START, "retrieve")
     graph = graph_builder.compile()
 
-    response = graph.invoke({"question": question})
+    response = graph.invoke(State(preferences=preferences, context=[], answer=""))
     print(response["answer"])
 
 
+def get_existing_story_ids() -> set[str]:
+    """Get the IDs of stories that already exist in the vector store."""
+    try:
+        collection = vector_store._collection
+        existing_ids = set()
+        for doc in collection.iterator():
+            story_id = doc.properties.get("story_id")
+            if story_id:
+                existing_ids.add(story_id)
+        return existing_ids
+    except Exception:
+        logging.warning("Could not retrieve existing story IDs", exc_info=True)
+        return set()
+
+
 async def fetch_hn_top_stories(
     limit: int = 10,
 ) -> list[langchain_core.documents.Document]:
     hn = HackerNewsClient()
-    stories = hn.get_top_stories(limit=limit)
+    stories = await hn.get_top_stories(limit=limit)
+
+    # Get existing story IDs to avoid re-fetching
+    existing_ids = get_existing_story_ids()
+    logging.info(f"Existing story IDs: {len(existing_ids)} found in vector store")
+
+    new_stories = [story for story in stories if story.id not in existing_ids]
+
+    print(f"Found {len(stories)} top stories, {len(new_stories)} are new")
+
+    if not new_stories:
+        print("No new stories to fetch")
+        return []
 
-    contents = {}
-
-    # Fetch content for each story asynchronously
+    # Fetch content for each new story asynchronously
     scraper = JinaScraper(os.getenv("JINA_API_KEY"))
 
     async def _fetch_content(story: Story) -> tuple[str, str]:
-        if not story.url:
-            return story.id, story.title
-        return story.id, await scraper.get_content(story.url)
+        try:
+            if not story.url:
+                return story.id, story.title
+            return story.id, await scraper.get_content(story.url)
+        except Exception as e:
+            logging.warning(f"Failed to fetch content for story {story.id}: {e}")
+            return story.id, story.title  # Fallback to title if content fetch fails
 
-    tasks = [_fetch_content(story) for story in stories]
-    results = await asyncio.gather(*tasks)
-    contents = dict(results)
+    tasks = [_fetch_content(story) for story in new_stories]
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Filter out exceptions and convert to dict
+    contents = {}
+    for result in results:
+        if isinstance(result, Exception):
+            logging.error(f"Task failed with exception: {result}")
+            continue
+        if isinstance(result, tuple) and len(result) == 2:
+            story_id, content = result
+            contents[story_id] = content
 
     documents = [
         langchain_core.documents.Document(
             page_content=contents[story.id],
             metadata={
                 "id": story.id,
+                "story_id": story.id,
                 "title": story.title,
                 "source": story.url,
                 "created_at": story.created_at.isoformat(),
             },
         )
-        for story in stories
+        for story in new_stories
     ]
     return documents
 
 
 async def main():
-    import mlflow
-
     mlflow.set_tracking_uri("http://localhost:5000")
     mlflow.set_experiment("langchain-rag-hn")
     mlflow.langchain.autolog()
 
-    # 1. Load
-    stories = await fetch_hn_top_stories(limit=3)
+    # 1. Load only new stories
+    new_stories = await fetch_hn_top_stories(limit=20)
 
-    # 2. Split
-    splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
-        chunk_size=1000, chunk_overlap=200
-    )
-    all_splits = splitter.split_documents(stories)
+    if new_stories:
+        print(f"Processing {len(new_stories)} new stories")
 
-    # 3. Store
-    _ = vector_store.add_documents(all_splits)
+        # Categorize stories (optional)
+        from classify import categorize
+
+        for story in new_stories:
+            categories = categorize(story, llm)
+            story.metadata["categories"] = list(categories)
+            print(f"Story ID {story.metadata['story_id']} categorized as: {categories}")
+
+        # 2. Split
+        documents_to_store = []
+        for story in new_stories:
+            # If article is short enough, store as-is
+            if len(story.page_content) <= 3000:  # Adjust threshold as needed
+                documents_to_store.append(story)
+            else:
+                # For very long articles, chunk but keep story metadata
+                splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
+                    chunk_size=2000,
+                    chunk_overlap=200,
+                    add_start_index=True,
+                )
+                chunks = splitter.split_documents([story])
+                # Add chunk info to metadata
+                for i, chunk in enumerate(chunks):
+                    chunk.metadata["chunk_index"] = i
+                    chunk.metadata["total_chunks"] = len(chunks)
+                documents_to_store.extend(chunks)
 
+        # 3. Store
+        _ = vector_store.add_documents(documents_to_store)
+        print(f"Added {len(documents_to_store)} documents to vector store")
+    else:
+        print("No new stories to process")
 
     # 4. Query
-    question = "What are the top stories related to AI and Machine Learning right now?"
-    run_query(question)
+    preferences = ["Software Engineering", "Machine Learning", "Games"]
+    run_query(preferences)
 
 
 if __name__ == "__main__":
-    import asyncio
-
-    asyncio.run(main())
+    logging.basicConfig(level=logging.INFO)
+
+    try:
+        asyncio.run(main())
+    finally:
+        weaviate_client.close()
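
With these changes, retrieval is driven by stored user preferences rather than an ad-hoc question. A hypothetical smoke test (assumes the Weaviate container from compose.yml is running and the store has been populated; note that importing indexing connects to Weaviate at import time):

    # smoke_test.py - illustrative only
    from indexing import run_query

    # retrieve() embeds "Categories: Software Engineering, Databases", regroups
    # chunked hits by story_id, and hands at most 5 complete stories to generate()
    run_query(["Software Engineering", "Databases"])

pyproject.toml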
@@ -8,7 +8,7 @@ dependencies = [
     "hackernews>=2.0.0",
     "html2text>=2025.4.15",
     "httpx>=0.28.1",
-    "langchain-chroma>=0.2.4",
+    "langchain-weaviate>=0.0.5",
     "langchain[openai]>=0.3.26",
     "langgraph>=0.5.0",
     "mlflow>=3.1.1",

scrape.py (39 lines changed)
@@ -1,3 +1,5 @@
+import asyncio
+import logging
 from abc import ABC, abstractmethod
 from typing import override
 
@@ -6,17 +8,41 @@ import httpx
 
 class TextScraper(ABC):
     def __init__(self):
-        self._client = httpx.AsyncClient()
+        self._client = httpx.AsyncClient(timeout=httpx.Timeout(5.0))
 
     async def _fetch_text(self, url: str) -> str:
         """Fetch the raw HTML content from the URL."""
-        response = await self._client.get(url)
-        response.raise_for_status()
-        return response.text
+        response = None
+        try:
+            response = await self._client.get(url)
+            response.raise_for_status()
+            return response.text
+        except Exception:
+            logging.warning(f"Failed to fetch text from {url}", exc_info=True)
+            raise
+        finally:
+            if response:
+                await response.aclose()
 
     @abstractmethod
     async def get_content(self, url: str) -> str: ...
 
+    async def close(self):
+        """Close the underlying HTTP client."""
+        if self._client and not self._client.is_closed:
+            await self._client.aclose()
+
+    def __del__(self):
+        """Ensure the HTTP client is closed when the object is deleted."""
+        try:
+            loop = asyncio.get_event_loop()
+            if loop.is_running():
+                loop.create_task(self.close())
+            else:
+                loop.run_until_complete(self.close())
+        except Exception:
+            pass
+
 
 class Html2textScraper(TextScraper):
     @override
@@ -39,12 +65,13 @@ class JinaScraper(TextScraper):
     def __init__(self, api_key: str | None = None):
         super().__init__()
         if api_key:
-            self._client.headers.update({"Authorization": "Bearer {api_key}"})
+            self._client.headers.update({"Authorization": f"Bearer {api_key}"})
 
     @override
     async def get_content(self, url: str) -> str:
         print(f"Fetching content from: {url}")
         try:
             return await self._fetch_text(f"https://r.jina.ai/{url}")
-        except httpx.HTTPStatusError:
+        except Exception:
+            logging.warning(f"Failed to fetch content from {url}", exc_info=True)
             return ""
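
And a short sketch of the hardened scraper in use (hypothetical; JINA_API_KEY is optional, and get_content now returns an empty string instead of raising on failure):

    import asyncio
    import os

    from scrape import JinaScraper

    async def main():
        scraper = JinaScraper(os.getenv("JINA_API_KEY"))
        try:
            text = await scraper.get_content("https://example.com")
            print(text[:200] if text else "fetch failed; got empty string")
        finally:
            await scraper.close()

    asyncio.run(main())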