feat: Langchain HN RAG demo

This commit is contained in:
Adrian Rumpold
2025-07-01 09:26:52 +02:00
commit 648baf3263
6 changed files with 2193 additions and 0 deletions

14
.gitignore vendored Normal file
View File

@@ -0,0 +1,14 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
# MLflow
mlruns/
mlartifacts/

46
hn.py Normal file
View File

@@ -0,0 +1,46 @@
from datetime import datetime
import dateutil.parser
import httpx
from pydantic import BaseModel, Field
class Story(BaseModel):
    """Model representing a Hacker News story."""

    # Algolia "objectID" of the story (a stringified HN item id).
    id: str
    title: str = Field(description="Title of the story.")
    # None when the story has no external link (e.g. Ask HN posts) —
    # see get_top_stories, which uses hit.get("url").
    url: str | None = Field(description="URL of the story.")
    # Creation timestamp parsed from the API's ISO-8601 "created_at" field.
    created_at: datetime
class HackerNewsClient:
    """Minimal client for the Algolia-backed Hacker News search API."""

    def __init__(self, client: httpx.Client | None = None):
        """Create the client.

        Args:
            client: Optional pre-configured ``httpx.Client``. When given, its
                ``base_url`` is repointed at the HN API (note: this mutates
                the caller's client); otherwise a fresh client is created.
        """
        base_url = "https://hn.algolia.com/api/v1"
        if client:
            client.base_url = base_url
            self._client = client
        else:
            self._client = httpx.Client(base_url=base_url)

    def get_top_stories(self, limit: int = 10) -> list[Story]:
        """Return up to ``limit`` stories currently on the front page.

        Raises:
            httpx.HTTPStatusError: if the API responds with an error status.
        """
        resp = self._client.get(
            "search",
            params={"tags": "front_page", "hitsPerPage": limit, "page": 0},
        )
        resp.raise_for_status()

        stories: list[Story] = []
        for hit in resp.json().get("hits", []):
            story = Story(
                id=hit["objectID"],
                title=hit["title"],
                url=hit.get("url"),
                created_at=dateutil.parser.isoparse(hit["created_at"]),
            )
            stories.append(story)
        return stories

    def get_item(self, item_id):
        """Fetch the raw JSON payload for a single HN item (story + comments)."""
        resp = self._client.get(f"items/{item_id}")
        resp.raise_for_status()
        return resp.json()

114
indexing.py Normal file
View File

@@ -0,0 +1,114 @@
import asyncio
import os
from typing import TypedDict

import langchain
import langchain.chat_models
import langchain.hub
import langchain.text_splitter
import langchain_core
import langchain_core.documents
import langchain_core.vectorstores
import langchain_openai
import langgraph
import langgraph.graph
import mlflow

from hn import HackerNewsClient, Story
from scrape import JinaScraper
async def fetch_hn_top_stories(
    limit: int = 10,
) -> list[langchain_core.documents.Document]:
    """Fetch the HN front page and scrape each story's content concurrently.

    Args:
        limit: Maximum number of front-page stories to fetch.

    Returns:
        One LangChain ``Document`` per story: the scraped page text as
        content (falling back to the story title when the story has no URL),
        with id/title/source/created_at metadata.

    Note:
        Fix: the original relied on ``asyncio`` being imported inside the
        ``if __name__ == "__main__"`` guard, so calling this coroutine from
        an importing module raised ``NameError``; ``asyncio`` is now a
        regular top-level import. The dead ``contents = {}`` pre-assignment
        was also removed.
    """
    hn = HackerNewsClient()
    stories = hn.get_top_stories(limit=limit)

    # Fetch content for each story asynchronously
    scraper = JinaScraper(os.getenv("JINA_API_KEY"))

    async def _fetch_content(story: Story) -> tuple[str, str]:
        # Stories without an external URL (e.g. Ask HN) are indexed by title.
        if not story.url:
            return story.id, story.title
        return story.id, await scraper.get_content(story.url)

    tasks = [_fetch_content(story) for story in stories]
    results = await asyncio.gather(*tasks)
    contents = dict(results)

    documents = [
        langchain_core.documents.Document(
            page_content=contents[story.id],
            metadata={
                "id": story.id,
                "title": story.title,
                "source": story.url,
                "created_at": story.created_at.isoformat(),
            },
        )
        for story in stories
    ]
    return documents
