HFTrainer
Trains a new Hugging Face Transformer model using the Trainer framework.
Example
The following shows a simple example using this pipeline.
import pandas as pd
from datasets import load_dataset
from txtai.pipeline import HFTrainer
trainer = HFTrainer()
# Pandas DataFrame
df = pd.read_csv("training.csv")
model, tokenizer = trainer("bert-base-uncased", df)
# Hugging Face dataset
ds = load_dataset("glue", "sst2")
model, tokenizer = trainer("bert-base-uncased", ds["train"], columns=("sentence", "label"))
# List of dicts
dt = [{"text": "sentence 1", "label": 0}, {"text": "sentence 2", "label": 1}]
model, tokenizer = trainer("bert-base-uncased", dt)
# Support additional TrainingArguments
model, tokenizer = trainer("bert-base-uncased", dt,
                           learning_rate=3e-5, num_train_epochs=5)
All TrainingArguments are supported as function arguments to the trainer call.
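As a minimal sketch of how extra training arguments might be organized, the snippet below collects them in a dict and unpacks that dict into the trainer call. The helper name `train_classifier`, the `output_dir` value and the batch size are illustrative assumptions. `learning_rate`, `num_train_epochs`, `per_device_train_batch_size` and `output_dir` are standard `TrainingArguments` fields.

```python
# Illustrative training data in the default (text, label) column layout
data = [
    {"text": "sentence 1", "label": 0},
    {"text": "sentence 2", "label": 1},
]

# Keyword arguments forwarded to Hugging Face TrainingArguments.
# The values are placeholders, not tuned recommendations.
training_args = {
    "learning_rate": 3e-5,
    "num_train_epochs": 5,
    "per_device_train_batch_size": 8,
    "output_dir": "bert-text-labeler",
}


def train_classifier(rows, base="bert-base-uncased", **overrides):
    """Hypothetical helper: merges overrides into training_args, then trains."""
    # Imported lazily so the settings above can be inspected without txtai installed
    from txtai.pipeline import HFTrainer

    trainer = HFTrainer()
    return trainer(base, rows, **{**training_args, **overrides})
```

A call such as `model, tokenizer = train_classifier(data, num_train_epochs=3)` would then override a single setting while keeping the rest.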
See the links below for more detailed examples.
Notebook | Description |
---|---|
Train a text labeler | Build text sequence classification models |
Train without labels | Use zero-shot classifiers to train new models |
Train a QA model | Build and fine-tune question-answering models |
Train a language model from scratch | Build new language models |
Training tasks
The HFTrainer pipeline builds and/or fine-tunes models for the following training tasks.
Task | Description |
---|---|
language-generation | Causal language model for text generation (e.g. GPT) |
language-modeling | Masked language model for general tasks (e.g. BERT) |
question-answering | Extractive question-answering model, typically with the SQuAD dataset |
sequence-sequence | Sequence-Sequence model (e.g. T5) |
text-classification | Classify text with a set of labels |
token-detection | ELECTRA-style pre-training with replaced token detection |
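The task is selected with the `task` argument of the trainer call. The sketch below checks a task name against the table above before training starts, so that a typo might be caught early. The wrapper name `train_task` and the `TASKS` set are hypothetical and are not part of txtai.

```python
# Task names taken from the table above
TASKS = {
    "language-generation",
    "language-modeling",
    "question-answering",
    "sequence-sequence",
    "text-classification",
    "token-detection",
}


def train_task(base, data, task="text-classification", **args):
    """Hypothetical wrapper that validates the task name before training."""
    if task not in TASKS:
        raise ValueError(f"Unsupported task '{task}', expected one of {sorted(TASKS)}")

    # Imported lazily so the check above runs without txtai installed
    from txtai.pipeline import HFTrainer

    trainer = HFTrainer()
    return trainer(base, data, task=task, **args)
```

For example, `train_task("bert-base-uncased", data, task="text-classification")` passes the check, while an unknown task name raises a `ValueError` before any model is loaded.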
PEFT
Parameter-Efficient Fine-Tuning (PEFT) is supported through Hugging Face’s PEFT library. Quantization is provided through bitsandbytes. See the examples below.
from txtai.pipeline import HFTrainer
trainer = HFTrainer()
trainer(..., quantize=True, lora=True)
When these parameters are set to True, default configurations are used. Each can also be customized with a dictionary of settings.
quantize = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16"
}
lora = {
    "r": 16,
    "lora_alpha": 8,
    "target_modules": "all-linear",
    "lora_dropout": 0.05,
    "bias": "none"
}
trainer(..., quantize=quantize, lora=lora)
The parameters also accept transformers.BitsAndBytesConfig and peft.LoraConfig instances.
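A minimal sketch of the instance form, mirroring the values from the dictionary settings shown above. The function name `build_peft_configs` is hypothetical. Fields not listed in the dictionaries, such as `task_type`, are left at library defaults here, which may or may not suit a given task.

```python
def build_peft_configs():
    """Builds config objects that mirror the dictionary settings shown above."""
    # Imported lazily: requires transformers, peft and bitsandbytes to be installed
    from peft import LoraConfig
    from transformers import BitsAndBytesConfig

    # 4-bit NF4 quantization with double quantization and bfloat16 compute
    quantize = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="bfloat16",
    )

    # LoRA adapters applied to all linear layers
    lora = LoraConfig(
        r=16,
        lora_alpha=8,
        target_modules="all-linear",
        lora_dropout=0.05,
        bias="none",
    )

    return quantize, lora
```

The returned objects could then be passed as `trainer(..., quantize=quantize, lora=lora)`.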
See the following PEFT documentation links for more information.
Methods
Python documentation for the pipeline.
__call__(base, train, validation=None, columns=None, maxlength=None, stride=128, task='text-classification', prefix=None, metrics=None, tokenizers=None, checkpoint=None, quantize=None, lora=None, **args)
Builds a new model using arguments.
Parameters:
Name | Description | Default |
---|---|---|
base | path to base model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple | required |
train | training data | required |
validation | validation data | None |
columns | tuple of columns to use for text/label, defaults to (text, None, label) | None |
maxlength | maximum sequence length, defaults to tokenizer.model_max_length | None |
stride | chunk size for splitting data for QA tasks | 128 |
task | optional model task or category, determines the model type, defaults to "text-classification" | 'text-classification' |
prefix | optional source prefix | None |
metrics | optional function that computes and returns a dict of evaluation metrics | None |
tokenizers | optional number of concurrent tokenizers, defaults to None | None |
checkpoint | optional resume from checkpoint flag or path to checkpoint directory, defaults to None | None |
quantize | quantization configuration to pass to base model | None |
lora | LoRA configuration to pass to PEFT model | None |
args | training arguments | {} |
Returns:
Type | Description |
---|---|
tuple | (model, tokenizer) |
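The `metrics` parameter takes a function that returns a dict of evaluation metrics. Below is a minimal sketch for text classification. It assumes the function is handed to the Hugging Face `Trainer` as `compute_metrics` and therefore receives an object that unpacks into `(logits, labels)`; that pass-through is an assumption, not something stated on this page. The function name `accuracy` is illustrative.

```python
def accuracy(eval_pred):
    """Computes accuracy from a (logits, labels) pair."""
    logits, labels = eval_pred

    # Predicted class is the index of the largest logit in each row
    predictions = [max(range(len(row)), key=lambda i: row[i]) for row in logits]

    # Count predictions that match the reference labels
    correct = sum(int(p == int(y)) for p, y in zip(predictions, labels))
    total = len(predictions)

    return {"accuracy": correct / total if total else 0.0}
```

Under that assumption, a call such as `trainer("bert-base-uncased", train, validation=validation, metrics=accuracy)` would report accuracy on the validation data during evaluation.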
Source code in txtai/pipeline/train/hftrainer.py