Training API

Under the hood, EDS-NLP uses PyTorch to train and run deep-learning models. EDS-NLP acts as a sidekick to PyTorch, providing a set of tools to perform preprocessing, composition and evaluation. The trainable TorchComponents are actually PyTorch modules with a few extra methods to handle the feature preprocessing and postprocessing. Therefore, EDS-NLP is fully compatible with the PyTorch ecosystem.

To build and train a deep learning model, you can either build a training script from scratch (check out the Make a training script tutorial), or use the provided training API. The training API is designed to be flexible and can handle various types of models, including Named Entity Recognition (NER) models, span classifiers, and more. However, if you need more control over the training process, consider writing your own training script.

EDS-NLP supports training models either from the command line or from a Python script or notebook, and switching between the two is relatively straightforward thanks to the use of Confit.

A word about Confit

EDS-NLP makes heavy use of Confit, a configuration library that lets you call functions from Python or the command line, and validate and optionally cast their arguments.

The EDS-NLP function described on this page is the train function of the edsnlp.train module. When passing a dict to a type-hinted argument (either from a config.yml file, or when calling the function in Python), Confit instantiates the correct class with the arguments provided in the dict. For instance, the train_data parameter is type-hinted as a TrainingData: if you pass a dict, it is used as keyword arguments to instantiate a TrainingData object. You can also instantiate a TrainingData object directly and pass it to the function.

You can also tell Confit specifically which class you want to instantiate by using the @register_name = "name_of_the_registered_class" key and value in a dict or config section. We make heavy use of this mechanism to build pipeline architectures.
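For instance, here is a minimal sketch of the two equivalent ways to pass the training data. It assumes that nlp is a pipeline with at least one trainable component and that train_docs is a stream of documents defined elsewhere; the TrainingData import path is an assumption, adjust it if the class lives in another module.

```python
from edsnlp.train import train, TrainingData  # TrainingData import path assumed

# Dict form: Confit casts the dict into a TrainingData object,
# using its keys as keyword arguments.
train(
    nlp=nlp,
    train_data={"data": train_docs, "batch_size": "2000 words", "shuffle": "dataset"},
    max_steps=400,
)

# Strictly equivalent object form.
train(
    nlp=nlp,
    train_data=TrainingData(data=train_docs, batch_size="2000 words", shuffle="dataset"),
    max_steps=400,
)
```

The same mechanism applies when these parameters are written as sections of a config.yml file passed to the command line.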

How it works

To train a model with EDS-NLP, you need the following ingredients:

  • Pipeline: a pipeline with at least one trainable component. Components that share parameters or that must be updated together are trained in the same phase.

  • Training streams: one or more streams of documents, each wrapped in a TrainingData object. Each of these specifies how to shuffle the stream, how to batch it with a stat expression such as "2000 words" or "16 spans", whether to split batches into sub-batches for gradient accumulation, and which components it feeds.

  • Validation streams: optional streams of documents used for periodic evaluation.

  • Scorer: a scorer that defines the metrics to compute on the validation set. By default, it reports speed and uses autocast during scoring unless disabled.

  • Optimizer: an optimizer. Defaults to AdamW with linear warmup and two parameter groups: one for the transformer with a learning rate of 5e-5, and one for the rest of the model with a learning rate of 3e-4.

  • A bunch of hyperparameters: finally, the function accepts various hyperparameters (most of them with sensible defaults), such as max_steps, seed, validation_interval, checkpoint_interval, grad_max_norm, and more. A sketch combining all of these ingredients is shown right after this list.
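Here is a hedged sketch combining these ingredients. The component choices and parameters (eds.transformer, eds.ner_crf, the "gold_spans" span getter) and the BRAT corpus paths are illustrative and should be adapted to your project; the train arguments follow the parameter list below.

```python
import edsnlp, edsnlp.pipes as eds
from edsnlp.train import train

# 1. Pipeline: at least one trainable component
nlp = edsnlp.blank("eds")
nlp.add_pipe(
    eds.ner_crf(  # trainable NER head (illustrative choice)
        embedding=eds.transformer(model="camembert-base"),
        target_span_getter="gold_spans",  # illustrative span getter name
    ),
    name="ner",
)

# 2. & 3. Training and validation streams (illustrative BRAT corpus paths)
train_docs = edsnlp.data.read_standoff("corpus/train")
val_docs = edsnlp.data.read_standoff("corpus/dev")

train(
    nlp=nlp,
    # Training stream wrapped in a TrainingData (dict form, cast by Confit)
    train_data={
        "data": train_docs,
        "batch_size": "2000 words",
        "sub_batch_size": "4 splits",  # gradient accumulation
        "shuffle": "dataset",
        "pipe_names": ["ner"],
    },
    val_data=val_docs,
    # Scorer and optimizer are left to their defaults here
    # (speed reporting, AdamW with linear warmup).
    max_steps=2000,
    validation_interval=200,
    grad_max_norm=5.0,
    output_dir="artifacts",
)
```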

The training then proceeds in several steps:

Setup The function prepares the device with Accelerate, creates the output folders, materializes the validation set from the user-provided stream, and runs a post-initialization pass on the training data when requested. This post_init step lets the pipeline inspect the data before learning, for instance to adjust the number of prediction heads to the labels encountered. Finally, the optimizer is instantiated.

Phases Training runs by phases. A phase groups components that should be optimized together because they share parameters (think for instance of a BERT shared between multiple models). During a phase, losses are computed for each of these "active" components at each step, and only their parameters are updated.

Data preparation Each TrainingData object turns its stream of documents into device-ready batches. It optionally shuffles the stream, preprocesses the documents for the active components, builds stat-aware batches (for instance, limiting the number of tokens per batch), optionally splits batches into sub-batches for gradient accumulation, then converts everything into device-ready tensors. This can happen in parallel with the actual deep-learning work.

Optimization At every training step, the function draws one batch from each training stream (if there is more than one) and synchronizes statistics across processes (when training on multiple GPUs) to keep supports and losses consistent. It runs forward passes for the phase's components. When several components reuse the same intermediate features, a cache avoids recomputation. Gradients are accumulated over sub-batches.

Gradient safety Gradients are always clipped to grad_max_norm. Optionally, the function tracks an exponential moving mean and variance of the gradient norm. If a spike is detected, the gradients can be clipped to the running mean or to an adaptive threshold, or the update can be skipped entirely, depending on grad_dev_policy. This protects training from rare extreme updates.
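The spike-detection rule can be pictured as follows. This is only an illustrative sketch of the logic described above (and in the grad_dev_policy parameter below), not the actual EDS-NLP implementation; ewm_mean and ewm_std stand for the running statistics maintained over roughly grad_ewm_window steps.

```python
# Illustrative sketch of the gradient-safety rule (not the actual implementation).
def gradient_action(grad_norm, ewm_mean, ewm_std, grad_max_dev=7.0, grad_dev_policy=None):
    threshold = ewm_mean + grad_max_dev * ewm_std
    if grad_dev_policy is not None and grad_norm > threshold:  # spike detected
        if grad_dev_policy == "clip_mean":
            return "clip_to", ewm_mean    # rescale gradients to the running mean norm
        if grad_dev_policy == "clip_threshold":
            return "clip_to", threshold   # rescale gradients to the adaptive threshold
        if grad_dev_policy == "skip":
            return "skip", None           # drop this update entirely
    return "proceed", None  # grad_max_norm clipping is still always applied
```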

Validation and logging At regular intervals the scorer evaluates the pipeline on the validation documents. It isolates each task by copying docs and disabling unrelated pipes to avoid leakage. It reports throughput and metrics for NER and span attribute classifiers plus any custom metrics.

Checkpoints and output The model is saved on schedule and at the end in output_dir/model-last unless saving is disabled.
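Once training has finished, the saved pipeline can be reloaded with edsnlp.load; the path below assumes the default output_dir.

```python
import edsnlp

# Reload the pipeline saved at the end of training (default location).
nlp = edsnlp.load("artifacts/model-last")
doc = nlp("Le patient présente une pneumopathie.")
print(doc.ents)
```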

Tutorials and examples

For end-to-end examples, see the training tutorials, such as the Make a training script tutorial mentioned above.

Parameters of edsnlp.train

Here are the parameters you can pass to the train function:

PARAMETER DESCRIPTION
nlp

The pipeline that will be trained in place.

TYPE: Pipeline

train_data

The training data. Can be a single TrainingData object, a dict that will be cast to one, or a list of these objects.

TrainingData object/dictionary
PARAMETER DESCRIPTION
data

The stream of documents to train on. The documents will be preprocessed and collated according to the pipeline's components.

TYPE: Stream

batch_size

The batch size. Can be a batching expression like "2000 words", an int (number of documents), or a tuple (batch_size, batch_by). The batch_by argument should be a statistic produced by the pipes that will be trained. For instance, the eds.span_pooler component produces a "spans" statistic, which can be used to build batches of no more than 16 spans by setting batch_size to "16 spans".

TYPE: BatchSizeArg

shuffle

The shuffle strategy. Can be "dataset" to shuffle the entire dataset (this can be memory-intensive for large file-based datasets), "fragment" to shuffle fragment-based datasets (like Parquet files) at the fragment level, or a batching expression like "2000 words" to shuffle the dataset in chunks of 2000 words.

TYPE: Union[str, Literal[False]]

sub_batch_size

How to split each batch into sub-batches that will be fed to the model independently to accumulate gradients over. To split a batch of 8000 tokens into smaller batches of 1000 tokens each, just set this to "1000 tokens".

You can also request a number of splits, like "4 splits", to split the batch into N parts each close to (but less than) batch_size / N.

TYPE: Optional[BatchSizeArg] DEFAULT: None

pipe_names

The names of the pipes that should be trained on this data. If None, defaults to all trainable pipes.

TYPE: Optional[AsList[str]] DEFAULT: None

post_init

Whether to call the pipeline's post_init method with the data before training.

TYPE: bool DEFAULT: True

TYPE: AsList[TrainingData]
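To illustrate the batching options above, here are a few valid TrainingData configurations in dict form. They assume stream is a Stream of documents and that the trained pipes produce the referenced statistics; the "qualifier" pipe name is a placeholder.

```python
# 16 documents per batch
td_docs = {"data": stream, "batch_size": 16}

# Stat-aware batching and chunked shuffling
td_words = {"data": stream, "batch_size": "2000 words", "shuffle": "2000 words"}

# Equivalent tuple form (batch_size, batch_by)
td_tuple = {"data": stream, "batch_size": (2000, "words")}

# Gradient accumulation: 8000-token batches split into 1000-token sub-batches
td_accum = {"data": stream, "batch_size": "8000 tokens", "sub_batch_size": "1000 tokens"}

# Batch by a statistic produced by the trained pipes, e.g. "spans"
# for a span classifier built on eds.span_pooler
td_spans = {"data": stream, "batch_size": "16 spans", "pipe_names": ["qualifier"]}
```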

val_data

The validation data. Can be a single Stream object or a list of Stream.

TYPE: AsList[Stream] DEFAULT: []

seed

The random seed

TYPE: int DEFAULT: 42

max_steps

The maximum number of training steps

TYPE: int DEFAULT: 1000

optimizer

The optimizer. If None, a default optimizer will be used.

ScheduledOptimizer object/dictionary
PARAMETER DESCRIPTION
optim

The optimizer to use. If a string (like "adamw") or a type to instantiate, the module and groups must be provided.

TYPE: Union[str, Type[Optimizer], Optimizer]

module

The module to optimize. Usually the nlp pipeline object.

TYPE: Optional[Union[PipelineProtocol, Module]] DEFAULT: None

total_steps

The total number of steps, used for schedules.

TYPE: Optional[int] DEFAULT: None

groups

The groups to optimize. Each group is a dictionary containing:

  • a regex selector key to match the parameters of that group by their names (as listed by nlp.named_parameters())
  • several other keys that define the optimizer parameters for that group, such as lr, weight_decay, etc. The value for these keys can be a Schedule instance or a simple value
  • an optional exclude key that can be set to True to exclude the matched parameters from optimization

The matching is performed by running regex.search(selector, name) so you do not have to match the full name. Note that the order of the groups matters. If a parameter name matches multiple selectors, the configurations of these selectors are combined in reverse order (from the last matched selector to the first), allowing later selectors to complete options from earlier ones. If a selector contains exclude=True, any parameter matching it is excluded from optimization.

TYPE: Optional[List[Group]] DEFAULT: None

TYPE: Union[Draft[ScheduledOptimizer], ScheduledOptimizer, Optimizer] DEFAULT: None
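As a sketch of the groups mechanism, the dict below roughly reproduces the default setup described earlier (AdamW, a lower learning rate for the transformer, 3e-4 for the rest). The selector regexes and the frozen-group example are illustrative and should be checked against the names returned by nlp.named_parameters().

```python
optimizer = {
    "optim": "adamw",
    "module": nlp,         # usually the pipeline itself
    "total_steps": 2000,   # used by learning-rate schedules
    "groups": [
        # Groups are matched with regex.search(selector, name); earlier groups
        # take precedence and later ones only complete missing options.
        {"selector": "transformer", "lr": 5e-5, "weight_decay": 0.01},
        {"selector": "position_embeddings", "exclude": True},  # freeze these parameters (illustrative)
        {"selector": "", "lr": 3e-4},  # catch-all group for the rest of the model
    ],
}
```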

validation_interval

The number of steps between each evaluation. Defaults to 1/10 of max_steps

TYPE: Optional[int] DEFAULT: None

checkpoint_interval

The number of steps between each model save. Defaults to validation_interval

TYPE: Optional[int] DEFAULT: None

grad_max_norm

The maximum gradient norm

TYPE: float DEFAULT: 5.0

grad_dev_policy

The policy to apply when a gradient spike is detected, i.e. when the gradient norm is higher than the mean + std * grad_max_dev. Can be:

  • "clip_mean": clip the gradients to the mean gradient norm
  • "clip_threshold": clip the gradients to the mean + std * grad_max_dev
  • "skip": skip the step

These policies apply in addition to grad_max_norm, which is always enforced when it is not None. They exist because grad_max_norm is not adaptive: set low enough to catch spikes, it would most likely prevent the model from learning during the early stages of training, when gradients are expected to be high.

TYPE: Optional[Literal['clip_mean', 'clip_threshold', 'skip']] DEFAULT: None

grad_ewm_window

Approximately how many past steps are used to compute the exponential moving mean and variance of the gradient norm when detecting gradient spikes.

TYPE: int DEFAULT: 100

grad_max_dev

The threshold used to detect gradient spikes. A spike is detected when the gradient norm is higher than the mean + std * grad_max_dev.

TYPE: float DEFAULT: 7.0

loss_scales

The loss scales for each component (useful for multi-task learning)

TYPE: Dict[str, float] DEFAULT: {}

scorer

How to score the model. Expects a GenericScorer object or a dict containing a mapping of metric names to metric objects.

GenericScorer object/dictionary
PARAMETER DESCRIPTION
batch_size

The batch size to use for scoring. Can be an int (number of documents) or a string (batching expression like "2000 words").

TYPE: Union[int, str] DEFAULT: 1

speed

Whether to compute the model speed (words/documents per second)

TYPE: bool DEFAULT: True

autocast

Whether to use autocasting for mixed precision during the evaluation, defaults to True.

TYPE: Union[bool, Any] DEFAULT: None

metrics

A keyword arguments mapping of metric names to metrics objects. See the metrics documentation for more info.

DEFAULT: {}

TYPE: GenericScorer DEFAULT: GenericScorer()
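A sketch of a scorer passed in dict form, assuming that the extra keys are gathered into the metrics mapping as the metrics parameter above suggests. my_ner_metric is a placeholder for a metric object built as described in the metrics documentation.

```python
scorer = {
    "batch_size": "2000 words",  # batching expression used during evaluation
    "speed": True,               # report words / documents per second
    "autocast": True,            # mixed precision during scoring
    # Remaining keys map metric names to metric objects
    # (see the metrics documentation); this one is a placeholder.
    "ner": my_ner_metric,
}
```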

num_workers

The number of workers to use for preprocessing the data in parallel. Setting it to 0 means no parallelization: the data is processed on the main thread, which may induce latency and slow down the training. To avoid this, a good practice consists in doing the preprocessing either before training or in parallel in a separate process. Because of how EDS-NLP handles stream multiprocessing, changing this value will affect the order of the documents in the produced batches. A stream [1, 2, 3, 4, 5, 6] split in batches of size 3 will produce:

  • [1, 2, 3] and [4, 5, 6] with 1 worker
  • [1, 3, 5] and [2, 4, 6] with 2 workers

TYPE: int DEFAULT: 0

cpu

Whether to force training on CPU. On macOS, this might be necessary to get around some MPS backend issues.

TYPE: bool DEFAULT: False

mixed_precision

The mixed precision mode. Can be "no", "fp16", "bf16" or "fp8".

TYPE: Literal['no', 'fp16', 'bf16', 'fp8'] DEFAULT: 'no'

output_dir

The output directory, which will contain a model-last directory with the last model, and a train_metrics.json file with the training metrics and stats.

TYPE: Union[Path, str] DEFAULT: Path('artifacts')

output_model_dir

The directory where to save the model. If None, defaults to output_dir / "model-last".

TYPE: Optional[Union[Path, str]] DEFAULT: None

save_model

Whether to save the model or not. This can be useful if you are only interested in the metrics, but not the model, and want to avoid spending time dumping the model weights to disk.

TYPE: bool DEFAULT: True

logger

The logger to use. Can be a boolean to use the default loggers (rich and json), a list of logger names, or a list of logger objects.

You can use Hugging Face Accelerate's integrated loggers (tensorboard, wandb, comet_ml, aim, mlflow, clearml, dvclive), EDS-NLP's simple loggers, or a combination of both:

  • csv: logs to a CSV file in output_dir (artifacts/metrics.csv)
  • json: logs to a JSON file in output_dir (artifacts/metrics.json)
  • rich: logs to a rich table in the terminal

TYPE: Union[bool, AsList[Union[str, GeneralTracker, Draft[GeneralTracker]]]] DEFAULT: True
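For instance, to combine EDS-NLP's simple loggers with one of the Accelerate-integrated trackers listed above:

```python
# Log to a CSV file, a rich terminal table and TensorBoard.
logger = ["csv", "rich", "tensorboard"]
```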

log_weight_grads

Whether to log the weight gradients during training.

TYPE: bool DEFAULT: False

on_validation_callback

A callback function invoked during validation steps to handle custom logic.

TYPE: Optional[Callable[[Dict], None]] DEFAULT: None

project_name

The project name, used to group experiments in some loggers. If None, defaults to the path of the config file, relative to the home directory, with slashes replaced by double underscores.

TYPE: str DEFAULT: None

kwargs

Additional keyword arguments.

DEFAULT: {}