oracle.pretrained.BTS module

Pretrained model(s) for the BTS dataset.

class oracle.pretrained.BTS.ORACLE1_BTS(model_dir='/Users/vedshah/Documents/Research/NU-Miller/Projects/Hierarchical-VT/models/BTS/lemon-spaceship-252')

Bases: GRU_MD

ORACLE1_BTS is a model class, inheriting from GRU_MD, that loads a pre-trained BTS model and performs predictions on time-series data augmented with static features. The model uses a hierarchical taxonomy to output predictions at multiple levels of granularity.

taxonomy

An instance of the taxonomy used to structure the class labels.

Type:

ORACLE_Taxonomy

ts_feature_dim

Dimensionality of the time series input features.

Type:

int

static_feature_dim

Dimensionality of the static input features.

Type:

int

model_dir

Directory path where the model weights are stored.

Type:

str

embed(table)

Embed a table into its latent space representation.

Parameters:

table – Input data to be embedded, in the same format accepted by the make_batch method (an astropy.table.Table).

Returns:

A NumPy array containing the latent space embeddings corresponding to the input table.

Return type:

numpy.ndarray

make_batch(table)

Create a batch from the input table.

Parameters:

table (astropy.table.Table) – Input data containing one or more rows.

Returns:

A dictionary containing the batch data.

Return type:

dict

predict(table)

Predict the label at each hierarchical level for the table.

Parameters:

table (astropy.table.Table) – Input data containing one or more rows.

Returns:

Mapping from hierarchical level (the keys of the dictionary returned by self.score) to the predicted class label. For each level, self.score(table) is expected to return a pandas.DataFrame of shape (n_samples, n_classes) with class labels as columns; the predicted label is the column with the highest score in the first row, so only the first sample is used.

Return type:

dict

Raises:

Any exceptions raised by self.score or by numpy operations (e.g., if the score DataFrame is empty) will be propagated.
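The selection rule above can be sketched in plain Python. This is a minimal illustration, not the package's implementation: to stay dependency-free it assumes each per-level score table is given as a (column labels, rows) pair standing in for the pandas.DataFrame that self.score returns. The helper predict_from_scores and the class labels in the usage example are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for one pandas.DataFrame returned by self.score:
# (class labels in column order, one list of scores per sample/row).
ScoreTable = Tuple[List[str], List[List[float]]]


def predict_from_scores(level_scores: Dict[int, ScoreTable]) -> Dict[int, str]:
    """Return the highest-scoring class label of the FIRST sample at each level."""
    predictions: Dict[int, str] = {}
    for level, (labels, rows) in level_scores.items():
        first_row = rows[0]  # raises IndexError if the table has no rows
        # Index of the maximum score; like numpy.argmax, ties go to the first column.
        best = max(range(len(labels)), key=lambda i: first_row[i])
        predictions[level] = labels[best]
    return predictions


if __name__ == "__main__":
    # Placeholder labels and scores, not the real ORACLE taxonomy.
    example = {
        1: (["class_a", "class_b"], [[0.2, 0.8], [0.9, 0.1]]),
        2: (["class_c", "class_d", "class_e"], [[0.5, 0.3, 0.2]]),
    }
    print(predict_from_scores(example))  # {1: 'class_b', 2: 'class_c'}
```

The second row at level 1 is ignored, mirroring the documented behavior of predicting only from the first sample.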

predict_full_scores(table)

Predict class probability scores for a single time-series table.

Prepares the table for the model by calling augment_table and converting the time-dependent and time-independent features into torch tensors. The input table must contain the columns ‘magpsf’, ‘sigmapsf’, ‘fid’, ‘jd’, and ‘photflag’.

Parameters:

table (astropy.table.Table) – Astropy Table containing time-series data.

Returns:

A DataFrame containing class probability scores for each class in the taxonomy.

Return type:

pd.DataFrame

Raises:

ValueError – If the time-series columns have inconsistent lengths, or if the table is empty and the downstream model cannot process it.
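The input requirements above can be checked before calling the model. The sketch below is a hypothetical pre-flight helper, not part of oracle.pretrained.BTS: it assumes the table's columns are available as a mapping from column name to a list of values (for an astropy Table, {name: list(table[name]) for name in table.colnames} produces one), and it raises KeyError for missing columns and ValueError for inconsistent or zero lengths.

```python
from typing import Mapping, Sequence

# Columns listed as required for predict_full_scores ('sigmapsf' is the ZTF PSF magnitude error).
REQUIRED_COLUMNS = ("magpsf", "sigmapsf", "fid", "jd", "photflag")


def check_time_series_columns(columns: Mapping[str, Sequence]) -> int:
    """Hypothetical pre-flight check; returns the number of observations."""
    missing = [name for name in REQUIRED_COLUMNS if name not in columns]
    if missing:
        raise KeyError(f"missing required columns: {missing}")
    lengths = {name: len(columns[name]) for name in REQUIRED_COLUMNS}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"time-series columns have inconsistent lengths: {lengths}")
    n_obs = next(iter(lengths.values()))
    if n_obs == 0:
        raise ValueError("the table contains no observations")
    return n_obs


if __name__ == "__main__":
    # Placeholder values for two observations of one object.
    table = {
        "magpsf": [18.1, 18.4],
        "sigmapsf": [0.05, 0.07],
        "fid": [1, 2],
        "jd": [2459000.5, 2459001.5],
        "photflag": [0, 0],
    }
    print(check_time_series_columns(table))  # 2
```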

score(table)

Compute hierarchical scores for the input table. Predicts scores for all taxonomy nodes using self.predict_full_scores(), then groups those scores by taxonomy depth and returns a mapping from depth levels to DataFrames containing the corresponding node scores. Depth level 0 is removed because it is not relevant in the current taxonomy.

Parameters:

table (astropy.table.Table) – Input observations/features to be scored.

Returns:

A mapping from taxonomy depth level to a DataFrame of predicted scores for nodes at that level. Each DataFrame is a subset of the full prediction DataFrame containing only the columns for the nodes at that depth.

Return type:

dict[int, pandas.DataFrame]

Raises:

KeyError – If expected node columns (from self.taxonomy.get_nodes_by_depth()) are not present in the DataFrame returned by predict_full_scores().
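The grouping step can be sketched with plain dictionaries. This is an illustration under stated assumptions, not the package's code: it assumes the full scores are a mapping from node name to per-sample scores (standing in for the DataFrame from predict_full_scores) and that the nodes-by-depth information is a mapping from depth to node names; the actual return type of self.taxonomy.get_nodes_by_depth() is not documented here. The helper group_scores_by_depth and the node names are hypothetical.

```python
from typing import Dict, List, Mapping, Sequence


def group_scores_by_depth(
    full_scores: Mapping[str, List[float]],
    nodes_by_depth: Mapping[int, Sequence[str]],
) -> Dict[int, Dict[str, List[float]]]:
    """Split per-node scores into one sub-table per taxonomy depth, dropping depth 0."""
    grouped: Dict[int, Dict[str, List[float]]] = {}
    for depth, nodes in nodes_by_depth.items():
        if depth == 0:
            continue  # depth 0 is not relevant in the current taxonomy
        # Indexing raises KeyError for a node missing from full_scores,
        # analogous to selecting absent columns from a DataFrame.
        grouped[depth] = {node: full_scores[node] for node in nodes}
    return grouped


if __name__ == "__main__":
    # Placeholder taxonomy and scores for a single sample.
    nodes = {0: ["root"], 1: ["a", "b"], 2: ["a1", "a2", "b1"]}
    scores = {"root": [1.0], "a": [0.7], "b": [0.3], "a1": [0.5], "a2": [0.2], "b1": [0.3]}
    print(group_scores_by_depth(scores, nodes))
    # {1: {'a': [0.7], 'b': [0.3]}, 2: {'a1': [0.5], 'a2': [0.2], 'b1': [0.3]}}
```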

Note

  • This model is very similar to the one used in the original ORACLE paper.

oracle.pretrained.BTS.augment_table(table)

Augments a table by modifying its time-related and feature-specific values and splitting it into two separate tables. This function performs the following modifications:

  • Converts filter IDs (‘fid’) to their corresponding mean wavelengths using the mapping ZTF_fid_to_wavelengths.

  • Normalizes the ‘jd’ column by subtracting the minimum jd value to set the starting time at zero.

  • Reorders the columns based on a predefined list time_dependent_feature_list to create a time-dependent table, and adds a constant column ‘flag’ with a value of 1.

  • Extracts a time-independent table based on a predefined list time_independent_feature_list.

Parameters:

table (pandas.DataFrame) – The input table containing astronomical observations with at least ‘fid’ and ‘jd’ columns.

Returns:

A tuple containing two pandas DataFrames:
  • lc_table: The time-dependent table with re-ordered columns and an additional ‘flag’ column.

  • static_table: The table containing the time-independent features.

Return type:

tuple

Notes

  • The function assumes that the variables time_dependent_feature_list, time_independent_feature_list, and ZTF_fid_to_wavelengths are defined in the global scope.

  • Raises a KeyError if the required columns (‘fid’ or ‘jd’) are missing from the input table.
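The four augmentation steps listed above can be sketched without pandas. This minimal version is illustrative only: it represents the table as a mapping from column name to a list of values, and the feature lists and wavelength mapping are passed in as arguments instead of being read from module globals. The helper augment_columns and the feature lists and wavelength values in the usage example are placeholders; they are not the contents of time_dependent_feature_list, time_independent_feature_list, or ZTF_fid_to_wavelengths.

```python
from typing import Dict, List, Mapping, Sequence, Tuple

Columns = Dict[str, List]


def augment_columns(
    table: Mapping[str, Sequence],
    time_dependent_features: Sequence[str],
    time_independent_features: Sequence[str],
    fid_to_wavelength: Mapping[int, float],
) -> Tuple[Columns, Columns]:
    """Hypothetical dict-based sketch of augment_table's documented steps."""
    data = {name: list(values) for name, values in table.items()}

    # 1. Replace filter IDs with mean wavelengths (KeyError if 'fid' is missing).
    data["fid"] = [fid_to_wavelength[f] for f in data["fid"]]

    # 2. Shift 'jd' so the first observation is at time zero (KeyError if missing).
    t0 = min(data["jd"])
    data["jd"] = [t - t0 for t in data["jd"]]

    # 3. Time-dependent table in the requested column order, plus a constant flag of 1.
    lc_table = {name: data[name] for name in time_dependent_features}
    lc_table["flag"] = [1] * len(data["jd"])

    # 4. Time-independent (static) features.
    static_table = {name: data[name] for name in time_independent_features}
    return lc_table, static_table


if __name__ == "__main__":
    obs = {
        "fid": [1, 2, 1],
        "jd": [2459000.5, 2459001.5, 2459003.0],
        "magpsf": [18.0, 18.3, 18.6],
        "host_feature": [0.1, 0.1, 0.1],  # hypothetical static column
    }
    lc, static = augment_columns(
        obs,
        time_dependent_features=["jd", "magpsf", "fid"],
        time_independent_features=["host_feature"],
        fid_to_wavelength={1: 100.0, 2: 200.0},  # placeholder values, not real wavelengths
    )
    print(lc)
    print(static)
```

Passing the lists and mapping as arguments keeps the sketch testable; the documented function instead relies on module-level globals.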