oracle.visualization

Module for visualization functions in the ORACLE framework.

Functions

plot_average_performance_over_all_phases(...)

Plot the average performance over all phases for the specified metric.
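The page does not say how "average performance" is computed. The sketch below is a minimal guess at the aggregation step. It assumes that each phase key (for example, a number of days) maps to a pair of true and predicted label sequences, and that "average" means the macro-averaged F1 score across classes. The helpers `f1_for_class` and `macro_f1_per_phase` and this input layout are hypothetical. They are not the ORACLE API.

```python
from typing import Dict, List, Sequence, Tuple


def f1_for_class(y_true: Sequence[str], y_pred: Sequence[str], label: str) -> float:
    """Hypothetical helper: one-vs-rest F1 score for a single class label."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
    fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
    denom = 2 * tp + fp + fn
    # A class that never occurs and is never predicted scores 0.0 here.
    return 2 * tp / denom if denom else 0.0


def macro_f1_per_phase(
    phases: Dict[int, Tuple[Sequence[str], Sequence[str]]], labels: List[str]
) -> Dict[int, float]:
    """Hypothetical helper: macro-averaged F1 for each phase, keyed by phase."""
    result: Dict[int, float] = {}
    for phase, (y_true, y_pred) in sorted(phases.items()):
        scores = [f1_for_class(y_true, y_pred, lab) for lab in labels]
        result[phase] = sum(scores) / len(scores)
    return result


phases = {
    1: (["A", "B", "A"], ["A", "A", "A"]),
    8: (["A", "B", "A"], ["A", "B", "A"]),
}
print(macro_f1_per_phase(phases, ["A", "B"]))
```

You could then plot the returned values against the sorted phase keys as a single line.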

plot_class_wise_performance_over_all_phases(...)

Plot class-wise performance over all phases for a given metric.
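A class-wise view keeps one series per class instead of averaging them. The minimal sketch below assumes that each phase key maps to a pair of true and predicted label sequences, and it uses recall as the example metric. The helper `recall_by_class_over_phases` and the input layout are hypothetical and not taken from ORACLE.

```python
from typing import Dict, List, Sequence, Tuple


def recall_by_class_over_phases(
    phases: Dict[int, Tuple[Sequence[str], Sequence[str]]], labels: List[str]
) -> Dict[str, List[float]]:
    """Hypothetical helper: per-class recall for each phase, ordered by phase key."""
    series: Dict[str, List[float]] = {lab: [] for lab in labels}
    for _, (y_true, y_pred) in sorted(phases.items()):
        for lab in labels:
            # Recall = correctly predicted members of `lab` / all true members of `lab`.
            support = sum(t == lab for t in y_true)
            hits = sum(t == lab and p == lab for t, p in zip(y_true, y_pred))
            # Classes absent from a phase get NaN rather than a misleading 0.
            series[lab].append(hits / support if support else float("nan"))
    return series


phases = {
    1: (["A", "B", "A", "B"], ["A", "A", "A", "B"]),
    8: (["A", "B", "A", "B"], ["A", "B", "A", "B"]),
}
print(recall_by_class_over_phases(phases, ["A", "B"]))
```

You could draw each list in the result as one line against the sorted phase keys.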

plot_confusion_matrix(y_true, y_pred, labels)

Plot a confusion matrix using the given true and predicted labels and display it with matplotlib.
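The counting behind a confusion matrix can be sketched in plain Python. This sketch assumes that rows index the true label and columns index the predicted label, in the order given by `labels`. That is the convention of scikit-learn's `confusion_matrix`. This page does not say whether `plot_confusion_matrix` uses the same orientation or normalizes the counts. The helper `confusion_counts` is hypothetical.

```python
from typing import List, Sequence


def confusion_counts(
    y_true: Sequence[str], y_pred: Sequence[str], labels: List[str]
) -> List[List[int]]:
    """Hypothetical helper: rows are true labels, columns are predicted labels."""
    index = {lab: i for i, lab in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        # Pairs whose labels are not listed in `labels` are skipped.
        if t in index and p in index:
            matrix[index[t]][index[p]] += 1
    return matrix


print(confusion_counts(["A", "B", "B", "C"], ["A", "B", "A", "C"], ["A", "B", "C"]))
```

For display, scikit-learn's `ConfusionMatrixDisplay.from_predictions(y_true, y_pred, labels=labels)` draws a comparable matplotlib figure. This page does not say whether ORACLE uses it.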

plot_plain_cf(y_true, y_pred, labels[, ...])

Plot a plain confusion matrix visualization based on true and predicted labels.

plot_roc_curves(probs_true, probs_pred, labels)

Plot ROC curves for each class and compute the macro-average ROC curve.
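A common way to build a macro-average ROC curve works in three steps. First, compute one one-vs-rest curve per class. Next, take the union of all false-positive rates. Finally, interpolate each class's true-positive rate onto that grid and average. The NumPy-only sketch below follows that approach and assumes that `probs_true` is a one-hot array of shape `(n_samples, n_classes)` and that `probs_pred` holds per-class scores of the same shape. The helpers `roc_curve_binary` and `macro_average_roc` are hypothetical. Whether ORACLE computes the macro average this way is not stated. scikit-learn's `roc_curve` and `auc` cover the per-class step in practice.

```python
import numpy as np


def roc_curve_binary(y: np.ndarray, scores: np.ndarray):
    """One-vs-rest ROC curve as (fpr, tpr); needs at least one positive and one negative."""
    order = np.argsort(-scores, kind="mergesort")
    y_sorted = y[order]
    s_sorted = scores[order]
    # Keep the last index of each run of tied scores so that ties form one point.
    distinct = np.where(np.diff(s_sorted))[0]
    thresh_idx = np.r_[distinct, y_sorted.size - 1]
    tps = np.cumsum(y_sorted)[thresh_idx]
    fps = (thresh_idx + 1) - tps
    tpr = np.r_[0.0, tps / tps[-1]]
    fpr = np.r_[0.0, fps / fps[-1]]
    return fpr, tpr


def macro_average_roc(probs_true: np.ndarray, probs_pred: np.ndarray):
    """Hypothetical helper: macro-average ROC curve and its area over all classes."""
    curves = [
        roc_curve_binary(probs_true[:, k], probs_pred[:, k])
        for k in range(probs_true.shape[1])
    ]
    # Union of all FPR points, then the mean of each class's interpolated TPR.
    all_fpr = np.unique(np.concatenate([fpr for fpr, _ in curves]))
    mean_tpr = np.mean([np.interp(all_fpr, fpr, tpr) for fpr, tpr in curves], axis=0)
    # Trapezoidal area under the averaged curve.
    auc = float(np.sum(np.diff(all_fpr) * (mean_tpr[1:] + mean_tpr[:-1]) / 2))
    return all_fpr, mean_tpr, auc


probs_true = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)
probs_pred = np.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.4, 0.6]])
fpr, tpr, auc = macro_average_roc(probs_true, probs_pred)
print(auc)
```

Averaging interpolated TPR values weights every class equally. A micro average would weight every sample equally instead.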

plot_train_val_history(train_loss_history, ...)

Plot the training and validation loss curves along with their rolling averages on a logarithmic scale.
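The page does not give the window size or type of the rolling average. The minimal sketch below assumes a trailing window in which the first `window - 1` points average over the values available so far. That matches pandas' `Series.rolling(window, min_periods=1).mean()`. The helper `rolling_mean` is hypothetical.

```python
from typing import List, Sequence


def rolling_mean(values: Sequence[float], window: int) -> List[float]:
    """Hypothetical helper: trailing rolling mean; early points use a partial window."""
    if window < 1:
        raise ValueError("window must be >= 1")
    out: List[float] = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        # Drop the value that just left the window.
        if i >= window:
            running -= values[i - window]
        out.append(running / min(i + 1, window))
    return out


train_loss_history = [4.0, 2.0, 1.0, 0.5]
print(rolling_mean(train_loss_history, window=2))
```

You could then draw the raw and smoothed curves with matplotlib, setting `ax.set_yscale("log")` for the logarithmic axis. A log axis requires every loss value to be positive.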

plot_umap(embeddings, classes, bts_classes, ...)

Plot a UMAP projection of the embeddings.