oracle.presets module
Top-level module for defining presets and utility functions for model selection, data loading, and training configurations in the ORACLE project.
- oracle.presets.get_class_weights(labels)
Calculates the weights for each class using inverse frequency weighting.
- Parameters:
labels (array-like) – An iterable of labels from which unique classes and their counts are derived.
- Returns:
A dictionary where keys are the unique class labels and values are their corresponding weights (1/count).
- Return type:
dict
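The documented 1/count rule can be shown with a minimal pure-Python sketch. `inverse_frequency_weights` is a hypothetical stand-in, not the module's actual implementation (which may use NumPy or another counting method). It only reproduces the weighting rule described above, and the class names are illustrative.

```python
from collections import Counter
from typing import Dict, Hashable, Iterable


def inverse_frequency_weights(labels: Iterable[Hashable]) -> Dict[Hashable, float]:
    """Hypothetical sketch: map each unique label to 1 / (number of occurrences)."""
    counts = Counter(labels)  # label -> count, in first-seen order
    return {label: 1.0 / count for label, count in counts.items()}


weights = inverse_frequency_weights(["SNIa", "SNIa", "SNIa", "SNII", "TDE"])
print(weights)  # {'SNIa': 0.3333333333333333, 'SNII': 1.0, 'TDE': 1.0}
```

Rarer classes receive proportionally larger weights. Under this rule the weights are not normalized, so in general they do not sum to 1.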
- oracle.presets.get_model(model_choice)
Retrieves and instantiates a model based on the provided model choice.
- Parameters:
model_choice (str) –
A string identifier for the desired model configuration. Valid options and their corresponding behaviors are:
“BTS-lite”: Uses BTS_Taxonomy to instantiate a GRU model.
“BTS”: Uses BTS_Taxonomy to instantiate a GRU_MD model with a static feature dimension of 30.
“ZTF_Sims-lite”: Uses BTS_Taxonomy to instantiate a GRU model.
“ELAsTiCC-lite”: Uses ORACLE_Taxonomy to instantiate a GRU model.
“ELAsTiCC”: Uses ORACLE_Taxonomy to instantiate a GRU_MD model with a static feature dimension of 18.
- Returns:
The model instance associated with the given model_choice.
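The options listed above amount to a lookup table from `model_choice` to a taxonomy, an architecture, and (for the GRU_MD variants) a static feature dimension. The sketch below restates that table as data. `ModelSpec`, `MODEL_SPECS`, and `describe_model_choice` are hypothetical names, and raising `ValueError` for an unknown choice is an assumption, since the behavior for invalid input is not documented. The sketch does not instantiate any model.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ModelSpec:
    """Hypothetical summary of what each model_choice selects."""
    taxonomy: str
    architecture: str
    static_feature_dim: Optional[int] = None  # None where no static dimension is documented


# Restates the documented options; not the module's actual data structure.
MODEL_SPECS = {
    "BTS-lite": ModelSpec("BTS_Taxonomy", "GRU"),
    "BTS": ModelSpec("BTS_Taxonomy", "GRU_MD", 30),
    "ZTF_Sims-lite": ModelSpec("BTS_Taxonomy", "GRU"),
    "ELAsTiCC-lite": ModelSpec("ORACLE_Taxonomy", "GRU"),
    "ELAsTiCC": ModelSpec("ORACLE_Taxonomy", "GRU_MD", 18),
}


def describe_model_choice(model_choice: str) -> ModelSpec:
    """Look up a model_choice; raising on unknown keys is an assumption."""
    try:
        return MODEL_SPECS[model_choice]
    except KeyError:
        valid = ", ".join(MODEL_SPECS)
        raise ValueError(f"Unknown model_choice {model_choice!r}; expected one of: {valid}") from None


print(describe_model_choice("ELAsTiCC"))
# ModelSpec(taxonomy='ORACLE_Taxonomy', architecture='GRU_MD', static_feature_dim=18)
```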
- oracle.presets.get_test_loaders(model_choice, batch_size, max_n_per_class, days_list, excluded_classes=[], mapper=None)
Generates and returns a list of test DataLoaders configured according to the specified model type and parameters.
- Parameters:
model_choice (str) – Specifies which model and corresponding dataset to use. Supported values include “BTS-lite”, “BTS”, “ZTF_Sims-lite”, “ELAsTiCC-lite”, and “ELAsTiCC”.
batch_size (int) – The size of batches to generate from each DataLoader.
max_n_per_class (int) – The maximum number of samples to include per class in the dataset.
days_list (list) – A list of day values. For each day, a transformation that truncates the light curves at that day is applied to the test dataset, producing one DataLoader per day.
excluded_classes (list, optional) – A list of class labels to exclude from the dataset. Defaults to an empty list.
mapper (dict, optional) – An optional dictionary used to map or modify dataset labels/structure (only used for the “BTS” model). Defaults to None.
- Returns:
A list of DataLoader objects, each configured with a custom transformation based on a corresponding day from days_list.
- Return type:
List[DataLoader]
Note
Each DataLoader is constructed for a specific truncation of the light curves, determined by its day value. The torch.Generator used for DataLoader shuffling is explicitly created on the CPU.
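The one-loader-per-day pattern can be sketched as building one truncation transform per entry in `days_list`. The exact truncation rule is not documented here. This sketch assumes a light curve is a list of `(time_in_days, flux)` pairs and keeps points within `days` of the first observation. `truncate_light_curve` and `make_truncation_transforms` are hypothetical. In the real function each transform would be attached to a dataset and wrapped in its own DataLoader. That step is omitted so the sketch runs without PyTorch.

```python
from functools import partial
from typing import Callable, List, Sequence, Tuple

# Assumed representation: a light curve as (time_in_days, flux) pairs.
LightCurve = List[Tuple[float, float]]


def truncate_light_curve(light_curve: LightCurve, days: float) -> LightCurve:
    """Hypothetical transform: keep points within `days` of the first observation."""
    if not light_curve:
        return []
    t0 = min(t for t, _ in light_curve)  # time of the first observation
    return [(t, f) for t, f in light_curve if t - t0 <= days]


def make_truncation_transforms(days_list: Sequence[float]) -> List[Callable[[LightCurve], LightCurve]]:
    # partial binds each day value immediately; a bare `lambda lc: truncate_light_curve(lc, d)`
    # in this comprehension would capture the variable `d` and see only its last value.
    return [partial(truncate_light_curve, days=d) for d in days_list]


lc = [(0.0, 1.0), (2.0, 1.5), (10.0, 3.0), (40.0, 2.0)]
transforms = make_truncation_transforms([1, 16, 64])
print([len(tf(lc)) for tf in transforms])  # [1, 3, 4]
```

Using `functools.partial` rather than a closure avoids Python's late-binding pitfall, where every transform created in the loop would otherwise truncate at the final day value.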
- oracle.presets.get_train_loader(model_choice, batch_size, max_n_per_class, excluded_classes=[])
Generates a DataLoader for training based on the provided model choice and dataset configuration.
- Parameters:
model_choice (str) – The identifier for the model and corresponding dataset. Supported values include: “BTS-lite”, “BTS”, “ZTF_Sims-lite”, “ELAsTiCC-lite”, and “ELAsTiCC”.
batch_size (int) – The number of samples per batch to load.
max_n_per_class (int) – The maximum number of samples to include per class in the dataset.
excluded_classes (list, optional) – A list of class identifiers to exclude from the dataset. Defaults to an empty list.
- Returns:
- A tuple containing:
DataLoader (torch.utils.data.DataLoader): A DataLoader instance for iterating through the training dataset.
list: A list of all labels present in the training dataset.
- Return type:
tuple
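The `max_n_per_class` and `excluded_classes` parameters describe a per-class filtering step. The sketch below shows one way such a step could work. How the module actually chooses which samples to keep (first N, random, or otherwise) is not documented, so the seeded random subsample here is an assumption, and `select_indices` is a hypothetical helper.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence


def select_indices(
    labels: Sequence[Hashable],
    max_n_per_class: int,
    excluded_classes: Sequence[Hashable] = (),
    seed: int = 0,
) -> List[int]:
    """Hypothetical: drop excluded classes, keep at most max_n_per_class indices per class."""
    excluded = set(excluded_classes)
    by_class: Dict[Hashable, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        if label not in excluded:
            by_class[label].append(idx)

    rng = random.Random(seed)  # seeded so the subsample is reproducible
    selected: List[int] = []
    for indices in by_class.values():
        if len(indices) > max_n_per_class:
            indices = rng.sample(indices, max_n_per_class)
        selected.extend(indices)
    return sorted(selected)


labels = ["SNIa"] * 5 + ["SNII"] * 2 + ["AGN"] * 3
idx = select_indices(labels, max_n_per_class=3, excluded_classes=["AGN"])
print(len(idx))  # 5 (3 SNIa + 2 SNII)
```

The sketch uses an immutable `()` default instead of the `[]` shown in the documented signatures. This is a common precaution against Python's shared mutable default arguments.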
- oracle.presets.get_val_loader(model_choice, batch_size, val_truncation_days, max_n_per_class, excluded_classes=[])
Creates a DataLoader for the validation dataset along with its corresponding labels based on the specified model choice.
This function selects a dataset and transformation based on model_choice, builds one dataset per truncation day in val_truncation_days, and concatenates them. It returns a DataLoader over the concatenated dataset together with the list of all labels retrieved from the first dataset instance.
- Parameters:
model_choice (str) – The name of the model variant to use. Valid options include: “BTS-lite”, “BTS”, “ZTF_Sims-lite”, “ELAsTiCC-lite”, “ELAsTiCC”.
batch_size (int) – The number of samples per batch in the returned DataLoader.
val_truncation_days (list) – A list of days used to truncate the light curves; each day corresponds to a transformation applied to the dataset.
max_n_per_class (int) – The maximum number of samples to include per class in the dataset.
excluded_classes (list, optional) – A list of classes to be excluded from the dataset. Defaults to an empty list.
- Returns:
- A tuple containing:
DataLoader: The DataLoader for the concatenated validation dataset with the specified batch size and collate function, constructed with a CPU-based torch.Generator.
list: A list of validation labels obtained from the first dataset in the list via get_all_labels().
- Return type:
tuple
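Concatenating one dataset per truncation day is typically done in PyTorch with `torch.utils.data.ConcatDataset`, although the documentation above does not name the class used. That class maps a global index to a (dataset, local index) pair with a binary search over cumulative dataset sizes. `ConcatSequence` below is a hypothetical pure-Python stand-in for that indexing logic, and the toy "views" are illustrative. If each per-day dataset is a truncated view of the same N objects, which is an assumption, the concatenation holds len(val_truncation_days) × N samples.

```python
import bisect
from itertools import accumulate
from typing import Any, Sequence


class ConcatSequence:
    """Hypothetical stand-in mirroring ConcatDataset-style indexing."""

    def __init__(self, datasets: Sequence[Sequence[Any]]) -> None:
        self.datasets = list(datasets)
        # cumulative_sizes[k] = total number of items in datasets[0..k]
        self.cumulative_sizes = list(accumulate(len(d) for d in self.datasets))

    def __len__(self) -> int:
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx: int) -> Any:
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # The first dataset whose cumulative size exceeds idx holds the item.
        d = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[d - 1] if d > 0 else 0
        return self.datasets[d][idx - offset]


# One toy "view" of the same two objects per truncation day.
views = [[("obj0", day), ("obj1", day)] for day in (1, 16, 64)]
val = ConcatSequence(views)
print(len(val))  # 6
print(val[3])    # ('obj1', 16)
```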
- oracle.presets.worker_init_fn(worker_id)
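Ensures proper random seeding in each DataLoader worker process.
- Parameters:
worker_id (int) – The index of the worker process, passed in automatically when this function is used as a DataLoader's worker_init_fn.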
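Inside a PyTorch worker, `torch.initial_seed()` already returns a per-worker value (the loader's base seed plus the worker id). PyTorch's reproducibility notes seed Python's `random` module and NumPy from that value, reduced modulo 2**32. This documentation does not say which generators oracle.presets seeds or how. To keep the sketch runnable without PyTorch, it derives the same kind of per-worker seed from an explicit `base_seed`. `make_worker_init_fn` is a hypothetical factory.

```python
import random
from typing import Callable


def make_worker_init_fn(base_seed: int) -> Callable[[int], None]:
    """Hypothetical factory: give each worker a distinct, reproducible seed."""

    def worker_init_fn(worker_id: int) -> None:
        # Offsetting by worker_id keeps workers from drawing identical random streams;
        # the modulus keeps the seed within NumPy's accepted range [0, 2**32) if NumPy is also seeded.
        worker_seed = (base_seed + worker_id) % 2**32
        random.seed(worker_seed)

    return worker_init_fn


init = make_worker_init_fn(1234)
init(0)
a = random.random()
init(0)
b = random.random()
init(1)
c = random.random()
print(a == b)  # True: the same worker id reproduces the same stream
```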