Compare commits
2 Commits
87a17331fd
...
311c332b10
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
311c332b10 | ||
|
|
259c9699ad |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -15,3 +15,5 @@ mlartifacts/
|
|||||||
|
|
||||||
# Weaviate vector store
|
# Weaviate vector store
|
||||||
weaviate/
|
weaviate/
|
||||||
|
|
||||||
|
data/
|
||||||
159
indexing.py
159
indexing.py
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Iterable, TypedDict
|
from typing import Iterable, TypedDict
|
||||||
@@ -5,6 +6,7 @@ from typing import Iterable, TypedDict
|
|||||||
import langchain
|
import langchain
|
||||||
import langchain.chat_models
|
import langchain.chat_models
|
||||||
import langchain.prompts
|
import langchain.prompts
|
||||||
|
import langchain.retrievers
|
||||||
import langchain.text_splitter
|
import langchain.text_splitter
|
||||||
import langchain_core
|
import langchain_core
|
||||||
import langchain_core.documents
|
import langchain_core.documents
|
||||||
@@ -16,11 +18,12 @@ import slack
|
|||||||
import weaviate
|
import weaviate
|
||||||
from hn import HackerNewsClient, Story
|
from hn import HackerNewsClient, Story
|
||||||
from scrape import JinaScraper
|
from scrape import JinaScraper
|
||||||
|
from util import DocumentLocalFileStore
|
||||||
|
|
||||||
NUM_STORIES = 20
|
NUM_STORIES = 40 # Number of top stories to fetch from Hacker News
|
||||||
USER_PREFERENCES = ["Machine Learning", "Programming", "Robotics"]
|
USER_PREFERENCES = ["Machine Learning", "Programming", "DevOps"]
|
||||||
ENABLE_SLACK = True # Send updates to Slack, need to set SLACK_BOT_TOKEN env var
|
ENABLE_SLACK = True # Send updates to Slack, need to set SLACK_BOT_TOKEN env var
|
||||||
ENABLE_MLFLOW_TRACING = False # Use MLflow (at http://localhost:5000) for tracing
|
ENABLE_MLFLOW_TRACING = True # Use MLflow (at http://localhost:5000) for tracing
|
||||||
|
|
||||||
|
|
||||||
llm = langchain.chat_models.init_chat_model(
|
llm = langchain.chat_models.init_chat_model(
|
||||||
@@ -36,6 +39,18 @@ vector_store = langchain_weaviate.WeaviateVectorStore(
|
|||||||
embedding=embeddings,
|
embedding=embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=2000,
|
||||||
|
chunk_overlap=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_store = DocumentLocalFileStore(root_path="data/documents")
|
||||||
|
retriever = langchain.retrievers.ParentDocumentRetriever(
|
||||||
|
vectorstore=vector_store,
|
||||||
|
docstore=doc_store,
|
||||||
|
child_splitter=splitter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class State(TypedDict):
|
class State(TypedDict):
|
||||||
preferences: Iterable[str]
|
preferences: Iterable[str]
|
||||||
@@ -44,70 +59,15 @@ class State(TypedDict):
|
|||||||
summaries: list[dict]
|
summaries: list[dict]
|
||||||
|
|
||||||
|
|
||||||
def retrieve(state: State, top_n: int = 2 * len(USER_PREFERENCES)) -> State:
|
def retrieve_docs(state: State, top_n: int = 3):
|
||||||
# Search for relevant documents (with scores if available)
|
"""Retrieve relevant documents based on user preferences."""
|
||||||
retrieved_docs = vector_store.similarity_search(
|
|
||||||
"Show the most interesting articles about the following topics: "
|
|
||||||
+ ", ".join(state["preferences"]),
|
|
||||||
k=top_n * 20, # Chunks, not complete stories
|
|
||||||
return_score=True
|
|
||||||
if hasattr(vector_store, "similarity_search_with_score")
|
|
||||||
else False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# If scores are returned, unpack (doc, score) tuples; else, set score to None
|
docs = []
|
||||||
docs_with_scores = []
|
for preference in state["preferences"]:
|
||||||
if retrieved_docs and isinstance(retrieved_docs[0], tuple):
|
logging.info(f"Retrieving documents for preference: {preference}")
|
||||||
for doc, score in retrieved_docs:
|
docs.extend(retriever.invoke(preference)[:top_n])
|
||||||
docs_with_scores.append((doc, score))
|
|
||||||
else:
|
|
||||||
for doc in retrieved_docs:
|
|
||||||
docs_with_scores.append((doc, None))
|
|
||||||
|
|
||||||
# Group chunks by story_id and collect their scores
|
return {"context": docs}
|
||||||
story_groups = {}
|
|
||||||
for doc, score in docs_with_scores:
|
|
||||||
story_id = doc.metadata.get("story_id")
|
|
||||||
if story_id not in story_groups:
|
|
||||||
story_groups[story_id] = []
|
|
||||||
story_groups[story_id].append((doc, score))
|
|
||||||
|
|
||||||
# Aggregate max score per story and reconstruct complete stories
|
|
||||||
story_scores = {}
|
|
||||||
complete_stories = []
|
|
||||||
for story_id, chunks_scores in story_groups.items():
|
|
||||||
chunks = [doc for doc, _ in chunks_scores]
|
|
||||||
scores = [s for _, s in chunks_scores if s is not None]
|
|
||||||
max_score = max(scores) if scores else None
|
|
||||||
story_scores[story_id] = max_score
|
|
||||||
if len(chunks) == 1:
|
|
||||||
complete_stories.append((chunks[0], max_score))
|
|
||||||
else:
|
|
||||||
combined_content = "\n\n".join(
|
|
||||||
chunk.page_content
|
|
||||||
for chunk in sorted(
|
|
||||||
chunks, key=lambda x: x.metadata.get("chunk_index", 0)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
complete_story = langchain_core.documents.Document(
|
|
||||||
page_content=combined_content,
|
|
||||||
metadata=chunks[0].metadata, # Use metadata from first chunk
|
|
||||||
)
|
|
||||||
complete_stories.append((complete_story, max_score))
|
|
||||||
|
|
||||||
# Sort stories by max_score descending (None scores go last)
|
|
||||||
complete_stories_sorted = sorted(
|
|
||||||
complete_stories, key=lambda x: (x[1] is not None, x[1]), reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return top_n stories
|
|
||||||
top_stories = [doc for doc, _ in complete_stories_sorted[:top_n]]
|
|
||||||
return {
|
|
||||||
"preferences": state["preferences"],
|
|
||||||
"context": top_stories,
|
|
||||||
"answer": state.get("answer", ""),
|
|
||||||
"summaries": state.get("summaries", []),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def generate_structured_summaries(state: State):
|
def generate_structured_summaries(state: State):
|
||||||
@@ -217,9 +177,9 @@ def run_structured_query(
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Run query and return structured summary data."""
|
"""Run query and return structured summary data."""
|
||||||
graph_builder = langgraph.graph.StateGraph(State).add_sequence(
|
graph_builder = langgraph.graph.StateGraph(State).add_sequence(
|
||||||
[retrieve, generate_structured_summaries]
|
[retrieve_docs, generate_structured_summaries]
|
||||||
)
|
)
|
||||||
graph_builder.add_edge(langgraph.graph.START, "retrieve")
|
graph_builder.add_edge(langgraph.graph.START, "retrieve_docs")
|
||||||
graph = graph_builder.compile()
|
graph = graph_builder.compile()
|
||||||
|
|
||||||
response = graph.invoke(
|
response = graph.invoke(
|
||||||
@@ -246,6 +206,7 @@ def get_existing_story_ids() -> set[str]:
|
|||||||
|
|
||||||
async def fetch_hn_top_stories(
|
async def fetch_hn_top_stories(
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
|
force_fetch: bool = False,
|
||||||
) -> list[langchain_core.documents.Document]:
|
) -> list[langchain_core.documents.Document]:
|
||||||
hn = HackerNewsClient()
|
hn = HackerNewsClient()
|
||||||
stories = await hn.get_top_stories(limit=limit)
|
stories = await hn.get_top_stories(limit=limit)
|
||||||
@@ -256,11 +217,15 @@ async def fetch_hn_top_stories(
|
|||||||
|
|
||||||
new_stories = [story for story in stories if story.id not in existing_ids]
|
new_stories = [story for story in stories if story.id not in existing_ids]
|
||||||
|
|
||||||
print(f"Found {len(stories)} top stories, {len(new_stories)} are new")
|
logging.info(f"Found {len(stories)} top stories, {len(new_stories)} are new")
|
||||||
|
|
||||||
if not new_stories:
|
if not new_stories:
|
||||||
print("No new stories to fetch")
|
if not force_fetch:
|
||||||
return []
|
logging.info("No new stories to fetch")
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
logging.info("Force fetching all top stories regardless of existing IDs")
|
||||||
|
new_stories = stories
|
||||||
|
|
||||||
contents = {}
|
contents = {}
|
||||||
|
|
||||||
@@ -268,26 +233,18 @@ async def fetch_hn_top_stories(
|
|||||||
scraper = JinaScraper(os.getenv("JINA_API_KEY"))
|
scraper = JinaScraper(os.getenv("JINA_API_KEY"))
|
||||||
|
|
||||||
async def _fetch_content(story: Story) -> tuple[str, str]:
|
async def _fetch_content(story: Story) -> tuple[str, str]:
|
||||||
try:
|
if not story.url:
|
||||||
if not story.url:
|
return story.id, story.title
|
||||||
return story.id, story.title
|
return story.id, await scraper.get_content(story.url)
|
||||||
return story.id, await scraper.get_content(story.url)
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Failed to fetch content for story {story.id}: {e}")
|
|
||||||
return story.id, story.title # Fallback to title if content fetch fails
|
|
||||||
|
|
||||||
tasks = [_fetch_content(story) for story in new_stories]
|
tasks = [_fetch_content(story) for story in new_stories]
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# Filter out exceptions and convert to dict
|
# Filter out exceptions and convert to dict
|
||||||
contents = {}
|
contents = {}
|
||||||
for result in results:
|
for result in results:
|
||||||
if isinstance(result, Exception):
|
story_id, content = result
|
||||||
logging.error(f"Task failed with exception: {result}")
|
contents[story_id] = content
|
||||||
continue
|
|
||||||
if isinstance(result, tuple) and len(result) == 2:
|
|
||||||
story_id, content = result
|
|
||||||
contents[story_id] = content
|
|
||||||
|
|
||||||
documents = [
|
documents = [
|
||||||
langchain_core.documents.Document(
|
langchain_core.documents.Document(
|
||||||
@@ -324,31 +281,15 @@ async def main():
|
|||||||
for story in new_stories:
|
for story in new_stories:
|
||||||
categories = categorize(story, llm)
|
categories = categorize(story, llm)
|
||||||
story.metadata["categories"] = list(categories)
|
story.metadata["categories"] = list(categories)
|
||||||
print(f"Story ID {story.metadata["story_id"]} categorized as: {categories}")
|
print(
|
||||||
|
f"Story ID {story.metadata["story_id"]} ({story.metadata["title"]}) categorized as: {categories}"
|
||||||
|
)
|
||||||
|
|
||||||
# 2. Split
|
# 2. Split & 3. Store
|
||||||
documents_to_store = []
|
retriever.add_documents(
|
||||||
for story in new_stories:
|
new_stories, ids=[doc.metadata["story_id"] for doc in new_stories]
|
||||||
# If article is short enough, store as-is
|
)
|
||||||
if len(story.page_content) <= 3000:
|
print(f"Added {len(new_stories)} documents to document store")
|
||||||
documents_to_store.append(story)
|
|
||||||
else:
|
|
||||||
# For very long articles, chunk but keep story metadata
|
|
||||||
splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
|
|
||||||
chunk_size=2000,
|
|
||||||
chunk_overlap=200,
|
|
||||||
add_start_index=True,
|
|
||||||
)
|
|
||||||
chunks = splitter.split_documents([story])
|
|
||||||
# Add chunk info to metadata
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
chunk.metadata["chunk_index"] = i
|
|
||||||
chunk.metadata["total_chunks"] = len(chunks)
|
|
||||||
documents_to_store.extend(chunks)
|
|
||||||
|
|
||||||
# 3. Store
|
|
||||||
_ = vector_store.add_documents(documents_to_store)
|
|
||||||
print(f"Added {len(documents_to_store)} documents to vector store")
|
|
||||||
else:
|
else:
|
||||||
print("No new stories to process")
|
print("No new stories to process")
|
||||||
|
|
||||||
@@ -361,8 +302,6 @@ async def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
32
scrape.py
32
scrape.py
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import override
|
from typing import override
|
||||||
@@ -8,41 +7,22 @@ import httpx
|
|||||||
|
|
||||||
class TextScraper(ABC):
|
class TextScraper(ABC):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._client = httpx.AsyncClient(timeout=httpx.Timeout(5.0))
|
self._http_headers = {}
|
||||||
|
|
||||||
async def _fetch_text(self, url: str) -> str:
|
async def _fetch_text(self, url: str) -> str:
|
||||||
"""Fetch the raw HTML content from the URL."""
|
"""Fetch the raw HTML content from the URL."""
|
||||||
response = None
|
|
||||||
try:
|
try:
|
||||||
response = await self._client.get(url)
|
async with httpx.AsyncClient(headers=self._http_headers) as client:
|
||||||
response.raise_for_status()
|
response = await client.get(url)
|
||||||
return response.text
|
response.raise_for_status()
|
||||||
|
return response.text
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.warning(f"Failed to fetch text from {url}", exc_info=True)
|
logging.warning(f"Failed to fetch text from {url}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
if response:
|
|
||||||
await response.aclose()
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_content(self, url: str) -> str: ...
|
async def get_content(self, url: str) -> str: ...
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""Close the underlying HTTP client."""
|
|
||||||
if self._client and not self._client.is_closed:
|
|
||||||
await self._client.aclose()
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""Ensure the HTTP client is closed when the object is deleted."""
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if loop.is_running():
|
|
||||||
loop.create_task(self.close())
|
|
||||||
else:
|
|
||||||
loop.run_until_complete(self.close())
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Html2textScraper(TextScraper):
|
class Html2textScraper(TextScraper):
|
||||||
@override
|
@override
|
||||||
@@ -65,7 +45,7 @@ class JinaScraper(TextScraper):
|
|||||||
def __init__(self, api_key: str | None = None):
|
def __init__(self, api_key: str | None = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if api_key:
|
if api_key:
|
||||||
self._client.headers.update({"Authorization": f"Bearer {api_key}"})
|
self._http_headers.update({"Authorization": f"Bearer {api_key}"})
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def get_content(self, url: str) -> str:
|
async def get_content(self, url: str) -> str:
|
||||||
|
|||||||
1
slack.py
1
slack.py
@@ -77,6 +77,7 @@ def send_message(channel: str, blocks: list) -> None:
|
|||||||
text="Tech updates",
|
text="Tech updates",
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
unfurl_links=False,
|
unfurl_links=False,
|
||||||
|
unfurl_media=False,
|
||||||
)
|
)
|
||||||
response.validate()
|
response.validate()
|
||||||
logging.info(f"Message sent successfully to channel {channel}")
|
logging.info(f"Message sent successfully to channel {channel}")
|
||||||
|
|||||||
48
util.py
Normal file
48
util.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
from langchain.schema import BaseStore, Document
|
||||||
|
from langchain.storage import LocalFileStore
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentLocalFileStore(BaseStore[str, Document]):
|
||||||
|
def __init__(self, root_path: str):
|
||||||
|
self.store = LocalFileStore(root_path)
|
||||||
|
|
||||||
|
def mget(self, keys):
|
||||||
|
"""Get multiple documents by keys"""
|
||||||
|
results = []
|
||||||
|
for key in keys:
|
||||||
|
try:
|
||||||
|
serialized = self.store.mget([key])[0]
|
||||||
|
if serialized:
|
||||||
|
# Deserialize the document
|
||||||
|
doc_dict = json.loads(serialized.decode("utf-8"))
|
||||||
|
doc = Document(
|
||||||
|
page_content=doc_dict["page_content"],
|
||||||
|
metadata=doc_dict.get("metadata", {}),
|
||||||
|
)
|
||||||
|
results.append(doc)
|
||||||
|
else:
|
||||||
|
results.append(None)
|
||||||
|
except:
|
||||||
|
results.append(None)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def mset(self, key_value_pairs):
|
||||||
|
"""Set multiple key-value pairs"""
|
||||||
|
serialized_pairs = []
|
||||||
|
for key, doc in key_value_pairs:
|
||||||
|
# Serialize the document
|
||||||
|
doc_dict = {"page_content": doc.page_content, "metadata": doc.metadata}
|
||||||
|
serialized = json.dumps(doc_dict).encode("utf-8")
|
||||||
|
serialized_pairs.append((key, serialized))
|
||||||
|
|
||||||
|
self.store.mset(serialized_pairs)
|
||||||
|
|
||||||
|
def mdelete(self, keys):
|
||||||
|
"""Delete multiple keys"""
|
||||||
|
self.store.mdelete(keys)
|
||||||
|
|
||||||
|
def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
|
||||||
|
return self.store.yield_keys(prefix=prefix)
|
||||||
Reference in New Issue
Block a user