oracle.pretrained.BTS module
Pretrained model(s) for the BTS dataset.
- class oracle.pretrained.BTS.ORACLE1_BTS(model_dir='/Users/vedshah/Documents/Research/NU-Miller/Projects/Hierarchical-VT/models/BTS/lemon-spaceship-252')
Bases: GRU_MD
ORACLE1_BTS is a model class that inherits from GRU_MD. It is designed to load a pre-trained BTS model and perform predictions on time series data augmented with static features. The model uses a hierarchical taxonomy to output predictions at multiple levels of granularity.
- taxonomy
An instance of the taxonomy used to structure the class labels.
- Type:
- ts_feature_dim
Dimensionality of the time series input features.
- Type:
int
- static_feature_dim
Dimensionality of the static input features.
- Type:
int
- model_dir
Directory path where the model weights are stored.
- Type:
str
- embed(table)
Embed a table into its latent space representation.
- Parameters:
table – The input data (e.g., a table or structured data) to be embedded. The exact format is expected to be compatible with the make_batch method.
- Returns:
A NumPy array containing the latent space embeddings corresponding to the input table.
- Return type:
numpy.ndarray
- make_batch(table)
Create a batch from the input table.
- Parameters:
table (astropy.table.Table) – Input data containing one or more rows.
- Returns:
A dictionary containing the batch data.
- Return type:
dict
- predict(table)
Predict the label at each hierarchical level for the table.
- Parameters:
table (astropy.table.Table) – Input data containing one or more rows.
- Returns:
Mapping from hierarchical level (as returned by self.score) to the predicted class label. For each level, self.score(table) is expected to return a pandas.DataFrame of shape (n_samples, n_classes) with class labels as columns; the predicted label is the column with the highest score for the first sample.
- Return type:
dict
- Raises:
Any exceptions raised by self.score or by numpy operations (e.g., if the score DataFrame is empty) are propagated.
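The selection step described for predict can be sketched as follows. This is a minimal illustration rather than the library's implementation: predict_from_scores is a hypothetical helper, the class names and scores are made up, and pandas idxmax stands in for whatever numpy operation the real method uses.

```python
import pandas as pd


def predict_from_scores(scores_by_level):
    """Pick the highest-scoring class at each level for the first sample.

    scores_by_level is assumed to have the shape returned by score():
    a dict mapping level -> DataFrame of shape (n_samples, n_classes).
    """
    predictions = {}
    for level, scores in scores_by_level.items():
        # iloc[0] selects the first sample; idxmax returns the column
        # label (class name) holding the largest score in that row.
        predictions[level] = scores.iloc[0].idxmax()
    return predictions


# Made-up scores and class names, for illustration only.
scores_by_level = {
    1: pd.DataFrame({"Transient": [0.8], "Persistent": [0.2]}),
    2: pd.DataFrame({"SN": [0.7], "TDE": [0.1], "AGN": [0.2]}),
}
predictions = predict_from_scores(scores_by_level)
print(predictions)
```

In this sketch an empty score DataFrame makes iloc[0] raise an IndexError, which matches the documented behaviour of propagating errors rather than catching them.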
- predict_full_scores(table)
Predict class probability scores for a single time-series table.
Prepares a single observation table for the model by calling augment_table and converting time-dependent and time-independent features into torch tensors. The input table must contain the columns ‘magpsf’, ‘sigmapsf’, ‘fid’, ‘jd’, and ‘photflag’.
- Parameters:
table (astropy.table.Table) – Astropy Table containing time-series data.
- Returns:
A DataFrame containing class probability scores for each class in the taxonomy.
- Return type:
pd.DataFrame
- Raises:
ValueError – If time-series columns have inconsistent lengths or if the table is empty in a way that the downstream model cannot handle.
- score(table)
Compute hierarchical scores for the input table. Predicts scores for all taxonomy nodes using self.predict_full_scores(), then groups those scores by taxonomy depth and returns a mapping from depth levels to DataFrames containing the corresponding node scores. Taxonomy level 0 is removed because it is irrelevant in the current taxonomy.
- Parameters:
table (astropy.table.Table) – Input observations/features to be scored.
- Returns:
A mapping from taxonomy depth level to a DataFrame of predicted scores for nodes at that level. Each DataFrame is a subset of the full prediction DataFrame containing only the columns for the nodes at that depth.
- Return type:
dict[int, pandas.DataFrame]
- Raises:
KeyError – If expected node columns (from self.taxonomy.get_nodes_by_depth()) are not present in the DataFrame returned by predict_full_scores().
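The grouping performed by score can be sketched as below. This is a hedged illustration, not the actual method: group_scores_by_depth is a hypothetical helper, the node names are invented, and it assumes that self.taxonomy.get_nodes_by_depth() returns a mapping from depth to a list of node names, which the documentation does not confirm.

```python
import pandas as pd


def group_scores_by_depth(full_scores, nodes_by_depth):
    """Split a full score DataFrame into one DataFrame per taxonomy depth.

    full_scores: DataFrame with one column per taxonomy node.
    nodes_by_depth: assumed mapping {depth: [node names]}.
    """
    grouped = {}
    for depth, nodes in nodes_by_depth.items():
        if depth == 0:
            # Level 0 is dropped, as described for score().
            continue
        # Selecting with a list of labels raises KeyError if any
        # expected node column is missing from full_scores.
        grouped[depth] = full_scores[list(nodes)]
    return grouped


# Invented taxonomy and scores, for illustration only.
full_scores = pd.DataFrame(
    {"Alert": [1.0], "Transient": [0.8], "Persistent": [0.2],
     "SN": [0.7], "AGN": [0.2]}
)
nodes_by_depth = {
    0: ["Alert"],
    1: ["Transient", "Persistent"],
    2: ["SN", "AGN"],
}
grouped = group_scores_by_depth(full_scores, nodes_by_depth)
```

Each value in grouped is a column subset of full_scores, so row order and sample count are unchanged across levels.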
Note
This is very similar to the model used for the original ORACLE paper.
- oracle.pretrained.BTS.augment_table(table)
Augments a table by modifying its time-related and feature-specific values, and splitting it into two separate tables. This function performs the following modifications:
- Converts filter IDs (‘fid’) to their corresponding mean wavelengths using the mapping ZTF_fid_to_wavelengths.
- Normalizes the ‘jd’ column by subtracting the minimum jd value to set the starting time at zero.
- Reorders the columns based on a predefined list time_dependent_feature_list to create a time-dependent table, and adds a constant column ‘flag’ with a value of 1.
- Extracts a time-independent table based on a predefined list time_independent_feature_list.
- Parameters:
table (pandas.DataFrame) – The input table containing astronomical observations with at least ‘fid’ and ‘jd’ columns.
- Returns:
A tuple containing two pandas DataFrames:
- lc_table: The time-dependent table with re-ordered columns and an additional ‘flag’ column.
- static_table: The table containing the time-independent features.
- Return type:
tuple
Notes
- The function assumes that the variables time_dependent_feature_list, time_independent_feature_list, and ZTF_fid_to_wavelengths are defined in the global scope.
- Raises a KeyError if the required columns (‘fid’ or ‘jd’) are missing from the input table.
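The four steps listed for augment_table can be sketched in pandas as follows. This is a minimal sketch under stated assumptions, not the module's code: augment_table_sketch is a hypothetical name, the wavelength values and both feature lists are placeholders standing in for the real module-level globals, and the sketch keeps every row in the static table, which the real function may not do.

```python
import pandas as pd

# Placeholder stand-ins for the module-level globals; the real values
# and feature lists in oracle.pretrained.BTS may differ.
ZTF_fid_to_wavelengths = {1: 4800.0, 2: 6400.0, 3: 7900.0}
time_dependent_feature_list = ["jd", "magpsf", "fid"]
time_independent_feature_list = ["ra", "dec"]


def augment_table_sketch(table):
    """Return (lc_table, static_table) following the documented steps."""
    table = table.copy()  # avoid mutating the caller's DataFrame

    # 1. Map filter IDs to wavelengths (KeyError if 'fid' is missing).
    table["fid"] = table["fid"].map(ZTF_fid_to_wavelengths)

    # 2. Shift 'jd' so the first observation is at time zero.
    table["jd"] = table["jd"] - table["jd"].min()

    # 3. Time-dependent columns in the predefined order, plus 'flag' = 1.
    lc_table = table[time_dependent_feature_list].copy()
    lc_table["flag"] = 1

    # 4. Time-independent columns in the predefined order.
    static_table = table[time_independent_feature_list]
    return lc_table, static_table


# Made-up observations, for illustration only.
observations = pd.DataFrame(
    {"jd": [2459000.5, 2459001.5], "magpsf": [19.1, 18.7],
     "fid": [1, 2], "ra": [10.0, 10.0], "dec": [-5.0, -5.0]}
)
lc_table, static_table = augment_table_sketch(observations)
```

Working on a copy keeps the caller's table unchanged; whether the real function modifies its input in place is not stated in the documentation.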