oracle.taxonomies module

Top level module for defining taxonomies used in hierarchical classification tasks.

class oracle.taxonomies.BTS_Taxonomy(**attr)

Bases: Taxonomy

Class to represent the BTS taxonomy as a directed graph.

class oracle.taxonomies.ORACLE_Taxonomy(**attr)

Bases: Taxonomy

Class to represent the ORACLE ELAsTiCC taxonomy as a directed graph.

class oracle.taxonomies.Taxonomy(directed=True, **attr)

Bases: DiGraph

Class to represent a taxonomy as a directed graph.

get_class_probabilities(conditional_probabilities)

Compute the class probabilities for each node in the taxonomy using the given conditional probabilities. The method takes a tensor of conditional probabilities (the output of a model), and, for each node in the taxonomy (except the root), computes its class probability as the product of the conditional probabilities along the unique path from the root to that node.

Parameters:

conditional_probabilities (torch.Tensor) – A tensor of shape (N, M) where N is the number of samples and M is the number of nodes in the taxonomy. Each element represents the conditional probability of a node given its parent node.

Returns:

A tensor of the same shape as conditional_probabilities where each element represents the computed class probability for the corresponding node. The class probability is calculated as the product of the conditional probabilities along the path from the root to that node.

Return type:

torch.Tensor

Raises:

AssertionError – If the number of columns in conditional_probabilities does not match the number of nodes in the taxonomy.
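The path product can be illustrated on a toy taxonomy. The following is a minimal, dependency-free sketch and not the module's implementation: the node names, the parent_index list, and the class_probabilities helper are hypothetical, and it operates on a single row of plain floats rather than a torch.Tensor. It also assumes the root column carries a conditional probability of 1.

```python
# Toy taxonomy listed in level order (hypothetical names).
nodes = ["root", "A", "B", "A1", "A2"]
# parent_index[i] is the column of node i's parent; None marks the root.
parent_index = [None, 0, 0, 1, 1]


def class_probabilities(conditional_row):
    """Multiply conditional probabilities along each root-to-node path."""
    probs = []
    for i, p in enumerate(conditional_row):
        parent = parent_index[i]
        # In level order a parent always precedes its children, so its
        # class probability is already available when the child is reached.
        probs.append(p if parent is None else p * probs[parent])
    return probs


# One sample: P(A | root) = 0.5, P(A1 | A) = 0.25, and so on.
row = [1.0, 0.5, 0.5, 0.25, 0.75]
result = class_probabilities(row)
```

In this sketch the entry for A2 is 0.75 * 0.5, the product of the conditional probabilities of A2 and A.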

get_conditional_probabilities(logits, epsilon=1e-10)

Compute conditional probabilities over sibling groups of logits using masked softmax normalization. This function iterates over a set of binary masks, one per sibling group, and applies a softmax operation to the logits selected by each mask. The resulting probabilities are therefore normalized within each sibling group rather than across all classes.

Parameters:
  • logits (torch.Tensor) – A tensor of logits of shape (batch_size, num_classes) on which to apply the masked softmax operations.

  • epsilon (float, optional) – A small constant added to the denominator to avoid division by zero during normalization. Default is 1e-10.

Returns:

A tensor of conditional probabilities obtained by applying the masked softmax operations to the logits, where the probabilities are normalized within each sibling group.

Return type:

torch.Tensor
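A masked softmax over sibling groups can be sketched as follows. This is an illustration under stated assumptions, not the module's code: it uses plain Python lists for a single sample instead of a torch.Tensor, the masked_softmax helper and the toy masks are hypothetical, and subtracting the group maximum for numerical stability is a choice made for this sketch.

```python
import math


def masked_softmax(logits_row, masks, epsilon=1e-10):
    """Apply softmax separately within each sibling group of one sample."""
    out = list(logits_row)
    for mask in masks:
        group = [i for i, m in enumerate(mask) if m == 1]
        # Subtract the group maximum before exponentiating (stability
        # measure assumed for this sketch).
        peak = max(logits_row[i] for i in group)
        exps = {i: math.exp(logits_row[i] - peak) for i in group}
        # epsilon keeps the denominator away from zero.
        total = sum(exps.values()) + epsilon
        for i in group:
            out[i] = exps[i] / total
    return out


# Columns in level order: root, A, B, A1, A2 (hypothetical toy taxonomy).
masks = [
    [1, 0, 0, 0, 0],  # the root forms its own group
    [0, 1, 1, 0, 0],  # children of root
    [0, 0, 0, 1, 1],  # children of A
]
logits = [0.3, 2.0, 2.0, 0.0, math.log(3.0)]
cond = masked_softmax(logits, masks)
```

Each group sums to approximately 1, so the values for A1 and A2 can be read as probabilities conditional on their shared parent A.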

get_depths()

Compute the depths (number of edges from the root) for all nodes in the tree. This method performs a level-order traversal of the tree (using get_level_order_traversal) and calculates the depth of each node by finding the shortest path from the root node to the current node using networkx’s shortest_path function.

Returns:

An array where each element is the depth of the corresponding node in the level-order traversal.

Return type:

numpy.ndarray
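In a tree, the length of the shortest path from the root equals the number of parent links between a node and the root. The sketch below uses that equivalence instead of networkx's shortest_path; the parent mapping, node names, and depth helper are hypothetical, and it returns a list rather than a numpy.ndarray.

```python
# Hypothetical toy taxonomy: each node maps to its parent (None for the root).
parent = {"root": None, "A": "root", "B": "root", "A1": "A", "A2": "A"}
level_order = ["root", "A", "B", "A1", "A2"]


def depth(node):
    """Count the edges between a node and the root."""
    edges = 0
    while parent[node] is not None:
        node = parent[node]
        edges += 1
    return edges


# Depths are reported in the same order as the level-order traversal.
depths = [depth(node) for node in level_order]
```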

get_hierarchical_one_hot_encoding(labels)

Compute the hierarchical one-hot encoding for a set of taxonomy labels. This function creates an encoded vector for each label in the provided list, in which the entries set to 1 correspond to the nodes along the unique path from the root node to the label in the taxonomy graph. Because every node on that path is marked, a vector generally contains more than one 1. The positions in the encoding vector are determined by a level-order traversal of the taxonomy.

Parameters:

labels (iterable) – An iterable of labels (nodes) to be encoded. Each label must exist in the taxonomy, as verified by the level-order traversal.

Returns:

A 2D numpy array of shape (number of labels, total number of nodes in the taxonomy), where each row is the one-hot encoded vector for the corresponding label. A value of 1 indicates that a node is part of the path from the root to the label; all other positions are 0.

Return type:

numpy.ndarray

Raises:

AssertionError – If any of the provided labels is not found within the taxonomy.

Note

  • The function assumes the taxonomy is represented as a graph and relies on NetworkX’s shortest path algorithm to determine the unique path from the root node to each label.

  • The ordering of the encoding vector is based on the level-order traversal of the taxonomy nodes.
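The encoding can be illustrated without NetworkX by walking parent links from the label up to the root. This is a sketch on a hypothetical toy taxonomy; the parent mapping and the encode helper are illustrative names, and it produces nested lists rather than a 2D numpy array.

```python
# Hypothetical toy taxonomy: each node maps to its parent (None for the root).
parent = {"root": None, "A": "root", "B": "root", "A1": "A", "A2": "A"}
level_order = ["root", "A", "B", "A1", "A2"]


def encode(label):
    """Mark every node on the root-to-label path with a 1."""
    assert label in level_order, f"unknown label: {label}"
    on_path = set()
    node = label
    while node is not None:
        on_path.add(node)
        node = parent[node]
    # Vector positions follow the level-order traversal.
    return [1 if node in on_path else 0 for node in level_order]


encoding = [encode(label) for label in ["A2", "B"]]
```

The row for A2 marks root, A, and A2, while the row for B marks only root and B.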

get_leaf_nodes()

Return the leaf nodes in the taxonomy. This method identifies leaf nodes as those nodes with no outgoing edges (out_degree equals 0) and exactly one incoming edge (in_degree equals 1).

Returns:

A list containing all nodes that meet the criteria for being a leaf node.

Return type:

list
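The degree criteria can be checked by hand on a small edge list. The sketch below counts degrees with plain dictionaries instead of NetworkX's degree views; the edges and node names are hypothetical.

```python
# Hypothetical toy taxonomy given as (parent, child) edges.
edges = [("root", "A"), ("root", "B"), ("A", "A1"), ("A", "A2")]
nodes = ["root", "A", "B", "A1", "A2"]

out_degree = {node: 0 for node in nodes}
in_degree = {node: 0 for node in nodes}
for source, target in edges:
    out_degree[source] += 1
    in_degree[target] += 1

# A leaf has no outgoing edges and exactly one incoming edge.
leaves = [n for n in nodes if out_degree[n] == 0 and in_degree[n] == 1]
```

The in-degree condition excludes an isolated root, which would otherwise satisfy the out-degree test in a single-node graph.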

get_level_order_traversal()

Perform a level order traversal of the tree and return an ordering of its nodes.

This method initiates breadth-first search (BFS) starting from the root of the tree, using the networkx breadth-first search tree function. The traversal produces an iterable of nodes in the order they are visited.

Returns:

An iterable of node labels in level order starting from the root.
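Level-order traversal visits the root first, then all nodes at depth 1, then depth 2, and so on. A queue-based sketch of the idea follows; it uses collections.deque rather than the networkx function, and the children mapping and level_order helper are hypothetical.

```python
from collections import deque

# Hypothetical toy taxonomy: each node maps to its ordered list of children.
children = {
    "root": ["A", "B"],
    "A": ["A1", "A2"],
    "B": [],
    "A1": [],
    "A2": [],
}


def level_order(root):
    """Return the nodes in breadth-first order starting from root."""
    order = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        order.append(node)
        # Children join the back of the queue, so a whole level is
        # finished before the next one starts.
        queue.extend(children[node])
    return order


order = level_order("root")
```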

get_nodes_by_depth()

Return a dictionary mapping depths to nodes in the hierarchy. For each unique depth obtained from the hierarchy (via get_depths), this function:

  • Extracts the nodes corresponding to that depth from the level order traversal.

  • Stores them in a dictionary where keys are the depth levels.

  • Additionally, assigns the leaf nodes (from get_leaf_nodes) to the key -1.

Returns:

A dictionary where each key (of type int) corresponds to a depth level, and the associated value is a NumPy array containing the nodes at that depth. The key -1 is used to store a NumPy array of all leaf nodes.

Return type:

dict

Note

  • This function relies on the existence of helper methods:
    • get_depths(): to obtain the depth of each node.

    • get_level_order_traversal(): to obtain the nodes in level order.

    • get_leaf_nodes(): to obtain the list of leaf nodes.
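The grouping described above can be sketched with hard-coded stand-ins for the three helper outputs. The values below are hypothetical and describe a toy taxonomy; the sketch stores plain lists where the method stores NumPy arrays.

```python
# Stand-ins for the helper outputs on a hypothetical toy taxonomy.
level_order = ["root", "A", "B", "A1", "A2"]  # get_level_order_traversal()
depths = [0, 1, 1, 2, 2]                      # get_depths()
leaves = ["B", "A1", "A2"]                    # get_leaf_nodes()

nodes_by_depth = {}
for node, depth in zip(level_order, depths):
    nodes_by_depth.setdefault(depth, []).append(node)

# The special key -1 collects the leaves regardless of their depth.
nodes_by_depth[-1] = list(leaves)
```

Note that B appears under both key 1 and key -1, since it is a leaf located at depth 1.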

get_parent_nodes()

Retrieve the parent node of each node in the taxonomy, in level-order traversal order.

Returns:

A list where each element is the parent of the corresponding node from the level order traversal. The root node will have an empty string as its parent.

Return type:

list

get_paths(labels)

Construct hierarchical paths from the given labels.

Parameters:

labels (iterable) – A collection of labels for which hierarchical paths are to be generated. The labels are passed to get_hierarchical_one_hot_encoding, so each one must be a node in the taxonomy.

Returns:

A list where each element is a list representing the path (in level order) extracted from the hierarchy for the corresponding label.

Return type:

list of list

Note

  • This method relies on get_hierarchical_one_hot_encoding to obtain the binary encoded representation of the labels.

  • The ordering of nodes is determined by get_level_order_traversal, whose output is used to map encoded indexes to actual nodes.
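Mapping an encoding back to a path amounts to selecting the level-order nodes whose entry is 1. A small sketch with hard-coded, hypothetical inputs standing in for the two helper outputs:

```python
# Hypothetical toy taxonomy in level order.
level_order = ["root", "A", "B", "A1", "A2"]
# Encodings for the labels "A2" and "B", one row per label.
encodings = [
    [1, 1, 0, 0, 1],
    [1, 0, 1, 0, 0],
]

# Keep the nodes whose position is set to 1 in each row.
paths = [
    [node for node, bit in zip(level_order, row) if bit == 1]
    for row in encodings
]
```

Because level order lists ancestors before descendants, each recovered path runs from the root down to the label.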

get_sibling_masks()

Get the sibling masks for each node in the taxonomy using level order traversal.

Sibling masks indicate which nodes share the same parent. For each unique parent, a numpy array is created where each element is set to 1 if the corresponding node in the level order traversal has that parent, and 0 otherwise.

Returns:

A list of numpy arrays, each representing a mask for sibling nodes corresponding to one unique parent in the taxonomy.

Return type:

List[numpy.ndarray]
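The masks can be derived directly from the list of parents in level order. The following is a sketch on a hypothetical toy taxonomy, using nested lists instead of numpy arrays. The order in which the masks are listed (first appearance of each parent) is an assumption of this sketch and may differ from the method's output.

```python
# Hypothetical toy taxonomy in level order, with the parent of each node.
# The root's parent is an empty string, matching get_parent_nodes.
level_order = ["root", "A", "B", "A1", "A2"]
parents = ["", "root", "root", "A", "A"]

# dict.fromkeys removes duplicates while keeping first-seen order.
unique_parents = list(dict.fromkeys(parents))

# One mask per unique parent: 1 where the node has that parent, else 0.
masks = [
    [1 if parent == candidate else 0 for parent in parents]
    for candidate in unique_parents
]
```

Every node belongs to exactly one mask, so the masks partition the columns of the level-order traversal.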

plot_colored_taxonomy(probabilities)

Plot the hierarchical taxonomy with colors corresponding to node probabilities.

This method computes a level-order traversal of the taxonomy, maps each node to its corresponding probability value for coloring, and then plots the taxonomy using Graphviz to determine node positions. The nodes are drawn with colors derived from the provided probabilities, and the plot is displayed using Matplotlib.

Parameters:

probabilities (array-like) – An array-like object (e.g., list, NumPy array, or tensor) containing probability values corresponding to each node in the taxonomy. The order of probabilities should match the order obtained from the level-order traversal.

Note

  • The method uses NetworkX for graph handling and drawing.

  • The Graphviz layout algorithm (‘dot’) is used to determine node positions.

  • Matplotlib functions are used to adjust layout and display the final plot.

plot_taxonomy()

Plot the taxonomy graph.

This method computes the layout of the taxonomy graph using Graphviz’s ‘dot’ algorithm, then draws the network using NetworkX’s drawing functionalities with labels, arrows, and custom styling. Finally, it displays the plot using Matplotlib.