oracle.custom_datasets package

Submodules

oracle.custom_datasets.BTS module

Custom dataset class for the ZTF Bright Transient Survey light curve dataset.

class oracle.custom_datasets.BTS.BTS_LC_Dataset(parquet_file_path, mapper=None, max_n_per_class=None, include_postage_stamps=False, include_lc_plots=False, transform=None, over_sample=False, excluded_classes=[])

Bases: Dataset

A custom PyTorch Dataset class for handling BTS light curve data stored in a parquet file.

clean_up_dataset()
Clean up the dataset by executing a series of transformations:
  • Adjusts the observation times by subtracting the time of the first observation from each ‘jd’ entry.

  • Converts band labels in the ‘fid’ column to their corresponding mean wavelengths based on a predefined mapping.

  • Maps BTS sample explorer classes from the ‘bts_class’ column to astrophysical classes using a provided mapper.

  • Computes new WISE color indices:
    • ‘W1_minus_W3’ as the difference between ‘W1mag’ and ‘W3mag’.

    • ‘W2_minus_W3’ as the difference between ‘W2mag’ and ‘W3mag’.

  • Transforms celestial coordinates:
    • Converts the right ascension (‘ra’) and declination (‘dec’) into galactic longitude (‘l’) and galactic latitude (‘b’) using astropy’s SkyCoord.

Prints status messages at each step to indicate progress.
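For illustration, a minimal sketch of the coordinate transformation described above, assuming astropy is installed and that ‘ra’ and ‘dec’ are in degrees (the helper name is hypothetical):

    import astropy.units as u
    from astropy.coordinates import SkyCoord

    def radec_to_galactic(ra_deg, dec_deg):
        # Build ICRS coordinates and convert to the galactic frame.
        coords = SkyCoord(ra=ra_deg * u.deg, dec=dec_deg * u.deg, frame="icrs")
        return coords.galactic.l.deg, coords.galactic.b.deg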

convert_mags_to_flux()

Converts magnitude measurements to flux values and computes the corresponding flux uncertainties. The method processes the dataframe stored in self.parquet_df by adding two new columns:

  • “flux”: Computed from the “magpsf” column using the relation: flux = F0 * 10^(-0.4 * magpsf)

    where F0 is the AB zero-point flux, set to 3631.0 * 1e6 micro-Janskys (µJy), i.e. 3631 Jy.

  • “flux_err”: Calculated via error propagation using both “magpsf” and “sigmapsf” columns:

    flux_err = (0.4 * ln(10)) * F0 * 10^(-0.4 * magpsf) * sigmapsf

Note

  • Element-wise operations are employed to map the magnitude values (and their uncertainties) to corresponding flux values across the dataframe.

  • The resulting “flux” and “flux_err” columns are lists of floats.

Returns:

None. The dataframe self.parquet_df is modified in place.
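The conversion can be sketched with NumPy as follows (the function name is illustrative; the constants and relations are the ones documented above):

    import numpy as np

    F0 = 3631.0 * 1e6  # AB zero point in micro-Janskys

    def mag_to_flux(magpsf, sigmapsf):
        # flux = F0 * 10^(-0.4 * magpsf); uncertainty propagated as documented.
        flux = F0 * 10 ** (-0.4 * magpsf)
        flux_err = 0.4 * np.log(10) * flux * sigmapsf
        return flux, flux_err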

exclude_classes()

Exclude specified classes from the dataset. This method filters the dataset contained in self.parquet_df by removing any rows whose ‘class’ value is present in the self.excluded_classes list.

Returns:

None

Note

  • Assumes that the dataframe self.parquet_df has a column named ‘class’.

  • The method leverages the filtering function of the dataframe library (e.g., Polars).
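A minimal Polars sketch of this filter (the class values are placeholders):

    import polars as pl

    df = pl.DataFrame({"class": ["SN Ia", "AGN", "CV", "SN II"]})
    excluded_classes = ["AGN", "CV"]  # placeholder values
    df = df.filter(~pl.col("class").is_in(excluded_classes))
    # df now contains only the 'SN Ia' and 'SN II' rows.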

get_all_labels()

Retrieves all labels from the parquet dataframe’s ‘class’ column.

Returns:

A list of labels extracted from the ‘class’ column.

Return type:

list

get_lc_plots(x_ts)

Generates a light curve plot image from time series data and returns it as a Torch tensor.

Parameters:

x_ts (numpy.ndarray) – 2D array where each row corresponds to a time step and columns represent various features including ‘jd’ (Julian Date), ‘flux’ (observed flux), ‘flux_err’ (flux error), and ‘fid’ (filter identifier). The feature indices are determined using the global variable time_dependent_feature_list.

Returns:

A tensor representing the RGB image of the generated light curve plot with shape (3, H, W), where H and W are the height and width of the image in pixels.

Return type:

torch.Tensor

Note

  • The function uses matplotlib to create the plot and PIL to handle image conversion.

  • It iterates over wavelengths defined in the global dictionary ZTF_wavelength_to_color, plotting error bars for each wavelength filtered by the ‘fid’ feature.

  • The output image is saved to an in-memory buffer at 100 dpi, then converted from a PIL Image to a NumPy array and finally to a Torch tensor.

Warning

  • [TODO] This function can be optimized further to avoid using matplotlib and PIL altogether. It’s really slow right now…
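The figure-to-tensor round trip described in the note might look roughly like this (a sketch with placeholder data, not the exact implementation):

    import io

    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from PIL import Image

    fig, ax = plt.subplots()
    ax.errorbar([0, 1, 2], [1.0, 2.0, 1.5], yerr=0.1, fmt="o")  # placeholder data

    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=100)  # dpi from the note above
    plt.close(fig)
    buf.seek(0)

    img = np.array(Image.open(buf).convert("RGB"))   # (H, W, 3)
    tensor = torch.from_numpy(img).permute(2, 0, 1)  # (3, H, W)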

get_postage_stamp(examples)

Generates a postage stamp image tensor for the provided source. This function iterates over a set of predefined filters (ztf_filters) and extracts the corresponding reference images from the ‘examples’ dictionary. Each extracted image is stored in a canvas array, which is then scaled from a range [0, 1] to [0, 255] and converted into a PyTorch tensor.

Parameters:

examples (dict) – A dictionary containing image data corresponding to each filter reference, with keys formatted as ‘{filter}_reference’.

Returns:

A tensor of shape (len(ztf_filters), img_length, img_length) representing the postage stamp images, with pixel values scaled to the range [0, 255].

Return type:

torch.Tensor
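A sketch of the assembly, where ztf_filters, img_length, and examples mirror the names used above (the values here are stand-ins):

    import numpy as np
    import torch

    ztf_filters = ["g", "r", "i"]  # stand-in filter list
    img_length = 63                # stand-in image size
    examples = {f"{f}_reference": np.random.rand(img_length, img_length)
                for f in ztf_filters}

    canvas = np.zeros((len(ztf_filters), img_length, img_length))
    for i, filt in enumerate(ztf_filters):
        canvas[i] = examples[f"{filt}_reference"]
    stamp = torch.from_numpy(canvas * 255.0)  # scale [0, 1] -> [0, 255]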

get_postage_stamp_plot(examples)

Generates a postage stamp plot from the given image examples and returns it as a PyTorch tensor. The function assembles a canvas by arranging image sections corresponding to various filters (defined in the global ztf_filters) side-by-side. Each section is populated using the “{filter}_reference” key from the provided examples dictionary. The canvas is then plotted using matplotlib, without axis ticks or spines, and saved into a PNG image buffer. This image is subsequently loaded via PIL, converted to a NumPy array with channel-first ordering, and finally wrapped into a PyTorch tensor.

Parameters:

examples (dict) – A dictionary containing image data. For every filter in the global variable ztf_filters, the key corresponding to the image is expected to be formatted as “{filter}_reference”.

Returns:

The resulting image as a PyTorch tensor with shape (C, H, W), where C is the number of color channels, and H and W correspond to the height and width of the assembled image, respectively.

Return type:

torch.Tensor

Warning

  • [TODO] This function can be optimized further to avoid using matplotlib and PIL altogether. It’s really slow right now…

limit_max_samples_per_class()

Limits the number of samples for each class in the dataset to a maximum specified by self.max_n_per_class. This method performs the following steps:

  • Prints a message indicating that the dataset is being limited to self.max_n_per_class samples per class.

  • Retrieves the unique class labels from the ‘class’ column of self.parquet_df.

  • For each unique class, it filters the dataframe to include only entries belonging to that class and slices the result to retain only the first self.max_n_per_class rows.

  • Collects the limited dataframes for each class and concatenates them back into a single dataframe.

  • Prints the number of samples retained for each class.

  • Updates self.parquet_df with the concatenated, limited dataframe.

Note

  • Assumes self.parquet_df is a Polars DataFrame.

  • Uses NumPy to determine unique class labels.
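A minimal Polars sketch of the per-class cap (the DataFrame contents are placeholders):

    import numpy as np
    import polars as pl

    df = pl.DataFrame({"class": ["A", "A", "A", "B"], "x": [1, 2, 3, 4]})
    max_n_per_class = 2

    classes = np.unique(df["class"].to_numpy())
    df = pl.concat(
        [df.filter(pl.col("class") == c).head(max_n_per_class) for c in classes]
    )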

over_sample_minority_classes()

Oversamples the minority classes in the dataset contained in self.parquet_df. This method identifies the class with the maximum sample count and for every other class, performs sampling with replacement so that each class has the same number of samples as the majority class. The oversampled dataframes for the minority classes are concatenated and reassigned to self.parquet_df.

Side Effects:
  • Modifies self.parquet_df by oversampling minority classes.
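A Polars sketch of the oversampling step (a minimal illustration under the assumption of a ‘class’ column, not the exact implementation):

    import polars as pl

    def over_sample(df: pl.DataFrame) -> pl.DataFrame:
        # Sample every class with replacement up to the majority-class count.
        counts = df.group_by("class").len()
        n_max = counts["len"].max()
        return pl.concat(
            [
                df.filter(pl.col("class") == c).sample(n=n_max, with_replacement=True)
                for c in counts["class"]
            ]
        )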

print_dataset_composition()

Prints a summary of the dataset composition before any transformations or mappings are applied. This method extracts the unique classes and their respective counts from the ‘bts_class’ column of the dataframe stored in self.parquet_df, constructs a dictionary with this information, and then prints it as a formatted table.

oracle.custom_datasets.BTS.custom_collate_BTS(batch)

Collate function for batching BTS dataset samples. This function takes a list of sample dictionaries and collates them into a single batch dictionary suitable for training or inference. It pads the time-series data, concatenates static and meta features, and properly stacks optional image data (postage_stamp and lc_plot) if they are present in the sample dictionaries.

Parameters:

batch (list) – A list where each element is a dictionary containing the following keys:

Required keys:
  • ‘ts’ (numpy.ndarray): Time-series data for the sample.

  • ‘label’: Label corresponding to the sample.

  • ‘ZTFID’: Identifier for the sample.

  • ‘bts_class’: Raw label or class of the sample.

  • ‘meta’ (numpy.ndarray): Meta features array from which the last row is used.

  • ‘static’ (numpy.ndarray): Static features for the sample.

Optional keys:
  • ‘postage_stamp’: Tensor representing the postage stamp image.

  • ‘lc_plot’: Tensor representing the light curve plot image.

Returns:

A dictionary with the following entries:
  • ‘ts’ (Tensor): Padded time-series data tensor with shape (batch_size, max_seq_length, …).

  • ‘static’ (Tensor): Concatenated tensor of static and meta features.

  • ‘length’ (Tensor): Tensor containing the original lengths of each time-series sample.

  • ‘label’ (numpy.ndarray): Array of sample labels.

  • ‘raw_label’ (numpy.ndarray): Array of BTS raw labels.

  • ‘id’ (numpy.ndarray): Array of sample identifiers.

  • ‘postage_stamp’ (Tensor, optional): Stacked tensor of postage stamp images, if available.

  • ‘lc_plot’ (Tensor, optional): Stacked tensor of light curve plot images, if available.

Return type:

dict
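The core padding logic can be sketched with torch.nn.utils.rnn.pad_sequence (the names below are assumptions about the sample-dict layout described above):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def collate_sketch(batch):
        ts_list = [torch.as_tensor(s["ts"], dtype=torch.float32) for s in batch]
        lengths = torch.tensor([ts.shape[0] for ts in ts_list])
        # Pad to the longest sequence in the batch: (batch, max_seq_length, ...)
        ts_padded = pad_sequence(ts_list, batch_first=True)
        return {"ts": ts_padded, "length": lengths}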

oracle.custom_datasets.BTS.show_batch(images, labels, n=16)

Display a grid of images with corresponding labels. This function creates a visual representation of the first n images from the provided dataset. It arranges the images in a square grid and annotates each image with its corresponding label.

Parameters:
  • images (Tensor or array-like) – Collection of images to be displayed. Each image is expected to have the shape (C, H, W), where C is the number of channels. For grayscale images, C should be 1.

  • labels (Sequence) – Sequence of labels corresponding to each image.

  • n (int, optional) – The number of images to display. The function uses the first n images from the collection. Defaults to 16.

Displays:

A matplotlib figure containing a grid of images, each annotated with its respective label.
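A possible shape for such a helper (a sketch under the assumption of CHW image tensors, not the actual implementation):

    import math

    import matplotlib.pyplot as plt

    def show_batch_sketch(images, labels, n=16):
        side = math.ceil(math.sqrt(n))
        fig, axes = plt.subplots(side, side, figsize=(2 * side, 2 * side))
        for ax in axes.flat:
            ax.axis("off")
        for ax, img, lbl in zip(axes.flat, images[:n], labels[:n]):
            ax.imshow(img.permute(1, 2, 0).squeeze(), cmap="gray")  # CHW -> HWC
            ax.set_title(str(lbl))
        plt.show()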

oracle.custom_datasets.BTS.truncate_BTS_light_curve_by_days_since_trigger(x_ts, x_static, d=None, add_jitter=False, normalize_flux=False)

Truncate the BTS light curve based on the number of days since the first trigger. This function selects observations from the time-series data (x_ts) that occur within a specified number of days (d) from the first detection (trigger). It also optionally adds jitter to the flux measurements and normalizes the flux values.

Parameters:
  • x_ts (numpy.ndarray) –

    Time-dependent features array where each row represents an observation. Expected to contain the following features:

    • ‘jd’: Julian date of the observation.

    • ‘magpsf’: Magnitude of the observation.

    • ‘sigmapsf’: Uncertainty of the magnitude.

    • ‘fid’: Filter identifier.

    • ‘photflag’: Photometric flag (added as an extra column to maintain compatibility with ZTF sims and always assumed to be 1).

  • x_static (numpy.ndarray) – Corresponding static features array for each observation.

  • d (float, optional) – Maximum number of days from the trigger within which to keep observations. If None, a random threshold is generated using 2^(uniform(0, 11)).

  • add_jitter (bool, optional) – If True, adds Gaussian noise to the flux, using the flux errors as 1-sigma values.

  • normalize_flux (bool, optional) – If True, normalizes the flux (‘magpsf’) using its mean and standard deviation.

Returns:

A tuple containing:
  • x_ts (numpy.ndarray): The truncated time-dependent features array.

  • x_static (numpy.ndarray): The static features array.

Return type:

tuple

Note

  • The function assumes that the dataset does not contain any non-detections.
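A NumPy sketch of the truncation (the column index and helper name are hypothetical; the real code resolves indices via time_dependent_feature_list):

    import numpy as np

    def truncate_by_days(x_ts, jd_col=0, d=None):
        if d is None:
            d = 2 ** np.random.uniform(0, 11)  # random horizon, as documented
        days_since_trigger = x_ts[:, jd_col] - x_ts[:, jd_col].min()
        return x_ts[days_since_trigger <= d]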

oracle.custom_datasets.ELAsTiCC module

Custom dataset class for the ELAsTiCC light curve dataset for LSST.

class oracle.custom_datasets.ELAsTiCC.ELAsTiCC_LC_Dataset(parquet_file_path, max_n_per_class=None, include_lc_plots=False, transform=None, excluded_classes=[])

Bases: Dataset

clean_up_dataset()
Cleans and transforms the dataset contained in self.parquet_df by performing a series of operations:
  • Replaces band labels with their corresponding mean wavelengths using a predefined mapping.

  • Removes saturation-affected data points from time-dependent feature series (excluding PHOTFLAG) by:
    • Filtering the data using a bitmask to remove values with saturation (determined by the presence of a specific bit in the photometric flag).

  • Processes the PHOTFLAG series by:
    • Removing data points flagged as saturated.

    • Converting the remaining bitmask values into binary flags (1 for detections based on a specific bit and 0 otherwise).

  • Normalizes the MJD (Modified Julian Date) values by subtracting the time of the first observation, effectively realigning the time series.

  • Replaces missing values in time-independent features with a predetermined flag value.

Returns:

None

Note

  • This method defines two helper functions locally:
    • remove_saturations_from_series: Filters a given series based on the photometric flag to remove saturations.

    • replace_missing_flags: Substitutes missing data flags with a specified flag value.

exclude_classes()

Exclude classes specified in self.excluded_classes from the dataset. This method filters the records in self.parquet_df, removing any rows where the “class” column matches an entry in self.excluded_classes. It concatenates the remaining dataframes per class and updates self.parquet_df with the result.

Returns:

None

Side Effects:
  • Modifies self.parquet_df by excluding rows of unwanted classes.

get_all_labels()

Retrieves all labels from the parquet dataframe’s ‘class’ column.

Returns:

A list of labels extracted from the ‘class’ column.

Return type:

list

get_lc_plots(x_ts)

Generates a light curve plot image from time series data and returns it as a Torch tensor.

Parameters:

x_ts (numpy.ndarray) – 2D array where each row corresponds to a time step and columns represent various features including ‘jd’ (Julian Date), ‘flux’ (observed flux), ‘flux_err’ (flux error), and ‘fid’ (filter identifier). The feature indices are determined using the global variable time_dependent_feature_list.

Returns:

A tensor representing the RGB image of the generated light curve plot with shape (3, H, W), where H and W are the height and width of the image in pixels.

Return type:

torch.Tensor

Note

  • The function uses matplotlib to create the plot and PIL to handle image conversion.

  • It iterates over wavelengths defined in the global dictionary ZTF_wavelength_to_color, plotting error bars for each wavelength filtered by the ‘fid’ feature.

  • The output image is saved to an in-memory buffer at 100 dpi, then converted from a PIL Image to a NumPy array and finally to a Torch tensor.

Warning

  • [TODO] This function can be optimized further to avoid using matplotlib and PIL altogether. It’s really slow right now…

limit_max_samples_per_class()

Limits the number of samples for each class in the DataFrame. This method processes the DataFrame contained in the instance attribute self.parquet_df by:

  • Determining the unique classes present in the “class” column.

  • For each unique class, selecting only the first self.max_n_per_class samples.

  • Concatenating the limited samples from all classes back into self.parquet_df.

It also prints the maximum allowed samples per class and the resulting sample count for each class.

Note

  • Assumes self.parquet_df is a Polars DataFrame.

map_models_to_classes()

Maps ELAsTiCC classes to astrophysical classes by replacing the values in the ‘ELASTICC_class’ column of the DataFrame with the corresponding astrophysical class names. The transformation is done in-place on self.parquet_df by creating a new column named ‘class’.

Returns:

None

Side Effects:
  • Modifies self.parquet_df by adding/updating the ‘class’ column.

oracle.custom_datasets.ELAsTiCC.custom_collate_ELAsTiCC(batch)

Custom collation function for processing a batch of ELAsTiCC dataset samples.

Parameters:

batch (list) – A list of dictionaries, each representing a sample.

Returns:

A dictionary containing the collated batch with the following keys:
  • ‘ts’: A padded tensor of time series data with shape (batch_size, max_length, …), where padding is applied using the predefined flag_value.

  • ‘static’: A tensor of static features with shape (batch_size, n_static_features).

  • ‘length’: A tensor containing the lengths of each time series in the batch.

  • ‘label’: A numpy array of labels for the batch (array-like).

  • ‘raw_label’: A numpy array of raw ELAsTiCC class labels (array-like).

  • ‘id’: A numpy array of SNIDs corresponding to each sample.

  • ‘lc_plot’ (if present in the input samples): A tensor of light curve plots with shape (batch_size, n_channels, img_height, img_width).

Return type:

dict

oracle.custom_datasets.ELAsTiCC.show_batch(images, labels, n=16)

Display a grid of images with corresponding labels. This function creates a visual representation of the first n images from the provided dataset. It arranges the images in a square grid and annotates each image with its corresponding label.

Parameters:
  • images (Tensor or array-like) – Collection of images to be displayed. Each image is expected to have the shape (C, H, W), where C is the number of channels. For grayscale images, C should be 1.

  • labels (Sequence) – Sequence of labels corresponding to each image.

  • n (int, optional) – The number of images to display. The function uses the first n images from the collection. Defaults to 16.

Displays:

A matplotlib figure containing a grid of images, each annotated with its respective label.

oracle.custom_datasets.ELAsTiCC.truncate_ELAsTiCC_light_curve_by_days_since_trigger(x_ts, d=None)

Truncates the light curve data to only include observations within a specified number of days since the first detection.

Parameters:
  • x_ts (np.ndarray) – A 2D array representing the time series light curve data, where each row is an observation and columns correspond to different features.

  • d (float, optional) – The number of days after the first detection to use as the cutoff for truncation. If None, a random value is generated as 2^u, with u drawn uniformly from [0, 11].

Returns:

The truncated light curve array containing only the observations within ‘d’ days of the first detection.

Return type:

np.ndarray

Note

  • Assumes that the column corresponding to ‘PHOTFLAG’ (indicating detection status) and ‘MJD’ (the modified Julian date) exist in x_ts.

  • The indices for ‘PHOTFLAG’ and ‘MJD’ are obtained using the global list ‘time_dependent_feature_list’.

  • Raises an IndexError if no detection (i.e., a value of 1 in the ‘PHOTFLAG’ column) is found.
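A sketch of one plausible reading of this logic (column indices are hypothetical; the real code resolves them via time_dependent_feature_list):

    import numpy as np

    def truncate_since_trigger(x_ts, photflag_col=0, mjd_col=1, d=None):
        if d is None:
            d = 2 ** np.random.uniform(0, 11)
        # MJD of the first detection; raises IndexError if no detections exist.
        trigger_mjd = x_ts[x_ts[:, photflag_col] == 1, mjd_col][0]
        return x_ts[x_ts[:, mjd_col] <= trigger_mjd + d]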

oracle.custom_datasets.ELAsTiCC.truncate_ELAsTiCC_light_curve_fractionally(x_ts, f=None)

Truncate an ELAsTiCC light curve by a fractional amount. This function reduces the number of observations in the light curve array based on a specified fraction. If no fraction is provided, a random fraction between 0.1 and 1.0 is chosen. The truncation ensures that at least one observation remains.

Parameters:
  • x_ts (numpy.ndarray) – A 2D array representing the light curve data, where each row corresponds to an observation and the columns represent different features.

  • f (float, optional) – A fraction between 0.0 and 1.0 to determine the portion of the light curve to retain. If None, a random fraction in the range [0.1, 1.0] is used.

Returns:

The truncated light curve, containing only the first portion of the observations as determined by the fraction.

Return type:

numpy.ndarray
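A minimal NumPy sketch of this behavior:

    import numpy as np

    def truncate_fractionally(x_ts, f=None):
        if f is None:
            f = np.random.uniform(0.1, 1.0)
        n_keep = max(1, int(f * x_ts.shape[0]))  # at least one observation remains
        return x_ts[:n_keep]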

oracle.custom_datasets.ZTF_sims module

Custom dataset class for the simulated ZTF light curve dataset.

class oracle.custom_datasets.ZTF_sims.ZTF_SIM_LC_Dataset(parquet_file_path, max_n_per_class=None, include_lc_plots=False, transform=None)

Bases: Dataset

clean_up_dataset()

Cleans and transforms the dataset stored in the object’s parquet_df attribute. This method performs several dataset cleaning operations:

  • Replacing band labels in the “FLT” column with their corresponding mean wavelengths using the mapping defined in ZTF_passband_to_wavelengths.

  • Removing saturated measurements from time series features:
    • For each time-dependent feature, it removes data points where the corresponding “PHOTFLAG” bitmask indicates saturation (using bitwise logic with 1024).

  • Replacing the “PHOTFLAG” bitmask values:
    • Converts the cleaned “PHOTFLAG” list to binary flags where any instance of the flag 4096 is set to 1 (indicating detection) and 0 otherwise.

  • Adjusting the time series:
    • Normalizes the “MJD_clean” column by subtracting the time of the first observation from all entries.

  • Mapping simulation class labels:
    • Replaces ZTF simulation classes with astrophysical class labels using ZTF_sims_to_Astrophysical_mappings.

  • Handling missing data in time-independent features:
    • Replaces any feature value that matches a missing data flag with a specified flag_value.

The method updates the parquet_df in place and prints progress messages for each step.
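The bitmask handling described above can be sketched as follows (the array values are placeholders; the bits 1024 and 4096 are those named in the description):

    import numpy as np

    photflag = np.array([0, 4096, 1024, 4096 | 1024, 0])  # placeholder series

    keep = (photflag & 1024) == 0                            # drop saturated points
    detections = ((photflag[keep] & 4096) != 0).astype(int)  # 1 = detection, 0 otherwise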

get_all_labels()

Retrieve all labels from the dataset.

This method extracts the ‘class’ column from the parquet dataframe attribute and returns it as a list.

Returns:

A list containing all labels present in the ‘class’ column.

Return type:

list

get_lc_plots(x_ts)

Generates a light curve plot from the provided time series data and returns it as a PyTorch tensor.

The function performs the following steps:
  • Extracts light curve parameters such as Julian dates, flux measurements, flux errors, filter identifiers, and photometric flags from the input array.

  • For each wavelength (as defined in ‘ZTF_wavelength_to_color’), it plots:
    • Detected data points using ‘marker_style_detection’.

    • Non-detected data points using ‘marker_style_non_detection’.

    Both are drawn with error bars corresponding to the flux uncertainties.

  • Overlays a line plot connecting all points for each wavelength.

  • Configures the matplotlib figure by setting a fixed size, removing axis ticks and all spines.

  • Saves the plot into an in-memory PNG buffer at a specified DPI, then loads it via PIL.

  • Converts the image to a NumPy array, permutes the dimensions, and finally converts it into a PyTorch tensor.

Parameters:

x_ts (np.ndarray) – A 2D NumPy array representing the time series data.

Returns:

A tensor representation of the generated light curve plot image.

Return type:

torch.Tensor

limit_max_samples_per_class()

Limits the number of samples per class in the dataset.

For every unique class in the ‘class’ column of self.parquet_df, this method selects at most self.max_n_per_class entries (i.e., the first self.max_n_per_class samples) and concatenates them into a new dataframe that replaces self.parquet_df.

Informative messages are printed to indicate:
  • The overall limit being applied per class.

  • The resulting number of samples retained for each class.

print_dataset_composition()

Prints the composition of the dataset. It extracts the unique classes and their respective counts from the ‘class’ column, formats these values into a pandas DataFrame, and prints the resulting table to the console.

Returns:

None

oracle.custom_datasets.ZTF_sims.custom_collate_ZTF_SIM(batch)

Collates a batch of ZTF simulation samples into a single dictionary of tensors suitable for model input.

Each sample in the input batch is expected to be a dictionary with the following keys:
  • ‘ts’: A tensor representing the time series data. The tensor should have shape (num_time_points, …).

  • ‘label’: The label associated with the sample (e.g., class index or regression target).

  • ‘SNID’: A unique identifier for the sample.

  • ‘static’: A tensor of static features with a predefined number of features (n_static_features).

  • ‘lc_plot’: (Optional) A tensor representing the light curve plot image with shape (n_channels, img_height, img_width).

Parameters:

batch (list) – A list of sample dictionaries as described above.

Returns:

A dictionary containing the collated batch with the same keys as the input samples, where the time-series data is padded to a common length and the remaining features are stacked into batch-level tensors or arrays.

Return type:

dict

oracle.custom_datasets.ZTF_sims.show_batch(images, labels, n=16)

Display a grid of images with corresponding labels. This function creates a visual representation of the first n images from the provided dataset. It arranges the images in a square grid and annotates each image with its corresponding label.

Parameters:
  • images (Tensor or array-like) – Collection of images to be displayed. Each image is expected to have the shape (C, H, W), where C is the number of channels. For grayscale images, C should be 1.

  • labels (Sequence) – Sequence of labels corresponding to each image.

  • n (int, optional) – The number of images to display. The function uses the first n images from the collection. Defaults to 16.

Displays:

A matplotlib figure containing a grid of images, each annotated with its respective label.

oracle.custom_datasets.ZTF_sims.truncate_ZTF_SIM_light_curve_by_days_since_trigger(x_ts, d)

Truncates a ZTF SIM light curve by retaining only the observations within a specified number of days since the trigger (first detection).

Parameters:
  • x_ts (numpy.ndarray) – A 2D array representing the time series data of the light curve. It is expected to have columns corresponding to various features, including ‘PHOTFLAG’ and ‘MJD’, as specified in the global list ‘time_dependent_feature_list’.

  • d (float) – The time window (in days) from the first detection. Observations beyond this period are removed.

Returns:

The truncated light curve array containing only the observations within the specified days since the trigger.

Return type:

numpy.ndarray

oracle.custom_datasets.ZTF_sims.truncate_ZTF_SIM_light_curve_fractionally(x_ts, f=None)

Truncate a ZTF simulation light curve by retaining only a fraction of its observations.

Parameters:
  • x_ts (numpy.ndarray) – A 2D array representing the light curve where each row corresponds to an observation.

  • f (float, optional) – Fraction of the total observations to retain. If not provided (None), a random fraction between 0.1 and 1.0 will be used.

Returns:

A truncated version of the input light curve containing a fraction (at least one) of the original observations.

Return type:

numpy.ndarray

Module contents

Custom datasets for the ORACLE framework.