async def main():
    """Run the end-to-end RAG demo: load HN stories, index them, answer a query.

    Requires a local MLflow tracking server on port 5000 and OpenAI API
    credentials in the environment (used by the chat model and embeddings).
    """
    # Log all LangChain calls to the local MLflow tracking server.
    mlflow.set_tracking_uri("http://localhost:5000")
    mlflow.set_experiment("langchain-rag-hn")
    mlflow.langchain.autolog()

    llm = langchain.chat_models.init_chat_model(
        model="gpt-4o-mini", model_provider="openai"
    )
    embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
    # In-memory store: the index is rebuilt from scratch on every run.
    vector_store = langchain_core.vectorstores.InMemoryVectorStore(embeddings)

    # 1. Load: scrape the top 20 front-page stories into Documents.
    stories = await fetch_hn_top_stories(limit=20)

    # 2. Split: chunk the pages so each embedding covers a bounded span.
    splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200
    )
    all_splits = splitter.split_documents(stories)

    # 3. Store: embed and index all chunks.
    _ = vector_store.add_documents(all_splits)

    # 4. Query: use the community RAG prompt from the LangChain hub.
    prompt = langchain.hub.pull("rlm/rag-prompt")

    # State shared across the graph's steps.
    class State(TypedDict):
        question: str
        context: list[langchain_core.documents.Document]
        answer: str

    # Step 1: retrieve the 10 chunks most similar to the question.
    def retrieve(state: State):
        retrieved_docs = vector_store.similarity_search(state["question"], k=10)
        return {"context": retrieved_docs}

    # Step 2: answer the question from the retrieved context.
    def generate(state: State):
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = prompt.invoke(
            {"question": state["question"], "context": docs_content}
        )
        response = llm.invoke(messages)
        return {"answer": response.content}

    # Compile the two steps into a linear retrieve -> generate graph and run it.
    graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(langgraph.graph.START, "retrieve")
    graph = graph_builder.compile()
    response = graph.invoke(
        {"question": "Are there any news stories related to AI and Machine Learning?"}
    )
    print(response["answer"])
if __name__ == "__main__":
    # Script entry point: run the async pipeline to completion.
    import asyncio
    asyncio.run(main())

16
pyproject.toml Normal file
View File

@@ -0,0 +1,16 @@
[project]
name = "langchain-demo"
version = "0.1.0"
description = "RAG demo over Hacker News top stories using LangChain, LangGraph, and MLflow"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"hackernews>=2.0.0",
"html2text>=2025.4.15",
"httpx>=0.28.1",
"langchain[openai]>=0.3.26",
"langgraph>=0.5.0",
"mlflow>=3.1.1",
"python-dateutil>=2.9.0.post0",
"readability-lxml>=0.8.4.1",
]

47
scrape.py Normal file
View File

@@ -0,0 +1,47 @@
from abc import ABC, abstractmethod
from typing import override
import httpx
class TextScraper(ABC):
    """Base class for scrapers that turn a URL into plain text."""

    def __init__(self):
        self._client = httpx.AsyncClient()

    async def _fetch_text(self, url: str) -> str:
        """Fetch the raw HTML content from the URL."""
        resp = await self._client.get(url)
        resp.raise_for_status()
        return resp.text

    @abstractmethod
    async def get_content(self, url: str) -> str:
        """Return the extracted text content for ``url``."""
        ...
class Html2textScraper(TextScraper):
    """Scraper that converts fetched HTML to text with ``html2text``."""

    @override
    async def get_content(self, url: str) -> str:
        # Imported lazily so the dependency is only needed when this
        # scraper is actually used.
        import html2text

        html = await self._fetch_text(url)
        return html2text.html2text(html)
class ReadabilityScraper(TextScraper):
    """Scraper that extracts the main article body with ``readability``."""

    @override
    async def get_content(self, url: str) -> str:
        # Imported lazily so the dependency is only needed when this
        # scraper is actually used.
        import readability

        html = await self._fetch_text(url)
        return readability.Document(html).summary(html_partial=True)
class JinaScraper(TextScraper):
    """Scraper that delegates text extraction to the Jina Reader API."""

    def __init__(self, api_key: str | None = None):
        """Create the scraper.

        Args:
            api_key: Optional Jina API key, sent as a Bearer token on every
                request when provided.
        """
        super().__init__()
        if api_key:
            # Fix: the original used a plain string literal
            # ("Bearer {api_key}"), sending the placeholder text verbatim
            # instead of the key — it must be an f-string.
            self._client.headers.update({"Authorization": f"Bearer {api_key}"})

    @override
    async def get_content(self, url: str) -> str:
        """Return reader-extracted text for ``url`` via r.jina.ai."""
        print(f"Fetching content from: {url}")
        return await self._fetch_text(f"https://r.jina.ai/{url}")

1956
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff