Use Chroma vector store
.gitignore (vendored) | 3

@@ -12,3 +12,6 @@ wheels/
 # MLflow
 mlruns/
 mlartifacts/
+
+# ChromaDB
+chroma_db/

indexing.py | 83

@@ -5,17 +5,56 @@ import langchain
 import langchain.chat_models
 import langchain.hub
 import langchain.text_splitter
+import langchain_chroma
 import langchain_core
 import langchain_core.documents
-import langchain_core.vectorstores
 import langchain_openai
-import langgraph
 import langgraph.graph
 import mlflow
 
 from hn import HackerNewsClient, Story
 from scrape import JinaScraper
 
+llm = langchain.chat_models.init_chat_model(
+    model="gpt-4.1-nano", model_provider="openai"
+)
+embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
+vector_store = langchain_chroma.Chroma(
+    collection_name="hn_stories",
+    embedding_function=embeddings,
+    persist_directory="./chroma_db",
+    create_collection_if_not_exists=True,
+)
+
+
+class State(TypedDict):
+    question: str
+    context: list[langchain_core.documents.Document]
+    answer: str
+
+
+# Define application steps
+def retrieve(state: State):
+    retrieved_docs = vector_store.similarity_search(state["question"], k=10)
+    return {"context": retrieved_docs}
+
+
+def generate(state: State):
+    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
+    prompt = langchain.hub.pull("rlm/rag-prompt")
+    messages = prompt.invoke({"question": state["question"], "context": docs_content})
+    response = llm.invoke(messages)
+    return {"answer": response.content}
+
+
+def run_query(question: str):
+    graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate])
+    graph_builder.add_edge(langgraph.graph.START, "retrieve")
+    graph = graph_builder.compile()
+
+    response = graph.invoke({"question": question})
+    print(response["answer"])
+
+
 async def fetch_hn_top_stories(
     limit: int = 10,
@@ -57,14 +96,8 @@ async def main():
     mlflow.set_experiment("langchain-rag-hn")
     mlflow.langchain.autolog()
 
-    llm = langchain.chat_models.init_chat_model(
-        model="gpt-4o-mini", model_provider="openai"
-    )
-    embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")
-    vector_store = langchain_core.vectorstores.InMemoryVectorStore(embeddings)
-
     # 1. Load
-    stories = await fetch_hn_top_stories(limit=20)
+    stories = await fetch_hn_top_stories(limit=3)
 
     # 2. Split
     splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
@@ -76,36 +109,8 @@ async def main():
     _ = vector_store.add_documents(all_splits)
 
     # 4. Query
-    prompt = langchain.hub.pull("rlm/rag-prompt")
-
-    # Define state for application
-    class State(TypedDict):
-        question: str
-        context: list[langchain_core.documents.Document]
-        answer: str
-
-    # Define application steps
-    def retrieve(state: State):
-        retrieved_docs = vector_store.similarity_search(state["question"], k=10)
-        return {"context": retrieved_docs}
-
-    def generate(state: State):
-        docs_content = "\n\n".join(doc.page_content for doc in state["context"])
-        messages = prompt.invoke(
-            {"question": state["question"], "context": docs_content}
-        )
-        response = llm.invoke(messages)
-        return {"answer": response.content}
-
-    # Compile application and test
-    graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate])
-    graph_builder.add_edge(langgraph.graph.START, "retrieve")
-    graph = graph_builder.compile()
-
-    response = graph.invoke(
-        {"question": "Are there any news stories related to AI and Machine Learning?"}
-    )
-    print(response["answer"])
+    question = "What are the top stories related to AI and Machine Learning right now?"
+    run_query(question)
 
 
 if __name__ == "__main__":
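
Because the store is now constructed with persist_directory="./chroma_db" at module scope, the embeddings written by main() survive the process: a later run can reopen the same collection and query it without re-fetching or re-embedding anything. A minimal sketch of such a follow-up session (the script name and query string are illustrative, not part of this commit):

# query_persisted.py (hypothetical): reuse the Chroma collection that
# indexing.py wrote to disk in an earlier run.
import langchain_chroma
import langchain_openai

embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small")

# Same collection_name and persist_directory as indexing.py, so this
# opens the existing on-disk collection rather than creating a new one.
vector_store = langchain_chroma.Chroma(
    collection_name="hn_stories",
    embedding_function=embeddings,
    persist_directory="./chroma_db",
)

for doc in vector_store.similarity_search("AI and Machine Learning", k=10):
    print(doc.page_content[:80])
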
pyproject.toml | 1

@@ -8,6 +8,7 @@ dependencies = [
     "hackernews>=2.0.0",
     "html2text>=2025.4.15",
     "httpx>=0.28.1",
+    "langchain-chroma>=0.2.4",
     "langchain[openai]>=0.3.26",
     "langgraph>=0.5.0",
     "mlflow>=3.1.1",
scrape.py | 5

@@ -44,4 +44,7 @@ class JinaScraper(TextScraper):
     @override
     async def get_content(self, url: str) -> str:
         print(f"Fetching content from: {url}")
-        return await self._fetch_text(f"https://r.jina.ai/{url}")
+        try:
+            return await self._fetch_text(f"https://r.jina.ai/{url}")
+        except httpx.HTTPStatusError:
+            return ""
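
The new except clause assumes that TextScraper._fetch_text surfaces HTTP errors as httpx.HTTPStatusError, which is what httpx raises from response.raise_for_status(). A sketch of the shape that helper presumably has (the body here is an assumption; only get_content above is from the repo):

import httpx

# Hypothetical stand-in for TextScraper._fetch_text; the real helper
# lives elsewhere in scrape.py.
async def _fetch_text(url: str) -> str:
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        # raise_for_status() raises httpx.HTTPStatusError on 4xx/5xx,
        # which get_content now catches and maps to an empty string.
        response.raise_for_status()
        return response.text

With this change, a story whose page fails to fetch through https://r.jina.ai/ contributes an empty document instead of aborting the whole indexing run.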