feat: Langchain HN RAG demo
.gitignore (vendored, new file)
@@ -0,0 +1,14 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# MLflow
mlruns/
mlartifacts/
hn.py (new file)
@@ -0,0 +1,46 @@
from datetime import datetime

import dateutil.parser
import httpx
from pydantic import BaseModel, Field


class Story(BaseModel):
    """Model representing a Hacker News story."""

    id: str
    title: str = Field(description="Title of the story.")
    url: str | None = Field(description="URL of the story.")

    created_at: datetime


class HackerNewsClient:
    def __init__(self, client: httpx.Client | None = None):
        base_url = "https://hn.algolia.com/api/v1"
        if client:
            client.base_url = base_url
            self._client = client
        else:
            self._client = httpx.Client(base_url=base_url)

    def get_top_stories(self, limit: int = 10) -> list[Story]:
        resp = self._client.get(
            "search",
            params={"tags": "front_page", "hitsPerPage": limit, "page": 0},
        )
        resp.raise_for_status()
        return [
            Story(
                id=hit["objectID"],
                title=hit["title"],
                url=hit.get("url"),
                created_at=dateutil.parser.isoparse(hit["created_at"]),
            )
            for hit in resp.json().get("hits", [])
        ]

    def get_item(self, item_id: str) -> dict:
        resp = self._client.get(f"items/{item_id}")
        resp.raise_for_status()
        return resp.json()
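
A minimal usage sketch for the HackerNewsClient above; the fields printed are just one plausible choice, not part of the commit:

    from hn import HackerNewsClient

    client = HackerNewsClient()
    for story in client.get_top_stories(limit=5):
        print(f"{story.created_at:%Y-%m-%d}  {story.title}  ({story.url})")
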
indexing.py (new file)
@@ -0,0 +1,114 @@
import asyncio
import os
from typing import TypedDict

import langchain
import langchain.chat_models
import langchain.hub
import langchain.text_splitter
import langchain_core
import langchain_core.documents
import langchain_core.vectorstores
import langchain_openai
import langgraph
import langgraph.graph
import mlflow

from hn import HackerNewsClient, Story
from scrape import JinaScraper


async def fetch_hn_top_stories(
    limit: int = 10,
) -> list[langchain_core.documents.Document]:
    hn = HackerNewsClient()
    stories = hn.get_top_stories(limit=limit)

    # Fetch content for each story concurrently
    scraper = JinaScraper(os.getenv("JINA_API_KEY"))

    async def _fetch_content(story: Story) -> tuple[str, str]:
        if not story.url:
            return story.id, story.title
        return story.id, await scraper.get_content(story.url)

    tasks = [_fetch_content(story) for story in stories]
    results = await asyncio.gather(*tasks)
    contents = dict(results)

    documents = [
        langchain_core.documents.Document(
            page_content=contents[story.id],
            metadata={
                "id": story.id,
                "title": story.title,
                "source": story.url,
                "created_at": story.created_at.isoformat(),
            },
        )
        for story in stories
    ]
    return documents


async def main():
    mlflow.set_tracking_uri("http://localhost:5000")
    mlflow.set_experiment("langchain-rag-hn")
    mlflow.langchain.autolog()

    llm = langchain.chat_models.init_chat_model(
        model="gpt-4o-mini", model_provider="openai"
    )
    embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
    vector_store = langchain_core.vectorstores.InMemoryVectorStore(embeddings)

    # 1. Load
    stories = await fetch_hn_top_stories(limit=20)

    # 2. Split
    splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200
    )
    all_splits = splitter.split_documents(stories)

    # 3. Store
    _ = vector_store.add_documents(all_splits)

    # 4. Query
    prompt = langchain.hub.pull("rlm/rag-prompt")

    # Define state for application
    class State(TypedDict):
        question: str
        context: list[langchain_core.documents.Document]
        answer: str

    # Define application steps
    def retrieve(state: State):
        retrieved_docs = vector_store.similarity_search(state["question"], k=10)
        return {"context": retrieved_docs}

    def generate(state: State):
        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
        messages = prompt.invoke(
            {"question": state["question"], "context": docs_content}
        )
        response = llm.invoke(messages)
        return {"answer": response.content}

    # Compile application and test
    graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(langgraph.graph.START, "retrieve")
    graph = graph_builder.compile()

    response = graph.invoke(
        {"question": "Are there any news stories related to AI and Machine Learning?"}
    )
    print(response["answer"])


if __name__ == "__main__":
    asyncio.run(main())
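
The load/split/store steps above can be exercised without hitting Hacker News at all. A minimal sketch, assuming OPENAI_API_KEY is set; the sample text and source URL are made-up stand-ins for the scraped stories:

    import langchain.text_splitter
    import langchain_core.documents
    import langchain_core.vectorstores
    import langchain_openai

    # Hand-written stand-in for what fetch_hn_top_stories() returns
    doc = langchain_core.documents.Document(
        page_content="An example article about machine learning trends.",
        metadata={"source": "https://example.com"},
    )

    splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200
    )
    splits = splitter.split_documents([doc])

    embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
    store = langchain_core.vectorstores.InMemoryVectorStore(embeddings)
    store.add_documents(splits)
    print(store.similarity_search("machine learning", k=1))
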
pyproject.toml (new file)
@@ -0,0 +1,16 @@
[project]
name = "langchain-demo"
version = "0.1.0"
description = "LangChain RAG demo over Hacker News front-page stories"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "hackernews>=2.0.0",
    "html2text>=2025.4.15",
    "httpx>=0.28.1",
    "langchain[openai]>=0.3.26",
    "langgraph>=0.5.0",
    "mlflow>=3.1.1",
    "python-dateutil>=2.9.0.post0",
    "readability-lxml>=0.8.4.1",
]
scrape.py (new file)
@@ -0,0 +1,47 @@
from abc import ABC, abstractmethod
from typing import override

import httpx


class TextScraper(ABC):
    def __init__(self):
        self._client = httpx.AsyncClient()

    async def _fetch_text(self, url: str) -> str:
        """Fetch the raw HTML content from the URL."""
        response = await self._client.get(url)
        response.raise_for_status()
        return response.text

    @abstractmethod
    async def get_content(self, url: str) -> str: ...


class Html2textScraper(TextScraper):
    @override
    async def get_content(self, url: str) -> str:
        import html2text

        return html2text.html2text(await self._fetch_text(url))


class ReadabilityScraper(TextScraper):
    @override
    async def get_content(self, url: str) -> str:
        import readability

        doc = readability.Document(await self._fetch_text(url))
        return doc.summary(html_partial=True)


class JinaScraper(TextScraper):
    def __init__(self, api_key: str | None = None):
        super().__init__()
        if api_key:
            self._client.headers.update({"Authorization": f"Bearer {api_key}"})

    @override
    async def get_content(self, url: str) -> str:
        print(f"Fetching content from: {url}")
        return await self._fetch_text(f"https://r.jina.ai/{url}")
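
A quick way to compare the three scrapers on a single page; a minimal sketch, using https://example.com as a placeholder URL and constructing JinaScraper without an API key (which the class permits):

    import asyncio

    from scrape import Html2textScraper, JinaScraper, ReadabilityScraper

    async def demo() -> None:
        url = "https://example.com"
        for scraper in (Html2textScraper(), ReadabilityScraper(), JinaScraper()):
            text = await scraper.get_content(url)
            print(f"{type(scraper).__name__}: {len(text)} characters")

    asyncio.run(demo())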