oracle.train module
Interface for training models in the ORACLE framework.
- oracle.train.get_wandb_run(args)
Initializes and returns a Weights & Biases (wandb) run with the specified configuration.
- Parameters:
args – An object that must contain the following attributes:
1. num_epochs (int): The number of training epochs.
2. batch_size (int): The batch size to be used.
3. lr (float): The learning rate for training.
4. max_n_per_class (int): The maximum number of samples per class.
5. alpha (float): A hyperparameter used for controlling loss behavior.
6. gamma (float): A hyperparameter used for weighting.
7. dir (str): The directory path where the model should be saved.
8. model (str): The identifier for the chosen model architecture.
9. load_weights (str): The file path for the pretrained model weights, if any.
- Returns:
A wandb run instance initialized with the given configuration, which logs metadata and hyperparameters.
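As an illustration of how the listed attributes could become a run configuration, the sketch below collects them into a plain dict and passes that dict to wandb.init. The helper names build_wandb_config and start_run, the project name "oracle-demo", and all example values are assumptions made for this sketch; the project and run settings that ORACLE itself uses are not documented here.

```python
from argparse import Namespace

# Attribute names taken from the parameter list above.
CONFIG_KEYS = (
    "num_epochs", "batch_size", "lr", "max_n_per_class",
    "alpha", "gamma", "dir", "model", "load_weights",
)


def build_wandb_config(args):
    """Collect the documented attributes of `args` into a plain dict."""
    missing = [key for key in CONFIG_KEYS if not hasattr(args, key)]
    if missing:
        raise AttributeError(f"args is missing: {', '.join(missing)}")
    return {key: getattr(args, key) for key in CONFIG_KEYS}


def start_run(args, project="oracle-demo"):
    """Start a run. The project name is a placeholder, not ORACLE's."""
    import wandb  # imported lazily so build_wandb_config works without wandb
    return wandb.init(project=project, config=build_wandb_config(args))


# Hypothetical values, chosen only to make the example concrete.
example_args = Namespace(
    num_epochs=10, batch_size=64, lr=1e-3, max_n_per_class=1000,
    alpha=0.5, gamma=2.0, dir="models/demo", model="demo-model",
    load_weights=None,
)
config = build_wandb_config(example_args)
print(config)
```

Building the dict separately from the wandb.init call makes a missing attribute fail early, before any run is created.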
- oracle.train.main()
- oracle.train.parse_args()
Parses the command-line options.
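A parser that yields the nine attributes used elsewhere in this module might look roughly like the sketch below. The flag names are assumed to mirror the attribute names, and every default value is a placeholder; the real flags and defaults are not stated in this reference. The optional argv parameter is also an addition for this sketch, so the example can run without a real command line.

```python
import argparse


def parse_args_sketch(argv=None):
    """Illustrative parser; flag names and defaults are assumptions."""
    parser = argparse.ArgumentParser(description="Train a model (sketch).")
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--max_n_per_class", type=int, default=1000)
    parser.add_argument("--alpha", type=float, default=0.5)
    parser.add_argument("--gamma", type=float, default=2.0)
    parser.add_argument("--dir", type=str, default="models/demo")
    parser.add_argument("--model", type=str, default="demo-model")
    # None means "train from scratch" in this sketch.
    parser.add_argument("--load_weights", type=str, default=None)
    # argv=None makes argparse read sys.argv[1:], as a real CLI would.
    return parser.parse_args(argv)


args = parse_args_sketch(["--num_epochs", "5", "--lr", "0.01"])
print(args)
```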
- oracle.train.run_training_loop(args)
Runs the training loop for the model using the specified configuration and dataset loaders.
- This function performs the following steps:
1. Extracts the training configuration (number of epochs, batch size, learning rate, model choice, and so on) from args.
2. Initializes the model selected by args.model.
3. Retrieves the training and validation data loaders along with their corresponding labels.
4. Initializes a Weights & Biases (wandb) logging run and sets up the directory for saving the model and the training arguments.
5. Loads pretrained weights into the model if args.load_weights provides a valid path.
6. Moves the model to the appropriate device, sets up the training configuration (including the alpha and gamma hyperparameters), and starts training.
7. After training, saves the model to wandb and finalizes the logging run.
- Parameters:
args (argparse.Namespace) – An object containing all necessary configuration parameters and hyperparameters, including:
1. num_epochs (int): Number of epochs to train the model.
2. batch_size (int): Size of the batches used in training and validation.
3. lr (float): Learning rate for the optimizer.
4. max_n_per_class (int): Maximum number of samples per class for the training data.
5. alpha (float): Hyperparameter used during training (specific purpose defined by model’s setup).
6. gamma (float): Hyperparameter used during training (specific purpose defined by model’s setup).
7. dir (str): Directory path for saving the model and other related artifacts.
8. model (str): Identifier to select which model architecture to use.
9. load_weights (str or None): Path to pretrained model weights. If provided, these weights are loaded into the model.
- Returns:
None
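The order of operations described for run_training_loop can be sketched as a control-flow skeleton. This is not ORACLE's implementation: every callable on the steps object (build_model, get_loaders, start_run, load_weights, fit, save_and_finish) is a hypothetical stand-in, and the RecordingSteps stub only records which step ran so that the flow can be executed without a model, data, or a wandb account.

```python
import os
import tempfile
from argparse import Namespace


def run_training_loop_sketch(args, steps):
    """Walk through the documented steps in order.

    `steps` is any object exposing the hypothetical callables used below;
    none of these names come from ORACLE itself.
    """
    model = steps.build_model(args.model)
    train_loader, val_loader = steps.get_loaders(
        args.batch_size, args.max_n_per_class
    )
    run = steps.start_run(args)
    os.makedirs(args.dir, exist_ok=True)  # directory for model and saved args
    if args.load_weights:  # skipped when load_weights is None
        steps.load_weights(model, args.load_weights)
    # Device placement and the alpha/gamma setup are folded into fit() here.
    steps.fit(model, train_loader, val_loader, args)
    steps.save_and_finish(run, model)
    return None


class RecordingSteps:
    """Stub that only records the name of each step it is asked to run."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        def step(*_args, **_kwargs):
            self.calls.append(name)
            # get_loaders must return a (train, val) pair; others a label.
            return ("train", "val") if name == "get_loaders" else name
        return step


stubs = RecordingSteps()
demo_args = Namespace(
    num_epochs=1, batch_size=8, lr=1e-3, max_n_per_class=100,
    alpha=0.5, gamma=2.0, dir=tempfile.mkdtemp(), model="demo-model",
    load_weights=None,
)
result = run_training_loop_sketch(demo_args, stubs)
print(stubs.calls)
```

With load_weights set to None the weight-loading step is skipped; passing a non-empty path would add it between start_run and fit.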
- oracle.train.save_args_to_csv(args, filepath)
Save command-line arguments to a CSV file.
This function converts the attributes of an object, typically parsed from command-line input, into a single-row pandas DataFrame, and saves it to a CSV file at the specified filepath.
- Parameters:
args (object) – An object containing attributes to be saved, often created using argparse.
filepath (str) – The file path (including filename) where the CSV file will be written.
- Returns:
None
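To show the shape of the resulting file (one header row, one value row), here is a dependency-free sketch that uses the standard-library csv module in place of pandas. The function name save_args_to_csv_sketch, the file name, and the example values are assumptions. With pandas, a comparable result would come from pandas.DataFrame([vars(args)]).to_csv(filepath, index=False), although whether the real function writes an index column is not stated here.

```python
import csv
import os
import tempfile
from argparse import Namespace


def save_args_to_csv_sketch(args, filepath):
    """Write the attributes of `args` as a header row plus one data row."""
    row = vars(args)  # attribute name -> value; needs an object with __dict__
    with open(filepath, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(row))
        writer.writeheader()
        writer.writerow(row)


# Hypothetical arguments and output location.
demo_args = Namespace(num_epochs=10, lr=0.001, model="demo-model")
path = os.path.join(tempfile.mkdtemp(), "train_args.csv")
save_args_to_csv_sketch(demo_args, path)

# Read the file back to inspect what was written.
with open(path, newline="") as f:
    rows = list(csv.reader(f))
print(rows)
```

All values come back as strings when the file is read, so numeric arguments need to be converted again if the CSV is used to reproduce a run.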