Selectively Querying Documents in LlamaIndex
Running RAG over a Subset of your Embedded Corpus
When working with LlamaIndex or other Retrieval-Augmented Generation (RAG) systems, you'll find that most tutorials focus on ingesting and querying a single document, usually to show how content is retrieved from a high-dimensional vector space of embeddings. You typically read the document from a source, parse it, embed it, and store it in your vector store. Once it's there, querying is straightforward. But what if your vector store holds many documents and you want to query only a subset of them, such as Document #2 (doc_id=2), or all documents related to "technology" but not "finance"?
This article demonstrates how to use LlamaIndex to create a filtered query engine, which retrieves only the nodes whose custom metadata matches a filter. Scoping retrieval this way keeps unrelated documents out of the LLM's context and lets a single vector store serve many subsets, instead of requiring a separate index per document or category.
MetadataFilters: Customizing Your Query Scope
To start, you need to associate metadata with each node during ingestion. This metadata can include any information you need, such as document IDs, categories, or other attributes, and it is what MetadataFilters will later use to narrow the scope of a query. Attaching metadata in LlamaIndex is easy: SimpleDirectoryReader accepts a file_metadata function, which is called with each file's path and returns a dictionary of extra fields to store on that file's documents:
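```python
from llama_index.core import SimpleDirectoryReader
from llama_index.readers.file import UnstructuredReader

# Read all HTML files in the directory using Unstructured.
# To each file, add the company_id we fetched from the DB for this document.
# (source_directory and company_id are defined earlier in the author's pipeline.)
documents = SimpleDirectoryReader(
    input_dir=source_directory,
    file_extractor={".html": UnstructuredReader()},
    # The lambda ignores the path, so every file here gets the same company_id.
    file_metadata=lambda file_path: {"company_id": int(company_id)},  # <-- additional metadata
    required_exts=[".html"],
    recursive=True,
).load_data()
```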
Now each document, and every node parsed from it, carries its "company_id" in its metadata. I'd like to ask "What is the company's name?" and give the LLM context from only one particular company, not all of them.
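The lambda above ignores the path it receives, which works when each company's pages live in their own directory. If one directory tree mixes several companies, file_metadata can derive the id per file instead. Here is a minimal sketch, assuming a hypothetical layout where each company's pages sit under a folder named after its company_id (for example pages/65/about.html); the function name company_metadata is also hypothetical.

```python
from pathlib import Path


def company_metadata(file_path: str) -> dict:
    """Hypothetical file_metadata callable.

    Assumes a layout like <source_directory>/<company_id>/.../<page>.html and
    uses the nearest all-digit ancestor folder name as the company_id.
    """
    for part in reversed(Path(file_path).parent.parts):
        if part.isdigit():
            return {"company_id": int(part)}
    # No numeric folder found: leave company_id out rather than guess.
    # Such nodes will simply never match a company_id filter.
    return {}
```

You would pass it as file_metadata=company_metadata in the SimpleDirectoryReader call above; recursive=True already makes the reader descend into the per-company folders.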
The goal now is to choose the filter at query time, based on metadata, rather than fixing it once when the index is set up. That flexibility lets you tailor each query to a specific subset of documents without re-ingesting or re-indexing anything. Let's build out an example to prove the concept.
Example: Querying by company_id
Imagine you have a collection of webpages from various companies, each identified by a unique company_id. You want to ask questions such as "What products does {company_id} sell?" while choosing the company_id at query time. If you don't filter, the retriever returns the best-matching chunks from every company, and you'll get products from all of them.
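To see why the filter matters, here is a toy, in-memory model of metadata-filtered retrieval in plain Python. It sketches the idea only and is not how LlamaIndex or pgvector implement it: the two-dimensional "embeddings" and sample texts are made up, and the equality filter on company_id is applied before ranking by cosine similarity.

```python
import math
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    text: str
    embedding: list  # toy 2-D vectors standing in for real embeddings
    metadata: dict = field(default_factory=dict)


def cosine(a: list, b: list) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve(nodes: list, query_embedding: list, top_k: int = 2,
             company_id: Optional[int] = None) -> list:
    """Return the top_k most similar nodes, optionally restricted to one company_id."""
    # Keep only in-scope nodes (equality filter on metadata), then rank what is left.
    candidates = [
        n for n in nodes
        if company_id is None or n.metadata.get("company_id") == company_id
    ]
    candidates.sort(key=lambda n: cosine(n.embedding, query_embedding), reverse=True)
    return candidates[:top_k]


# Made-up sample data: two companies whose product pages embed similarly.
SAMPLE_NODES = [
    Node("Company 65 sells anvils.", [1.0, 0.0], {"company_id": 65}),
    Node("Company 114 sells widgets.", [0.9, 0.1], {"company_id": 114}),
    Node("Company 65 has a careers page.", [0.2, 0.8], {"company_id": 65}),
]

if __name__ == "__main__":
    query = [1.0, 0.0]  # pretend embedding of "What products does the company sell?"
    # prints ['Company 65 sells anvils.', 'Company 114 sells widgets.']
    print([n.text for n in retrieve(SAMPLE_NODES, query)])
    # prints ['Company 65 sells anvils.', 'Company 65 has a careers page.']
    print([n.text for n in retrieve(SAMPLE_NODES, query, company_id=65)])
```

Without the filter, company 114's page takes one of the top-k slots in a question meant for company 65; with it, only company 65's nodes compete. In the LlamaIndex code below, a MetadataFilter with key="company_id" and FilterOperator.EQ plays the role of the company_id argument here.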
Updated Code for Creating a Filtered Query Engine
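Here's the simplified code that encapsulates the logic for creating a filtered query engine. In my setup, I'm using pgvector as my vector store, hence the Postgres connection strings. The connection string, schema, table name, embedding model, and embedding dimension are all read from environment variables.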
```python
import logging
import os
from urllib.parse import urlparse

from llama_index.core import VectorStoreIndex
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.settings import Settings
from llama_index.core.vector_stores import FilterOperator, MetadataFilter, MetadataFilters
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.postgres import PGVectorStore


def init_vector_store():
    """
    Creates a vector store from the Postgres connection string.
    """
    original_conn_string = os.getenv("PG_CONNECTION_STRING")
    if not original_conn_string:
        raise ValueError("PG_CONNECTION_STRING environment variable is not set.")
    # Rewrite the URL scheme for the sync (psycopg2) and async (asyncpg) drivers.
    scheme = urlparse(original_conn_string).scheme + "://"
    conn_string = original_conn_string.replace(scheme, "postgresql+psycopg2://")
    async_conn_string = original_conn_string.replace(scheme, "postgresql+asyncpg://")
    vector_store = PGVectorStore(
        connection_string=conn_string,
        async_connection_string=async_conn_string,
        schema_name=os.getenv("PGVECTOR_SCHEMA", "public"),
        table_name=os.getenv("PGVECTOR_TABLE", "company_embeddings"),
        # Must match the output dimension of the embedding model.
        embed_dim=int(os.getenv("EMBEDDING_DIM", 1024)),
    )
    logging.debug("Initialized vector store.")
    return vector_store


def init_embed_model(use_globally=True) -> BaseEmbedding:
    """
    Creates an embedding model from the HuggingFace library.
    """
    config = {"model_name": os.getenv("EMBEDDING_MODEL")}
    embed_model = HuggingFaceEmbedding(**config, trust_remote_code=True)
    if use_globally:
        Settings.embed_model = embed_model
    logging.info(f"Initialized embed model: {config}")
    return embed_model


def create_filtered_query_engine(company_id, vector_store, embed_model, top_k=10):
    """
    Creates a query engine that filters results based on a company_id.
    """
    # Create the filters to find nodes matching the company_id.
    filters = MetadataFilters(
        filters=[
            MetadataFilter(
                key="company_id",
                value=company_id,
                operator=FilterOperator.EQ,
            )
        ]
    )
    # Wrap the existing vector store; nothing is re-embedded here.
    index = VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=embed_model)
    # The retriever is scoped to fetch ONLY nodes that match the filters.
    vector_retriever = index.as_retriever(similarity_top_k=top_k, filters=filters)
    # Answers are synthesized by Settings.llm (OpenAI unless you configure another LLM).
    query_engine = RetrieverQueryEngine(retriever=vector_retriever)
    return query_engine


vector_store = init_vector_store()
embed_model = init_embed_model()

# Query for company 65.
company_id = 65
query_engine = create_filtered_query_engine(company_id, vector_store, embed_model, top_k=10)
response = query_engine.query("What is the company's name?")
print(response.response)

# Query for company 114.
company_id = 114
query_engine = create_filtered_query_engine(company_id, vector_store, embed_model, top_k=10)
response = query_engine.query("What is the company's name?")
print(response.response)
```
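Of course, you could take this a step further and wrap all of that in a single call that takes the company as a parameter, e.g. response = query_engine.query("What is the company's name?", company_id=114). A LlamaIndex query engine's query() method does not accept extra arguments like company_id, so this would be a small wrapper of your own, and that exercise is left to the reader. :)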
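One possible shape for that single-call wrapper, as a sketch: the class name CompanyQueryRouter and its query(question, company_id) method are hypothetical and not part of LlamaIndex. It accepts any factory that builds a query engine for a company_id (for example create_filtered_query_engine with the vector store and embedding model bound in) and caches one engine per company, so repeated questions about the same company reuse it.

```python
from typing import Any, Callable, Dict, Protocol


class SupportsQuery(Protocol):
    """Anything with a query(question) method, such as a LlamaIndex query engine."""

    def query(self, question: str) -> Any: ...


class CompanyQueryRouter:
    """Hypothetical wrapper that routes each question to a per-company query engine."""

    def __init__(self, engine_factory: Callable[[int], SupportsQuery]):
        self._engine_factory = engine_factory
        # Unbounded cache: one engine per company_id seen so far.
        self._engines: Dict[int, SupportsQuery] = {}

    def query(self, question: str, company_id: int) -> Any:
        engine = self._engines.get(company_id)
        if engine is None:
            # Build the company-scoped engine on first use, then reuse it.
            engine = self._engine_factory(company_id)
            self._engines[company_id] = engine
        return engine.query(question)
```

With the functions above, usage would look like router = CompanyQueryRouter(lambda cid: create_filtered_query_engine(cid, vector_store, embed_model, top_k=10)) followed by response = router.query("What is the company's name?", company_id=114). The cache never evicts entries; with very many companies you might instead rebuild the engine on every call, which should stay inexpensive because from_vector_store wraps the existing store without re-embedding anything.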
This approach keeps the filtering logic in one small function: you choose the metadata filter at query time, and retrieval is scoped to just the subset of content you want. Because the filter lives in a single place, it is also easy to extend later, for example by adding more MetadataFilter entries for other metadata keys. Happy querying!