Improved retrieval step with relevance ranking

indexing.py (56 changed lines)

--- a/indexing.py
+++ b/indexing.py
@@ -18,7 +18,7 @@ from hn import HackerNewsClient, Story
 from scrape import JinaScraper
 
 NUM_STORIES = 20
-USER_PREFERENCES = ["Machine Learning", "Linux", "Open-Source"]
+USER_PREFERENCES = ["Machine Learning", "Programming", "Robotics"]
 ENABLE_SLACK = True  # Send updates to Slack, need to set SLACK_BOT_TOKEN env var
 ENABLE_MLFLOW_TRACING = False  # Use MLflow (at http://localhost:5000) for tracing
 
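Note: the preference list above now also drives retrieval sizing. With the new default top_n = 2 * len(USER_PREFERENCES) introduced in the retrieve() hunk below, a rough sketch of the numbers (values follow directly from this commit's constants):

    USER_PREFERENCES = ["Machine Learning", "Programming", "Robotics"]
    top_n = 2 * len(USER_PREFERENCES)  # default number of stories returned -> 6
    k = top_n * 20                     # chunks fetched per similarity search -> 120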
@@ -44,27 +44,45 @@ class State(TypedDict):
     summaries: list[dict]
 
 
-def retrieve(state: State, top_n: int = 5) -> State:
-    # Search for relevant documents
+def retrieve(state: State, top_n: int = 2 * len(USER_PREFERENCES)) -> State:
+    # Search for relevant documents (with scores if available)
     retrieved_docs = vector_store.similarity_search(
-        "Categories: " + ", ".join(state["preferences"]), k=20
+        "Show the most interesting articles about the following topics: "
+        + ", ".join(state["preferences"]),
+        k=top_n * 20,  # Chunks, not complete stories
+        return_score=True
+        if hasattr(vector_store, "similarity_search_with_score")
+        else False,
     )
 
-    # If you're using chunks, group them back into complete stories
-    story_groups = {}
-    for doc in retrieved_docs:
+    # If scores are returned, unpack (doc, score) tuples; else, set score to None
+    docs_with_scores = []
+    if retrieved_docs and isinstance(retrieved_docs[0], tuple):
+        for doc, score in retrieved_docs:
+            docs_with_scores.append((doc, score))
+    else:
+        for doc in retrieved_docs:
+            docs_with_scores.append((doc, None))
+
+    # Group chunks by story_id and collect their scores
+    story_groups = {}
+    for doc, score in docs_with_scores:
         story_id = doc.metadata.get("story_id")
         if story_id not in story_groups:
             story_groups[story_id] = []
-        story_groups[story_id].append(doc)
+        story_groups[story_id].append((doc, score))
 
-    # Reconstruct complete stories or use the best chunk per story
+    # Aggregate max score per story and reconstruct complete stories
+    story_scores = {}
     complete_stories = []
-    for story_id, chunks in story_groups.items():
+    for story_id, chunks_scores in story_groups.items():
+        chunks = [doc for doc, _ in chunks_scores]
+        scores = [s for _, s in chunks_scores if s is not None]
+        max_score = max(scores) if scores else None
+        story_scores[story_id] = max_score
         if len(chunks) == 1:
-            complete_stories.append(chunks[0])
+            complete_stories.append((chunks[0], max_score))
         else:
-            # Combine chunks back into complete story
             combined_content = "\n\n".join(
                 chunk.page_content
                 for chunk in sorted(
@@ -75,9 +93,21 @@ def retrieve(state: State, top_n: int = 5) -> State:
                 page_content=combined_content,
                 metadata=chunks[0].metadata,  # Use metadata from first chunk
             )
-            complete_stories.append(complete_story)
+            complete_stories.append((complete_story, max_score))
 
-    return {"context": complete_stories[:top_n]}
+    # Sort stories by max_score descending (None scores go last)
+    complete_stories_sorted = sorted(
+        complete_stories, key=lambda x: (x[1] is not None, x[1]), reverse=True
+    )
+
+    # Return top_n stories
+    top_stories = [doc for doc, _ in complete_stories_sorted[:top_n]]
+    return {
+        "preferences": state["preferences"],
+        "context": top_stories,
+        "answer": state.get("answer", ""),
+        "summaries": state.get("summaries", []),
+    }
 
 
 def generate_structured_summaries(state: State):
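Note: the sort key (x[1] is not None, x[1]) with reverse=True orders scored stories by score, highest first, and pushes score-less stories to the end; because tuple comparison short-circuits on the leading boolean, None is never compared against a float. A minimal standalone sketch with hypothetical scores:

    # Hypothetical (story, score) pairs; None means the store returned no score
    stories = [("a", 0.31), ("b", None), ("c", 0.87), ("d", None)]
    ranked = sorted(stories, key=lambda x: (x[1] is not None, x[1]), reverse=True)
    print(ranked)  # [('c', 0.87), ('a', 0.31), ('b', None), ('d', None)]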
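One caveat, assuming a LangChain-style vector store: the base similarity_search API does not generally accept a return_score keyword, so the hasattr guard above may silently fall back to unscored documents. The portable way to get (Document, score) tuples is similarity_search_with_score; a sketch under that assumption, not a drop-in from this repo:

    # Sketch: portable score-aware retrieval (assumes a LangChain VectorStore)
    query = ("Show the most interesting articles about the following topics: "
             + ", ".join(state["preferences"]))
    if hasattr(vector_store, "similarity_search_with_score"):
        # Returns [(Document, score), ...]
        retrieved_docs = vector_store.similarity_search_with_score(query, k=top_n * 20)
    else:
        # Returns [Document, ...]; the downstream code pads scores with None
        retrieved_docs = vector_store.similarity_search(query, k=top_n * 20)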