Model Training#
Interfaces and support for model training.
- class lenskit.training.TrainingOptions(retrain=True, device=None, rng=None)#
Bases:
object
Options and context settings that govern model training.
- retrain: bool = True#
Whether the model should retrain if it is already trained. If False, the model should cleanly skip training if it is already trained.
- device: str | None = None#
The device on which to train (e.g. 'cuda'). May be ignored if the model does not support the specified device.
- rng: RNGInput = None#
Random number generator to use for any randomness in the training process. This option accepts any SPEC 7-compatible random number generator specification; random_generator() converts it into a NumPy Generator.
- random_generator()#
Obtain a random generator from the configured RNG or seed.
Note
Each call to this method will return a fresh generator from the same seed. Components should call it once at the beginning of their training processes.
- Return type:
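The note above can be illustrated without LensKit installed. The sketch below is only an approximation: `SketchOptions` is a hypothetical stand-in for `TrainingOptions`, it accepts only an integer seed, and it uses the standard library's `random.Random` in place of a NumPy Generator. It mirrors the documented behaviour (a fresh generator from the same seed on each call), not LensKit's implementation.

```python
from __future__ import annotations

import random
from dataclasses import dataclass


@dataclass
class SketchOptions:
    """Hypothetical stand-in for TrainingOptions (not the LensKit class)."""

    rng: int | None = None  # assumption: an integer seed only

    def random_generator(self) -> random.Random:
        # Build a new generator from the stored seed on every call.
        return random.Random(self.rng)


opts = SketchOptions(rng=42)

# Calling random_generator() repeatedly restarts the same random stream.
first = [opts.random_generator().random() for _ in range(2)]

# Obtaining the generator once and reusing it advances through the stream.
gen = opts.random_generator()
reused = [gen.random() for _ in range(2)]
```

Here both entries of `first` are equal, while the entries of `reused` differ, which is why a component would normally obtain its generator once at the start of training.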
- configured_device(*, gpu_default=False)#
Get the configured device, consulting environment variables and defaults if necessary. It looks for a device in the following order:
1. The device, if specified on this object.
2. The LK_DEVICE environment variable.
3. If CUDA is enabled and gpu_default is True, return "cuda".
4. The CPU.
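The lookup order above could be sketched as a small pure function. Everything here is an assumption made for illustration: `resolve_device` is a hypothetical helper, the environment is passed in as a mapping so the function is easy to test, CUDA availability is supplied by the caller, and the fallback is written as the string "cpu".

```python
from __future__ import annotations

from typing import Mapping


def resolve_device(
    device: str | None,
    env: Mapping[str, str],
    *,
    cuda_available: bool,
    gpu_default: bool = False,
) -> str:
    """Hypothetical helper mirroring the documented lookup order."""
    # 1. An explicitly configured device wins.
    if device is not None:
        return device
    # 2. Otherwise consult the LK_DEVICE environment variable.
    env_device = env.get("LK_DEVICE")
    if env_device:
        return env_device
    # 3. Use CUDA only when it is available and the caller opts in.
    if cuda_available and gpu_default:
        return "cuda"
    # 4. Fall back to the CPU (assumed here to be spelled "cpu").
    return "cpu"


explicit = resolve_device("cuda:1", {"LK_DEVICE": "cpu"}, cuda_available=True)
from_env = resolve_device(None, {"LK_DEVICE": "cuda"}, cuda_available=False)
opted_in = resolve_device(None, {}, cuda_available=True, gpu_default=True)
fallback = resolve_device(None, {}, cuda_available=True)
```

In real use the mapping would typically be `os.environ`.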
- class lenskit.training.Trainable(*args, **kwargs)#
Bases:
Protocol
Interface for components and objects that can learn parameters from training data. It supports training and checking if a component has already been trained. The resulting model should be pickleable. Trainable objects are usually also components.
Note
Trainable components must also implement __call__.
Note
A future LensKit version will add support for extracting model parameters a la PyTorch’s state_dict, but this capability was not ready for 2025.1.
- Stability:
- Full (see Stability Levels).
- train(data, options)#
Train the model to learn its parameters from a training dataset.
- Parameters:
data (Dataset) – The training dataset.
options (TrainingOptions) – The training options.
- Return type:
None
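As a rough illustration of the protocol, the sketch below defines hypothetical stand-ins (`SketchOptions`, `SketchTrainable`) rather than importing the LensKit classes, and uses a plain list of ratings in place of a Dataset. The toy `MeanScorer` component and its `mean_` attribute are invented for this example. It shows a `train(data, options)` method that skips work when the model is already trained and `retrain` is False, along with the required `__call__`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable


@dataclass
class SketchOptions:
    """Hypothetical stand-in for TrainingOptions."""

    retrain: bool = True


@runtime_checkable
class SketchTrainable(Protocol):
    """Hypothetical stand-in for the Trainable protocol."""

    def train(self, data: Any, options: SketchOptions) -> None: ...


class MeanScorer:
    """Toy component that learns the mean of a list of ratings."""

    def __init__(self) -> None:
        self.mean_: float | None = None
        self.train_count = 0

    def train(self, data: list[float], options: SketchOptions) -> None:
        # Cleanly skip training if already trained and retrain is off.
        if self.mean_ is not None and not options.retrain:
            return
        self.mean_ = sum(data) / len(data)
        self.train_count += 1

    def __call__(self) -> float:
        if self.mean_ is None:
            raise RuntimeError("model has not been trained")
        return self.mean_


scorer = MeanScorer()
scorer.train([3.0, 4.0, 5.0], SketchOptions())
scorer.train([1.0], SketchOptions(retrain=False))  # skipped: already trained
```

Because the protocol is structural, `MeanScorer` satisfies `SketchTrainable` without inheriting from it.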
- class lenskit.training.IterativeTraining(*args, **kwargs)#
Base class for components that support iterative training. This both automates the Trainable.train() method for iterative training in terms of initialization, epoch, and finalization methods, and exposes those methods to client code that may wish to directly control the iterative training process.
- Stability:
- Full (see Stability Levels).
- property expected_training_epochs: int | None#
Get the number of training epochs expected to run. The default implementation looks for an epochs attribute on the configuration object (self.config).
- train(data, options=TrainingOptions(retrain=True, device=None, rng=None))#
Implementation of Trainable.train() that uses the training loop. It also uses the trained_epochs attribute to detect if the model has already been trained for the purposes of honoring TrainingOptions.retrain, and updates that attribute as model training progresses.
- Parameters:
data (Dataset)
options (TrainingOptions)
- Return type:
None
- abstract training_loop(data, options)#
Training loop implementation, to be supplied by the derived class. This method should return an iterator that, when iterated, will perform each training epoch; when training is complete, it should finalize the model and signal iteration completion.
Each epoch can yield metrics, such as training or validation loss, which are logged with structured logging and can be used by calling code for further analysis.
See Iterative Training for more details on writing iterative training loops.
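The generator-style loop described above might look roughly like the following. This is a simplified sketch under stated assumptions, not the LensKit implementation: `SketchIterative` and `SketchOptions` are hypothetical stand-ins for `IterativeTraining` and `TrainingOptions`, metrics are collected in a `history` list instead of being logged, and the toy `ScalarFit` model fits a single number to a list of values by gradient descent.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterator


@dataclass
class SketchOptions:
    """Hypothetical stand-in for TrainingOptions."""

    retrain: bool = True


class SketchIterative(ABC):
    """Hypothetical stand-in for IterativeTraining."""

    def __init__(self) -> None:
        self.trained_epochs = 0
        self.history: list[dict[str, float]] = []

    def train(self, data: list[float], options: SketchOptions | None = None) -> None:
        options = options or SketchOptions()
        # Honor retrain=False by skipping an already-trained model.
        if self.trained_epochs > 0 and not options.retrain:
            return
        self.trained_epochs = 0
        self.history = []
        # Each iteration of the loop runs exactly one epoch.
        for metrics in self.training_loop(data, options):
            self.trained_epochs += 1
            self.history.append(metrics)

    @abstractmethod
    def training_loop(
        self, data: list[float], options: SketchOptions
    ) -> Iterator[dict[str, float]]: ...


class ScalarFit(SketchIterative):
    """Toy model: fit one value to the data by gradient descent."""

    def __init__(self, epochs: int = 20, lr: float = 0.5) -> None:
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.value = 0.0
        self.finalized = False

    def training_loop(self, data, options):
        self.value = 0.0  # initialization
        for _ in range(self.epochs):
            grad = sum(self.value - x for x in data) / len(data)
            self.value -= self.lr * grad
            loss = sum((self.value - x) ** 2 for x in data) / len(data)
            yield {"loss": loss}  # per-epoch metrics
        self.finalized = True  # finalization, then the generator ends


model = ScalarFit()
model.train([1.0, 2.0, 3.0])
```

Client code that wants direct control could iterate `model.training_loop(...)` itself, for example to stop early when the yielded loss stops improving.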