Integrate TiDB Vector Search with SQLAlchemy
This tutorial walks you through how to use SQLAlchemy to interact with TiDB Vector Search, store embeddings, and perform vector search queries.
Prerequisites
To complete this tutorial, you need:
- Python 3.8 or higher installed.
- Git installed.
- A TiDB Cloud Serverless cluster. If you don't have one, follow "Creating a TiDB Cloud Serverless cluster" to create your own.
Run the sample app
Follow the steps below to see how TiDB Vector Search integrates with SQLAlchemy.
Step 1. Clone the repository
Clone the tidb-vector-python repository to your local machine:
git clone https://github.com/pingcap/tidb-vector-python.git
Step 2. Create a virtual environment
Create a virtual environment for your project:
cd tidb-vector-python/examples/orm-sqlalchemy-quickstart
python3 -m venv .venv
source .venv/bin/activate
Step 3. Install the required dependencies
Install the required dependencies for the demo project:
pip install -r requirements.txt
Alternatively, you can install the following packages for your project:
pip install pymysql python-dotenv sqlalchemy tidb-vector
Step 4. Configure the environment variables
Navigate to the Clusters page, and then click the name of your target cluster to go to its overview page.
Click Connect in the upper-right corner. A connection dialog is displayed.
Ensure the configurations in the connection dialog match your environment.
- Connection Type is set to Public.
- Branch is set to main.
- Connect With is set to SQLAlchemy.
- Operating System matches your environment.
Click the PyMySQL tab and copy the connection string.
In the root directory of your Python project, create a .env file and paste the connection string into it. The following is an example for macOS:
TIDB_DATABASE_URL="mysql+pymysql://<prefix>.root:<password>@gateway01.<region>.prod.aws.tidbcloud.com:4000/test?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true"
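Before running the demo, it can help to confirm that the connection string has the expected shape. The following is a minimal sketch that uses only the Python standard library. The helper name summarize_database_url is hypothetical, and the prefix, password, and region in the sample URL are placeholders for illustration, not real TiDB Cloud values.

```python
from urllib.parse import parse_qs, urlsplit


def summarize_database_url(url: str) -> dict:
    """Return the non-secret parts of a SQLAlchemy-style connection URL."""
    parts = urlsplit(url)
    query = parse_qs(parts.query)
    return {
        "driver": parts.scheme,  # for example "mysql+pymysql"
        "user": parts.username,
        "host": parts.hostname,
        "port": parts.port,
        "database": parts.path.lstrip("/"),
        # True only when both TLS verification flags are set to "true".
        "tls_verified": (
            query.get("ssl_verify_cert") == ["true"]
            and query.get("ssl_verify_identity") == ["true"]
        ),
    }


# Placeholder values for illustration only; substitute your own connection string.
example_url = (
    "mysql+pymysql://myprefix.root:example-password"
    "@gateway01.example-region.prod.aws.tidbcloud.com:4000/test"
    "?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true"
)
summary = summarize_database_url(example_url)
print(summary)
```

The password is deliberately left out of the summary so that the result is safe to print or log.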
Step 5. Run the demo
python sqlalchemy-quickstart.py
Example output:
Get 3-nearest neighbor documents:
  - distance: 0.00853986601633272
    document: fish
  - distance: 0.12712843905603044
    document: dog
  - distance: 0.7327387580875756
    document: tree
Get documents within a certain distance:
  - distance: 0.00853986601633272
    document: fish
  - distance: 0.12712843905603044
    document: dog
Sample code snippets
You can refer to the following sample code snippets to develop your application.
Create vector tables
Connect to TiDB cluster
import os
import dotenv
from sqlalchemy import Column, Integer, create_engine, Text
from sqlalchemy.orm import declarative_base, Session
from tidb_vector.sqlalchemy import VectorType
dotenv.load_dotenv()
tidb_connection_string = os.environ['TIDB_DATABASE_URL']
engine = create_engine(tidb_connection_string)
Define a vector column
Create a table with a column named embedding that stores a 3-dimensional vector.
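Base = declarative_base()
class Document(Base):
    __tablename__ = 'sqlalchemy_demo_documents'
    id = Column(Integer, primary_key=True)
    content = Column(Text)
    embedding = Column(VectorType(3))  # 3-dimensional vector column
# Create the table in the database if it does not already exist.
Base.metadata.create_all(engine)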
Store documents with embeddings
with Session(engine) as session:
    session.add(Document(content="dog", embedding=[1, 2, 1]))
    session.add(Document(content="fish", embedding=[1, 2, 4]))
    session.add(Document(content="tree", embedding=[1, 0, 0]))
    session.commit()
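Because the embedding column is declared as VectorType(3), each stored vector needs exactly three components. The following is a minimal sketch of a client-side length check, assuming you prefer to catch a mismatch before sending the row to the database. The names check_embedding and EMBEDDING_DIM are hypothetical and are not part of tidb-vector.

```python
from typing import List, Sequence

EMBEDDING_DIM = 3  # assumed to match VectorType(3) in the Document model


def check_embedding(embedding: Sequence[float], dim: int = EMBEDDING_DIM) -> List[float]:
    """Return the embedding as a list of floats, or raise if its length is wrong."""
    if len(embedding) != dim:
        raise ValueError(f"expected {dim} dimensions, got {len(embedding)}")
    return [float(x) for x in embedding]


print(check_embedding([1, 2, 1]))  # prints [1.0, 2.0, 1.0]
```

With a helper like this, a call such as Document(content="dog", embedding=check_embedding([1, 2, 1])) fails early if a vector of the wrong length is passed.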
Search the nearest neighbor documents
Search for the top-3 documents that are semantically closest to the query vector [1, 2, 3] based on the cosine distance function.
with Session(engine) as session:
    # Label the distance expression so it can be selected and used for ordering.
    distance = Document.embedding.cosine_distance([1, 2, 3]).label('distance')
    results = session.query(
        Document, distance
    ).order_by(distance).limit(3).all()
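To see where the distances in the example output come from, the same ranking can be reproduced without a database. The following is a minimal pure-Python sketch, assuming the usual definition of cosine distance as 1 minus cosine similarity. The local cosine_distance function is written here for illustration and is not the TiDB implementation.

```python
import math


def cosine_distance(a, b):
    """Cosine distance: 1 minus the cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


query = [1, 2, 3]
documents = {"dog": [1, 2, 1], "fish": [1, 2, 4], "tree": [1, 0, 0]}

# Rank documents by ascending distance, mirroring ORDER BY distance in the query.
ranked = sorted((cosine_distance(vec, query), name) for name, vec in documents.items())
for dist, name in ranked:
    print(f"{name}: {dist:.4f}")
# fish: 0.0085
# dog: 0.1271
# tree: 0.7327
```

The ordering and values agree with the example output of the demo: fish is nearest to [1, 2, 3], followed by dog, then tree.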
Search documents within a certain distance
Search for documents whose cosine distance from the query vector [1, 2, 3] is less than 0.2.
with Session(engine) as session:
    distance = Document.embedding.cosine_distance([1, 2, 3]).label('distance')
    # Keep only rows within the distance threshold, nearest first.
    results = session.query(
        Document, distance
    ).filter(distance < 0.2).order_by(distance).limit(3).all()
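Each row returned by session.query(Document, distance) is a pair of the mapped Document object and its distance value, so the results can be unpacked in a loop. The following is a minimal sketch of printing them in a layout similar to the example output. The helper format_results is hypothetical, and SimpleNamespace objects stand in for Document rows so that the sketch runs without a database.

```python
from types import SimpleNamespace


def format_results(rows):
    """Render (document, distance) pairs as indented text, one document per entry."""
    lines = []
    for doc, distance in rows:
        lines.append(f"  - distance: {distance}")
        lines.append(f"    document: {doc.content}")
    return "\n".join(lines)


# Stand-in rows for illustration; with a live session, pass the `results`
# list returned by the query instead.
sample_rows = [
    (SimpleNamespace(content="fish"), 0.00853986601633272),
    (SimpleNamespace(content="dog"), 0.12712843905603044),
]
print(format_results(sample_rows))
```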