Compare commits
	
		
			2 Commits
		
	
	
		
			87a17331fd
			...
			311c332b10
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 311c332b10 | ||
|  | 259c9699ad | 
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -15,3 +15,5 @@ mlartifacts/ | |||||||
|  |  | ||||||
| # Weaviate vector store | # Weaviate vector store | ||||||
| weaviate/ | weaviate/ | ||||||
|  |  | ||||||
|  | data/ | ||||||
							
								
								
									
										145
									
								
								indexing.py
									
									
									
									
									
								
							
							
						
						
									
										145
									
								
								indexing.py
									
									
									
									
									
								
							| @@ -1,3 +1,4 @@ | |||||||
|  | import asyncio | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
| from typing import Iterable, TypedDict | from typing import Iterable, TypedDict | ||||||
| @@ -5,6 +6,7 @@ from typing import Iterable, TypedDict | |||||||
| import langchain | import langchain | ||||||
| import langchain.chat_models | import langchain.chat_models | ||||||
| import langchain.prompts | import langchain.prompts | ||||||
|  | import langchain.retrievers | ||||||
| import langchain.text_splitter | import langchain.text_splitter | ||||||
| import langchain_core | import langchain_core | ||||||
| import langchain_core.documents | import langchain_core.documents | ||||||
| @@ -16,11 +18,12 @@ import slack | |||||||
| import weaviate | import weaviate | ||||||
| from hn import HackerNewsClient, Story | from hn import HackerNewsClient, Story | ||||||
| from scrape import JinaScraper | from scrape import JinaScraper | ||||||
|  | from util import DocumentLocalFileStore | ||||||
|  |  | ||||||
| NUM_STORIES = 20 | NUM_STORIES = 40  # Number of top stories to fetch from Hacker News | ||||||
| USER_PREFERENCES = ["Machine Learning", "Programming", "Robotics"] | USER_PREFERENCES = ["Machine Learning", "Programming", "DevOps"] | ||||||
| ENABLE_SLACK = True  # Send updates to Slack, need to set SLACK_BOT_TOKEN env var | ENABLE_SLACK = True  # Send updates to Slack, need to set SLACK_BOT_TOKEN env var | ||||||
| ENABLE_MLFLOW_TRACING = False  # Use MLflow (at http://localhost:5000) for tracing | ENABLE_MLFLOW_TRACING = True  # Use MLflow (at http://localhost:5000) for tracing | ||||||
|  |  | ||||||
|  |  | ||||||
| llm = langchain.chat_models.init_chat_model( | llm = langchain.chat_models.init_chat_model( | ||||||
| @@ -36,6 +39,18 @@ vector_store = langchain_weaviate.WeaviateVectorStore( | |||||||
|     embedding=embeddings, |     embedding=embeddings, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( | ||||||
|  |     chunk_size=2000, | ||||||
|  |     chunk_overlap=200, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | doc_store = DocumentLocalFileStore(root_path="data/documents") | ||||||
|  | retriever = langchain.retrievers.ParentDocumentRetriever( | ||||||
|  |     vectorstore=vector_store, | ||||||
|  |     docstore=doc_store, | ||||||
|  |     child_splitter=splitter, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class State(TypedDict): | class State(TypedDict): | ||||||
|     preferences: Iterable[str] |     preferences: Iterable[str] | ||||||
| @@ -44,70 +59,15 @@ class State(TypedDict): | |||||||
|     summaries: list[dict] |     summaries: list[dict] | ||||||
|  |  | ||||||
|  |  | ||||||
| def retrieve(state: State, top_n: int = 2 * len(USER_PREFERENCES)) -> State: | def retrieve_docs(state: State, top_n: int = 3): | ||||||
|     # Search for relevant documents (with scores if available) |     """Retrieve relevant documents based on user preferences.""" | ||||||
|     retrieved_docs = vector_store.similarity_search( |  | ||||||
|         "Show the most interesting articles about the following topics: " |  | ||||||
|         + ", ".join(state["preferences"]), |  | ||||||
|         k=top_n * 20,  # Chunks, not complete stories |  | ||||||
|         return_score=True |  | ||||||
|         if hasattr(vector_store, "similarity_search_with_score") |  | ||||||
|         else False, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     # If scores are returned, unpack (doc, score) tuples; else, set score to None |     docs = [] | ||||||
|     docs_with_scores = [] |     for preference in state["preferences"]: | ||||||
|     if retrieved_docs and isinstance(retrieved_docs[0], tuple): |         logging.info(f"Retrieving documents for preference: {preference}") | ||||||
|         for doc, score in retrieved_docs: |         docs.extend(retriever.invoke(preference)[:top_n]) | ||||||
|             docs_with_scores.append((doc, score)) |  | ||||||
|     else: |  | ||||||
|         for doc in retrieved_docs: |  | ||||||
|             docs_with_scores.append((doc, None)) |  | ||||||
|  |  | ||||||
|     # Group chunks by story_id and collect their scores |     return {"context": docs} | ||||||
|     story_groups = {} |  | ||||||
|     for doc, score in docs_with_scores: |  | ||||||
|         story_id = doc.metadata.get("story_id") |  | ||||||
|         if story_id not in story_groups: |  | ||||||
|             story_groups[story_id] = [] |  | ||||||
|         story_groups[story_id].append((doc, score)) |  | ||||||
|  |  | ||||||
|     # Aggregate max score per story and reconstruct complete stories |  | ||||||
|     story_scores = {} |  | ||||||
|     complete_stories = [] |  | ||||||
|     for story_id, chunks_scores in story_groups.items(): |  | ||||||
|         chunks = [doc for doc, _ in chunks_scores] |  | ||||||
|         scores = [s for _, s in chunks_scores if s is not None] |  | ||||||
|         max_score = max(scores) if scores else None |  | ||||||
|         story_scores[story_id] = max_score |  | ||||||
|         if len(chunks) == 1: |  | ||||||
|             complete_stories.append((chunks[0], max_score)) |  | ||||||
|         else: |  | ||||||
|             combined_content = "\n\n".join( |  | ||||||
|                 chunk.page_content |  | ||||||
|                 for chunk in sorted( |  | ||||||
|                     chunks, key=lambda x: x.metadata.get("chunk_index", 0) |  | ||||||
|                 ) |  | ||||||
|             ) |  | ||||||
|             complete_story = langchain_core.documents.Document( |  | ||||||
|                 page_content=combined_content, |  | ||||||
|                 metadata=chunks[0].metadata,  # Use metadata from first chunk |  | ||||||
|             ) |  | ||||||
|             complete_stories.append((complete_story, max_score)) |  | ||||||
|  |  | ||||||
|     # Sort stories by max_score descending (None scores go last) |  | ||||||
|     complete_stories_sorted = sorted( |  | ||||||
|         complete_stories, key=lambda x: (x[1] is not None, x[1]), reverse=True |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     # Return top_n stories |  | ||||||
|     top_stories = [doc for doc, _ in complete_stories_sorted[:top_n]] |  | ||||||
|     return { |  | ||||||
|         "preferences": state["preferences"], |  | ||||||
|         "context": top_stories, |  | ||||||
|         "answer": state.get("answer", ""), |  | ||||||
|         "summaries": state.get("summaries", []), |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_structured_summaries(state: State): | def generate_structured_summaries(state: State): | ||||||
| @@ -217,9 +177,9 @@ def run_structured_query( | |||||||
| ) -> list[dict]: | ) -> list[dict]: | ||||||
|     """Run query and return structured summary data.""" |     """Run query and return structured summary data.""" | ||||||
|     graph_builder = langgraph.graph.StateGraph(State).add_sequence( |     graph_builder = langgraph.graph.StateGraph(State).add_sequence( | ||||||
|         [retrieve, generate_structured_summaries] |         [retrieve_docs, generate_structured_summaries] | ||||||
|     ) |     ) | ||||||
|     graph_builder.add_edge(langgraph.graph.START, "retrieve") |     graph_builder.add_edge(langgraph.graph.START, "retrieve_docs") | ||||||
|     graph = graph_builder.compile() |     graph = graph_builder.compile() | ||||||
|  |  | ||||||
|     response = graph.invoke( |     response = graph.invoke( | ||||||
| @@ -246,6 +206,7 @@ def get_existing_story_ids() -> set[str]: | |||||||
|  |  | ||||||
| async def fetch_hn_top_stories( | async def fetch_hn_top_stories( | ||||||
|     limit: int = 10, |     limit: int = 10, | ||||||
|  |     force_fetch: bool = False, | ||||||
| ) -> list[langchain_core.documents.Document]: | ) -> list[langchain_core.documents.Document]: | ||||||
|     hn = HackerNewsClient() |     hn = HackerNewsClient() | ||||||
|     stories = await hn.get_top_stories(limit=limit) |     stories = await hn.get_top_stories(limit=limit) | ||||||
| @@ -256,11 +217,15 @@ async def fetch_hn_top_stories( | |||||||
|  |  | ||||||
|     new_stories = [story for story in stories if story.id not in existing_ids] |     new_stories = [story for story in stories if story.id not in existing_ids] | ||||||
|  |  | ||||||
|     print(f"Found {len(stories)} top stories, {len(new_stories)} are new") |     logging.info(f"Found {len(stories)} top stories, {len(new_stories)} are new") | ||||||
|  |  | ||||||
|     if not new_stories: |     if not new_stories: | ||||||
|         print("No new stories to fetch") |         if not force_fetch: | ||||||
|  |             logging.info("No new stories to fetch") | ||||||
|             return [] |             return [] | ||||||
|  |         else: | ||||||
|  |             logging.info("Force fetching all top stories regardless of existing IDs") | ||||||
|  |             new_stories = stories | ||||||
|  |  | ||||||
|     contents = {} |     contents = {} | ||||||
|  |  | ||||||
| @@ -268,24 +233,16 @@ async def fetch_hn_top_stories( | |||||||
|     scraper = JinaScraper(os.getenv("JINA_API_KEY")) |     scraper = JinaScraper(os.getenv("JINA_API_KEY")) | ||||||
|  |  | ||||||
|     async def _fetch_content(story: Story) -> tuple[str, str]: |     async def _fetch_content(story: Story) -> tuple[str, str]: | ||||||
|         try: |  | ||||||
|         if not story.url: |         if not story.url: | ||||||
|             return story.id, story.title |             return story.id, story.title | ||||||
|         return story.id, await scraper.get_content(story.url) |         return story.id, await scraper.get_content(story.url) | ||||||
|         except Exception as e: |  | ||||||
|             logging.warning(f"Failed to fetch content for story {story.id}: {e}") |  | ||||||
|             return story.id, story.title  # Fallback to title if content fetch fails |  | ||||||
|  |  | ||||||
|     tasks = [_fetch_content(story) for story in new_stories] |     tasks = [_fetch_content(story) for story in new_stories] | ||||||
|     results = await asyncio.gather(*tasks, return_exceptions=True) |     results = await asyncio.gather(*tasks) | ||||||
|  |  | ||||||
|     # Filter out exceptions and convert to dict |     # Filter out exceptions and convert to dict | ||||||
|     contents = {} |     contents = {} | ||||||
|     for result in results: |     for result in results: | ||||||
|         if isinstance(result, Exception): |  | ||||||
|             logging.error(f"Task failed with exception: {result}") |  | ||||||
|             continue |  | ||||||
|         if isinstance(result, tuple) and len(result) == 2: |  | ||||||
|         story_id, content = result |         story_id, content = result | ||||||
|         contents[story_id] = content |         contents[story_id] = content | ||||||
|  |  | ||||||
| @@ -324,31 +281,15 @@ async def main(): | |||||||
|         for story in new_stories: |         for story in new_stories: | ||||||
|             categories = categorize(story, llm) |             categories = categorize(story, llm) | ||||||
|             story.metadata["categories"] = list(categories) |             story.metadata["categories"] = list(categories) | ||||||
|             print(f"Story ID {story.metadata["story_id"]} categorized as: {categories}") |             print( | ||||||
|  |                 f"Story ID {story.metadata["story_id"]} ({story.metadata["title"]}) categorized as: {categories}" | ||||||
|         # 2. Split |  | ||||||
|         documents_to_store = [] |  | ||||||
|         for story in new_stories: |  | ||||||
|             # If article is short enough, store as-is |  | ||||||
|             if len(story.page_content) <= 3000: |  | ||||||
|                 documents_to_store.append(story) |  | ||||||
|             else: |  | ||||||
|                 # For very long articles, chunk but keep story metadata |  | ||||||
|                 splitter = langchain.text_splitter.RecursiveCharacterTextSplitter( |  | ||||||
|                     chunk_size=2000, |  | ||||||
|                     chunk_overlap=200, |  | ||||||
|                     add_start_index=True, |  | ||||||
|             ) |             ) | ||||||
|                 chunks = splitter.split_documents([story]) |  | ||||||
|                 # Add chunk info to metadata |  | ||||||
|                 for i, chunk in enumerate(chunks): |  | ||||||
|                     chunk.metadata["chunk_index"] = i |  | ||||||
|                     chunk.metadata["total_chunks"] = len(chunks) |  | ||||||
|                 documents_to_store.extend(chunks) |  | ||||||
|  |  | ||||||
|         # 3. Store |         # 2. Split & 3. Store | ||||||
|         _ = vector_store.add_documents(documents_to_store) |         retriever.add_documents( | ||||||
|         print(f"Added {len(documents_to_store)} documents to vector store") |             new_stories, ids=[doc.metadata["story_id"] for doc in new_stories] | ||||||
|  |         ) | ||||||
|  |         print(f"Added {len(new_stories)} documents to document store") | ||||||
|     else: |     else: | ||||||
|         print("No new stories to process") |         print("No new stories to process") | ||||||
|  |  | ||||||
| @@ -361,8 +302,6 @@ async def main(): | |||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     import asyncio |  | ||||||
|  |  | ||||||
|     logging.basicConfig(level=logging.INFO) |     logging.basicConfig(level=logging.INFO) | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|   | |||||||
							
								
								
									
										28
									
								
								scrape.py
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								scrape.py
									
									
									
									
									
								
							| @@ -1,4 +1,3 @@ | |||||||
| import asyncio |  | ||||||
| import logging | import logging | ||||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||||
| from typing import override | from typing import override | ||||||
| @@ -8,41 +7,22 @@ import httpx | |||||||
|  |  | ||||||
| class TextScraper(ABC): | class TextScraper(ABC): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self._client = httpx.AsyncClient(timeout=httpx.Timeout(5.0)) |         self._http_headers = {} | ||||||
|  |  | ||||||
|     async def _fetch_text(self, url: str) -> str: |     async def _fetch_text(self, url: str) -> str: | ||||||
|         """Fetch the raw HTML content from the URL.""" |         """Fetch the raw HTML content from the URL.""" | ||||||
|         response = None |  | ||||||
|         try: |         try: | ||||||
|             response = await self._client.get(url) |             async with httpx.AsyncClient(headers=self._http_headers) as client: | ||||||
|  |                 response = await client.get(url) | ||||||
|                 response.raise_for_status() |                 response.raise_for_status() | ||||||
|                 return response.text |                 return response.text | ||||||
|         except Exception: |         except Exception: | ||||||
|             logging.warning(f"Failed to fetch text from {url}", exc_info=True) |             logging.warning(f"Failed to fetch text from {url}", exc_info=True) | ||||||
|             raise |             raise | ||||||
|         finally: |  | ||||||
|             if response: |  | ||||||
|                 await response.aclose() |  | ||||||
|  |  | ||||||
|     @abstractmethod |     @abstractmethod | ||||||
|     async def get_content(self, url: str) -> str: ... |     async def get_content(self, url: str) -> str: ... | ||||||
|  |  | ||||||
|     async def close(self): |  | ||||||
|         """Close the underlying HTTP client.""" |  | ||||||
|         if self._client and not self._client.is_closed: |  | ||||||
|             await self._client.aclose() |  | ||||||
|  |  | ||||||
|     def __del__(self): |  | ||||||
|         """Ensure the HTTP client is closed when the object is deleted.""" |  | ||||||
|         try: |  | ||||||
|             loop = asyncio.get_event_loop() |  | ||||||
|             if loop.is_running(): |  | ||||||
|                 loop.create_task(self.close()) |  | ||||||
|             else: |  | ||||||
|                 loop.run_until_complete(self.close()) |  | ||||||
|         except Exception: |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Html2textScraper(TextScraper): | class Html2textScraper(TextScraper): | ||||||
|     @override |     @override | ||||||
| @@ -65,7 +45,7 @@ class JinaScraper(TextScraper): | |||||||
|     def __init__(self, api_key: str | None = None): |     def __init__(self, api_key: str | None = None): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         if api_key: |         if api_key: | ||||||
|             self._client.headers.update({"Authorization": f"Bearer {api_key}"}) |             self._http_headers.update({"Authorization": f"Bearer {api_key}"}) | ||||||
|  |  | ||||||
|     @override |     @override | ||||||
|     async def get_content(self, url: str) -> str: |     async def get_content(self, url: str) -> str: | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								slack.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								slack.py
									
									
									
									
									
								
							| @@ -77,6 +77,7 @@ def send_message(channel: str, blocks: list) -> None: | |||||||
|             text="Tech updates", |             text="Tech updates", | ||||||
|             blocks=blocks, |             blocks=blocks, | ||||||
|             unfurl_links=False, |             unfurl_links=False, | ||||||
|  |             unfurl_media=False, | ||||||
|         ) |         ) | ||||||
|         response.validate() |         response.validate() | ||||||
|         logging.info(f"Message sent successfully to channel {channel}") |         logging.info(f"Message sent successfully to channel {channel}") | ||||||
|   | |||||||
							
								
								
									
										48
									
								
								util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								util.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | |||||||
|  | import json | ||||||
|  | from collections.abc import Iterator | ||||||
|  |  | ||||||
|  | from langchain.schema import BaseStore, Document | ||||||
|  | from langchain.storage import LocalFileStore | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DocumentLocalFileStore(BaseStore[str, Document]): | ||||||
|  |     def __init__(self, root_path: str): | ||||||
|  |         self.store = LocalFileStore(root_path) | ||||||
|  |  | ||||||
|  |     def mget(self, keys): | ||||||
|  |         """Get multiple documents by keys""" | ||||||
|  |         results = [] | ||||||
|  |         for key in keys: | ||||||
|  |             try: | ||||||
|  |                 serialized = self.store.mget([key])[0] | ||||||
|  |                 if serialized: | ||||||
|  |                     # Deserialize the document | ||||||
|  |                     doc_dict = json.loads(serialized.decode("utf-8")) | ||||||
|  |                     doc = Document( | ||||||
|  |                         page_content=doc_dict["page_content"], | ||||||
|  |                         metadata=doc_dict.get("metadata", {}), | ||||||
|  |                     ) | ||||||
|  |                     results.append(doc) | ||||||
|  |                 else: | ||||||
|  |                     results.append(None) | ||||||
|  |             except: | ||||||
|  |                 results.append(None) | ||||||
|  |         return results | ||||||
|  |  | ||||||
|  |     def mset(self, key_value_pairs): | ||||||
|  |         """Set multiple key-value pairs""" | ||||||
|  |         serialized_pairs = [] | ||||||
|  |         for key, doc in key_value_pairs: | ||||||
|  |             # Serialize the document | ||||||
|  |             doc_dict = {"page_content": doc.page_content, "metadata": doc.metadata} | ||||||
|  |             serialized = json.dumps(doc_dict).encode("utf-8") | ||||||
|  |             serialized_pairs.append((key, serialized)) | ||||||
|  |  | ||||||
|  |         self.store.mset(serialized_pairs) | ||||||
|  |  | ||||||
|  |     def mdelete(self, keys): | ||||||
|  |         """Delete multiple keys""" | ||||||
|  |         self.store.mdelete(keys) | ||||||
|  |  | ||||||
|  |     def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]: | ||||||
|  |         return self.store.yield_keys(prefix=prefix) | ||||||
		Reference in New Issue
	
	Block a user