Model Training#

Interfaces and support for model training.

class lenskit.training.TrainingOptions(retrain=True, device=None, rng=None)#

Bases: object

Options and context settings that govern model training.

Parameters:
  • retrain (bool)

  • device (str | None)

  • rng (lenskit.random.RNGInput)

retrain: bool = True#

Whether the model should retrain if it is already trained. If False, the model should cleanly skip training if it is already trained.

device: str | None = None#

The device on which to train (e.g. 'cuda'). May be ignored if the model does not support the specified device.

rng: RNGInput = None#

Random number generator to use for any randomness in the training process. This option accepts any SPEC 7-compatible random number generator specification; random_generator() will convert it into a NumPy Generator.

random_generator()#

Obtain a random generator from the configured RNG or seed.

Note

Each call to this method will return a fresh generator from the same seed. Components should call it once at the beginning of their training processes.

Return type:

Generator

configured_device(*, gpu_default=False)#

Get the configured device, consulting environment variables and defaults if necessary. It looks for a device in the following order:

  1. The device, if specified on this object.

  2. The LK_DEVICE environment variable.

  3. If CUDA is enabled and gpu_default is True, return 'cuda'.

  4. The CPU.

Parameters:

gpu_default (bool) – Whether a CUDA GPU should be preferred if it is available and no device has been specified.

Return type:

str

class lenskit.training.Trainable(*args, **kwargs)#

Bases: Protocol

Interface for components and objects that can learn parameters from training data. It supports training and checking if a component has already been trained. This protocol only captures the concept of trainability; most trainable components should have other properties and behaviors as well:

  • They are usually components (Component), with an appropriate __call__ method.

  • They should be pickleable.

  • They should also usually implement ParameterContainer, to allow the learned parameters to be serialized and deserialized without pickling.

Stability:
Full (see Stability Levels).
train(data, options)#

Train the model to learn its parameters from a training dataset.

Parameters:
  • data (Dataset)

  • options (TrainingOptions)

Return type:

None

class lenskit.training.IterativeTraining(*args, **kwargs)#

Bases: ABC, Trainable

Base class for components that support iterative training. This both automates the Trainable.train() method for iterative training in terms of initialization, epoch, and finalization methods, and exposes those methods to client code that may wish to directly control the iterative training process.

Stability:
Full (see Stability Levels).
trained_epochs: int = 0#

The number of epochs for which this model has been trained.

property expected_training_epochs: int | None#

Get the number of training epochs expected to run. The default implementation looks for an epochs attribute on the configuration object (self.config).

train(data, options=TrainingOptions(retrain=True, device=None, rng=None))#

Implementation of Trainable.train() that uses the training loop. It also uses the trained_epochs attribute to detect if the model has already been trained for the purposes of honoring TrainingOptions.retrain, and updates that attribute as model training progresses.

Parameters:
  • data (Dataset)

  • options (TrainingOptions)

Return type:

None

abstractmethod training_loop(data, options)#

Training loop implementation, to be supplied by the derived class. This method should return an iterator that performs one training epoch per step; when training is complete, it should finalize the model and signal the end of iteration.

Each epoch can yield metrics, such as training or validation loss, which are recorded with structured logging and can be used by calling code for further analysis.

See Iterative Training for more details on writing iterative training loops.

Parameters:
  • data (Dataset)

  • options (TrainingOptions)

Return type:

Iterator[dict[str, float] | None]

class lenskit.training.UsesTrainer(config=None, **kwargs)#

Bases: IterativeTraining, Component, ABC

Base class for models that implement Trainable via a ModelTrainer. This class implements IterativeTraining for compatibility, but the IterativeTraining interface is deprecated.

The component’s configuration must have an epochs attribute noting the number of epochs to train.

Parameters:
  • config (Any)

  • kwargs (Any)

train(data, options=TrainingOptions(retrain=True, device=None, rng=None))#

Implementation of Trainable.train() that uses the model trainer.

Parameters:
  • data (Dataset)

  • options (TrainingOptions)

Return type:

None

training_loop(data, options)#

Training loop implementation, to be supplied by the derived class. This method should return an iterator that performs one training epoch per step; when training is complete, it should finalize the model and signal the end of iteration.

Each epoch can yield metrics, such as training or validation loss, which are recorded with structured logging and can be used by calling code for further analysis.

See Iterative Training for more details on writing iterative training loops.

Parameters:
  • data (Dataset)

  • options (TrainingOptions)

Return type:

Iterator[dict[str, float] | None]

abstractmethod create_trainer(data, options)#

Create a model trainer to train this model.

Parameters:
  • data (Dataset)

  • options (TrainingOptions)

Return type:

ModelTrainer

class lenskit.training.ModelTrainer#

Bases: ABC

Protocol implemented by iterative trainers for models. Models that implement UsesTrainer will return an object implementing this protocol from their create_trainer() method.

This protocol only defines the core aspects of training a model. Trainers should also implement ParameterContainer to allow training to be checkpointed and resumed.

It is also a good idea for the trainer to be pickleable, but the parameter container interface is the primary mechanism for checkpointing.

abstractmethod train_epoch()#

Perform one epoch of the training process, optionally returning metrics on the training behavior. After each training iteration, the model must be usable.

Return type:

dict[str, float] | None

abstractmethod finalize()#

Finish the training process, cleaning up any unneeded data structures and performing any finalization steps on the model.