Creating your own embedding function

```python
from chromadb.api.types import (
    Documents,
    EmbeddingFunction,
    Embeddings,
)


class MyCustomEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        my_ef_param: str
    ):
        """Initialize the embedding function."""

    def __call__(self, input: Documents) -> Embeddings:
        """Embed the input documents."""
        # _my_ef is a placeholder for your own embedding logic
        return self._my_ef(input)
```

Now let’s break the above down.

First you create a class that inherits from `EmbeddingFunction[Documents]`. The `Documents` type is a list of `Document` objects, where each `Document` is a string containing the text of the document. Chroma also supports multi-modal embedding functions, which can embed inputs such as images in addition to text.
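
Once defined, the custom embedding function can be passed to a collection, which will then use it to embed documents on insertion and querying (a minimal usage sketch; the collection name and parameter value are illustrative):

```python
import chromadb

client = chromadb.Client()
ef = MyCustomEmbeddingFunction(my_ef_param="some-value")

# The collection calls the embedding function on add() and query()
collection = client.create_collection(
    name="my_collection", embedding_function=ef
)
```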

Example Implementation

Below is an implementation of an embedding function that works with transformers models.

Note

This example requires the `transformers` and `torch` python packages. You can install them with `pip install transformers torch`.

Many of the transformers models on Hugging Face (HF) are also supported by the sentence-transformers package, for which Chroma provides out-of-the-box support.
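
If your model is available through sentence-transformers, you can use Chroma's built-in wrapper instead of writing a custom class (a short sketch; the model name is illustrative):

```python
from chromadb.utils import embedding_functions

# Built-in embedding function backed by sentence-transformers
sentence_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
```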

```python
import importlib
from typing import Optional

import numpy as np
import numpy.typing as npt

from chromadb.api.types import EmbeddingFunction, Documents, Embeddings


class TransformerEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        model_name: str = "dbmdz/bert-base-turkish-cased",
        cache_dir: Optional[str] = None,
    ):
        try:
            from transformers import AutoModel, AutoTokenizer

            self._torch = importlib.import_module("torch")
            self._tokenizer = AutoTokenizer.from_pretrained(model_name)
            self._model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
        except ImportError:
            raise ValueError(
                "The transformers and/or pytorch python packages are not installed. "
                "Please install them with `pip install transformers torch`"
            )

    @staticmethod
    def _normalize(vectors: npt.NDArray) -> npt.NDArray:
        """Normalizes each embedding to unit length using the L2 norm."""
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        norms[norms == 0] = 1.0  # avoid division by zero
        return vectors / norms

    def __call__(self, input: Documents) -> Embeddings:
        inputs = self._tokenizer(
            input, padding=True, truncation=True, return_tensors="pt"
        )
        with self._torch.no_grad():
            outputs = self._model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)  # mean pooling
        return [e.tolist() for e in self._normalize(embeddings.cpu().numpy())]
```
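
The embedding function can then be called directly on a list of documents (a brief sketch; the sample texts are illustrative and in Turkish to match the default model):

```python
ef = TransformerEmbeddingFunction()
embeddings = ef(["Merhaba dünya", "İstanbul güzel bir şehir"])
print(len(embeddings), len(embeddings[0]))  # number of documents, embedding dimension
```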

May 18, 2024