Use Chroma vector store

Adrian Rumpold
2025-07-01 10:48:56 +02:00
parent 648baf3263
commit f093a488f3
6 changed files with 981 additions and 50 deletions

.gitignore vendored (5 changed lines)

@@ -11,4 +11,7 @@ wheels/
 # MLflow
 mlruns/
-mlartifacts/
+mlartifacts/
+
+# ChromaDB
+chroma_db/

README.md (new empty file)


@@ -5,17 +5,56 @@ import langchain
 import langchain.chat_models
 import langchain.hub
 import langchain.text_splitter
+import langchain_chroma
 import langchain_core
 import langchain_core.documents
 import langchain_core.vectorstores
 import langchain_openai
 import langgraph
 import langgraph.graph
 import mlflow
 from hn import HackerNewsClient, Story
 from scrape import JinaScraper
 
+llm = langchain.chat_models.init_chat_model(
+    model="gpt-4.1-nano", model_provider="openai"
+)
+embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
+vector_store = langchain_chroma.Chroma(
+    collection_name="hn_stories",
+    embedding_function=embeddings,
+    persist_directory="./chroma_db",
+    create_collection_if_not_exists=True,
+)
+
+
+class State(TypedDict):
+    question: str
+    context: list[langchain_core.documents.Document]
+    answer: str
+
+
+# Define application steps
+def retrieve(state: State):
+    retrieved_docs = vector_store.similarity_search(state["question"], k=10)
+    return {"context": retrieved_docs}
+
+
+def generate(state: State):
+    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
+    prompt = langchain.hub.pull("rlm/rag-prompt")
+    messages = prompt.invoke({"question": state["question"], "context": docs_content})
+    response = llm.invoke(messages)
+    return {"answer": response.content}
+
+
+def run_query(question: str):
+    graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate])
+    graph_builder.add_edge(langgraph.graph.START, "retrieve")
+    graph = graph_builder.compile()
+    response = graph.invoke({"question": question})
+    print(response["answer"])
+
+
 async def fetch_hn_top_stories(
     limit: int = 10,
@@ -57,14 +96,8 @@ async def main():
     mlflow.set_experiment("langchain-rag-hn")
     mlflow.langchain.autolog()
 
-    llm = langchain.chat_models.init_chat_model(
-        model="gpt-4o-mini", model_provider="openai"
-    )
-    embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
-    vector_store = langchain_core.vectorstores.InMemoryVectorStore(embeddings)
-
     # 1. Load
-    stories = await fetch_hn_top_stories(limit=20)
+    stories = await fetch_hn_top_stories(limit=3)
 
     # 2. Split
     splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
@@ -76,36 +109,8 @@ async def main():
     _ = vector_store.add_documents(all_splits)
 
     # 4. Query
-    prompt = langchain.hub.pull("rlm/rag-prompt")
-
-    # Define state for application
-    class State(TypedDict):
-        question: str
-        context: list[langchain_core.documents.Document]
-        answer: str
-
-    # Define application steps
-    def retrieve(state: State):
-        retrieved_docs = vector_store.similarity_search(state["question"], k=10)
-        return {"context": retrieved_docs}
-
-    def generate(state: State):
-        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = prompt.invoke(
-            {"question": state["question"], "context": docs_content}
-        )
-        response = llm.invoke(messages)
-        return {"answer": response.content}
-
-    # Compile application and test
-    graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate])
-    graph_builder.add_edge(langgraph.graph.START, "retrieve")
-    graph = graph_builder.compile()
-    response = graph.invoke(
-        {"question": "Are there any news stories related to AI and Machine Learning?"}
-    )
-    print(response["answer"])
+    question = "What are the top stories related to AI and Machine Learning right now?"
+    run_query(question)
 
 
 if __name__ == "__main__":
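Taken together, the changes in this module hoist the model, embeddings, and vector store to module scope and swap the ephemeral InMemoryVectorStore for a Chroma collection persisted under ./chroma_db. One practical consequence is that indexed stories now survive across runs. The sketch below is not part of the commit; it assumes OPENAI_API_KEY is set and simply reopens the persisted collection for a query without re-ingesting anything:

```python
import langchain_chroma
import langchain_openai

# The query-time embedding model must match the one used at indexing time,
# otherwise query vectors live in a different space than the stored ones.
embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")

# Reopen the collection that the commit persists under ./chroma_db.
store = langchain_chroma.Chroma(
    collection_name="hn_stories",
    embedding_function=embeddings,
    persist_directory="./chroma_db",
)

# retrieve() above uses k=10; k=3 is enough for a quick spot check.
for doc in store.similarity_search("AI and Machine Learning", k=3):
    print(doc.page_content[:100])
```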

pyproject.toml

@@ -8,6 +8,7 @@ dependencies = [
     "hackernews>=2.0.0",
     "html2text>=2025.4.15",
     "httpx>=0.28.1",
+    "langchain-chroma>=0.2.4",
     "langchain[openai]>=0.3.26",
     "langgraph>=0.5.0",
     "mlflow>=3.1.1",

scrape.py

@@ -44,4 +44,7 @@ class JinaScraper(TextScraper):
     @override
     async def get_content(self, url: str) -> str:
         print(f"Fetching content from: {url}")
-        return await self._fetch_text(f"https://r.jina.ai/{url}")
+        try:
+            return await self._fetch_text(f"https://r.jina.ai/{url}")
+        except httpx.HTTPStatusError:
+            return ""

uv.lock generated (937 changed lines)

File diff suppressed because it is too large.