Use Chroma vector store
This commit is contained in:
		
							
								
								
									
										83
									
								
								indexing.py
									
									
									
									
									
								
							
							
						
						
									
										83
									
								
								indexing.py
									
									
									
									
									
								
							| @@ -5,17 +5,56 @@ import langchain | ||||
| import langchain.chat_models | ||||
| import langchain.hub | ||||
| import langchain.text_splitter | ||||
| import langchain_chroma | ||||
| import langchain_core | ||||
| import langchain_core.documents | ||||
| import langchain_core.vectorstores | ||||
| import langchain_openai | ||||
| import langgraph | ||||
| import langgraph.graph | ||||
| import mlflow | ||||
|  | ||||
| from hn import HackerNewsClient, Story | ||||
| from scrape import JinaScraper | ||||
|  | ||||
| llm = langchain.chat_models.init_chat_model( | ||||
|     model="gpt-4.1-nano", model_provider="openai" | ||||
| ) | ||||
| embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small") | ||||
| vector_store = langchain_chroma.Chroma( | ||||
|     collection_name="hn_stories", | ||||
|     embedding_function=embeddings, | ||||
|     persist_directory="./chroma_db", | ||||
|     create_collection_if_not_exists=True, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class State(TypedDict): | ||||
|     question: str | ||||
|     context: list[langchain_core.documents.Document] | ||||
|     answer: str | ||||
|  | ||||
|  | ||||
| # Define application steps | ||||
| def retrieve(state: State): | ||||
|     retrieved_docs = vector_store.similarity_search(state["question"], k=10) | ||||
|     return {"context": retrieved_docs} | ||||
|  | ||||
|  | ||||
| def generate(state: State): | ||||
|     docs_content = "\n\n".join(doc.page_content for doc in state["context"]) | ||||
|     prompt = langchain.hub.pull("rlm/rag-prompt") | ||||
|     messages = prompt.invoke({"question": state["question"], "context": docs_content}) | ||||
|     response = llm.invoke(messages) | ||||
|     return {"answer": response.content} | ||||
|  | ||||
|  | ||||
| def run_query(question: str): | ||||
|     graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate]) | ||||
|     graph_builder.add_edge(langgraph.graph.START, "retrieve") | ||||
|     graph = graph_builder.compile() | ||||
|  | ||||
|     response = graph.invoke({"question": question}) | ||||
|     print(response["answer"]) | ||||
|  | ||||
|  | ||||
| async def fetch_hn_top_stories( | ||||
|     limit: int = 10, | ||||
| @@ -57,14 +96,8 @@ async def main(): | ||||
|     mlflow.set_experiment("langchain-rag-hn") | ||||
|     mlflow.langchain.autolog() | ||||
|  | ||||
|     llm = langchain.chat_models.init_chat_model( | ||||
|         model="gpt-4o-mini", model_provider="openai" | ||||
|     ) | ||||
|     embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-small") | ||||
|     vector_store = langchain_core.vectorstores.InMemoryVectorStore(embeddings) | ||||
|  | ||||
|     # 1. Load | ||||
|     stories = await fetch_hn_top_stories(limit=20) | ||||
|     stories = await fetch_hn_top_stories(limit=3) | ||||
|  | ||||
|     # 2. Split | ||||
|     splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( | ||||
| @@ -76,36 +109,8 @@ async def main(): | ||||
|     _ = vector_store.add_documents(all_splits) | ||||
|  | ||||
|     # 4. Query | ||||
|     prompt = langchain.hub.pull("rlm/rag-prompt") | ||||
|  | ||||
|     # Define state for application | ||||
|     class State(TypedDict): | ||||
|         question: str | ||||
|         context: list[langchain_core.documents.Document] | ||||
|         answer: str | ||||
|  | ||||
|     # Define application steps | ||||
|     def retrieve(state: State): | ||||
|         retrieved_docs = vector_store.similarity_search(state["question"], k=10) | ||||
|         return {"context": retrieved_docs} | ||||
|  | ||||
|     def generate(state: State): | ||||
|         docs_content = "\n\n".join(doc.page_content for doc in state["context"]) | ||||
|         messages = prompt.invoke( | ||||
|             {"question": state["question"], "context": docs_content} | ||||
|         ) | ||||
|         response = llm.invoke(messages) | ||||
|         return {"answer": response.content} | ||||
|  | ||||
|     # Compile application and test | ||||
|     graph_builder = langgraph.graph.StateGraph(State).add_sequence([retrieve, generate]) | ||||
|     graph_builder.add_edge(langgraph.graph.START, "retrieve") | ||||
|     graph = graph_builder.compile() | ||||
|  | ||||
|     response = graph.invoke( | ||||
|         {"question": "Are there any news stories related to AI and Machine Learning?"} | ||||
|     ) | ||||
|     print(response["answer"]) | ||||
|     question = "What are the top stories related to AI and Machine Learning right now?" | ||||
|     run_query(question) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user