oracle.loss module
Top-level module that defines the Weighted Hierarchical Cross Entropy (WHXE) loss function for hierarchical classification tasks.
- class oracle.loss.WHXE_Loss(taxonomy: Taxonomy, labels, alpha=0.5, beta=1)
Bases: Module
Implementation of the Weighted Hierarchical Cross Entropy Loss function.
- compute_lambda_term()
Compute the lambda term for node weighting.
This method calculates the secondary weighting term using an exponential decay over node depth, with the decay rate controlled by the attribute ‘alpha’. The resulting lambda term weights each level of the tree according to its depth.
- Returns:
None
- Side Effects:
Sets self.lambda_term to a PyTorch tensor of shape (N_nodes,) containing the computed values.
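A minimal sketch of the depth-based decay, assuming lambda_i = exp(-alpha * depth_i) for a hypothetical list of node depths. The exact formula inside compute_lambda_term() is not spelled out above, so treat the sign convention as an assumption. The sketch uses plain Python lists instead of PyTorch tensors so that it runs standalone.

```python
import math
from typing import List, Sequence


def compute_lambda_term(depths: Sequence[int], alpha: float = 0.5) -> List[float]:
    """Return one exponentially decaying weight per taxonomy node.

    Assumption: lambda_i = exp(-alpha * depth_i). The documented method
    stores its result in self.lambda_term as a tensor instead of returning it.
    """
    # Larger alpha makes the weight fall off faster with depth.
    return [math.exp(-alpha * d) for d in depths]


# Hypothetical taxonomy: root (depth 0), two children (depth 1), one grandchild (depth 2).
print(compute_lambda_term([0, 1, 1, 2], alpha=0.5))
```

Setting alpha to 0 gives every node the same weight of 1.0, while larger values shift emphasis toward the shallow levels of the tree.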
- forward(logits, true, epsilon=1e-10)
Compute the hierarchical loss from the pseudo-probabilities produced by masked softmaxes over the taxonomy structure.
- Parameters:
logits (torch.Tensor) – The raw output logits from the model for a batch.
true (torch.Tensor) – A tensor containing the indicator values for the true class labels.
epsilon (float, optional) – A small constant that prevents taking the logarithm of zero; defaults to 1e-10.
- Returns:
A scalar tensor representing the averaged hierarchical loss over the batch.
- Return type:
torch.Tensor
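The following sketch illustrates the idea behind forward(). It assumes that the masked softmax normalizes logits separately within each group of sibling nodes, so each value is a node's probability conditioned on its parent. It also assumes that the per-sample loss is -sum_i w_i * lambda_i * y_i * log(p_i + epsilon), averaged over the batch. The sibling grouping (groups), the weight lists, and the helper names are hypothetical. This is not the module's actual implementation, and it uses plain Python lists rather than tensors.

```python
import math
from typing import List, Sequence


def masked_softmax(logits: Sequence[float], groups: List[List[int]]) -> List[float]:
    """Apply a softmax separately within each sibling group (hypothetical grouping)."""
    probs = [0.0] * len(logits)
    for group in groups:
        # Subtract the group maximum for numerical stability.
        m = max(logits[i] for i in group)
        exps = {i: math.exp(logits[i] - m) for i in group}
        total = sum(exps.values())
        for i in group:
            probs[i] = exps[i] / total
    return probs


def whxe_loss(batch_logits, batch_true, groups, class_weights, lambda_term, epsilon=1e-10):
    """Batch-averaged weighted hierarchical cross entropy (assumed form)."""
    losses = []
    for logits, true in zip(batch_logits, batch_true):
        probs = masked_softmax(logits, groups)
        # Only nodes on the true path (y_i = 1) contribute to the sum.
        losses.append(-sum(
            w * lam * y * math.log(p + epsilon)
            for w, lam, y, p in zip(class_weights, lambda_term, true, probs)
        ))
    # Reduce to one scalar by averaging over the batch.
    return sum(losses) / len(losses)


# Hypothetical 4-node taxonomy: nodes 0 and 1 are siblings, nodes 2 and 3 are children of node 0.
groups = [[0, 1], [2, 3]]
print(whxe_loss([[2.0, 0.0, 1.0, -1.0]], [[1, 0, 1, 0]], groups,
                class_weights=[1.0, 1.0, 1.0, 1.0], lambda_term=[1.0, 1.0, 0.6, 0.6]))
```

Because each sibling group is normalized on its own, a sample's loss depends only on the conditional probabilities along its true path. The class weights and lambda term then rescale how much each node on that path counts.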
- get_class_weights(true)
Compute a class weight for each node in the taxonomy from the true label data, using inverse-frequency weighting.
- Parameters:
true (torch.Tensor) – A binary tensor of shape (N_samples, N_nodes) where each row represents a sample and each column corresponds to a node in the taxonomy. An element should be 1 if the sample belongs to the class represented by the node, and 0 otherwise.
- Returns:
A 1D tensor of shape (N_nodes,) containing the computed class weights.
- Return type:
torch.Tensor
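One way to express inverse-frequency weighting, assuming w_i = N_samples / count_i, where count_i is the number of samples labeled with node i. Nodes that never appear get a weight of 0.0. Both the scaling and the zero-count convention are assumptions rather than the documented behavior of get_class_weights(), and the sketch operates on plain lists instead of a torch.Tensor.

```python
from typing import List, Sequence


def get_class_weights(true: Sequence[Sequence[int]]) -> List[float]:
    """Inverse-frequency weight per node from a binary (N_samples, N_nodes) matrix.

    Assumption: w_i = N_samples / count_i, and absent nodes receive 0.0.
    """
    n_samples = len(true)
    n_nodes = len(true[0])
    # Column sums: how many samples belong to each node's class.
    counts = [sum(row[j] for row in true) for j in range(n_nodes)]
    # Rare classes get large weights; absent classes are zeroed to avoid dividing by zero.
    return [n_samples / c if c > 0 else 0.0 for c in counts]


# Hypothetical labels: 3 samples, 4 taxonomy nodes.
true = [
    [1, 0, 1, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 0],
]
print(get_class_weights(true))
```

Since the weights are inversely proportional to how often each node occurs, rarer classes contribute more per sample, which counteracts class imbalance in the loss.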