Creating your own embedding function
from chromadb.api.types import (
    Documents,
    EmbeddingFunction,
    Embeddings,
)


class MyCustomEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(self, my_ef_param: str):
        """Initialize the embedding function."""

    def __call__(self, input: Documents) -> Embeddings:
        """Embed the input documents."""
        return self._my_ef(input)
Now let’s break the above down.
First, you create a class that inherits from EmbeddingFunction[Documents]. The Documents type is a list of Document values, where each Document is the text of a single document. Chroma also supports multi-modal embedding functions, but here we focus on text.
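In practice, you pass an instance of your embedding function to a collection when you create it. Below is a minimal usage sketch (the in-memory client, collection name, parameter value, and documents are illustrative, and it assumes your __call__ returns real embeddings rather than the placeholder above):

import chromadb

client = chromadb.Client()
collection = client.create_collection(
    name="my_collection",  # illustrative name
    embedding_function=MyCustomEmbeddingFunction(my_ef_param="some-value"),  # illustrative parameter
)

# Chroma invokes the embedding function automatically on add() and query()
collection.add(ids=["doc1", "doc2"], documents=["first document", "second document"])
results = collection.query(query_texts=["first"], n_results=1)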
Example Implementation
Below is an implementation of an embedding function that works with transformers models.
Note
This example requires the transformers and torch python packages. You can install them with pip install transformers torch.
By default, transformers models on HF are also supported by the sentence-transformers package, for which Chroma provides out-of-the-box support.
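If your model is available through sentence-transformers, you can use that built-in support directly instead of writing a custom function. A brief sketch (the model name is illustrative):

from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

# Chroma's built-in sentence-transformers wrapper
sentence_ef = SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")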
import importlib
from typing import Optional

import numpy as np
import numpy.typing as npt

from chromadb.api.types import EmbeddingFunction, Documents, Embeddings


class TransformerEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        model_name: str = "dbmdz/bert-base-turkish-cased",
        cache_dir: Optional[str] = None,
    ):
        try:
            from transformers import AutoModel, AutoTokenizer

            self._torch = importlib.import_module("torch")
            self._tokenizer = AutoTokenizer.from_pretrained(model_name)
            self._model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
        except ImportError:
            raise ValueError(
                "The transformers and/or pytorch python package is not installed. Please install it with "
                "`pip install transformers` or `pip install torch`"
            )
    @staticmethod
    def _normalize(vector: npt.NDArray) -> npt.NDArray:
        """Normalizes a vector (or each row of a batch of vectors) to unit length using the L2 norm."""
        norm = np.linalg.norm(vector, axis=-1, keepdims=True)
        norm[norm == 0] = 1.0  # guard against division by zero
        return vector / norm
    def __call__(self, input: Documents) -> Embeddings:
        inputs = self._tokenizer(
            input, padding=True, truncation=True, return_tensors="pt"
        )
        with self._torch.no_grad():
            outputs = self._model(**inputs)
        # Mean pooling over the token dimension, then unit-normalize each document embedding
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        return [e.tolist() for e in self._normalize(embeddings)]
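Once defined, the embedding function can be called directly or passed to a collection via the embedding_function argument, just like the custom class above. A short usage sketch (the example texts are illustrative; the model is downloaded from HF on first use):

# Instantiate with the default model from the example above
ef = TransformerEmbeddingFunction()

# Embed a couple of (illustrative) documents
embeddings = ef(["Merhaba dünya", "Vektör veritabanları harika"])
print(len(embeddings), len(embeddings[0]))  # number of documents, embedding dimensionality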