49 lines
1.6 KiB
Python
49 lines
1.6 KiB
Python
import json
import logging
from collections.abc import Iterator, Sequence

from langchain.schema import BaseStore, Document
from langchain.storage import LocalFileStore
|
|
|
|
|
|
class DocumentLocalFileStore(BaseStore[str, Document]):
    """Key-value store of ``Document`` objects persisted through a ``LocalFileStore``.

    Each document is stored as UTF-8 encoded JSON of the form
    ``{"page_content": ..., "metadata": ...}``.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, root_path: str):
        """Create a store backed by a ``LocalFileStore`` rooted at ``root_path``."""
        self.store = LocalFileStore(root_path)

    @staticmethod
    def _serialize(doc: Document) -> bytes:
        """Encode ``doc`` as UTF-8 JSON bytes (page_content + metadata only)."""
        doc_dict = {"page_content": doc.page_content, "metadata": doc.metadata}
        return json.dumps(doc_dict).encode("utf-8")

    @staticmethod
    def _deserialize(serialized: bytes) -> Document:
        """Decode bytes produced by ``_serialize`` back into a ``Document``.

        Raises:
            ValueError: the bytes are not valid UTF-8 / JSON, or ``Document``
                construction rejects the values (pydantic validation errors
                subclass ``ValueError``).
            KeyError: the JSON object has no ``"page_content"`` field.
            TypeError: the decoded JSON value is not an object.
        """
        doc_dict = json.loads(serialized.decode("utf-8"))
        return Document(
            page_content=doc_dict["page_content"],
            metadata=doc_dict.get("metadata", {}),
        )

    def mget(self, keys: Sequence[str]) -> list[Document | None]:
        """Get multiple documents by keys.

        Returns a list aligned with ``keys``. Each slot holds the stored
        ``Document``, or ``None`` when the key has no value, the read fails, or
        the stored bytes do not decode to a document. Failures are logged
        instead of raised, so one bad entry does not block the others.
        """
        results: list[Document | None] = []
        for key in keys:
            # Read one key at a time so a failure on one key only empties
            # that key's slot instead of failing the whole batch.
            try:
                serialized = self.store.mget([key])[0]
            except Exception:
                self._logger.warning("Failed to read key %r", key, exc_info=True)
                results.append(None)
                continue

            if not serialized:
                results.append(None)
                continue

            try:
                results.append(self._deserialize(serialized))
            except (ValueError, KeyError, TypeError):
                self._logger.warning(
                    "Failed to deserialize document for key %r", key, exc_info=True
                )
                results.append(None)
        return results

    def mset(self, key_value_pairs: Sequence[tuple[str, Document]]) -> None:
        """Set multiple key-value pairs.

        All documents are serialized before anything is written. A document
        that cannot be serialized therefore raises (e.g. ``TypeError`` for
        non-JSON-serializable metadata) before any pair reaches the store.
        """
        serialized_pairs = [
            (key, self._serialize(doc)) for key, doc in key_value_pairs
        ]
        self.store.mset(serialized_pairs)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete multiple keys by delegating to the underlying store."""
        self.store.mdelete(keys)

    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
        """Return the underlying store's key iterator, forwarding ``prefix``."""
        return self.store.yield_keys(prefix=prefix)
|