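"""Personalized Hacker News digest.

Fetches the current HN top stories, scrapes their content, stores embeddings in a
local Weaviate instance, then runs a small LangGraph pipeline
(retrieve -> generate_structured_summaries) to produce per-preference summaries,
optionally posting them to Slack.
"""
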
import asyncio  # needed at module level: fetch_hn_top_stories uses asyncio.gather
import logging
import os
from typing import Iterable, TypedDict

import langchain
import langchain.chat_models
import langchain.prompts
import langchain.text_splitter
import langchain_core
import langchain_core.documents
import langchain_openai
import langchain_weaviate
import langgraph.graph

import slack
import weaviate
from hn import HackerNewsClient, Story
from scrape import JinaScraper

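# Environment variables used below: OPENAI_API_KEY (LLM + embeddings),
# JINA_API_KEY (scraping), SLACK_BOT_TOKEN (only if ENABLE_SLACK is True).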
NUM_STORIES = 20
USER_PREFERENCES = ["Machine Learning", "Programming", "Robotics"]
ENABLE_SLACK = True  # Send updates to Slack; requires the SLACK_BOT_TOKEN env var
ENABLE_MLFLOW_TRACING = False  # Use MLflow (at http://localhost:5000) for tracing


llm = langchain.chat_models.init_chat_model(
    model="gpt-4o-mini", model_provider="openai"
)
embeddings = langchain_openai.OpenAIEmbeddings(model="text-embedding-3-large")

weaviate_client = weaviate.connect_to_local()
vector_store = langchain_weaviate.WeaviateVectorStore(
    weaviate_client,
    index_name="hn_stories",
    text_key="page_content",
    embedding=embeddings,
)


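# State shared between the LangGraph nodes (retrieve -> generate_structured_summaries).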
class State(TypedDict):
    preferences: Iterable[str]
    context: list[langchain_core.documents.Document]
    answer: str
    summaries: list[dict]


def retrieve(state: State, top_n: int = 2 * len(USER_PREFERENCES)) -> State:
    """Retrieve chunk-level matches for the user's preferences, reassemble them
    into complete stories, and keep the top_n best-scoring stories as context."""
    # Search for relevant documents (with scores if available)
    retrieved_docs = vector_store.similarity_search(
        "Show the most interesting articles about the following topics: "
        + ", ".join(state["preferences"]),
        k=top_n * 20,  # Chunks, not complete stories
        return_score=hasattr(vector_store, "similarity_search_with_score"),
    )

    # If scores are returned, unpack (doc, score) tuples; else, set score to None
    docs_with_scores = []
    if retrieved_docs and isinstance(retrieved_docs[0], tuple):
        for doc, score in retrieved_docs:
            docs_with_scores.append((doc, score))
    else:
        for doc in retrieved_docs:
            docs_with_scores.append((doc, None))

    # Group chunks by story_id and collect their scores
    story_groups = {}
    for doc, score in docs_with_scores:
        story_id = doc.metadata.get("story_id")
        if story_id not in story_groups:
            story_groups[story_id] = []
        story_groups[story_id].append((doc, score))

    # Aggregate max score per story and reconstruct complete stories
    story_scores = {}
    complete_stories = []
    for story_id, chunks_scores in story_groups.items():
        chunks = [doc for doc, _ in chunks_scores]
        scores = [s for _, s in chunks_scores if s is not None]
        max_score = max(scores) if scores else None
        story_scores[story_id] = max_score
        if len(chunks) == 1:
            complete_stories.append((chunks[0], max_score))
        else:
            combined_content = "\n\n".join(
                chunk.page_content
                for chunk in sorted(
                    chunks, key=lambda x: x.metadata.get("chunk_index", 0)
                )
            )
            complete_story = langchain_core.documents.Document(
                page_content=combined_content,
                metadata=chunks[0].metadata,  # Use metadata from first chunk
            )
            complete_stories.append((complete_story, max_score))

    # Sort stories by max_score descending (None scores go last)
    complete_stories_sorted = sorted(
        complete_stories, key=lambda x: (x[1] is not None, x[1]), reverse=True
    )

    # Return top_n stories
    top_stories = [doc for doc, _ in complete_stories_sorted[:top_n]]
    return {
        "preferences": state["preferences"],
        "context": top_stories,
        "answer": state.get("answer", ""),
        "summaries": state.get("summaries", []),
    }


def generate_structured_summaries(state: State):
    """Generate structured summaries for each story individually."""
    summaries = []

    for doc in state["context"]:
        # Create a prompt for individual story summarization
        prompt = langchain.prompts.PromptTemplate(
            input_variables=["preferences", "title", "content", "source", "categories"],
            template=(
                "You are a helpful assistant that summarizes technology articles.\n\n"
                "User preferences: {preferences}\n\n"
                "Article title: {title}\n"
                "Article categories: {categories}\n"
                "Article content: {content}\n"
                "Source URL: {source}\n\n"
                "Use an informative but not too formal tone.\n"
                "Please provide:\n"
                "1. A concise summary (around 50 words) that highlights the key insights from the article.\n"
                "2. The single user preference that this article best matches (or 'Other' if none match well)\n\n"
                "Format your response as:\n"
                "PREFERENCE: [preference name or 'Other']\n"
                "SUMMARY: [your summary here]\n"
            ),
        )

        messages = prompt.invoke(
            {
                "preferences": ", ".join(state["preferences"]),
                "title": doc.metadata.get("title", "Unknown Title"),
                "content": doc.page_content[:5000],  # Limit content length for LLM
                "source": doc.metadata.get("source", ""),
                "categories": ", ".join(doc.metadata.get("categories", [])),
            }
        )

        response = llm.invoke(messages).content

        # Parse the LLM response to extract preference and summary
        response_text = response if isinstance(response, str) else str(response)
        lines = response_text.strip().split("\n")
        matched_preference = "Other"
        summary_text = response_text

        for line in lines:
            if line.startswith("PREFERENCE:"):
                matched_preference = line.replace("PREFERENCE:", "").strip()
            elif line.startswith("SUMMARY:"):
                summary_text = line.replace("SUMMARY:", "").strip()

        # If we didn't find the structured format, use the whole response as summary
        if not any(line.startswith("SUMMARY:") for line in lines):
            summary_text = response_text.strip()

        summaries.append(
            {
                "title": doc.metadata.get("title", "Unknown Title"),
                "summary": summary_text,
                "source_url": doc.metadata.get("source", ""),
                "categories": doc.metadata.get("categories", []),
                "story_id": doc.metadata.get("story_id"),
                "matched_preference": matched_preference,
            }
        )

    return {"summaries": summaries}


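# Each summary dict produced above carries: title, summary, source_url, categories,
# story_id, matched_preference. The grouping/formatting helpers below rely on that shape.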
def group_stories_by_preference(
    summaries: list[dict], preferences: list[str]
) -> dict[str, list[dict]]:
    """Group stories by their matched preferences in the order of user preferences."""
    preference_groups = {}

    # Group stories by the LLM-determined preference matching
    for summary in summaries:
        matched_preference = summary.get("matched_preference", "Other")

        if matched_preference not in preference_groups:
            preference_groups[matched_preference] = []
        preference_groups[matched_preference].append(summary)

    # Create ordered groups based on user preferences
    ordered_groups = {}

    # Add groups for user preferences in order
    for preference in preferences:
        if preference in preference_groups:
            ordered_groups[preference] = preference_groups[preference]

    # Add "Other" group at the end if it exists
    if "Other" in preference_groups:
        ordered_groups["Other"] = preference_groups["Other"]

    return ordered_groups


def create_slack_blocks(summaries: list[dict], preferences: list[str]) -> list[dict]:
    """Convert structured summaries into Slack block format grouped by user preferences."""
    grouped_stories = group_stories_by_preference(summaries, preferences)
    return slack.format_slack_blocks(grouped_stories)


def run_structured_query(
    preferences: Iterable[str],
) -> list[dict]:
    """Run the retrieve -> summarize graph and return structured summary data."""
    graph_builder = langgraph.graph.StateGraph(State).add_sequence(
        [retrieve, generate_structured_summaries]
    )
    graph_builder.add_edge(langgraph.graph.START, "retrieve")
    graph = graph_builder.compile()

    response = graph.invoke(
        State(preferences=preferences, context=[], answer="", summaries=[])
    )
    summaries = response["summaries"]
    return summaries


def get_existing_story_ids() -> set[str]:
    """Get the IDs of stories that already exist in the vector store."""
    try:
        # Iterate the underlying Weaviate collection directly
        # (relies on WeaviateVectorStore's private `_collection` attribute).
        collection = vector_store._collection
        existing_ids = set()
        for doc in collection.iterator():
            story_id = doc.properties.get("story_id")
            if story_id:
                existing_ids.add(story_id)
        return existing_ids
    except Exception:
        logging.warning("Could not retrieve existing story IDs", exc_info=True)
        return set()


async def fetch_hn_top_stories(
    limit: int = 10,
) -> list[langchain_core.documents.Document]:
    """Fetch the current HN top stories, scrape content for those not yet in the
    vector store, and return them as LangChain Documents."""
    hn = HackerNewsClient()
    stories = await hn.get_top_stories(limit=limit)

    # Get existing story IDs to avoid re-fetching
    existing_ids = get_existing_story_ids()
    logging.info(f"Existing story IDs: {len(existing_ids)} found in vector store")

    new_stories = [story for story in stories if story.id not in existing_ids]

    print(f"Found {len(stories)} top stories, {len(new_stories)} are new")

    if not new_stories:
        print("No new stories to fetch")
        return []

    # Fetch content for each new story asynchronously
    scraper = JinaScraper(os.getenv("JINA_API_KEY"))

    async def _fetch_content(story: Story) -> tuple[str, str]:
        try:
            if not story.url:
                return story.id, story.title
            return story.id, await scraper.get_content(story.url)
        except Exception as e:
            logging.warning(f"Failed to fetch content for story {story.id}: {e}")
            return story.id, story.title  # Fallback to title if content fetch fails

    tasks = [_fetch_content(story) for story in new_stories]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Filter out exceptions and collect the scraped content per story ID
    contents = {}
    for result in results:
        if isinstance(result, Exception):
            logging.error(f"Task failed with exception: {result}")
            continue
        if isinstance(result, tuple) and len(result) == 2:
            story_id, content = result
            contents[story_id] = content

    documents = [
        langchain_core.documents.Document(
            page_content=contents[story.id],
            metadata={
                "story_id": story.id,
                "title": story.title,
                "source": story.url,
                "created_at": story.created_at.isoformat(),
            },
        )
        for story in new_stories
        if story.id in contents  # Skip stories whose fetch task failed outright
    ]
    return documents


async def main():
    if ENABLE_MLFLOW_TRACING:
        import mlflow

        mlflow.set_tracking_uri("http://localhost:5000")
        mlflow.set_experiment("langchain-rag-hn")
        mlflow.langchain.autolog()

    # 1. Load only new stories
    new_stories = await fetch_hn_top_stories(limit=NUM_STORIES)

    if new_stories:
        print(f"Processing {len(new_stories)} new stories")

        # Categorize stories (optional)
        from classify import categorize

        for story in new_stories:
            categories = categorize(story, llm)
            story.metadata["categories"] = list(categories)
            print(f"Story ID {story.metadata['story_id']} categorized as: {categories}")

        # 2. Split
        documents_to_store = []
        for story in new_stories:
            # If article is short enough, store as-is
            if len(story.page_content) <= 3000:
                documents_to_store.append(story)
            else:
                # For very long articles, chunk but keep story metadata
                splitter = langchain.text_splitter.RecursiveCharacterTextSplitter(
                    chunk_size=2000,
                    chunk_overlap=200,
                    add_start_index=True,
                )
                chunks = splitter.split_documents([story])
                # Add chunk info to metadata
                for i, chunk in enumerate(chunks):
                    chunk.metadata["chunk_index"] = i
                    chunk.metadata["total_chunks"] = len(chunks)
                documents_to_store.extend(chunks)

        # 3. Store
        _ = vector_store.add_documents(documents_to_store)
        print(f"Added {len(documents_to_store)} documents to vector store")
    else:
        print("No new stories to process")

    # 4. Query
    summaries = run_structured_query(USER_PREFERENCES)
    if ENABLE_SLACK:
        blocks = create_slack_blocks(summaries, USER_PREFERENCES)
        slack.send_message(channel="#ragpull-demo", blocks=blocks)
    print(summaries)


if __name__ == "__main__":
|
|
import asyncio
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
|
try:
|
|
asyncio.run(main())
|
|
finally:
|
|
weaviate_client.close()
|