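"""RAG over Hacker News top stories.

Fetches the current top stories, scrapes their content via Jina, chunks and
embeds them into a persistent Chroma collection, then answers a question with
a two-step LangGraph pipeline (retrieve -> generate), traced with MLflow.
"""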
import asyncio
import os
from typing import TypedDict

import langchain.chat_models
import langchain.hub
import langchain_chroma
import langchain_core.documents
import langchain_openai
import langchain_text_splitters
import langgraph.graph
import mlflow

from hn import HackerNewsClient, Story
from scrape import JinaScraper
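# `hn` and `scrape` are local modules, not published packages. A minimal
# sketch of the interfaces this script assumes (names and signatures are
# inferred from usage below, not from any official API):
#
#     class Story:        # fields: id, title, url, created_at (datetime)
#     class HackerNewsClient:
#         def get_top_stories(self, limit: int) -> list[Story]: ...
#     class JinaScraper:
#         def __init__(self, api_key: str | None) -> None: ...
#         async def get_content(self, url: str) -> str: ...
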
llm = langchain.chat_models.init_chat_model(
    model="gpt-4.1-nano", model_provider="openai"
)
embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
vector_store = langchain_chroma.Chroma(
    collection_name="hn_stories",
    embedding_function=embeddings,
    persist_directory="./chroma_db",
    create_collection_if_not_exists=True,
)

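# State shared between the LangGraph nodes: the user's question, the
# retrieved context documents, and the generated answer.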
class State(TypedDict):
    question: str
    context: list[langchain_core.documents.Document]
    answer: str

# Define application steps
def retrieve(state: State):
    # Pull the 10 chunks most similar to the question from the vector store.
    retrieved_docs = vector_store.similarity_search(state["question"], k=10)
    return {"context": retrieved_docs}

def generate(state: State):
    # Concatenate the retrieved chunks and fill the shared rlm/rag-prompt
    # template, then ask the chat model for an answer.
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    prompt = langchain.hub.pull("rlm/rag-prompt")
    messages = prompt.invoke({"question": state["question"], "context": docs_content})
    response = llm.invoke(messages)
    return {"answer": response.content}

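# Wire retrieve -> generate into a linear LangGraph and answer one question.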
def run_query(question: str):
    graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate])
    graph_builder.add_edge(langgraph.graph.START, "retrieve")
    graph = graph_builder.compile()

    response = graph.invoke({"question": question})
    print(response["answer"])

async def fetch_hn_top_stories(
    limit: int = 10,
) -> list[langchain_core.documents.Document]:
    hn = HackerNewsClient()
    stories = hn.get_top_stories(limit=limit)

    # Fetch content for each story concurrently; fall back to the title for
    # stories without a URL (e.g. Ask HN posts).
    scraper = JinaScraper(os.getenv("JINA_API_KEY"))

    async def _fetch_content(story: Story) -> tuple[str, str]:
        if not story.url:
            return story.id, story.title
        return story.id, await scraper.get_content(story.url)

    tasks = [_fetch_content(story) for story in stories]
    results = await asyncio.gather(*tasks)
    contents = dict(results)

    documents = [
        langchain_core.documents.Document(
            page_content=contents[story.id],
            metadata={
                "id": story.id,
                "title": story.title,
                "source": story.url,
                "created_at": story.created_at.isoformat(),
            },
        )
        for story in stories
    ]
    return documents

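# End-to-end pipeline: load -> split -> store -> query, with MLflow
# autologging tracing each LangChain call.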
async def main():
    mlflow.set_tracking_uri("http://localhost:5000")
    mlflow.set_experiment("langchain-rag-hn")
    mlflow.langchain.autolog()

    # 1. Load
    stories = await fetch_hn_top_stories(limit=3)

    # 2. Split
    splitter = langchain_text_splitters.RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200
    )
    all_splits = splitter.split_documents(stories)

    # 3. Store
    _ = vector_store.add_documents(all_splits)

    # 4. Query
    question = "What are the top stories related to AI and Machine Learning right now?"
    run_query(question)

if __name__ == "__main__":
    asyncio.run(main())
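# Assumed environment (inferred from the calls above): OPENAI_API_KEY for the
# chat model and embeddings, JINA_API_KEY for the scraper, and an MLflow
# tracking server reachable at http://localhost:5000.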