oracle.trainer module
Module for training hierarchical models in the ORACLE framework.
- class oracle.trainer.EarlyStopper(patience=1, min_delta=0)
Bases: object
EarlyStopper class for monitoring validation loss and triggering early stopping during training.
- patience
Number of consecutive epochs with insufficient improvement allowed before stopping early.
- Type:
int
- min_delta
Tolerance on the validation loss: a loss that exceeds min_validation_loss by no more than this amount does not count toward patience.
- Type:
float
- counter
Counts the number of consecutive epochs without sufficient improvement.
- Type:
int
- min_validation_loss
The lowest validation loss observed so far.
- Type:
float
- early_stop(validation_loss)
Checks whether training should stop early, given the current validation loss. The method compares validation_loss with min_validation_loss, the lowest value seen so far:
If validation_loss is lower than min_validation_loss, the minimum is updated and counter is reset to 0.
If validation_loss exceeds min_validation_loss by more than min_delta, counter is incremented.
Once counter reaches or exceeds patience, the method signals that training should stop.
- Parameters:
validation_loss (float) – The current validation loss from the evaluation phase.
- Returns:
True if early stopping criterion is met (i.e., the counter has reached the patience limit), otherwise False.
- Return type:
bool
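The rule above fits in a few lines of plain Python. This is a minimal sketch consistent with the description, not the oracle.trainer source; it assumes that min_validation_loss starts at positive infinity and counter at zero.

```python
class EarlyStopper:
    """Sketch of the patience/min_delta rule described for oracle.trainer.EarlyStopper."""

    def __init__(self, patience: int = 1, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        # Assumption: start from +inf so the first loss always becomes the best.
        self.min_validation_loss = float("inf")

    def early_stop(self, validation_loss: float) -> bool:
        if validation_loss < self.min_validation_loss:
            # New best loss: remember it and reset the counter.
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > self.min_validation_loss + self.min_delta:
            # Worse than the best by more than min_delta: count it against patience.
            self.counter += 1
        # Stop once the counter reaches (or exceeds) patience.
        return self.counter >= self.patience


stopper = EarlyStopper(patience=2, min_delta=0.01)
decisions = [stopper.early_stop(loss) for loss in [1.0, 0.8, 0.85, 0.9]]
print(decisions)  # [False, False, False, True]
```

With patience=2, the two epochs at 0.85 and 0.9 (both more than 0.01 above the best loss of 0.8) trigger the stop; a loss within min_delta of the best would leave the counter untouched.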
- class oracle.trainer.Trainer
Bases: object
Top-level class providing training functionality for hierarchical classification models.
- fit(train_loader, val_loader, num_epochs=5)
Train the model for a specified number of epochs.
This method moves the model to the designated device and runs the given number of epochs. In each epoch it trains the model on the training data, evaluates it on the validation data, and records the training loss, validation loss, and macro F1 score. It saves the best models, one by lowest validation loss and one by highest F1 score, logs metrics to Weights & Biases, steps the learning-rate scheduler, and saves the loss histories. Training stops early if the early-stopping condition is met.
- Parameters:
train_loader (DataLoader) – DataLoader providing batches of training data.
val_loader (DataLoader) – DataLoader providing batches of validation data.
num_epochs (int, optional) – Number of epochs to train for. Defaults to 5.
- Returns:
None
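The per-epoch control flow of fit can be outlined as follows. This is a hypothetical sketch: fit_sketch and its callable arguments stand in for the Trainer methods, while checkpoint saving, W&B logging, and the scheduler step are reduced to comments. The metric keys 'val_loss' and 'macro_f1' are taken from the documented return value of validate_one_epoch.

```python
from typing import Callable, Dict, List


def fit_sketch(
    train_one_epoch: Callable[[], float],
    validate_one_epoch: Callable[[], Dict[str, float]],
    early_stop: Callable[[float], bool],
    num_epochs: int = 5,
) -> Dict[str, object]:
    """Hypothetical outline of the loop in Trainer.fit (device moves omitted)."""
    train_loss_history: List[float] = []
    val_loss_history: List[float] = []
    f1_history: List[float] = []
    best_val_loss, best_f1 = float("inf"), float("-inf")
    best_loss_epoch = best_f1_epoch = -1

    for epoch in range(num_epochs):
        train_loss = train_one_epoch()
        metrics = validate_one_epoch()
        train_loss_history.append(train_loss)
        val_loss_history.append(metrics["val_loss"])
        f1_history.append(metrics["macro_f1"])

        # Track both "best" criteria; the real method saves a checkpoint here.
        if metrics["val_loss"] < best_val_loss:
            best_val_loss, best_loss_epoch = metrics["val_loss"], epoch
        if metrics["macro_f1"] > best_f1:
            best_f1, best_f1_epoch = metrics["macro_f1"], epoch

        # W&B logging, the scheduler step and save_loss_history would go here.

        if early_stop(metrics["val_loss"]):
            break

    return {
        "train_loss_history": train_loss_history,
        "val_loss_history": val_loss_history,
        "f1_history": f1_history,
        "best_loss_epoch": best_loss_epoch,
        "best_f1_epoch": best_f1_epoch,
    }


# Usage with canned per-epoch values instead of real training.
train_losses = iter([1.0, 0.8, 0.7])
val_metrics = iter([
    {"val_loss": 0.9, "macro_f1": 0.50},
    {"val_loss": 0.7, "macro_f1": 0.60},
    {"val_loss": 0.8, "macro_f1": 0.65},
])
result = fit_sketch(lambda: next(train_losses), lambda: next(val_metrics),
                    early_stop=lambda loss: False, num_epochs=3)
print(result["best_loss_epoch"], result["best_f1_epoch"])  # 1 2
```

Note that the two criteria can pick different epochs, which is why two separate "best" checkpoints are kept.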
- log_data_in_wandb(train_loss_history, val_loss_history, f1_history, cf)
Logs training and validation metrics to Weights & Biases (wandb).
- Parameters:
train_loss_history (List[float]) – List of training loss values recorded per epoch.
val_loss_history (List[float]) – List of validation loss values recorded per epoch.
f1_history (List[float]) – List of F1 score values recorded per epoch.
cf (Any) – Additional information to log, e.g. the confusion-matrix plot (wandb.Image) returned by validate_one_epoch under the ‘cf’ key.
- Returns:
None
- save_loss_history(train_loss_history, val_loss_history, f1_history)
Saves the loss and F1 score histories as NumPy binary files in the model directory. Each history is converted to a NumPy array before saving.
- Parameters:
train_loss_history (list or array-like) – A collection of training loss values.
val_loss_history (list or array-like) – A collection of validation loss values.
f1_history (list or array-like) – A collection of F1 score values.
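A minimal sketch of this behaviour with np.save, assuming one .npy file per history. The file names used here (train_loss_history.npy and so on) and the function name save_loss_history_sketch are hypothetical, since the names actually written by save_loss_history are not documented.

```python
import os
import tempfile

import numpy as np


def save_loss_history_sketch(model_dir, train_loss_history, val_loss_history, f1_history):
    """Write each history as a .npy file in model_dir (file names are hypothetical)."""
    histories = {
        "train_loss_history": train_loss_history,
        "val_loss_history": val_loss_history,
        "f1_history": f1_history,
    }
    for name, values in histories.items():
        # Convert to an ndarray first, then save in NumPy's binary .npy format.
        np.save(os.path.join(model_dir, f"{name}.npy"), np.asarray(values))


with tempfile.TemporaryDirectory() as model_dir:
    save_loss_history_sketch(model_dir, [1.0, 0.8], [0.9, 0.7], [0.5, 0.6])
    print(np.load(os.path.join(model_dir, "f1_history.npy")))  # [0.5 0.6]
```

The saved arrays can be reloaded with np.load for plotting learning curves after training.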
- save_model_in_wandb()
Saves training artifacts to Weights & Biases (wandb) for experiment tracking. All artifact files are expected to reside in the directory specified by self.model_dir.
- setup_training(alpha, gamma, lr, train_labels, val_labels, model_dir, device, wandb_run)
Set up the training components for the model. This method initializes the training and validation loss criteria, the optimizer, the learning-rate scheduler, and early stopping. It also stores the device to use, the model directory for saving checkpoints, and the Weights & Biases run instance.
- Parameters:
alpha (float) – Hyperparameter for the training loss function to adjust the influence of the taxonomy.
gamma (float) – Hyperparameter for the loss function that modulates the weighting of different classes.
lr (float) – Learning rate for the optimizer.
train_labels (iterable) – Labels for the training dataset, used to compute class weights in the training loss.
val_labels (iterable) – Labels for the validation dataset, used to compute class weights in the validation loss.
model_dir (str) – Directory path where model checkpoints and related outputs will be stored.
device (torch.device or str) – The computing device (CPU or GPU) on which to run the model.
wandb_run – Instance of the Weights & Biases run for logging training progress.
- Returns:
None
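The documentation does not state how class weights are derived from train_labels and val_labels. One common choice is inverse-frequency ("balanced") weighting, sketched below purely as an assumption; the function name inverse_frequency_weights is hypothetical.

```python
from collections import Counter
from typing import Dict, Hashable, Iterable


def inverse_frequency_weights(labels: Iterable[Hashable]) -> Dict[Hashable, float]:
    """Hypothetical 'balanced' weights: total / (n_classes * count_c)."""
    counts = Counter(labels)
    total = sum(counts.values())
    n_classes = len(counts)
    # Rare classes get weights above 1, common classes below 1;
    # perfectly balanced labels all get exactly 1.0.
    return {label: total / (n_classes * count) for label, count in counts.items()}


train_labels = ["A", "A", "A", "B"]
print(inverse_frequency_weights(train_labels))  # {'A': 0.6666666666666666, 'B': 2.0}
```

Weights of this kind would then be handed to the loss criterion so that under-represented classes contribute more per sample.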
- train_one_epoch(train_loader)
Train the model for one epoch.
This method performs a full training cycle over all batches provided by the train_loader. For each batch, it moves the batch data to the appropriate device, converts labels into a hierarchical one-hot encoding using the taxonomy, performs a forward pass to compute logits, calculates the loss using the train criterion, and applies backpropagation along with an optimizer step.
- Parameters:
train_loader (iterable) – A data loader that yields batches of training data. Each batch is expected to be a dictionary where the key ‘label’ is used to obtain the correct one-hot encoding.
- Returns:
The mean loss value computed over all batches during the epoch.
- Return type:
float
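The exact encoding is not specified here. A common hierarchical one-hot scheme, assumed in the sketch below, marks the true node and every ancestor up to the root. The parent map, the level-order node list, and the node names are illustrative only.

```python
from typing import Dict, List, Optional


def hierarchical_one_hot(
    label: str,
    parent: Dict[str, Optional[str]],
    node_order: List[str],
) -> List[int]:
    """Mark `label` and all of its ancestors (assumed encoding scheme)."""
    index = {node: i for i, node in enumerate(node_order)}
    vector = [0] * len(node_order)
    node: Optional[str] = label
    while node is not None:
        vector[index[node]] = 1
        node = parent[node]  # the root maps to None, which ends the walk
    return vector


# Toy taxonomy: root -> {A, B}, A -> {A1, A2}, B -> {B1}
parent = {"root": None, "A": "root", "B": "root", "A1": "A", "A2": "A", "B1": "B"}
node_order = ["root", "A", "B", "A1", "A2", "B1"]  # level order
print(hierarchical_one_hot("A2", parent, node_order))  # [1, 1, 0, 0, 1, 0]
```

Encoding the full root-to-label path lets the loss reward correct coarse-level predictions even when the fine-grained class is wrong.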
- validate_one_epoch(val_loader)
Performs validation for one epoch.
This method sets the model to evaluation mode and iterates over the validation data loader, performing forward passes without gradient computation. It computes the loss using a specified validation criterion and aggregates the true and predicted labels for metric calculation. Additionally, it constructs a confusion matrix visualization and computes the macro F1 score.
- Parameters:
val_loader (torch.utils.data.DataLoader) – A data loader for the validation dataset. Each batch should be a dictionary that contains a ‘label’ key (among other entries); tensor entries are moved to the device.
- Returns:
- A dictionary containing:
‘val_loss’ (float): The average loss computed over all validation batches.
‘macro_f1’ (float): The macro F1 score calculated using aggregated true and predicted labels.
‘cf’ (wandb.Image): A wandb.Image object representing the confusion matrix plot.
- Return type:
dict
Note
The method uses the taxonomy provided by self.taxonomy to filter and encode the labels hierarchically.
Leaf node indices are determined using the taxonomy’s level order traversal.
All computations are performed with gradients disabled using torch.no_grad().
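The leaf lookup and the macro F1 score mentioned in the note can be sketched in pure Python. The children-map taxonomy, the node names, and the predict_leaf helper are hypothetical; macro_f1 averages per-class F1 = 2TP / (2TP + FP + FN) over every label that appears in either sequence, which is the usual unweighted macro average.

```python
from collections import deque
from typing import Dict, List, Sequence


def level_order(children: Dict[str, List[str]], root: str) -> List[str]:
    """Breadth-first (level-order) traversal of the taxonomy tree."""
    order: List[str] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        order.append(node)
        queue.extend(children.get(node, []))
    return order


def leaf_indices(children: Dict[str, List[str]], root: str) -> List[int]:
    """Positions (in level order) of nodes that have no children."""
    order = level_order(children, root)
    return [i for i, node in enumerate(order) if not children.get(node)]


def predict_leaf(scores: Sequence[float], leaves: List[int], order: List[str]) -> str:
    """Hypothetical helper: pick the highest-scoring leaf node."""
    return order[max(leaves, key=lambda i: scores[i])]


def macro_f1(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN)."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


children = {"root": ["A", "B"], "A": ["A1", "A2"], "B": ["B1"]}
order = level_order(children, "root")    # ['root', 'A', 'B', 'A1', 'A2', 'B1']
leaves = leaf_indices(children, "root")  # [3, 4, 5]
print(predict_leaf([0.9, 0.6, 0.4, 0.1, 0.5, 0.3], leaves, order))  # A2
print(macro_f1(["A1", "A1", "A2", "B1"], ["A1", "A2", "A2", "B1"]))  # approximately 7/9
```

Because every class counts equally in the macro average, a rare leaf class that is always misclassified pulls the score down as much as a common one, which is why it is a natural model-selection metric for imbalanced taxonomies.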