oracle.presets

Top-level module for defining presets and utility functions for model selection, data loading, and training configurations in the ORACLE project.

Functions

get_class_weights(labels)

Calculates the weights for each class using inverse frequency weighting.
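As an illustration of inverse frequency weighting, the sketch below uses one common normalization, weight = N / (K * n_c), where N is the number of samples, K the number of classes, and n_c the count for class c. The normalization that get_class_weights actually applies is not stated here, so the helper name `inverse_frequency_weights` and the sample labels are hypothetical.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Return {class_label: weight} with weight = N / (K * count).

    Hypothetical helper: the normalization is an assumption, not
    necessarily the one used by oracle.presets.get_class_weights.
    """
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    # Rare classes get larger weights, common classes smaller ones.
    return {cls: n_samples / (n_classes * n) for cls, n in counts.items()}


# Illustrative labels only: three samples of one class, one of another.
weights = inverse_frequency_weights(["SN Ia", "SN Ia", "SN Ia", "SN II"])
```

With this normalization, a perfectly balanced dataset yields a weight of 1.0 for every class, which keeps the overall loss scale comparable to the unweighted case.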

get_model(model_choice)

Retrieves and instantiates a model based on the provided model choice.

get_test_loaders(model_choice, batch_size, ...)

Generates and returns a list of test DataLoaders configured according to the specified model type and parameters.

:param model_choice: Specifies which model and corresponding dataset to use. Supported values include "BTS-lite", "BTS", "ZTF_Sims-lite", "ELAsTiCC-lite", and "ELAsTiCC".
:type model_choice: str
:param batch_size: The size of batches to generate from each DataLoader.
:type batch_size: int
:param max_n_per_class: The maximum number of samples to include per class in the dataset.
:type max_n_per_class: int
:param days_list: A list of day values used to dynamically configure a transformation (truncating light curves) applied to the test dataset.
:type days_list: list
:param excluded_classes: A list of class labels to exclude from the dataset. Defaults to an empty list.
:type excluded_classes: list, optional
:param mapper: An optional dictionary used to map or modify dataset labels/structure (only used for the "BTS" model). Defaults to None.
:type mapper: dict, optional
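To make the role of days_list more concrete, here is a minimal sketch of truncating a light curve to a given number of days after its first observation. The function `truncate_light_curve`, the plain-list data layout, and the idea of producing one truncated copy per entry in days_list are all assumptions for illustration; the actual ORACLE transformation may differ.

```python
def truncate_light_curve(times, fluxes, days):
    """Keep only observations taken within `days` of the first one.

    Hypothetical helper: times are in days, and times/fluxes are
    parallel lists. This is not the ORACLE implementation.
    """
    if not times:
        return [], []
    t0 = min(times)
    kept = [(t, f) for t, f in zip(times, fluxes) if t - t0 <= days]
    kept_times = [t for t, _ in kept]
    kept_fluxes = [f for _, f in kept]
    return kept_times, kept_fluxes


# Illustrative light curve and day cutoffs.
times = [0.0, 1.5, 4.0, 9.0, 20.0]
fluxes = [10.0, 12.0, 15.0, 11.0, 7.0]
days_list = [2, 8, 32]

# One truncated view of the same light curve per cutoff.
truncated = {d: truncate_light_curve(times, fluxes, d) for d in days_list}
```

Evaluating on several cutoffs like this shows how a classifier's performance changes as more of each light curve becomes available.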

get_train_loader(model_choice, batch_size, ...)

Generates a DataLoader for training based on the provided model choice and dataset configuration.

get_val_loader(model_choice, batch_size, ...)

Creates a DataLoader for the validation dataset, along with the corresponding labels, for the specified model choice.

worker_init_fn(worker_id)

Ensures proper random seeding in each worker process.
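A PyTorch DataLoader accepts a worker_init_fn callable that is invoked with the worker's integer id when each worker process starts. The sketch below shows the general per-worker seeding pattern using only the standard library so that it runs on its own. The name `seed_worker` and the `base_seed` parameter are hypothetical, and ORACLE's function may derive its seed differently (for example from torch.initial_seed()) and seed additional libraries.

```python
import random


def seed_worker(worker_id, base_seed=0):
    """Seed Python's RNG with a value unique to this worker.

    Hypothetical sketch: `base_seed` is an illustrative parameter,
    not part of the oracle.presets.worker_init_fn signature.
    """
    # Offset the base seed by the worker id so workers do not
    # produce identical random streams; keep it within 32 bits.
    seed = (base_seed + worker_id) % 2**32
    random.seed(seed)
    return seed


# Same worker id and base seed give a reproducible stream.
seed_worker(0, base_seed=42)
first_draw = random.random()
seed_worker(0, base_seed=42)
second_draw = random.random()
```

Without per-worker seeding, workers forked from the same parent can inherit identical random state and apply the same "random" augmentations to different batches.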