oracle.architectures module
Top-level module for defining various neural network architectures for hierarchical classification.
- class oracle.architectures.GRU(taxonomy: Taxonomy, ts_feature_dim=5)
Bases: Hierarchical_classifier
GRU-based neural network architecture for hierarchical classification.
- forward(batch)
Compute the logits for a given input batch through the model.
This method processes the input batch by first obtaining its latent space embeddings, applying a ReLU activation, and then projecting the result to logits using a fully connected output layer.
- Parameters:
batch (dict) – Input data batch.
- Returns:
The computed output logits after the final linear transformation.
- Return type:
torch.Tensor
- get_latent_space_embeddings(batch)
Compute latent space embeddings for a batch of time-series data.
- Parameters:
batch (dict) – A dictionary containing: 1. ‘ts’ (torch.Tensor): Input time-series data of shape (batch_size, seq_len, n_ts_features). 2. ‘length’ (torch.Tensor or list[int]): Valid (unpadded) length of each time series in the batch.
- Returns:
The latent space embeddings resulting from the network’s forward pass.
- Return type:
torch.Tensor
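The embedding and forward steps described for GRU can be pictured with a small PyTorch sketch. TinyGRUClassifier, its layer names (other than fc_out, which this page mentions for GRU_MD), and every layer size are hypothetical stand-ins chosen for illustration; the real class takes a Taxonomy and may be structured differently.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence


class TinyGRUClassifier(nn.Module):
    """Hypothetical stand-in; layer sizes are illustrative only."""

    def __init__(self, n_nodes, ts_feature_dim=5, hidden_dim=16, latent_dim=8):
        super().__init__()
        self.gru = nn.GRU(ts_feature_dim, hidden_dim, batch_first=True)
        self.fc_latent = nn.Linear(hidden_dim, latent_dim)
        self.fc_out = nn.Linear(latent_dim, n_nodes)

    def get_latent_space_embeddings(self, batch):
        # Lengths may arrive as a tensor or a list; packing needs them on the CPU.
        lengths = torch.as_tensor(batch["length"]).cpu()
        packed = pack_padded_sequence(
            batch["ts"], lengths, batch_first=True, enforce_sorted=False
        )
        # h_n holds the final hidden state of each sequence, ignoring padding.
        _, h_n = self.gru(packed)
        return self.fc_latent(h_n[-1])

    def forward(self, batch):
        # Embeddings -> ReLU -> fully connected output layer.
        z = self.get_latent_space_embeddings(batch)
        return self.fc_out(torch.relu(z))


batch = {"ts": torch.randn(3, 10, 5), "length": [10, 7, 4]}
model = TinyGRUClassifier(n_nodes=6)
logits = model(batch)  # one row of logits per time series
```

Packing the padded sequences means the final hidden state reflects only the valid time steps of each series, which is why the ‘length’ entry is required alongside ‘ts’.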
- class oracle.architectures.GRU_MD(taxonomy: Taxonomy, ts_feature_dim=5, static_feature_dim=30)
Bases: Hierarchical_classifier
GRU-based neural network architecture with multi-dimensional static features for hierarchical classification.
- forward(batch)
Computes the forward pass through the network module. This method performs the following steps:
1. Extracts latent space embeddings from the input batch using self.get_latent_space_embeddings.
2. Applies a ReLU activation to the latent embeddings.
3. Computes the logits by passing the activated embeddings through a fully connected layer (self.fc_out).
- Parameters:
batch (dict) – Input data batch with the keys expected by get_latent_space_embeddings (‘ts’, ‘length’, and ‘static’).
- Returns:
Logits produced by the network.
- Return type:
torch.Tensor
- get_latent_space_embeddings(batch)
Generates the latent space embedding for a given batch of data by processing both time series and static features.
- Parameters:
batch (dict) – A dictionary containing the following keys: 1. ‘ts’ (torch.Tensor): Time series data of shape (batch_size, seq_len, n_ts_features). 2. ‘length’ (torch.Tensor): True length of each time series in the batch, of shape (batch_size,). 3. ‘static’ (torch.Tensor): Static features of shape (batch_size, n_static_features).
- Returns:
The latent space embedding computed by merging the processed time series and static features after passing them through respective dense layers and nonlinear activations.
- Return type:
torch.Tensor
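A minimal sketch of the merge described for GRU_MD.get_latent_space_embeddings follows: each input passes through its own dense layer and activation, the two branches are concatenated, and a final dense layer produces the embedding. The class name TinyGRUWithStatic, the layer names, and all widths are assumptions made for illustration and are not taken from the library.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence


class TinyGRUWithStatic(nn.Module):
    """Hypothetical two-branch encoder; widths are illustrative only."""

    def __init__(self, ts_feature_dim=5, static_feature_dim=30,
                 hidden_dim=16, branch_dim=12, latent_dim=8):
        super().__init__()
        self.gru = nn.GRU(ts_feature_dim, hidden_dim, batch_first=True)
        self.ts_dense = nn.Linear(hidden_dim, branch_dim)
        self.static_dense = nn.Linear(static_feature_dim, branch_dim)
        self.merge_dense = nn.Linear(2 * branch_dim, latent_dim)

    def get_latent_space_embeddings(self, batch):
        packed = pack_padded_sequence(
            batch["ts"], batch["length"].cpu(),
            batch_first=True, enforce_sorted=False,
        )
        _, h_n = self.gru(packed)
        # Separate dense layer + nonlinearity for each input type.
        ts_branch = torch.relu(self.ts_dense(h_n[-1]))
        static_branch = torch.relu(self.static_dense(batch["static"]))
        # Merge along the feature axis, then project to the latent space.
        merged = torch.cat([ts_branch, static_branch], dim=1)
        return self.merge_dense(merged)


batch = {
    "ts": torch.randn(4, 9, 5),
    "length": torch.tensor([9, 5, 3, 1]),
    "static": torch.randn(4, 30),
}
embeddings = TinyGRUWithStatic().get_latent_space_embeddings(batch)
```

Concatenating along dim=1 keeps one row per sample, so the static features must share the batch dimension of the time series.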
- class oracle.architectures.GRU_MD_MM(output_dim, ts_feature_dim=5, static_feature_dim=18)
Bases: Hierarchical_classifier
GRU-based neural network architecture with multi-dimensional static features and multi-modal inputs for hierarchical classification.
- forward(batch)
Compute the network’s output logits from an input batch. This method first extracts latent space embeddings from the batch, applies a ReLU activation to introduce non-linearity, and then computes the final logits using a fully connected output layer.
- Parameters:
batch (dict) – The input data batch from which latent embeddings are derived; see get_latent_space_embeddings for the expected keys.
- Returns:
The computed logits after processing the latent embeddings through the ReLU activation and linear layer.
- Return type:
torch.Tensor
- get_latent_space_embeddings(batch)
Compute latent space embeddings from the given batch of inputs.
Processes time-series, static, and postage stamp inputs to produce a latent representation. The function performs the following steps:
1. Packs the padded time-series data using the sequence lengths provided.
2. Processes the time-series data with a bidirectional GRU and extracts its last hidden state.
3. Processes the postage stamp data through a Swin-based module and a fully connected layer.
4. Transforms the GRU and static features through separate dense layers with non-linear activation.
5. Concatenates the processed time-series, static, and postage stamp features.
6. Passes the merged representation through additional dense layers with activation functions to yield the final latent space embedding.
- Parameters:
batch (dict) – A dictionary containing: 1. “ts” (torch.Tensor): Time series data of shape (batch_size, seq_len, n_ts_features). 2. “length” (torch.Tensor): Length of each time series in the batch, of shape (batch_size,). 3. “static” (torch.Tensor): Static features of shape (batch_size, n_static_features). 4. “postage_stamp” (torch.Tensor): Postage stamp data, processed by the Swin-based module.
- Returns:
The latent space embeddings computed from the combined features.
- Return type:
torch.Tensor
- class oracle.architectures.Hierarchical_classifier(taxonomy: Taxonomy)
Bases: Module, Trainer, Tester
Base class for hierarchical classification architectures.
- embed(table)
Embeds the provided table into the model’s latent space.
- Parameters:
table (object) – The table data to be embedded. The expected type and structure of this data should be compatible with the model’s input requirements.
- Raises:
NotImplementedError – This method is not implemented by default and is intended for pretrained models only.
- get_latent_space_embeddings(batch)
Compute the latent space embeddings for a given batch of data.
This method should be implemented by subclasses to transform the input batch into a latent representation. The exact nature of the embedding (e.g., dimensionality, transformation mechanism) depends on the specific architecture.
- Parameters:
batch (dictionary) – A batch of input data on which to compute the latent embeddings. The expected format and type should be defined by the implementing subclass.
- Returns:
The latent space embeddings corresponding to the provided batch. The structure and type of the embeddings is determined by the subclass implementation.
- Return type:
Any
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
- predict(table)
Predict the label at each hierarchical level for the table.
- Parameters:
table (astropy.table.Table) – Input data containing one or more rows.
- Returns:
Mapping from hierarchical level (as returned by self.score) to the predicted class label. For each level, self.score(table) is expected to return a pandas.DataFrame of shape (n_samples, n_classes) with class labels as columns; the predicted label is the column with the highest score for the first sample.
- Return type:
dict
- Raises:
Exception – Any exception raised by self.score or by NumPy operations (e.g., if the score DataFrame is empty) is propagated.
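The selection rule described for predict can be sketched without pandas. Here plain dicts stand in for the per-level DataFrames (label mapped to a list of per-sample scores); the function name predict_from_scores and the class labels are hypothetical.

```python
def predict_from_scores(scores_by_level):
    """Per level, pick the label whose score is highest for the first sample."""
    predictions = {}
    for level, columns in scores_by_level.items():
        # columns maps label -> list of scores, one entry per sample (row).
        predictions[level] = max(columns, key=lambda label: columns[label][0])
    return predictions


# Two samples per level; only the first sample (index 0) drives the result.
scores_by_level = {
    1: {"Transient": [0.8, 0.1], "Variable": [0.2, 0.9]},
    2: {"SN": [0.7, 0.05], "KN": [0.1, 0.05], "RRLyrae": [0.2, 0.9]},
}
predicted = predict_from_scores(scores_by_level)
```

Because only the first row is inspected, a table with several rows still yields a single label per level under this rule.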
- predict_class_probabilities(batch)
Predicts the class probabilities for a given batch of data. This method first computes the conditional probabilities for the batch using the ‘predict_conditional_probabilities’ method. It then leverages the taxonomy to convert these conditional probabilities into final class probabilities.
- Parameters:
batch (dictionary) – The input data batch for which class probabilities are to be predicted.
- Returns:
The computed class probabilities derived from the input batch.
- Return type:
torch.Tensor
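One common way to turn conditional probabilities into class probabilities is to multiply the conditionals along the path from the root to each node. The sketch below assumes that rule; the page does not state how the taxonomy performs the conversion, and the node names and values are made up for illustration.

```python
# Hypothetical five-node taxonomy: child -> parent (None marks the root).
parent = {
    "Alert": None,
    "Transient": "Alert",
    "Variable": "Alert",
    "SN": "Transient",
    "KN": "Transient",
}

# P(node | parent) for one sample; siblings sum to 1.
conditional = {
    "Alert": 1.0,
    "Transient": 0.75,
    "Variable": 0.25,
    "SN": 0.5,
    "KN": 0.5,
}


def class_probability(node):
    """Multiply conditionals from the node up to the root (assumed rule)."""
    prob = 1.0
    while node is not None:
        prob *= conditional[node]
        node = parent[node]
    return prob


class_probs = {node: class_probability(node) for node in parent}
```

Under this assumption the leaf probabilities sum to 1, and every internal node's probability equals the sum over its children.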
- predict_class_probabilities_df(batch)
Predict class probabilities for a given batch and return the results in a DataFrame.
This method performs the following steps: 1. Retrieves the list of taxonomy nodes ordered by level using the taxonomy’s get_level_order_traversal method. 2. Computes the class probabilities for the input batch via the predict_class_probabilities method. 3. Constructs and returns a pandas DataFrame with the computed probabilities, where each column corresponds to a taxonomy node in level order.
- Parameters:
batch (dictionary) – The input data batch on which to perform predictions. The expected format and type of batch depend on the implementation details of the prediction model.
- Returns:
A DataFrame with columns representing taxonomy nodes and each cell containing the predicted class probability for the corresponding node.
- Return type:
pandas.DataFrame
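The level-order column layout can be illustrated with a breadth-first traversal. This sketch uses a plain dict tree and a list of row dicts in place of the Taxonomy object and pandas; the tree, the helper level_order, and the probability values are all hypothetical.

```python
from collections import deque

# Hypothetical taxonomy: node -> ordered list of children.
children = {
    "Alert": ["Transient", "Variable"],
    "Transient": ["SN", "KN"],
    "Variable": [],
    "SN": [],
    "KN": [],
}


def level_order(root):
    """Visit nodes breadth-first, so shallower nodes come first."""
    order, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        order.append(node)
        queue.extend(children[node])
    return order


columns = level_order("Alert")

# One row of probabilities per sample, laid out in the same node order.
probs = [[1.0, 0.75, 0.25, 0.375, 0.375]]
rows = [dict(zip(columns, row)) for row in probs]
```

Keeping the probability tensor and the column names in the same traversal order is what lets each value be labelled with its node.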
- predict_conditional_probabilities(batch)
Compute conditional probabilities from the model’s output. This method performs a forward pass using the given batch of inputs to obtain the logits. It then utilizes the taxonomy’s get_conditional_probabilities method to convert these logits into conditional probabilities. The resulting tensor is detached from the computation graph and returned.
- Parameters:
batch (dictionary) – A batch of input data to be fed into the model. The exact format and type of the batch depend on the requirements of the forward method.
- Returns:
A tensor representing the conditional probabilities computed from the model’s logits.
- Return type:
torch.Tensor
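As an illustration of how logits might become conditional probabilities, the sketch below applies a softmax within each group of sibling nodes. This is an assumption about what the taxonomy's get_conditional_probabilities does, not a description of its implementation; the node list, sibling grouping, and logit values are invented.

```python
import math


def softmax(values):
    """Numerically stable softmax over a short list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Nodes in level order, with indices grouped by shared parent (hypothetical).
nodes = ["Alert", "Transient", "Variable", "SN", "KN"]
sibling_groups = [[0], [1, 2], [3, 4]]
logits = [0.3, 2.0, 0.0, 1.0, 1.0]

conditional = [0.0] * len(nodes)
for group in sibling_groups:
    group_probs = softmax([logits[i] for i in group])
    for idx, p in zip(group, group_probs):
        # Each entry is P(node | parent); siblings sum to 1.
        conditional[idx] = p
```

Normalizing per sibling group rather than across all nodes is what makes each value a probability conditioned on the parent.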
- predict_conditional_probabilities_df(batch)
Predict conditional probabilities for a batch and return them as a pandas DataFrame.
This method retrieves a level-order traversal of nodes from the taxonomy, computes the conditional probabilities for the given batch using the predict_conditional_probabilities method, and then constructs a pandas DataFrame with the probabilities. The DataFrame’s columns are named using the node order from the taxonomy’s level-order traversal.
- Parameters:
batch (dictionary) – The input data batch for which the conditional probabilities are to be predicted.
- Returns:
A DataFrame containing the predicted conditional probabilities with columns corresponding to the level-order nodes.
- Return type:
pandas.DataFrame
- predict_full_scores(table)
Predict scores for all taxonomy nodes for the given table. The result is the DataFrame that score() groups by taxonomy depth.
- score(table)
Compute hierarchical scores for the input table. Predicts scores for all taxonomy nodes using self.predict_full_scores(), then groups those scores by taxonomy depth and returns a mapping from depth levels to DataFrames containing the corresponding node scores.
- Parameters:
table (astropy.table.Table) – Input observations/features to be scored.
- Returns:
A mapping from taxonomy depth level to a DataFrame of predicted scores for nodes at that level. Each DataFrame is a subset of the full prediction DataFrame containing only the columns for the nodes at that depth.
- Return type:
dict[int, pandas.DataFrame]
- Raises:
KeyError – If expected node columns (from self.taxonomy.get_nodes_by_depth()) are not present in the DataFrame returned by predict_full_scores().
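The grouping step in score can be sketched with dicts standing in for DataFrames. The helper group_scores_by_depth, the node names, and the depth mapping are hypothetical; the mapping plays the role of self.taxonomy.get_nodes_by_depth(), whose actual return format is not shown on this page.

```python
def group_scores_by_depth(full_scores, nodes_by_depth):
    """Split a node -> scores mapping into one sub-mapping per depth."""
    grouped = {}
    for depth, nodes in nodes_by_depth.items():
        # Indexing raises KeyError if a node has no column in full_scores.
        grouped[depth] = {node: full_scores[node] for node in nodes}
    return grouped


# Stand-in for the full prediction table: one column per taxonomy node.
full_scores = {
    "Alert": [1.0],
    "Transient": [0.75],
    "Variable": [0.25],
    "SN": [0.375],
    "KN": [0.375],
}
nodes_by_depth = {0: ["Alert"], 1: ["Transient", "Variable"], 2: ["SN", "KN"]}

grouped = group_scores_by_depth(full_scores, nodes_by_depth)
```

The KeyError listed under Raises corresponds to the lookup inside the loop: a node listed for some depth but absent from the full score table cannot be selected.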