oracle.visualization module

Module for visualization functions in the ORACLE framework.

oracle.visualization.plot_average_performance_over_all_phases(metric, metrics_dictionary, model_dir=None)

Plot the average performance over all phases for the specified metric.

This function iterates over a dictionary of metrics grouped by depth. For each depth, it extracts a summary row (e.g., ‘macro avg’) from the corresponding pandas DataFrame and plots its values against the days from first detection on a logarithmic x-axis. It then either displays the plot interactively or saves it to a file in the specified directory.

Parameters:
  • metric (str) – The performance metric to be plotted. This is used to label the y-axis and is printed along with the plot data.

  • metrics_dictionary (dict) – A dictionary mapping each depth level to a pandas DataFrame. Each DataFrame should contain summary rows such as ‘macro avg’. Within such a row, the index (the DataFrame’s column labels) gives the days from first detection and the values give the metric values.

  • model_dir (str, optional) – Directory whose ‘plots’ subdirectory receives the generated plot as a PDF file. If None, the plot is displayed interactively using plt.show(). Defaults to None.

Note

  • All existing matplotlib figures are closed at the start to prevent plots from overlapping.

  • The plot style is set to ‘default’.

  • The x-axis uses a logarithmic scale and its ticks are set based on the days from first detection.

  • A legend is added to the lower right of the plot.
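The steps listed above can be approximated with the following matplotlib sketch. It is not the ORACLE implementation: it assumes each DataFrame has a ‘macro avg’ row whose column labels are the days from first detection, and the helper name plot_average_sketch and the output file name average_{metric}.pdf are hypothetical, since the actual file name is not documented here.

```python
import os

import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to display plots interactively
import matplotlib.pyplot as plt
import pandas as pd


def plot_average_sketch(metric, metrics_dictionary, model_dir=None):
    """Hypothetical re-creation of the documented steps, not the ORACLE code."""
    plt.close("all")            # close existing figures so plots do not overlap
    plt.style.use("default")
    fig, ax = plt.subplots()

    days = None
    for depth, df in metrics_dictionary.items():
        row = df.loc["macro avg"]               # assumed summary row; its index holds the days
        days = [float(d) for d in row.index]
        ax.plot(days, row.values, marker="o", label=f"Depth {depth}")

    ax.set_xscale("log")
    if days is not None:                        # ticks taken from the last depth's days
        ax.set_xticks(days)
        ax.set_xticklabels([f"{d:g}" for d in days])
    ax.set_xlabel("Days from first detection")
    ax.set_ylabel(metric)
    ax.legend(loc="lower right")

    if model_dir is None:
        plt.show()
        return None
    out_dir = os.path.join(model_dir, "plots")
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"average_{metric}.pdf")  # hypothetical file name
    fig.savefig(path)
    plt.close(fig)
    return path  # returned only for convenience in this sketch
```

Passing a model_dir exercises the save path; omitting it falls back to plt.show(), matching the documented default.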

oracle.visualization.plot_class_wise_performance_over_all_phases(metric, metrics_dictionary, model_dir=None)

Plot class-wise performance over all phases for a given metric.

This function iterates over each depth level present in the metrics_dictionary, extracts class-specific metric values (ignoring summary rows such as ‘accuracy’, ‘macro avg’, and ‘weighted avg’), and plots these values against the days from the first detection. The x-axis is set to a logarithmic scale.

Parameters:
  • metric (str) – The name of the performance metric to be displayed on the y-axis.

  • metrics_dictionary (dict) – A dictionary mapping each depth level to a pandas DataFrame. The DataFrame’s rows are indexed by class names, plus summary rows (‘accuracy’, ‘macro avg’, ‘weighted avg’) that are skipped; its columns represent days from the first detection.

  • model_dir (str or None, optional) – Directory path where the plot PDFs are saved. If provided, each plot is saved as ‘class_wise_{metric}.pdf’ in a subdirectory. Defaults to None.
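For illustration, the sketch below builds one depth-level DataFrame in the layout described above and removes the summary rows before iterating over classes. The class names, days, and values are invented, and class_wise_series is a hypothetical helper; errors="ignore" keeps the drop from failing when a summary row is absent.

```python
import pandas as pd

SUMMARY_ROWS = ["accuracy", "macro avg", "weighted avg"]


def class_wise_series(df):
    """Hypothetical helper: map each class name to its Series of values indexed by day."""
    per_class = df.drop(index=SUMMARY_ROWS, errors="ignore")
    return {name: row for name, row in per_class.iterrows()}


# Invented metric table for one depth: rows are classes, columns are days.
days = [1, 2, 4, 8]
example = pd.DataFrame(
    [
        [0.20, 0.35, 0.50, 0.65],
        [0.40, 0.55, 0.60, 0.70],
        [0.30, 0.45, 0.55, 0.68],  # accuracy
        [0.30, 0.45, 0.55, 0.68],  # macro avg
        [0.32, 0.47, 0.56, 0.69],  # weighted avg
    ],
    index=["class_a", "class_b", "accuracy", "macro avg", "weighted avg"],
    columns=days,
)

for name, series in class_wise_series(example).items():
    print(name, list(series.index), list(series.values))
```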

oracle.visualization.plot_confusion_matrix(y_true, y_pred, labels, normalize='true', title=None, img_file=None)

Plot a confusion matrix using the given true and predicted labels and display it with matplotlib.

Parameters:
  • y_true (array-like) – Array of true labels.

  • y_pred (array-like) – Array of predicted labels corresponding to y_true.

  • labels (list) – List of label names to be used in the confusion matrix.

  • normalize (str, optional) – Normalization mode for the confusion matrix. Default is ‘true’. Accepted values are typically ‘true’, ‘pred’, or ‘all’.

  • title (str, optional) – Title of the plot, set only if provided.

  • img_file (str, optional) – File path to save the generated plot image. If None, the plot is not saved.

Returns:

None

Note

  • The function filters out any entries where the true label is None.

  • It adjusts the figure size and text properties based on the number of classes.

  • The function closes all previous matplotlib figures at the start and closes the plot at the end.
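The filtering and normalization described above can be sketched in NumPy as follows. This is a stand-in rather than the module’s code: normalized_confusion_matrix is a hypothetical helper, and it follows the convention of scikit-learn’s confusion_matrix, where normalize='true' divides each row (true class) by its total, 'pred' divides each column, and 'all' divides by the grand total.

```python
import numpy as np


def normalized_confusion_matrix(y_true, y_pred, labels, normalize="true"):
    """Hypothetical helper: count (true, predicted) pairs, skipping missing true labels."""
    position = {label: i for i, label in enumerate(labels)}
    cm = np.zeros((len(labels), len(labels)), dtype=float)
    for t, p in zip(y_true, y_pred):
        if t is None or t not in position or p not in position:
            continue  # mirrors the documented filtering of entries whose true label is None
        cm[position[t], position[p]] += 1

    if normalize == "true":
        totals = cm.sum(axis=1, keepdims=True)  # per true class (rows)
    elif normalize == "pred":
        totals = cm.sum(axis=0, keepdims=True)  # per predicted class (columns)
    elif normalize == "all":
        totals = cm.sum()
    else:
        return cm  # raw counts
    # Divide safely: empty rows or columns stay at zero instead of becoming NaN.
    return np.divide(cm, totals, out=np.zeros_like(cm), where=totals != 0)
```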

oracle.visualization.plot_plain_cf(y_true, y_pred, labels, normalize='true', title=None, img_file=None)

Plot a plain confusion matrix from true and predicted labels.

This function computes a confusion matrix for the given true and predicted labels and displays it using a predefined style. Only entries whose y_true value is not None are considered. The matrix is drawn without tick marks or spines and can optionally be saved to an image file.

Parameters:
  • y_true (array-like) – Array of true labels. Only the elements that are not None are used.

  • y_pred (array-like) – Array of predicted labels corresponding to y_true.

  • labels (array-like) – The set of labels to index the confusion matrix.

  • normalize (str, optional) – Normalization method for the confusion matrix (e.g., ‘true’). Defaults to ‘true’.

  • title (str, optional) – Title of the plot. Currently unused by the function.

  • img_file (str, optional) – If provided, the plot is saved to this file path.

Returns:

None
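A minimal sketch of a “plain” rendering, assuming it amounts to an imshow of an already computed matrix with tick marks removed and all spines hidden. The helper plain_matrix_plot, the colormap, the 0 to 1 color range (which presumes a normalized matrix), and the figure size are arbitrary choices, not taken from the module.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts and tests
import matplotlib.pyplot as plt
import numpy as np


def plain_matrix_plot(cm, img_file=None):
    """Hypothetical helper: draw a matrix with no tick marks or spines; optionally save it."""
    plt.close("all")
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(np.asarray(cm), cmap="Blues", vmin=0.0, vmax=1.0)  # assumes values in [0, 1]
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    if img_file is not None:
        fig.savefig(img_file, bbox_inches="tight")
    plt.close(fig)
    return fig
```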

oracle.visualization.plot_roc_curves(probs_true, probs_pred, labels, title=None, img_file=None)

Plot ROC curves for each class and compute the macro-average ROC curve.

Parameters:
  • probs_true (ndarray) – A 2D array of shape (n_samples, n_classes) containing the ground-truth binary labels for each class. Rows with all-zero values (indicating missing true labels) are removed before plotting.

  • probs_pred (ndarray) – A 2D array of shape (n_samples, n_classes) containing the predicted probabilities for each class.

  • labels (list of str) – A list of class labels corresponding to the columns in probs_true and probs_pred.

  • title (str, optional) – Title of the ROC plot. Defaults to None.

  • img_file (str, optional) – File path to save the plot image. If None, the plot is not saved to a file.

Returns:

None

Note

  • The ROC curves are plotted with an equal aspect ratio, and a legend is included showing the AUC for each class along with the macro-average AUC.
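Assuming the usual one-vs-rest construction, the per-class ROC curves and the macro-average curve described above can be sketched in NumPy: each class gets its own (FPR, TPR) curve, and the macro average interpolates every class onto a shared FPR grid before averaging the TPR values. The helper names are hypothetical; scikit-learn’s roc_curve and auc provide equivalent per-class curves and areas in practice.

```python
import numpy as np


def roc_points(y, scores):
    """Return (fpr, tpr) for one binary class, one point per distinct threshold."""
    order = np.argsort(-scores, kind="mergesort")
    y, scores = y[order], scores[order]
    cut = np.r_[np.where(np.diff(scores))[0], y.size - 1]  # last index of each threshold
    tps = np.cumsum(y)[cut]
    fps = (cut + 1) - tps
    # Assumes every class has at least one positive and one negative sample.
    tpr = np.r_[0.0, tps] / tps[-1]
    fpr = np.r_[0.0, fps] / fps[-1]
    return fpr, tpr


def trapezoid_area(x, y):
    """Area under a piecewise-linear curve via the trapezoid rule."""
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0))


def per_class_and_macro_roc(probs_true, probs_pred):
    probs_true = np.asarray(probs_true, dtype=float)
    probs_pred = np.asarray(probs_pred, dtype=float)
    keep = probs_true.sum(axis=1) > 0  # drop rows with no true label
    probs_true, probs_pred = probs_true[keep], probs_pred[keep]

    curves, aucs = [], []
    for k in range(probs_true.shape[1]):
        fpr, tpr = roc_points(probs_true[:, k], probs_pred[:, k])
        curves.append((fpr, tpr))
        aucs.append(trapezoid_area(fpr, tpr))

    # Macro-average curve: interpolate every class onto a shared FPR grid, then average.
    grid = np.unique(np.concatenate([fpr for fpr, _ in curves]))
    mean_tpr = np.mean([np.interp(grid, fpr, tpr) for fpr, tpr in curves], axis=0)
    return curves, aucs, (grid, mean_tpr, trapezoid_area(grid, mean_tpr))
```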

oracle.visualization.plot_train_val_history(train_loss_history, val_loss_history, file_name)

Plot the training and validation loss curves along with their rolling averages on a logarithmic scale.

Parameters:
  • train_loss_history (list or array-like) – The history of training loss values.

  • val_loss_history (list or array-like) – The history of validation loss values.

  • file_name (str) – The file path where the generated plot will be saved.
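The sketch below shows one way to produce such a plot. The rolling-average window (10 points, trailing), the choice to put the logarithmic scale on the loss axis, and the helper names rolling_mean and plot_loss_history are all assumptions; the documentation above does not specify them.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; the plot is written to file_name
import matplotlib.pyplot as plt
import numpy as np


def rolling_mean(values, window):
    """Trailing mean over `window` points (shorter windows at the start)."""
    values = np.asarray(values, dtype=float)
    csum = np.cumsum(np.insert(values, 0, 0.0))
    ends = np.arange(1, values.size + 1)
    counts = np.minimum(ends, window)
    return (csum[ends] - csum[ends - counts]) / counts


def plot_loss_history(train_loss_history, val_loss_history, file_name, window=10):
    """Hypothetical helper: raw losses (faint) plus rolling means, loss axis on a log scale."""
    plt.close("all")
    fig, ax = plt.subplots()
    for name, history in (("Train", train_loss_history), ("Validation", val_loss_history)):
        epochs = np.arange(1, len(history) + 1)
        line, = ax.plot(epochs, history, alpha=0.3, label=f"{name} loss")
        ax.plot(epochs, rolling_mean(history, window), color=line.get_color(),
                label=f"{name} rolling mean ({window})")
    ax.set_yscale("log")  # assumption: the log scale applies to the loss axis
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    fig.savefig(file_name)
    plt.close(fig)
```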

oracle.visualization.plot_umap(embeddings, classes, bts_classes, id, d, model_dir=None)

Plot a UMAP projection of embeddings.

This function computes a 2D UMAP projection of high-dimensional embeddings and generates both a static scatter plot (matplotlib) and an interactive scatter plot (plotly). Points are colored by the provided classes. If model_dir is specified, the plots are saved to disk; otherwise, the static plot is displayed.

Parameters:
  • embeddings (array-like) – High-dimensional feature data (e.g., a numpy array) with shape (n_samples, n_features).

  • classes (array-like) – Class labels for each embedding, used for color-coding the points in the plot.

  • bts_classes (array-like) – Additional class information for tooltips in the interactive plot.

  • id (array-like) – Unique identifiers for each source, used for hover information in the interactive plot.

  • d (int or str) – Identifier (e.g., number of days or a delay parameter) used in the plot title and file names.

  • model_dir (str, optional) – Directory where the plots will be saved. If None, the static plot is shown instead.

Returns:

None

Raises:

Any exception raised by UMAP, matplotlib, or plotly during the projection or plotting is propagated to the caller.