oracle.presets module

Top-level module for defining presets and utility functions for model selection, data loading, and training configurations in the ORACLE project.

oracle.presets.get_class_weights(labels)

Calculates the weights for each class using inverse frequency weighting.

Parameters:

labels (array-like) – An iterable of labels from which unique classes and their counts are derived.

Returns:

A dictionary where keys are the unique class labels and values are their corresponding weights (1/count).

Return type:

dict
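A minimal sketch of the documented weighting, assuming the labels are hashable: each unique class receives the weight 1/count. This illustrates the behavior described above and is not the ORACLE source. The class names in the example are made up.

```python
from collections import Counter
from typing import Hashable, Iterable


def get_class_weights(labels: Iterable[Hashable]) -> dict:
    """Inverse-frequency class weights: each unique class gets 1 / count."""
    counts = Counter(labels)  # unique class -> number of occurrences
    return {cls: 1.0 / n for cls, n in counts.items()}


weights = get_class_weights(["SN Ia", "SN Ia", "SN Ia", "SN II"])
print(weights)  # {'SN Ia': 0.3333333333333333, 'SN II': 1.0}
```

Rare classes end up with larger weights, which counteracts class imbalance when the weights are fed to a loss function.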

oracle.presets.get_model(model_choice)

Retrieves and instantiates a model based on the provided model choice.

Parameters:

model_choice (str) –

A string identifier for the desired model configuration. Valid options and their corresponding behaviors are:

  • “BTS-lite”: Uses BTS_Taxonomy to instantiate a GRU model.

  • “BTS”: Uses BTS_Taxonomy to instantiate a GRU_MD model with a static feature dimension of 30.

  • “ZTF_Sims-lite”: Uses BTS_Taxonomy to instantiate a GRU model.

  • “ELAsTiCC-lite”: Uses ORACLE_Taxonomy to instantiate a GRU model.

  • “ELAsTiCC”: Uses ORACLE_Taxonomy to instantiate a GRU_MD model with a static feature dimension of 18.

Returns:

The model instance associated with the given model_choice.
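The option list above amounts to a lookup table. The sketch below records it as data using a hypothetical ModelSpec record and a hypothetical lookup_model_spec helper. It does not import or instantiate the real GRU or GRU_MD classes, and raising ValueError for an unknown model_choice is an assumption, not documented behavior.

```python
from typing import NamedTuple, Optional


class ModelSpec(NamedTuple):
    """Hypothetical record of what get_model builds for one model_choice."""
    taxonomy: str                      # taxonomy class name
    architecture: str                  # "GRU" or "GRU_MD"
    static_feature_dim: Optional[int]  # only set for GRU_MD models


# Transcribed from the documented model_choice options.
MODEL_SPECS = {
    "BTS-lite": ModelSpec("BTS_Taxonomy", "GRU", None),
    "BTS": ModelSpec("BTS_Taxonomy", "GRU_MD", 30),
    "ZTF_Sims-lite": ModelSpec("BTS_Taxonomy", "GRU", None),
    "ELAsTiCC-lite": ModelSpec("ORACLE_Taxonomy", "GRU", None),
    "ELAsTiCC": ModelSpec("ORACLE_Taxonomy", "GRU_MD", 18),
}


def lookup_model_spec(model_choice: str) -> ModelSpec:
    """Return the spec for model_choice; unknown names raise ValueError (assumption)."""
    try:
        return MODEL_SPECS[model_choice]
    except KeyError:
        valid = ", ".join(MODEL_SPECS)
        raise ValueError(
            f"Unknown model_choice {model_choice!r}; expected one of: {valid}"
        ) from None


print(lookup_model_spec("ELAsTiCC"))
# ModelSpec(taxonomy='ORACLE_Taxonomy', architecture='GRU_MD', static_feature_dim=18)
```

Keeping the mapping in one table makes the pattern visible at a glance: every "-lite" variant is a plain GRU, and only the full variants add static features.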

oracle.presets.get_test_loaders(model_choice, batch_size, max_n_per_class, days_list, excluded_classes=[], mapper=None)

Generates and returns a list of test DataLoaders configured according to the specified model type and parameters.

Parameters:
  • model_choice (str) – Specifies which model and corresponding dataset to use. Supported values include: “BTS-lite”, “BTS”, “ZTF_Sims-lite”, “ELAsTiCC-lite”, and “ELAsTiCC”.

  • batch_size (int) – The size of batches to generate from each DataLoader.

  • max_n_per_class (int) – The maximum number of samples to include per class in the dataset.

  • days_list (list) – A list of day values; each value configures a light-curve truncation transform applied to the test dataset, producing one DataLoader per day.

  • excluded_classes (list, optional) – A list of class labels to exclude from the dataset. Defaults to an empty list.

  • mapper (dict, optional) – An optional dictionary used to map or modify dataset labels/structure (only used for the “BTS” model). Defaults to None.

Returns:

A list of DataLoader objects, each configured with a custom transformation based on a corresponding day from days_list.

Return type:

List[DataLoader]

Note

Each DataLoader applies the light-curve truncation for its corresponding day value. The torch.Generator used for DataLoader shuffling is explicitly created on the CPU.
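A sketch of the one-loader-per-day pattern, assuming PyTorch. ToyLightCurves, truncate_after and make_test_loaders are hypothetical stand-ins for the ORACLE datasets and truncation transform. Masking late observations (rather than deleting them) is an assumption made here so the default collate function can stack fixed-length tensors; shuffling with a CPU generator mirrors the note above.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyLightCurves(Dataset):
    """Hypothetical stand-in for an ORACLE test dataset: one light curve per source."""

    def __init__(self, n_sources=8, n_obs=20, transform=None):
        g = torch.Generator().manual_seed(0)
        # Observation times in days since first detection, sorted per source.
        self.times = torch.sort(torch.rand(n_sources, n_obs, generator=g) * 100, dim=1).values
        self.fluxes = torch.randn(n_sources, n_obs, generator=g)
        self.transform = transform

    def __len__(self):
        return self.times.shape[0]

    def __getitem__(self, idx):
        sample = {"time": self.times[idx], "flux": self.fluxes[idx]}
        return self.transform(sample) if self.transform is not None else sample


def truncate_after(day):
    """Hypothetical transform factory: hide observations later than `day` days."""
    def transform(sample):
        mask = sample["time"] <= day
        return {
            "time": sample["time"],
            "flux": torch.where(mask, sample["flux"], torch.zeros_like(sample["flux"])),
            "mask": mask,
        }
    return transform


def make_test_loaders(days_list, batch_size):
    """Build one shuffled DataLoader per truncation day."""
    loaders = []
    for day in days_list:
        dataset = ToyLightCurves(transform=truncate_after(day))
        loaders.append(
            DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                generator=torch.Generator(device="cpu"),  # shuffling RNG kept on the CPU
            )
        )
    return loaders


days_list = [7, 30, 100]
loaders = make_test_loaders(days_list, batch_size=4)
for day, loader in zip(days_list, loaders):
    batch = next(iter(loader))
    print(day, tuple(batch["flux"].shape), int(batch["mask"].sum()))
```

Building each transform through a factory such as truncate_after(day) binds the current day value. A bare lambda defined inside the loop would look up the loop variable only when called, so every loader would truncate at the last day.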

oracle.presets.get_train_loader(model_choice, batch_size, max_n_per_class, excluded_classes=[])

Generates a DataLoader for training based on the provided model choice and dataset configuration.

Parameters:
  • model_choice (str) – The identifier for the model and corresponding dataset. Supported values include: “BTS-lite”, “BTS”, “ZTF_Sims-lite”, “ELAsTiCC-lite”, and “ELAsTiCC”.

  • batch_size (int) – The number of samples per batch to load.

  • max_n_per_class (int) – The maximum number of samples to include per class in the dataset.

  • excluded_classes (list, optional) – A list of class identifiers to exclude from the dataset. Defaults to an empty list.

Returns:

A tuple containing:
  • DataLoader (torch.utils.data.DataLoader): A DataLoader instance for iterating through the training dataset.

  • list(list): A list of all labels present in the training dataset.

Return type:

tuple
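The max_n_per_class and excluded_classes parameters describe a per-class filter applied before the loader is built. The hypothetical select_indices helper below is one way to express that filter; it keeps the first occurrences in dataset order, whereas the real code may sample randomly or filter at a different stage.

```python
from collections import defaultdict
from typing import Hashable, List, Sequence


def select_indices(
    labels: Sequence[Hashable],
    max_n_per_class: int,
    excluded_classes: Sequence[Hashable] = (),
) -> List[int]:
    """Keep at most max_n_per_class samples per class, dropping excluded classes."""
    excluded = set(excluded_classes)
    kept_per_class = defaultdict(int)  # class -> samples kept so far
    selected = []
    for idx, label in enumerate(labels):
        if label in excluded or kept_per_class[label] >= max_n_per_class:
            continue
        kept_per_class[label] += 1
        selected.append(idx)
    return selected


labels = ["SN Ia", "SN Ia", "SN Ia", "SN II", "AGN", "SN II"]
print(select_indices(labels, max_n_per_class=2, excluded_classes=["AGN"]))
# [0, 1, 3, 5]
```

The sketch uses an immutable default, (), for excluded_classes; a tuple cannot be mutated and shared across calls the way a list default can.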

oracle.presets.get_val_loader(model_choice, batch_size, val_truncation_days, max_n_per_class, excluded_classes=[])

Creates a DataLoader for the validation dataset along with its corresponding labels based on the specified model choice.

This function selects a dataset and transformation based on the provided model_choice, builds one dataset per truncation day in val_truncation_days, and concatenates them. It returns a DataLoader over the concatenated dataset together with the list of all labels retrieved from the first dataset instance.

Parameters:
  • model_choice (str) – The name of the model variant to use. Valid options include: “BTS-lite”, “BTS”, “ZTF_Sims-lite”, “ELAsTiCC-lite”, “ELAsTiCC”.

  • batch_size (int) – The number of samples per batch in the returned DataLoader.

  • val_truncation_days (list) – A list of days used to truncate the light curves; each day corresponds to a transformation applied to the dataset.

  • max_n_per_class (int) – The maximum number of samples to include per class in the dataset.

  • excluded_classes (list, optional) – A list of classes to be excluded from the dataset. Defaults to an empty list.

Returns:

A tuple containing:
  • DataLoader: The DataLoader for the concatenated validation dataset with the specified batch size and collate function, constructed with a CPU-based torch.Generator.

  • list: A list of validation labels obtained from the first dataset in the list via get_all_labels().

Return type:

tuple
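A sketch of the concatenation step, assuming PyTorch's ConcatDataset. TruncatedToy and make_val_loader are hypothetical; get_all_labels() is named in the return description above, but its implementation here is an assumption. Shuffling with the CPU generator is also an assumption, since a generator only matters for random sampling.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class TruncatedToy(Dataset):
    """Hypothetical validation dataset whose light curves are truncated at `day` days."""

    def __init__(self, day, n_sources=5, n_obs=10):
        self.day = day
        # Every source is observed at the same evenly spaced epochs (0 to 90 days).
        self.times = torch.linspace(0, 90, n_obs).repeat(n_sources, 1)
        self.labels = [i % 2 for i in range(n_sources)]  # toy binary labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        mask = self.times[idx] <= self.day
        return {"time": self.times[idx], "mask": mask, "label": self.labels[idx]}

    def get_all_labels(self):
        return list(self.labels)


def make_val_loader(val_truncation_days, batch_size):
    """Concatenate one dataset copy per truncation day behind a single DataLoader."""
    datasets = [TruncatedToy(day) for day in val_truncation_days]
    loader = DataLoader(
        ConcatDataset(datasets),
        batch_size=batch_size,
        shuffle=True,  # assumption: the CPU generator drives shuffling
        generator=torch.Generator(device="cpu"),
    )
    # Labels come from the first copy only, mirroring the documented return value.
    return loader, datasets[0].get_all_labels()


loader, labels = make_val_loader([10, 30, 90], batch_size=4)
print(len(loader.dataset), len(labels))  # 15 5
```

Because the labels come from the first copy only, their length matches one truncation day's dataset, not the concatenated dataset the DataLoader iterates over.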

oracle.presets.worker_init_fn(worker_id)

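Ensures proper random seeding in each DataLoader worker process.

Parameters:

worker_id (int) – The index of the worker process (0 to num_workers - 1), passed in by torch.utils.data.DataLoader.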
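The docstring does not show the function body. A common PyTorch recipe, sketched here as an assumption about what "proper seeding" means in ORACLE, derives a seed from torch.initial_seed(), which DataLoader sets to base_seed + worker_id inside each worker, and uses it to seed NumPy and Python's random module.

```python
import random

import numpy as np
import torch


def worker_init_fn(worker_id: int) -> None:
    """Seed NumPy and Python RNGs from the per-worker torch seed."""
    # Inside a worker, torch.initial_seed() already differs per worker
    # (base_seed + worker_id), so worker_id itself is not needed here.
    worker_seed = torch.initial_seed() % 2**32  # NumPy requires a 32-bit seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```

Pass it as DataLoader(..., num_workers=2, worker_init_fn=worker_init_fn, generator=g) with a seeded torch.Generator g; the generator determines base_seed, so every worker's seed becomes reproducible across runs.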