oracle package

Subpackages

Submodules

oracle.architectures module

Top-level module for defining various neural network architectures for hierarchical classification.

class oracle.architectures.GRU(taxonomy: Taxonomy, ts_feature_dim=5)

Bases: Hierarchical_classifier

GRU-based neural network architecture for hierarchical classification.

forward(batch)

Compute the logits for a given input batch through the model.

This method processes the input batch by first obtaining its latent space embeddings, applying a ReLU activation, and then projecting the result to logits using a fully connected output layer.

Parameters:

batch (dict) – Input data batch.

Returns:

The computed output logits after the final linear transformation.

Return type:

torch.Tensor

get_latent_space_embeddings(batch)

Compute latent space embeddings for a batch of time-series data.

Parameters:

batch (dict) – A dictionary containing: 1. ‘ts’ (torch.Tensor): Input time-series data of shape (batch_size, seq_len, n_ts_features). 2. ‘length’ (torch.Tensor or list[int]): Sequence lengths indicating the valid lengths of each time-series in the batch.

Returns:

The latent space embeddings resulting from the network’s forward pass.

Return type:

torch.Tensor
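
For orientation, a minimal usage sketch is shown below. The batch shapes and the choice of BTS_Taxonomy are assumptions for illustration; only the documented keys ‘ts’ and ‘length’ are required.

    import torch

    from oracle.architectures import GRU
    from oracle.taxonomies import BTS_Taxonomy

    # Hypothetical batch: 8 light curves padded to 50 time steps, 5 features each.
    model = GRU(taxonomy=BTS_Taxonomy(), ts_feature_dim=5)
    batch = {
        "ts": torch.randn(8, 50, 5),           # (batch_size, seq_len, n_ts_features)
        "length": torch.randint(1, 51, (8,)),  # valid length of each sequence
    }
    logits = model(batch)                      # raw logits over the taxonomy nodes
    embeddings = model.get_latent_space_embeddings(batch)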

class oracle.architectures.GRU_MD(taxonomy: Taxonomy, ts_feature_dim=5, static_feature_dim=30)

Bases: Hierarchical_classifier

GRU-based neural network architecture with multi-dimensional static features for hierarchical classification.

forward(batch)

Computes the forward pass through the network module. This method performs the following steps:

  • Extracts latent space embeddings from the input batch using self.get_latent_space_embeddings.

  • Applies a ReLU activation to the latent embeddings.

  • Computes the logits by passing the activated embeddings through a fully connected layer (self.fc_out).

Parameters:

batch – Input data (e.g., tensors or other data structures) used to compute the latent embeddings.

Returns:

Logits produced by the network.

Return type:

Tensor

get_latent_space_embeddings(batch)

Generates the latent space embedding for a given batch of data by processing both time series and static features.

Parameters:

batch (dict) – A dictionary containing the following keys: 1. ‘ts’ (torch.Tensor): Time series data of shape (batch_size, seq_len, n_ts_features). 2. ‘length’ (torch.Tensor): Tensor containing the true lengths for each time series in the batch (shape: (batch_size)). 3. ‘static’ (torch.Tensor): Static features of shape (batch_size, n_static_features).

Returns:

The latent space embedding computed by merging the processed time series and static features after passing them through respective dense layers and nonlinear activations.

Return type:

torch.Tensor

class oracle.architectures.GRU_MD_MM(output_dim, ts_feature_dim=5, static_feature_dim=18)

Bases: Hierarchical_classifier

GRU-based neural network architecture with multi-dimensional static features and multi-modal inputs for hierarchical classification.

forward(batch)

Compute the network’s output logits from an input batch. This method first extracts latent space embeddings from the batch, applies a ReLU activation to introduce non-linearity, and then computes the final logits using a fully connected output layer.

Parameters:

batch – The input data batch from which latent embeddings are derived.

Returns:

The computed logits after processing the latent embeddings through the ReLU activation and linear layer.

Return type:

torch.Tensor

get_latent_space_embeddings(batch)

Compute latent space embeddings from the given batch of inputs.

Processes time-series, static, and postage stamp inputs to produce a latent representation. The function performs the following steps:

  • Packs the padded time-series data using the sequence lengths provided.

  • Processes the time-series data with a bidirectional GRU and extracts its last hidden state.

  • Processes the postage stamp data through a Swin-based module and a fully connected layer.

  • Transforms the GRU and static features through separate dense layers with non-linear activation.

  • Concatenates the processed time-series, static, and postage stamp features.

  • Passes the merged representation through additional dense layers with activation functions to yield the final latent space embedding.

Parameters:

batch (dict) – A dictionary containing: 1. “ts” (torch.Tensor): Time series data of shape (batch_size, seq_len, n_ts_features). 2. “length” (torch.Tensor): Lengths of each time series in the batch (batch_size, ). 3. “static” (torch.Tensor): Static features of shape (batch_size, n_static_features). 4. “postage_stamp” (torch.Tensor): Postage stamp data for additional processing.

Returns:

The latent space embeddings computed from the combined features.

Return type:

torch.Tensor
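
For reference, the expected multi-modal batch layout, sketched with hypothetical shapes (the postage-stamp dimensions in particular are an assumption, since they are dictated by the Swin-based module):

    import torch

    batch = {
        "ts": torch.randn(8, 50, 5),                 # (batch_size, seq_len, n_ts_features)
        "length": torch.randint(1, 51, (8,)),        # true length of each light curve
        "static": torch.randn(8, 18),                # (batch_size, n_static_features)
        "postage_stamp": torch.randn(8, 3, 63, 63),  # image cutouts; shape assumed here
    }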

class oracle.architectures.Hierarchical_classifier(taxonomy: Taxonomy)

Bases: Module, Trainer, Tester

Base class for hierarchical classification architectures.

embed(table)

Embeds the provided table into the model’s latent space.

Parameters:

table (object) – The table data to be embedded. The expected type and structure of this data should be compatible with the model’s input requirements.

Raises:

NotImplementedError – This method is not implemented by default and is intended for pretrained models only.

get_latent_space_embeddings(batch)

Compute the latent space embeddings for a given batch of data.

This method should be implemented by subclasses to transform the input batch into a latent representation. The exact nature of the embedding (e.g., dimensionality, transformation mechanism) depends on the specific architecture.

Parameters:

batch (dictionary) – A batch of input data on which to compute the latent embeddings. The expected format and type should be defined by the implementing subclass.

Returns:

The latent space embeddings corresponding to the provided batch. The structure and type of the embeddings is determined by the subclass implementation.

Return type:

Any

Raises:

NotImplementedError – If the method is not implemented by the subclass.

predict(table)

Predict the label at each hierarchical level for the table.

Parameters:

table (astropy.table.Table) – Input data containing one or more rows.

Returns:

Mapping from hierarchical level (as returned by self.score) to the predicted class label. For each level, self.score(table) is expected to return a pandas.DataFrame of shape (n_samples, n_classes) with class labels as columns; the predicted label is the column with the highest score for the first sample.

Return type:

dict

Raises:

Any exceptions raised by self.score or by numpy operations (e.g., if the score DataFrame is empty) will be propagated.

predict_class_probabilities(batch)

Predicts the class probabilities for a given batch of data. This method first computes the conditional probabilities for the batch using the ‘predict_conditional_probabilities’ method. It then leverages the taxonomy to convert these conditional probabilities into final class probabilities.

Parameters:

batch (dictionary) – The input data batch for which class probabilities are to be predicted.

Returns:

The computed class probabilities derived from the input batch.

Return type:

torch.Tensor

predict_class_probabilities_df(batch)

Predict class probabilities for a given batch and return the results in a DataFrame.

This method performs the following steps: 1. Retrieves the list of taxonomy nodes ordered by level using the taxonomy’s get_level_order_traversal method. 2. Computes the class probabilities for the input batch via the predict_class_probabilities method. 3. Constructs and returns a pandas DataFrame with the computed probabilities, where each column corresponds to a taxonomy node in level order.

Parameters:

batch (dictionary) – The input data batch on which to perform predictions. The expected format and type of batch depend on the implementation details of the prediction model.

Returns:

A DataFrame with columns representing taxonomy nodes and each cell containing the predicted class probability for the corresponding node.

Return type:

pandas.DataFrame

predict_conditional_probabilities(batch)

Compute conditional probabilities from the model’s output. This method performs a forward pass using the given batch of inputs to obtain the logits. It then utilizes the taxonomy’s get_conditional_probabilities method to convert these logits into conditional probabilities. The resulting tensor is detached from the computation graph and returned.

Parameters:

batch (dictionary) – A batch of input data to be fed into the model. The exact format and type of the batch depend on the requirements of the forward method.

Returns:

A tensor representing the conditional probabilities computed from the model’s logits.

Return type:

torch.Tensor

predict_conditional_probabilities_df(batch)

Predict conditional probabilities for a batch and return them as a pandas DataFrame.

This method retrieves a level-order traversal of nodes from the taxonomy, computes the conditional probabilities for the given batch using the predict_conditional_probabilities method, and then constructs a pandas DataFrame with the probabilities. The DataFrame’s columns are named using the node order from the taxonomy’s level-order traversal.

Parameters:

batch (dictionary) – The input data batch for which the conditional probabilities are to be predicted.

Returns:

A DataFrame containing the predicted conditional probabilities with columns corresponding to the level-order nodes.

Return type:

pandas.DataFrame

predict_full_scores(table)
score(table)

Compute hierarchical scores for the input table. Predicts scores for all taxonomy nodes using self.predict_full_scores(), then groups those scores by taxonomy depth and returns a mapping from depth levels to DataFrames containing the corresponding node scores.

Parameters:

table (astropy.table.Table) – Input observations/features to be scored.

Returns:

A mapping from taxonomy depth level to a DataFrame of predicted scores for nodes at that level. Each DataFrame is a subset of the full prediction DataFrame containing only the columns for the nodes at that depth.

Return type:

dict[int, pandas.DataFrame]

Raises:

KeyError – If expected node columns (from self.taxonomy.get_nodes_by_depth()) are not present in the DataFrame returned by predict_full_scores().
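
A sketch of how score() and predict() relate, assuming a trained classifier instance named model and an astropy Table named table with the expected columns:

    # Mapping from taxonomy depth to a DataFrame of shape (n_samples, n_classes).
    scores_by_level = model.score(table)

    # predict() reduces each level to the highest-scoring class of the first sample,
    # roughly equivalent to:
    predictions = {
        level: df.columns[df.iloc[0].to_numpy().argmax()]
        for level, df in scores_by_level.items()
    }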

oracle.constants module

Module containing constant mappings and configurations for the ORACLE project.

oracle.loss module

Top level module for defining the Weighted Hierarchical Cross Entropy Loss function for hierarchical classification tasks.

class oracle.loss.WHXE_Loss(taxonomy: Taxonomy, labels, alpha=0.5, beta=1)

Bases: Module

Implementation of the Weighted Hierarchical Cross Entropy Loss function.

compute_lambda_term()

Compute the lambda term for node weighting.

This method calculates the secondary weighting term using an exponential decay based on the node depths. The decay is controlled by the attribute ‘alpha’. The resulting lambda term emphasizes different levels of the tree according to their depth.

Returns:

None

Side Effects:
  • Sets self.lambda_term to a PyTorch tensor of shape (N_nodes) containing the computed values.
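
One plausible form of this term, assuming depths is the array returned by the taxonomy’s get_depths() and alpha is the constructor argument (a sketch of the documented behavior, not necessarily the exact expression used):

    import torch

    # Exponential decay over node depths: larger alpha emphasizes shallower levels.
    lambda_term = torch.exp(-alpha * torch.as_tensor(depths, dtype=torch.float32))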

forward(logits, true, epsilon=1e-10)

Compute the hierarchical loss using the pseudo probabilities from masked softmaxes based on a taxonomy structure.

Parameters:
  • logits (torch.Tensor) – The raw output logits from the model for a batch.

  • true (torch.Tensor) – A tensor containing the indicator values for the true class labels.

  • epsilon (float, optional) – A small constant to prevent logarithm of zero; defaults to 1e-10.

Returns:

A scalar tensor representing the averaged hierarchical loss over the batch.

Return type:

torch.Tensor

get_class_weights(true)

Computes the class weights for each node in the taxonomy based on the true label data, using inverse frequency weighting.

Parameters:

true (torch.Tensor) – A binary tensor of shape (N_samples, N_nodes) where each row represents a sample and each column corresponds to a node in the taxonomy. An element should be 1 if the sample belongs to the class represented by the node, and 0 otherwise.

Returns:

A 1D tensor of shape (N_nodes,) containing the computed class weights.

Return type:

torch.Tensor
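
A minimal sketch of inverse-frequency weighting over such an indicator matrix (the exact normalization used by the loss may differ):

    import torch

    def inverse_frequency_weights(true: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor:
        # true is a binary tensor of shape (N_samples, N_nodes).
        counts = true.sum(dim=0)         # samples falling under each taxonomy node
        return 1.0 / (counts + epsilon)  # rarer nodes receive larger weights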

oracle.presets module

Top-level module for defining presets and utility functions for model selection, data loading, and training configurations in the ORACLE project.

oracle.presets.get_class_weights(labels)

Calculates the weights for each class using inverse frequency weighting.

Parameters:

labels (array-like) – An iterable of labels from which unique classes and their counts are derived.

Returns:

A dictionary where keys are the unique class labels and values are their corresponding weights (1/count).

Return type:

dict

oracle.presets.get_model(model_choice)

Retrieves and instantiates a model based on the provided model choice.

Parameters:

model_choice (str) –

A string identifier for the desired model configuration. Valid options and their corresponding behaviors are:

  • ”BTS-lite”: Uses BTS_Taxonomy to instantiate a GRU model.

  • ”BTS”: Uses BTS_Taxonomy to instantiate a GRU_MD model with a static feature dimension of 30.

  • ”ZTF_Sims-lite”: Uses BTS_Taxonomy to instantiate a GRU model.

  • ”ELAsTiCC-lite”: Uses ORACLE_Taxonomy to instantiate a GRU model.

  • ”ELAsTiCC”: Uses ORACLE_Taxonomy to instantiate a GRU_MD model with a static feature dimension of 18.

Returns:

The model instance associated with the given model_choice.
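
Typical usage (the strings must match one of the identifiers listed above):

    from oracle.presets import get_model

    model = get_model("BTS-lite")     # GRU over the BTS taxonomy
    model_md = get_model("ELAsTiCC")  # GRU_MD with an 18-dimensional static feature vector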

oracle.presets.get_test_loaders(model_choice, batch_size, max_n_per_class, days_list, excluded_classes=[], mapper=None)

Generates and returns a list of test DataLoaders configured according to the specified model type and parameters.

Parameters:
  • model_choice (str) – Specifies which model and corresponding dataset to use. Supported values include “BTS-lite”, “BTS”, “ZTF_Sims-lite”, “ELAsTiCC-lite”, and “ELAsTiCC”.

  • batch_size (int) – The size of batches to generate from each DataLoader.

  • max_n_per_class (int) – The maximum number of samples to include per class in the dataset.

  • days_list (list) – A list of day values used to dynamically configure a transformation (truncating light curves) applied to the test dataset.

  • excluded_classes (list, optional) – A list of class labels to exclude from the dataset. Defaults to an empty list.

  • mapper (dict, optional) – An optional dictionary used to map or modify dataset labels/structure (only used for the “BTS” model). Defaults to None.

Returns:

A list of DataLoader objects, each configured with a custom transformation based on a corresponding day from days_list.

Return type:

List[DataLoader]

Note

Each DataLoader is constructed for a specific truncation of the light curve based on the day value. The generator for DataLoader shuffling is explicitly created on the CPU.

oracle.presets.get_train_loader(model_choice, batch_size, max_n_per_class, excluded_classes=[])

Generates a DataLoader for training based on the provided model choice and dataset configuration.

Parameters:
  • model_choice (str) – The identifier for the model and corresponding dataset. Supported values include: “BTS-lite”, “BTS”, “ZTF_Sims-lite”, “ELAsTiCC-lite”, and “ELAsTiCC”.

  • batch_size (int) – The number of samples per batch to load.

  • max_n_per_class (int) – The maximum number of samples to include per class in the dataset.

  • excluded_classes (list, optional) – A list of class identifiers to exclude from the dataset. Defaults to an empty list.

Returns:

A tuple containing:
  • DataLoader (torch.utils.data.DataLoader): A DataLoader instance for iterating through the training dataset.

  • list: A list of all labels present in the training dataset.

Return type:

tuple
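
A usage sketch (the batch size and per-class cap below are arbitrary example values):

    from oracle.presets import get_train_loader

    train_loader, train_labels = get_train_loader(
        model_choice="BTS-lite",
        batch_size=128,
        max_n_per_class=1000,
    )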

oracle.presets.get_val_loader(model_choice, batch_size, val_truncation_days, max_n_per_class, excluded_classes=[])

Creates a DataLoader for the validation dataset along with its corresponding labels based on the specified model choice.

This function selects a dataset and transformation based on the model_choice provided and then concatenates the datasets (one for each truncation day specified in val_truncation_days). It returns a DataLoader constructed with the concatenated dataset and the set of all labels retrieved from the first dataset instance.

Parameters:
  • model_choice (str) – The name of the model variant to use. Valid options include: “BTS-lite”, “BTS”, “ZTF_Sims-lite”, “ELAsTiCC-lite”, “ELAsTiCC”.

  • batch_size (int) – The number of samples per batch in the returned DataLoader.

  • val_truncation_days (list) – A list of days used to truncate the light curves; each day corresponds to a transformation applied to the dataset.

  • max_n_per_class (int) – The maximum number of samples to include per class in the dataset.

  • excluded_classes (list, optional) – A list of classes to be excluded from the dataset. Defaults to an empty list.

Returns:

A tuple containing:
  • DataLoader: The DataLoader for the concatenated validation dataset with the specified batch size and collate function, constructed with a CPU-based torch.Generator.

  • list: A list of validation labels obtained from the first dataset in the list via get_all_labels().

Return type:

tuple

oracle.presets.worker_init_fn(worker_id)

Ensure proper random seeding in each worker process.
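
The body is not reproduced here; a common pattern for this kind of DataLoader hook, shown purely as an assumption of what “proper random seeding” usually entails, is:

    import random

    import numpy as np
    import torch

    def worker_init_fn_sketch(worker_id: int) -> None:
        # Derive a distinct, deterministic seed for this worker from the base
        # seed that PyTorch assigned to it, then seed the other RNGs with it.
        seed = torch.initial_seed() % 2**32
        np.random.seed(seed)
        random.seed(seed)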

oracle.taxonomies module

Top level module for defining taxonomies used in hierarchical classification tasks.

class oracle.taxonomies.BTS_Taxonomy(**attr)

Bases: Taxonomy

Class to represent the BTS taxonomy as a directed graph.

class oracle.taxonomies.ORACLE_Taxonomy(**attr)

Bases: Taxonomy

Class to represent the ORACLE ELAsTiCC taxonomy as a directed graph.

class oracle.taxonomies.Taxonomy(directed=True, **attr)

Bases: DiGraph

Class to represent a taxonomy as a directed graph.

get_class_probabilities(conditional_probabilities)

Compute the class probabilities for each node in the taxonomy using the given conditional probabilities. The method takes a tensor of conditional probabilities (the output of a model), and, for each node in the taxonomy (except the root), computes its class probability as the product of the conditional probabilities along the unique path from the root to that node.

Parameters:

conditional_probabilities (torch.Tensor) – A tensor of shape (N, M) where N is the number of samples and M is the number of nodes in the taxonomy. Each element represents the conditional probability of a node given its parent node.

Returns:

A tensor of the same shape as conditional_probabilities where each element represents the computed class probability for the corresponding node. The class probability is calculated as the product of the conditional probabilities along the path from the root to that node.

Return type:

torch.Tensor

Raises:

AssertionError – If the number of columns in conditional_probabilities does not match the number of nodes in the taxonomy.
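
The root-to-node product can be sketched with NetworkX as follows. The root lookup and the use of the level-order traversal for column indexing are assumptions consistent with the surrounding methods:

    import networkx as nx
    import torch

    def class_probabilities_sketch(taxonomy, conditional_probabilities: torch.Tensor) -> torch.Tensor:
        nodes = list(taxonomy.get_level_order_traversal())
        col = {node: i for i, node in enumerate(nodes)}
        root = next(n for n in taxonomy.nodes if taxonomy.in_degree(n) == 0)
        out = torch.ones_like(conditional_probabilities)
        for node in nodes:
            path = nx.shortest_path(taxonomy, source=root, target=node)
            # Product of conditional probabilities along the unique root-to-node path.
            out[:, col[node]] = torch.prod(
                conditional_probabilities[:, [col[n] for n in path]], dim=1
            )
        return out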

get_conditional_probabilities(logits, epsilon=1e-10)

Compute conditional probabilities over sibling groups of logits using masked softmax normalization. This function iterates over a set of binary masks corresponding to sibling groups and, for each mask, applies a softmax operation. This allows the function to compute conditional probabilities within each sibling group sequentially.

Parameters:
  • logits (torch.Tensor) – A tensor of logits of shape (batch_size, num_classes) on which to apply the masked softmax operations.

  • epsilon (float, optional) – A small constant added to the denominator to avoid division by zero during normalization. Default is 1e-10.

Returns:

The tensor of logits after applying conditional probability computations using masked softmax operations, where probabilities are normalized within the sibling groups.

Return type:

torch.Tensor
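
A sketch of the masked-softmax idea, assuming the binary masks come from get_sibling_masks() and are aligned with the logit columns:

    import torch

    def masked_softmax_sketch(logits: torch.Tensor, sibling_masks, epsilon: float = 1e-10) -> torch.Tensor:
        probs = torch.zeros_like(logits)
        for mask in sibling_masks:  # one binary mask per group of siblings
            m = torch.as_tensor(mask, dtype=torch.bool)
            group = logits[:, m]
            exp = torch.exp(group - group.max(dim=1, keepdim=True).values)
            probs[:, m] = exp / (exp.sum(dim=1, keepdim=True) + epsilon)
        return probs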

get_depths()

Compute the depths (number of edges from the root) for all nodes in the tree. This method performs a level-order traversal of the tree (using get_level_order_traversal) and calculates the depth of each node by finding the shortest path from the root node to the current node using networkx’s shortest_path function.

Returns:

An array where each element represents the depth of a node in the tree.

Return type:

numpy.ndarray

get_hierarchical_one_hot_encoding(labels)

Compute the hierarchical one-hot encoding for a set of taxonomy labels. This function creates a one-hot encoded representation for each label in the provided list, where each encoding corresponds to the nodes along the unique shortest path—from the root node to the label—in the taxonomy graph. The positions in the encoding vector are determined by a level-order traversal of the taxonomy.

Parameters:

labels (iterable) – An iterable of labels (nodes) to be encoded. Each label must exist in the taxonomy, as verified by the level-order traversal.

Returns:

A 2D numpy array of shape (number of labels, total number of nodes in the taxonomy), where each row is the one-hot encoded vector for the corresponding label. A value of 1 indicates that a node is part of the path from the root to the label; all other positions are 0.

Return type:

numpy.ndarray

Raises:

AssertionError – If any of the provided labels is not found within the taxonomy.

Note

  • The function assumes the taxonomy is represented as a graph and relies on NetworkX’s shortest path algorithm to determine the unique path from the root node to each label.

  • The ordering of the encoding vector is based on the level-order traversal of the taxonomy nodes.
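
A sketch of the encoding, again assuming the root is the unique node with no incoming edges, the labels form a list, and the column order follows the level-order traversal:

    import networkx as nx
    import numpy as np

    def hierarchical_one_hot_sketch(taxonomy, labels) -> np.ndarray:
        nodes = list(taxonomy.get_level_order_traversal())
        col = {node: i for i, node in enumerate(nodes)}
        root = next(n for n in taxonomy.nodes if taxonomy.in_degree(n) == 0)
        encoding = np.zeros((len(labels), len(nodes)))
        for row, label in enumerate(labels):
            for node in nx.shortest_path(taxonomy, source=root, target=label):
                encoding[row, col[node]] = 1  # mark every node on the root-to-label path
        return encoding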

get_leaf_nodes()

Return the leaf nodes in the taxonomy. This method identifies leaf nodes as those nodes with no outgoing edges (out_degree equals 0) and exactly one incoming edge (in_degree equals 1).

Returns:

A list containing all nodes that meet the criteria for being a leaf node.

Return type:

list

get_level_order_traversal()

Perform a level order traversal of the tree and return an ordering of its nodes.

This method initiates breadth-first search (BFS) starting from the root of the tree, using the networkx breadth-first search tree function. The traversal produces an iterable of nodes in the order they are visited.

Returns:

An iterable of node labels in level order starting from the root.

get_nodes_by_depth()

Return a dictionary mapping depths to nodes in the hierarchy. For each unique depth obtained from the hierarchy (via get_depths), this function:

  • Extracts the nodes corresponding to that depth from the level order traversal.

  • Stores them in a dictionary where keys are the depth levels.

  • Additionally, assigns the leaf nodes (from get_leaf_nodes) to the key -1.

Returns:

A dictionary where each key (of type int) corresponds to a depth level, and the associated value is a NumPy array containing the nodes at that depth. The key -1 is used to store a NumPy array of all leaf nodes.

Return type:

dict

Note

  • This function relies on the existence of helper methods:
    • get_depths(): to obtain the depth of each node.

    • get_level_order_traversal(): to obtain the nodes in level order.

    • get_leaf_nodes(): to obtain the list of leaf nodes.

get_parent_nodes()

Retrieves the parent node for each node in the taxonomy following a level order traversal.

Returns:

A list where each element is the parent of the corresponding node from the level order traversal. The root node will have an empty string as its parent.

Return type:

list

get_paths(labels)

Construct hierarchical paths from the given labels.

Parameters:

labels (iterable) – A collection of labels for which hierarchical paths are to be generated. The expected format should be compatible with the one-hot encoding produced by get_hierarchical_one_hot_encoding.

Returns:

A list where each element is a list representing the path (in level order) extracted from the hierarchy for the corresponding label.

Return type:

list of list

Note

  • This method relies on get_hierarchical_one_hot_encoding to obtain the binary encoded representation of the labels.

  • The ordering of nodes is determined by get_level_order_traversal, whose output is used to map encoded indexes to actual nodes.

get_sibling_masks()

Get the sibling masks for each node in the taxonomy using level order traversal.

Sibling masks indicate which nodes share the same parent. For each unique parent, a numpy array is created where each element is set to 1 if the corresponding node in the level order traversal has that parent, and 0 otherwise.

Returns:

A list of numpy arrays, each representing a mask for sibling nodes corresponding to one unique parent in the taxonomy.

Return type:

List[numpy.ndarray]

plot_colored_taxonomy(probabilities)

Plot the hierarchical taxonomy with colors corresponding to node probabilities.

This method computes a level-order traversal of the taxonomy, maps each node to its corresponding probability value for coloring, and then plots the taxonomy using Graphviz to determine node positions. The nodes are drawn with colors derived from the provided probabilities, and the plot is displayed using Matplotlib.

Parameters:

probabilities (array-like) – An array-like object (e.g., list, NumPy array, or tensor) containing probability values corresponding to each node in the taxonomy. The order of probabilities should match the order obtained from the level-order traversal.

Note

  • The method uses NetworkX for graph handling and drawing.

  • The Graphviz layout algorithm (‘dot’) is used to determine node positions.

  • Matplotlib functions are used to adjust layout and display the final plot.

plot_taxonomy()

Plots the taxonomy graph.

This method computes the layout of the taxonomy graph using Graphviz’s ‘dot’ algorithm, then draws the network using NetworkX’s drawing functionalities with labels, arrows, and custom styling. Finally, it displays the plot using Matplotlib.

oracle.test module

Interface for testing saved models in the ORACLE framework.

oracle.test.main()
oracle.test.parse_args()

Get command-line options

oracle.test.run_testing_loop(args)

Runs the testing loop for a specified model using the provided arguments.

This function performs the following steps:
  • Extracts key parameters (batch_size, max_n_per_class, and model directory) from the input arguments.

  • Reads the model choice from ‘train_args.csv’ located in the specified model directory.

  • Creates necessary subdirectories (‘plots’ and ‘reports’) within the model directory if they do not already exist.

  • Loads the appropriate model architecture based on the model choice and loads its pre-trained weights from ‘best_model_f1.pth’.

  • Sets up the model for testing and moves it to the designated device.

  • Retrieves test datasets for multiple default days (ignoring the ‘Anomaly’ class) and runs comprehensive analysis on each.

  • Retrieves additional test datasets (using an alternative mapping for anomalies) to generate embeddings for anomaly detection.

  • Generates and saves performance plots, including loss history and metrics across different phases.

  • Merges and displays performance tables based on a predefined list of thresholds.

Parameters:

args (Namespace) – An object containing the following attributes: 1. batch_size (int): The batch size for data loading. 2. max_n_per_class (int): The maximum number of samples per class for testing. 3. dir (str): The directory path where the model and related files are stored.

Returns:

None

Note

The function assumes that helper functions such as get_model, get_test_loaders, and model-specific methods (e.g., setup_testing, run_all_analysis, make_embeddings_for_AD, create_loss_history_plot, create_metric_phase_plots, merge_performance_tables) are defined elsewhere in the codebase.

oracle.tester module

Module for testing hierarchical models in the ORACLE framework.

class oracle.tester.Tester

Bases: object

Top-level class providing testing functionalities for hierarchical classification models.

create_classification_report(y_true, y_pred, file_name=None)

Generates a classification report comparing true and predicted labels, and optionally writes the report to a CSV file. This method first filters the input arrays to include only entries with a non-None true label. It then computes the classification report using scikit-learn’s classification_report function. If a file name is provided, it also exports the detailed report as a CSV file.

Parameters:
  • y_true (array-like) – Array of true labels. Only entries where the label is not None will be considered.

  • y_pred (array-like) – Array of predicted labels, corresponding to y_true.

  • file_name (str, optional) – The file path where the CSV report will be saved. If None, the CSV file is not generated.

Returns:

A text summary of the classification report.

Return type:

str

create_loss_history_plot()

Create and save a plot of the training and validation loss history.

Note

  • The numpy files must exist in the specified directory.

  • The ‘plot_train_val_history’ function must be properly defined and accessible.

create_metric_phase_plots()

Generates phase plots for key evaluation metrics across all experimental phases. This method iterates over a predefined list of metrics (‘f1-score’, ‘precision’, ‘recall’), retrieving the corresponding metric values across different phases by invoking the get_metric_over_all_phases method. For each metric, it then generates two types of plots:

  1. Class-wise performance over all phases using plot_class_wise_performance_over_all_phases.

  2. Level-averaged performance over all phases using plot_average_performance_over_all_phases.

The plots are saved to the directory specified by the model_dir attribute.

get_metric_over_all_phases(metric)

Calculates and aggregates the specified metric (f1-score, precision, or recall) across all non-root taxonomy depths.

Parameters:

metric (str) – The name of the metric to process. Must be one of [‘f1-score’, ‘precision’, ‘recall’].

Returns:

A dictionary where each key is a taxonomy depth (int) and each value is a pandas DataFrame containing the day-wise aggregated metric data.

Return type:

dict

Raises:

AssertionError – If the provided metric is not one of the accepted values.

make_embeddings_for_AD(test_loader, d)

Generate latent space embeddings for anomaly detection (AD) analysis and save the results. This method processes the test dataset by running model inference, extracting latent embeddings, and gathering the corresponding class labels and identifiers. It then creates a UMAP plot of the embeddings and saves both the plot and a CSV file containing the combined embedding data.

Parameters:
  • test_loader (iterable) – A DataLoader or iterable over the test dataset where each batch is a dictionary with keys ‘label’, ‘raw_label’, ‘id’, and any tensor data needed for embedding generation.

  • d (int or float) – A parameter indicating the number of days (or a similar metric) used in naming the output files and plots.

merge_performance_tables(days)

Merge performance tables for specified days and print LaTeX formatted results.

Parameters:

days (iterable) – A collection (e.g., list) of identifiers representing different report days.

Returns:

None

Side Effects:

Outputs a LaTeX formatted table to standard output.

run_all_analysis(test_loader, d)

Run analysis on the test set and generate evaluation plots and reports. This method sets the model to evaluation mode and iterates over the test_loader to perform inference. It aggregates the predicted class probabilities and the corresponding true labels, translating them into a hierarchical format based on the taxonomy provided. For each depth level (excluding the root level), it computes:

  • The recovery of true labels for the corresponding hierarchy level.

  • Confusion matrices for recall and precision, saving the plots as PDF files.

  • ROC curves for the predicted probabilities, saving the plots as PDF files.

  • A classification report that is both printed on the console and saved as a CSV file.

Parameters:
  • test_loader (iterable) – An iterable (e.g., DataLoader) that yields batches of test data, where each batch is a dictionary containing tensors (and other values) including the key ‘label’.

  • d (int) – An integer representing the number of days used in the trigger, incorporated into the naming of output files.

Returns:

None

setup_testing(model_dir, device)

Sets up the testing environment by configuring the model directory and device used for testing.

Parameters:
  • model_dir (str) – The directory path where the model files are stored.

  • device (torch.device or str) – The device on which the model will run (e.g., CPU or GPU).

Returns:

None

oracle.train module

Interface for training models in the ORACLE framework.

oracle.train.get_wandb_run(args)

Initializes and returns a Weights & Biases (wandb) run with the specified configuration.

Parameters:

args – An object that must contain the following attributes: 1. num_epochs (int): The number of training epochs. 2. batch_size (int): The batch size to be used. 3. lr (float): The learning rate for training. 4. max_n_per_class (int): The maximum number of samples per class. 5. alpha (float): A hyperparameter used for controlling loss behavior. 6. gamma (float): A hyperparameter used for weighting. 7. dir (str): The directory path where the model should be saved. 8. model (str): The identifier for the chosen model architecture. 9. load_weights (str): The file path for the pretrained model weights, if any.

Returns:

A wandb run instance initialized with the given configuration, which logs metadata and hyperparameters.

oracle.train.main()
oracle.train.parse_args()

Get command-line options

oracle.train.run_training_loop(args)

Runs the training loop for the model using the specified configuration and dataset loaders.

This function performs the following steps:
  1. Extracts training configuration parameters (e.g., number of epochs, batch size, learning rate, model choice, etc.) from the args argument.

  2. Initializes the model based on the provided model choice.

  3. Retrieves the training and validation data loaders along with their corresponding labels.

  4. Initializes a logging run (using WandB) and sets up the directory for saving models and training arguments.

  5. Optionally loads a pretrained model’s weights if a valid path is provided.

  6. Moves the model to the appropriate device, sets up the training configuration (including hyperparameters such as alpha and gamma), and begins model training.

  7. After training, saves the model to WandB and finalizes the logging run.

Parameters:

args (argparse.Namespace) – An object containing all necessary configuration parameters and hyperparameters including: 1. num_epochs (int): Number of epochs to train the model. 2. batch_size (int): Size of the batches used in training and validation. 3. lr (float): Learning rate for the optimizer. 4. max_n_per_class (int): Maximum number of samples per class for the training data. 5. alpha (float): Hyperparameter used during training (specific purpose defined by model’s setup). 6. gamma (float): Hyperparameter used during training (specific purpose defined by model’s setup). 7. dir (str): Directory path for saving the model and other related artifacts. 8. model (str): Identifier to select which model architecture to use. 9. load_weights (str or None): Path to pretrained model weights. If provided, these weights are loaded into the model.

Returns:

None

oracle.train.save_args_to_csv(args, filepath)

Save command-line arguments to a CSV file.

This function converts the attributes of an object, typically parsed from command-line input, into a single-row pandas DataFrame, and saves it to a CSV file at the specified filepath.

Parameters:
  • args (object) – An object containing attributes to be saved, often created using argparse.

  • filepath (str) – The file path (including filename) where the CSV file will be written.

Returns:

None

oracle.trainer module

Module for training hierarchical models in the ORACLE framework.

class oracle.trainer.EarlyStopper(patience=1, min_delta=0)

Bases: object

EarlyStopper class for monitoring validation loss and triggering early stopping during training.

patience

Number of consecutive epochs with insufficient improvement allowed before stopping early.

Type:

int

min_delta

Minimum change in validation loss to be considered an improvement.

Type:

float

counter

Counts the number of consecutive epochs without sufficient improvement.

Type:

int

min_validation_loss

The lowest validation loss observed so far.

Type:

float

early_stop(validation_loss)

Checks if training should be stopped early based on the current validation loss. This method compares the provided validation loss against the minimum validation loss seen so far.

  • If the validation loss is less than the minimum, it updates the minimum value and resets the counter.

  • If the validation loss exceeds the minimum by more than a specified delta, the counter increments.

  • When the counter reaches or exceeds the patience threshold, the method indicates that early stopping is warranted.

Parameters:

validation_loss (float) – The current validation loss from the evaluation phase.

Returns:

True if early stopping criterion is met (i.e., the counter has reached the patience limit), otherwise False.

Return type:

bool
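
The rule described above amounts to the following self-contained sketch (attribute names match the documented ones; the initial values are assumptions):

    class EarlyStopperSketch:
        def __init__(self, patience: int = 1, min_delta: float = 0.0):
            self.patience = patience
            self.min_delta = min_delta
            self.counter = 0
            self.min_validation_loss = float("inf")

        def early_stop(self, validation_loss: float) -> bool:
            if validation_loss < self.min_validation_loss:
                self.min_validation_loss = validation_loss  # new best: reset the counter
                self.counter = 0
            elif validation_loss > self.min_validation_loss + self.min_delta:
                self.counter += 1                           # no sufficient improvement
                if self.counter >= self.patience:
                    return True                             # patience exhausted
            return False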

class oracle.trainer.Trainer

Bases: object

Top-level class providing training functionalities for hierarchical classification models.

fit(train_loader, val_loader, num_epochs=5)

Train the model for a specified number of epochs.

This method moves the model to the designated device and iterates over the given number of epochs. During each epoch, it trains the model on the training data and evaluates it on the validation data. It records the training loss, validation loss, and macro F1 score, and saves the best models based on the lowest validation loss and highest F1 score. Additionally, the method logs metrics to an external service (e.g., Weights and Biases), updates the learning rate scheduler, saves the loss histories, and stops training early if an early stopping condition is met.

Parameters:
  • train_loader (DataLoader) – DataLoader providing batches of training data.

  • val_loader (DataLoader) – DataLoader providing batches of validation data.

  • num_epochs (int, optional) – Number of epochs to train for. Defaults to 5.

Returns:

None

log_data_in_wandb(train_loss_history, val_loss_history, f1_history, cf)

Logs training and validation metrics to Weights and Biases (wandb).

Parameters:
  • train_loss_history (List[float]) – List of training loss values recorded per epoch.

  • val_loss_history (List[float]) – List of validation loss values recorded per epoch.

  • f1_history (List[float]) – List of f1 score values recorded per epoch.

  • cf (Any) – A configuration parameter or additional information to log.

Returns:

None

save_loss_history(train_loss_history, val_loss_history, f1_history)

Saves the loss and F1 score histories as NumPy binary files in the model directory.

Parameters:
  • train_loss_history (list or array-like) – A collection of training loss values.

  • val_loss_history (list or array-like) – A collection of validation loss values.

  • f1_history (list or array-like) – A collection of F1 score values.

Each history is converted to a NumPy array before saving.

save_model_in_wandb()

Saves training artifacts to Weights & Biases (wandb) for experiment tracking. Each file is expected to reside in the directory specified by self.model_dir.

setup_training(alpha, gamma, lr, train_labels, val_labels, model_dir, device, wandb_run)

Set up the training components for the model. This method configures the training environment by initializing both the training and validation loss criteria, setting up the optimizer, scheduling learning rate adjustments, and configuring early stopping. It also assigns various parameters such as the device to use, the model directory for saving checkpoints, and the Weights & Biases run instance.

Parameters:
  • alpha (float) – Hyperparameter for the training loss function to adjust the influence of the taxonomy.

  • gamma (float) – Hyperparameter for the loss function that modulates the weighting of different classes.

  • lr (float) – Learning rate for the optimizer.

  • train_labels (iterable) – Labels for the training dataset, used to compute class weights in the training loss.

  • val_labels (iterable) – Labels for the validation dataset, used to compute class weights in the validation loss.

  • model_dir (str) – Directory path where model checkpoints and related outputs will be stored.

  • device (torch.device or str) – The computing device (CPU or GPU) on which to run the model.

  • wandb_run – Instance of the Weights & Biases run for logging training progress.

Returns:

None

train_one_epoch(train_loader)

Train the model for one epoch.

This method performs a full training cycle over all batches provided by the train_loader. For each batch, it moves the batch data to the appropriate device, converts labels into a hierarchical one-hot encoding using the taxonomy, performs a forward pass to compute logits, calculates the loss using the train criterion, and applies backpropagation along with an optimizer step.

Parameters:

train_loader (iterable) – A data loader that yields batches of training data. Each batch is expected to be a dictionary where the key ‘label’ is used to obtain the correct one-hot encoding.

Returns:

The mean loss value computed over all batches during the epoch.

Return type:

float
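
A sketch of the loop described above. The helper names (optimizer, criterion) and the device-moving logic are assumptions consistent with the surrounding documentation, not the literal attribute names used by Trainer:

    import numpy as np
    import torch

    def train_one_epoch_sketch(model, train_loader, optimizer, criterion, taxonomy, device):
        model.train()
        losses = []
        for batch in train_loader:
            batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
            # Hierarchical one-hot targets built from the string labels in the batch.
            true = torch.tensor(
                taxonomy.get_hierarchical_one_hot_encoding(batch["label"]),
                dtype=torch.float32, device=device,
            )
            logits = model(batch)
            loss = criterion(logits, true)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        return float(np.mean(losses))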

validate_one_epoch(val_loader)

Performs validation for one epoch.

This method sets the model to evaluation mode and iterates over the validation data loader, performing forward passes without gradient computation. It computes the loss using a specified validation criterion and aggregates the true and predicted labels for metric calculation. Additionally, it constructs a confusion matrix visualization and computes the macro F1 score.

Parameters:

val_loader (torch.utils.data.DataLoader) – A data loader for the validation dataset. Each batch should be a dictionary with a ‘label’ key among others and may include tensors that need to be moved to the device.

Returns:

A dictionary containing:
  • ’val_loss’ (float): The average loss computed over all validation batches.

  • ’macro_f1’ (float): The macro F1 score calculated using aggregated true and predicted labels.

  • ’cf’ (wandb.Image): A wandb.Image object representing the confusion matrix plot.

Return type:

dict

Note

  • The method uses the taxonomy provided by self.taxonomy to filter and encode the labels hierarchically.

  • Leaf node indices are determined using the taxonomy’s level order traversal.

  • All computations are performed with gradients disabled using torch.no_grad().

oracle.visualization module

Module for visualization functions in the ORACLE framework.

oracle.visualization.plot_average_performance_over_all_phases(metric, metrics_dictionary, model_dir=None)

Plot the average performance over all phases for the specified metric.

This function iterates over a dictionary of metrics grouped by different depths. For each depth, it extracts the row corresponding to a specified metric (e.g., ‘macro avg’) from a pandas DataFrame, plots the metric values against the days from first detection on a logarithmic x-scale, and either displays the plot interactively or saves it to a file in the specified directory.

Parameters:
  • metric (str) – The performance metric to be plotted. This is used to label the y-axis and is printed along with the plot data.

  • metrics_dictionary (dict) – A dictionary where each key corresponds to a depth level and each value is a pandas DataFrame. The DataFrame should contain metric rows (e.g., ‘macro avg’) with its index representing days from first detection and the row values representing the metric values.

  • model_dir (str, optional) – Directory in which to save the generated plot as a PDF file under the subdirectory ‘plots’. If None, the plot is displayed interactively using plt.show(). Defaults to None.

Note

  • All existing matplotlib figures are closed at the beginning to prevent overlapping.

  • The plot style is set to ‘default’.

  • The x-axis uses a logarithmic scale and its ticks are set based on the days from first detection.

  • A legend is added to the lower right of the plot.

oracle.visualization.plot_class_wise_performance_over_all_phases(metric, metrics_dictionary, model_dir=None)

Plots class-wise performance over all phases for a given metric.

This function iterates over each depth level present in the metrics_dictionary, extracts class-specific metric values (ignoring summary rows such as ‘accuracy’, ‘macro avg’, and ‘weighted avg’), and plots these values against the days from the first detection. The x-axis is set to a logarithmic scale.

Parameters:
  • metric (str) – The name of the performance metric to be displayed on the y-axis.

  • metrics_dictionary (dict) – A dictionary where each key represents a depth level and its corresponding value is a pandas DataFrame. The DataFrame should have its rows indexed by class names (with some entries like ‘accuracy’, ‘macro avg’, and ‘weighted avg’ to be skipped) and columns representing days from the first detection.

  • model_dir (str or None, optional) – The directory path where the plot PDFs will be saved. If provided, each plot is saved as ‘class_wise_{metric}.pdf’ in a subdirectory.

oracle.visualization.plot_confusion_matrix(y_true, y_pred, labels, normalize='true', title=None, img_file=None)

Plot a confusion matrix using the given true and predicted labels and display it with matplotlib.

Parameters:
  • y_true (array-like) – Array of true labels.

  • y_pred (array-like) – Array of predicted labels corresponding to y_true.

  • labels (list) – List of label names to be used in the confusion matrix.

  • normalize (str, optional) – Normalization mode for the confusion matrix. Default is ‘true’. Accepted values are typically ‘true’, ‘pred’, or ‘all’.

  • title (str, optional) – Title of the plot; if provided, it will be set on the plot.

  • img_file (str, optional) – File path to save the generated plot image. If None, the plot is not saved.

Returns:

None

Note

  • The function filters out any entries where the true label is None.

  • It adjusts the figure size and text properties based on the number of classes.

  • The function closes all previous matplotlib figures at the start and closes the plot at the end.

oracle.visualization.plot_plain_cf(y_true, y_pred, labels, normalize='true', title=None, img_file=None)

Plot a plain confusion matrix visualization based on true and predicted labels. This function computes and displays a confusion matrix for the provided true and predicted labels using a predefined style. It only considers entries in y_true that are not None. The resulting confusion matrix is displayed without tick marks or spines, and can optionally be saved to an image file.

Parameters:
  • y_true (array-like) – Array of true labels. Only the elements that are not None are used.

  • y_pred (array-like) – Array of predicted labels corresponding to y_true.

  • labels (array-like) – The set of labels to index the confusion matrix.

  • normalize (str, optional) – Normalization method for the confusion matrix (e.g., ‘true’). Defaults to ‘true’.

  • title (str, optional) – Title of the plot. (Currently not utilized in the function.)

  • img_file (str, optional) – If provided, the plot is saved to this file path.

Returns:

None

oracle.visualization.plot_roc_curves(probs_true, probs_pred, labels, title=None, img_file=None)

Plot ROC curves for each class and compute the macro-average ROC curve.

Parameters:
  • probs_true (ndarray) – A 2D array of shape (n_samples, n_classes) containing the ground truth binary labels for each class. Rows with all-zero values (indicating missing true labels) are removed before plotting.

  • probs_pred (ndarray) – A 2D array of shape (n_samples, n_classes) containing the predicted probabilities for each class.

  • labels (list of str) – A list of class labels corresponding to the columns in probs_true and probs_pred.

  • title (str, optional) – Title of the ROC plot. Defaults to None.

  • img_file (str, optional) – File path to save the plot image. If None, the plot is not saved to a file.

Returns:

None

Note

  • The ROC curves are plotted with an equal aspect ratio, and a legend is included showing the AUC for each class along with the macro-average AUC.

oracle.visualization.plot_train_val_history(train_loss_history, val_loss_history, file_name)

Plot the training and validation loss curves along with their rolling averages on a logarithmic scale.

Parameters:
  • train_loss_history (list or array-like) – The history of training loss values.

  • val_loss_history (list or array-like) – The history of validation loss values.

  • file_name (str) – The file path where the generated plot will be saved.

oracle.visualization.plot_umap(embeddings, classes, bts_classes, id, d, model_dir=None)

Plot UMAP projection of embeddings. This function computes a 2D UMAP projection from high-dimensional embeddings and generates both a static scatter plot using matplotlib and an interactive scatter plot using plotly. Points in the plots are colored based on the provided classes. If a model directory is specified, the plots are saved to disk; otherwise, the static plot is displayed.

Parameters:
  • embeddings (array-like) – High-dimensional feature data (e.g., a numpy array) with shape (n_samples, n_features).

  • classes (array-like) – Class labels for each embedding, used for color-coding the points in the plot.

  • bts_classes (array-like) – Additional class information for tooltips in the interactive plot.

  • id (array-like) – Unique identifiers for each source, used for hover information in the interactive plot.

  • d (int or str) – Identifier (e.g., number of days or a delay parameter) used in the plot title and file names.

  • model_dir (str, optional) – Directory where the plots will be saved. If None, the static plot is shown instead.

Returns:

None

Raises:

Exceptions from UMAP, matplotlib, or plotly if issues occur during the transformation or plotting process.

Module contents

ORACLE is a general framework for hierarchical classification.