Use Chroma vector store
This commit is contained in:
		
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -11,4 +11,7 @@ wheels/ | |||||||
|  |  | ||||||
| # MLflow | # MLflow | ||||||
| mlruns/ | mlruns/ | ||||||
| mlartifacts/ | mlartifacts/ | ||||||
|  |  | ||||||
|  | # ChromaDB | ||||||
|  | chroma_db/ | ||||||
							
								
								
									
										83
									
								
								indexing.py
									
									
									
									
									
								
							
							
						
						
									
										83
									
								
								indexing.py
									
									
									
									
									
								
							| @@ -5,17 +5,56 @@ import langchain | |||||||
| import langchain.chat_models | import langchain.chat_models | ||||||
| import langchain.hub | import langchain.hub | ||||||
| import langchain.text_splitter | import langchain.text_splitter | ||||||
|  | import langchain_chroma | ||||||
| import langchain_core | import langchain_core | ||||||
| import langchain_core.documents | import langchain_core.documents | ||||||
| import langchain_core.vectorstores |  | ||||||
| import langchain_openai | import langchain_openai | ||||||
| import langgraph |  | ||||||
| import langgraph.graph | import langgraph.graph | ||||||
| import mlflow | import mlflow | ||||||
|  |  | ||||||
| from hn import HackerNewsClient, Story | from hn import HackerNewsClient, Story | ||||||
| from scrape import JinaScraper | from scrape import JinaScraper | ||||||
|  |  | ||||||
|  | llm = langchain.chat_models.init_chat_model( | ||||||
|  |     model="gpt-4.1-nano", model_provider="openai" | ||||||
|  | ) | ||||||
|  | embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small") | ||||||
|  | vector_store = langchain_chroma.Chroma( | ||||||
|  |     collection_name="hn_stories", | ||||||
|  |     embedding_function=embeddings, | ||||||
|  |     persist_directory="./chroma_db", | ||||||
|  |     create_collection_if_not_exists=True, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class State(TypedDict): | ||||||
|  |     question: str | ||||||
|  |     context: list[langchain_core.documents.Document] | ||||||
|  |     answer: str | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Define application steps | ||||||
|  | def retrieve(state: State): | ||||||
|  |     retrieved_docs = vector_store.similarity_search(state["question"], k=10) | ||||||
|  |     return {"context": retrieved_docs} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def generate(state: State): | ||||||
|  |     docs_content = "\n\n".join(doc.page_content for doc in state["context"]) | ||||||
|  |     prompt = langchain.hub.pull("rlm/rag-prompt") | ||||||
|  |     messages = prompt.invoke({"question": state["question"], "context": docs_content}) | ||||||
|  |     response = llm.invoke(messages) | ||||||
|  |     return {"answer": response.content} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def run_query(question: str): | ||||||
|  |     graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate]) | ||||||
|  |     graph_builder.add_edge(langgraph.graph.START, "retrieve") | ||||||
|  |     graph = graph_builder.compile() | ||||||
|  |  | ||||||
|  |     response = graph.invoke({"question": question}) | ||||||
|  |     print(response["answer"]) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def fetch_hn_top_stories( | async def fetch_hn_top_stories( | ||||||
|     limit: int = 10, |     limit: int = 10, | ||||||
| @@ -57,14 +96,8 @@ async def main(): | |||||||
|     mlflow.set_experiment("langchain-rag-hn") |     mlflow.set_experiment("langchain-rag-hn") | ||||||
|     mlflow.langchain.autolog() |     mlflow.langchain.autolog() | ||||||
|  |  | ||||||
|     llm = langchain.chat_models.init_chat_model( |  | ||||||
|         model="gpt-4o-mini", model_provider="openai" |  | ||||||
|     ) |  | ||||||
|     embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small") |  | ||||||
|     vector_store = langchain_core.vectorstores.InMemoryVectorStore(embeddings) |  | ||||||
|  |  | ||||||
|     # 1. Load |     # 1. Load | ||||||
|     stories = await fetch_hn_top_stories(limit=20) |     stories = await fetch_hn_top_stories(limit=3) | ||||||
|  |  | ||||||
|     # 2. Split |     # 2. Split | ||||||
|     splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( |     splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( | ||||||
| @@ -76,36 +109,8 @@ async def main(): | |||||||
|     _ = vector_store.add_documents(all_splits) |     _ = vector_store.add_documents(all_splits) | ||||||
|  |  | ||||||
|     # 4. Query |     # 4. Query | ||||||
|     prompt = langchain.hub.pull("rlm/rag-prompt") |     question = "What are the top stories related to AI and Machine Learning right now?" | ||||||
|  |     run_query(question) | ||||||
|     # Define state for application |  | ||||||
|     class State(TypedDict): |  | ||||||
|         question: str |  | ||||||
|         context: list[langchain_core.documents.Document] |  | ||||||
|         answer: str |  | ||||||
|  |  | ||||||
|     # Define application steps |  | ||||||
|     def retrieve(state: State): |  | ||||||
|         retrieved_docs = vector_store.similarity_search(state["question"], k=10) |  | ||||||
|         return {"context": retrieved_docs} |  | ||||||
|  |  | ||||||
|     def generate(state: State): |  | ||||||
|         docs_content = "\n\n".join(doc.page_content for doc in state["context"]) |  | ||||||
|         messages = prompt.invoke( |  | ||||||
|             {"question": state["question"], "context": docs_content} |  | ||||||
|         ) |  | ||||||
|         response = llm.invoke(messages) |  | ||||||
|         return {"answer": response.content} |  | ||||||
|  |  | ||||||
|     # Compile application and test |  | ||||||
|     graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate]) |  | ||||||
|     graph_builder.add_edge(langgraph.graph.START, "retrieve") |  | ||||||
|     graph = graph_builder.compile() |  | ||||||
|  |  | ||||||
|     response = graph.invoke( |  | ||||||
|         {"question": "Are there any news stories related to AI and Machine Learning?"} |  | ||||||
|     ) |  | ||||||
|     print(response["answer"]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ dependencies = [ | |||||||
|     "hackernews>=2.0.0", |     "hackernews>=2.0.0", | ||||||
|     "html2text>=2025.4.15", |     "html2text>=2025.4.15", | ||||||
|     "httpx>=0.28.1", |     "httpx>=0.28.1", | ||||||
|  |     "langchain-chroma>=0.2.4", | ||||||
|     "langchain[openai]>=0.3.26", |     "langchain[openai]>=0.3.26", | ||||||
|     "langgraph>=0.5.0", |     "langgraph>=0.5.0", | ||||||
|     "mlflow>=3.1.1", |     "mlflow>=3.1.1", | ||||||
|   | |||||||
| @@ -44,4 +44,7 @@ class JinaScraper(TextScraper): | |||||||
|     @override |     @override | ||||||
|     async def get_content(self, url: str) -> str: |     async def get_content(self, url: str) -> str: | ||||||
|         print(f"Fetching content from: {url}") |         print(f"Fetching content from: {url}") | ||||||
|         return await self._fetch_text(f"https://r.jina.ai/{url}") |         try: | ||||||
|  |             return await self._fetch_text(f"https://r.jina.ai/{url}") | ||||||
|  |         except httpx.HTTPStatusError: | ||||||
|  |             return "" | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user