oracle.custom_datasets.ELAsTiCC module

Custom dataset class for the ELAsTiCC light curve dataset for LSST.

class oracle.custom_datasets.ELAsTiCC.ELAsTiCC_LC_Dataset(parquet_file_path, max_n_per_class=None, include_lc_plots=False, transform=None, excluded_classes=[])

Bases: Dataset

clean_up_dataset()
Cleans and transforms the dataset contained in self.parquet_df by performing a series of operations:
  • Replaces band labels with their corresponding mean wavelengths using a predefined mapping.

  • Removes saturation-affected data points from time-dependent feature series (excluding PHOTFLAG) by:
    • Applying a bitmask to the photometric flag and dropping every observation whose saturation bit is set.

  • Processes the PHOTFLAG series by:
    • Removing data points flagged as saturated.

    • Converting the remaining bitmask values into binary flags (1 if the detection bit is set, 0 otherwise).

  • Normalizes the MJD (Modified Julian Date) values by subtracting the time of the first observation, so that each light curve starts at t = 0.

  • Replaces missing values in time-independent features with a predetermined flag value.

Returns:

None

Note

  • This method defines two helper functions locally:
    • remove_saturations_from_series: Filters a given series using the photometric flag to remove saturated points.

    • replace_missing_flags: Substitutes missing data flags with a specified flag value.
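The per-light-curve steps above can be sketched with NumPy as follows. This is a minimal illustration, not the module's code: the bit values SATURATION_BIT = 1024 and DETECTION_BIT = 4096 are assumptions (they follow the SNANA-style PHOTFLAG convention, but this page does not say which bits the module tests), the function names are hypothetical, and the band-to-wavelength replacement is omitted.

```python
import numpy as np

# Assumed PHOTFLAG bits (hypothetical; check the module's own constants).
SATURATION_BIT = 1024
DETECTION_BIT = 4096


def clean_light_curve(mjd, flux, photflag):
    """Drop saturated points, binarize PHOTFLAG and shift MJD to start at 0."""
    photflag = np.asarray(photflag, dtype=np.int64)
    # Keep only observations whose saturation bit is not set.
    keep = (photflag & SATURATION_BIT) == 0
    mjd = np.asarray(mjd, dtype=np.float64)[keep]
    flux = np.asarray(flux, dtype=np.float64)[keep]
    # 1 where the detection bit is set, 0 otherwise.
    detected = ((photflag[keep] & DETECTION_BIT) != 0).astype(np.int64)
    # Realign time so the first remaining observation sits at t = 0.
    if mjd.size > 0:
        mjd = mjd - mjd[0]
    return mjd, flux, detected


def replace_missing_flags(values, missing_flags, flag_value):
    """Replace NaNs and any sentinel listed in missing_flags with flag_value."""
    values = np.array(values, dtype=np.float64)  # copy; the input is not modified
    missing = np.isnan(values) | np.isin(values, missing_flags)
    values[missing] = flag_value
    return values
```

For example, clean_light_curve([100, 101, 102, 103], [1, 2, 3, 4], [0, 1024, 4096, 6144]) drops the second point and returns MJDs [0, 2, 3] with detection flags [0, 1, 1].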

exclude_classes()

Excludes the classes listed in self.excluded_classes from the dataset. This method filters the records in self.parquet_df, removing any rows whose “class” value appears in self.excluded_classes. It then concatenates the remaining per-class dataframes and updates self.parquet_df with the result.

Returns:

None

Side Effects:
  • Modifies self.parquet_df by excluding rows of unwanted classes.
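Since self.parquet_df is a Polars DataFrame (as noted for limit_max_samples_per_class()), the same effect could also be obtained with a single boolean filter such as df.filter(~pl.col("class").is_in(excluded_classes)). The NumPy sketch below shows the idea on a hypothetical dict of column arrays standing in for the DataFrame; the function name is also hypothetical.

```python
import numpy as np


def drop_excluded_classes(columns, excluded_classes):
    """Drop every row whose 'class' value is in excluded_classes.

    columns: hypothetical stand-in for a DataFrame, a dict mapping column
    names to equal-length 1D arrays.
    """
    labels = np.asarray(columns["class"])
    keep = ~np.isin(labels, list(excluded_classes))  # True for rows to keep
    return {name: np.asarray(col)[keep] for name, col in columns.items()}
```

Unlike the per-class concatenation described above, a single mask keeps the surviving rows in their original order.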

get_all_labels()

Retrieves all labels from the parquet dataframe’s ‘class’ column.

Returns:

A list of labels extracted from the ‘class’ column.

Return type:

list

get_lc_plots(x_ts)

Generates a light curve plot image from time series data and returns it as a Torch tensor.

Parameters:

x_ts (numpy.ndarray) – 2D array where each row corresponds to a time step and columns represent various features including ‘jd’ (Julian Date), ‘flux’ (observed flux), ‘flux_err’ (flux error), and ‘fid’ (filter identifier). The feature indices are determined using the global variable time_dependent_feature_list.

Returns:

A tensor representing the RGB image of the generated light curve plot with shape (3, H, W), where H and W are the height and width of the image in pixels.

Return type:

torch.Tensor

Note

  • The function uses matplotlib to create the plot and PIL to handle image conversion.

  • It iterates over wavelengths defined in the global dictionary ZTF_wavelength_to_color, plotting error bars for each wavelength filtered by the ‘fid’ feature.

  • The output image is saved to an in-memory buffer at 100 dpi, then converted from a PIL Image to a NumPy array and finally to a Torch tensor.

Warning

  • [TODO] This function could be optimized to avoid matplotlib and PIL altogether; it is currently very slow.
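The render-to-buffer pipeline described in the notes can be sketched as follows. The column indices, the placeholder band values in BAND_TO_COLOR and the function name are all hypothetical (the real code looks them up via time_dependent_feature_list and ZTF_wavelength_to_color), and the sketch returns a NumPy array rather than a tensor.

```python
import io

import matplotlib

matplotlib.use("Agg")  # off-screen backend, no display required
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Hypothetical column layout and placeholder band values, for illustration only.
TIME_IDX, FLUX_IDX, FLUX_ERR_IDX, BAND_IDX = 0, 1, 2, 3
BAND_TO_COLOR = {1.0: "tab:blue", 2.0: "tab:red"}


def light_curve_to_image(x_ts, figsize=(2, 2), dpi=100):
    """Render per-band error bars and return a (3, H, W) uint8 array."""
    fig, ax = plt.subplots(figsize=figsize)
    for band, color in BAND_TO_COLOR.items():
        rows = x_ts[:, BAND_IDX] == band
        if not rows.any():
            continue  # skip bands with no observations
        ax.errorbar(
            x_ts[rows, TIME_IDX],
            x_ts[rows, FLUX_IDX],
            yerr=x_ts[rows, FLUX_ERR_IDX],
            fmt=".",
            color=color,
        )
    ax.axis("off")
    # Save to an in-memory PNG, then decode it back into pixels.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    plt.close(fig)  # free the figure; matters when called once per sample
    buf.seek(0)
    img = np.asarray(Image.open(buf).convert("RGB"))  # (H, W, 3)
    return np.transpose(img, (2, 0, 1))  # channels first: (3, H, W)
```

With figsize=(2, 2) and dpi=100 the image is 200 x 200 pixels; torch.from_numpy(img.copy()) then gives a (3, H, W) tensor. The PNG encode/decode round trip is one likely reason the approach is slow.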

limit_max_samples_per_class()

Limits the number of samples for each class in the DataFrame. This method processes the DataFrame contained in the instance attribute self.parquet_df by:

  • Determining the unique classes present in the “class” column.

  • For each unique class, selecting only the first self.max_n_per_class samples.

  • Concatenating the limited samples from all classes back into self.parquet_df.

It also prints the maximum allowed samples per class and the resulting sample count for each class.

Note

  • Assumes self.parquet_df is a Polars DataFrame.
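The per-class cap can be sketched on a plain label array. The function name is hypothetical and it returns row indices rather than a DataFrame; in Polars, a grouped head such as df.group_by("class", maintain_order=True).head(n) expresses a similar cap.

```python
import numpy as np


def first_n_per_class(labels, max_n_per_class=None):
    """Return row indices keeping at most max_n_per_class rows of each class.

    Rows keep their original order within a class, and classes are
    concatenated in sorted order (np.unique sorts). None keeps every row.
    """
    labels = np.asarray(labels)
    kept = [np.flatnonzero(labels == cls)[:max_n_per_class] for cls in np.unique(labels)]
    return np.concatenate(kept) if kept else np.array([], dtype=np.int64)
```

For labels ["a", "b", "a", "a", "b", "c"] and a cap of 2, this keeps rows [0, 2, 1, 4, 5].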

map_models_to_classes()

Maps ELAsTiCC classes to astrophysical classes: each value in the ‘ELASTICC_class’ column of self.parquet_df is translated to its corresponding astrophysical class name, and the result is stored in place in a new column named ‘class’.

Returns:

None

Side Effects:
  • Modifies self.parquet_df by adding/updating the ‘class’ column.
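Such a mapping is typically many-to-one, with several simulation models sharing one astrophysical class. A minimal dictionary-lookup sketch follows; the entries in MODEL_TO_CLASS are illustrative placeholders rather than the module's table, the function name is hypothetical, and raising on unknown names is a design choice of this sketch.

```python
# Illustrative placeholder entries only; not the module's actual mapping table.
MODEL_TO_CLASS = {
    "SNII-NMF": "SNII",
    "SNII-Templates": "SNII",
    "KN_K17": "KN",
}


def map_model_names(model_names, mapping=MODEL_TO_CLASS):
    """Map each ELAsTiCC model name to its astrophysical class.

    An unknown model name raises KeyError, so gaps in the mapping surface
    immediately instead of silently producing missing labels.
    """
    return [mapping[name] for name in model_names]
```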

oracle.custom_datasets.ELAsTiCC.custom_collate_ELAsTiCC(batch)

Custom collation function for processing a batch of ELAsTiCC dataset samples.

Parameters:

batch (list) – A list of dictionaries, each representing a sample.

Returns:

A dictionary containing the collated batch with the following keys:
  • ‘ts’: A padded tensor of time series data with shape (batch_size, max_length, …), where padding is applied using the predefined flag_value.

  • ‘static’: A tensor of static features with shape (batch_size, n_static_features).

  • ‘length’: A tensor containing the length of each time series in the batch.

  • ‘label’: A NumPy array of labels for the batch.

  • ‘raw_label’: A NumPy array of raw ELAsTiCC class labels.

  • ‘id’: A NumPy array of SNIDs corresponding to each sample.

  • ‘lc_plot’ (only if present in the input samples): A tensor of light curve plots with shape (batch_size, n_channels, img_height, img_width).

Return type:

dict
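A NumPy sketch of the padding and stacking follows. It assumes each input sample is a dict with the same keys as the output (‘ts’ of shape (length, n_features), ‘static’, ‘label’, ‘raw_label’, ‘id’ and optionally ‘lc_plot’) and takes flag_value as a parameter because its actual value is not given here; the function name is hypothetical. In PyTorch the padding step is commonly written with torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=flag_value).

```python
import numpy as np


def collate_padded(batch, flag_value):
    """Pad variable-length 'ts' arrays with flag_value and stack the rest."""
    lengths = np.array([len(s["ts"]) for s in batch], dtype=np.int64)
    n_features = np.asarray(batch[0]["ts"]).shape[1]
    # Pre-fill with flag_value so every position past a sample's length is padding.
    ts = np.full((len(batch), int(lengths.max()), n_features), flag_value, dtype=np.float32)
    for i, sample in enumerate(batch):
        ts[i, : lengths[i]] = sample["ts"]
    out = {
        "ts": ts,
        "static": np.stack([np.asarray(s["static"], dtype=np.float32) for s in batch]),
        "length": lengths,
        "label": np.array([s["label"] for s in batch]),
        "raw_label": np.array([s["raw_label"] for s in batch]),
        "id": np.array([s["id"] for s in batch]),
    }
    if "lc_plot" in batch[0]:
        out["lc_plot"] = np.stack([np.asarray(s["lc_plot"]) for s in batch])
    return out
```

A function of this shape is what a DataLoader receives through its collate_fn argument; keeping ‘length’ alongside the padded array lets downstream code mask out the padded steps.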

oracle.custom_datasets.ELAsTiCC.show_batch(images, labels, n=16)

Displays a grid of images with their corresponding labels. This function shows the first n images from the provided collection, arranges them in a square grid, and annotates each image with its label.

Parameters:
  • images (Tensor or array-like) – Collection of images to be displayed. Each image is expected to have the shape (C, H, W), where C is the number of channels. For grayscale images, C should be 1.

  • labels (Sequence) – Sequence of labels corresponding to each image.

  • n (int, optional) – The number of images to display. The function uses the first n images from the collection. Defaults to 16.

Displays:

A matplotlib figure containing a grid of images, each annotated with its respective label.
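The grid layout can be sketched as below: the side of the square grid is ceil(sqrt(n)), unused cells are hidden, and channel-first images are transposed to the (H, W, C) layout that imshow expects. The function name is hypothetical, capping n at the number of available images is an assumption about edge-case handling, and the figure is returned rather than shown so the caller can use plt.show() or savefig.

```python
import math

import matplotlib

matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np


def show_image_grid(images, labels, n=16):
    """Plot the first n (C, H, W) images in a square grid, titled by label."""
    n = min(n, len(images))
    side = max(1, math.ceil(math.sqrt(n)))  # grid is side x side
    fig, axes = plt.subplots(side, side, figsize=(2 * side, 2 * side), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")  # hide ticks; unused cells stay blank
        if i >= n:
            continue
        img = np.transpose(np.asarray(images[i]), (1, 2, 0))  # (H, W, C)
        if img.shape[-1] == 1:
            ax.imshow(img[..., 0], cmap="gray")  # grayscale: drop the channel axis
        else:
            ax.imshow(img)
        ax.set_title(str(labels[i]), fontsize=8)
    fig.tight_layout()
    return fig
```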

oracle.custom_datasets.ELAsTiCC.truncate_ELAsTiCC_light_curve_by_days_since_trigger(x_ts, d=None)

Truncates the light curve data to only include observations within a specified number of days since the first detection.

Parameters:
  • x_ts (np.ndarray) – A 2D array representing the time series light curve data, where each row is an observation and columns correspond to different features.

  • d (float, optional) – The number of days after the first detection to use as the truncation cutoff. If None, d is drawn at random as 2**u, with the exponent u sampled uniformly from the range [0, 11], so the cutoff lies between 1 and 2048 days.

Returns:

The truncated light curve array containing only the observations within ‘d’ days of the first detection.

Return type:

np.ndarray

Note

  • Assumes that the columns corresponding to ‘PHOTFLAG’ (detection status) and ‘MJD’ (Modified Julian Date) exist in x_ts.

  • The indices for ‘PHOTFLAG’ and ‘MJD’ are obtained using the global list ‘time_dependent_feature_list’.

  • Raises an IndexError if no detection (i.e., a value of 1 in the ‘PHOTFLAG’ column) is found.
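A NumPy sketch of this truncation follows. MJD_IDX and PHOTFLAG_IDX are hypothetical stand-ins for the lookups in time_dependent_feature_list, the function name is hypothetical, and the sketch assumes observations before the first detection are kept (only rows later than the first detection plus d are dropped). Indexing the empty result of np.flatnonzero reproduces the documented IndexError when there is no detection.

```python
import numpy as np

# Hypothetical column positions; the module derives them from
# time_dependent_feature_list.
MJD_IDX, PHOTFLAG_IDX = 0, 1


def truncate_by_days(x_ts, d=None, rng=None):
    """Keep observations with MJD <= MJD(first detection) + d."""
    if d is None:
        rng = np.random.default_rng() if rng is None else rng
        # d = 2**u with u ~ Uniform[0, 11): from 1 day up to just under 2048 days.
        d = 2.0 ** rng.uniform(0, 11)
    # Raises IndexError when PHOTFLAG contains no 1 (no detection).
    first_det = np.flatnonzero(x_ts[:, PHOTFLAG_IDX] == 1)[0]
    cutoff = x_ts[first_det, MJD_IDX] + d
    return x_ts[x_ts[:, MJD_IDX] <= cutoff]
```

Sampling the exponent rather than d itself spreads the random cutoffs evenly on a log scale, so short, early-time light curves are drawn about as often as long ones.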

oracle.custom_datasets.ELAsTiCC.truncate_ELAsTiCC_light_curve_fractionally(x_ts, f=None)

Truncates an ELAsTiCC light curve by a fractional amount. This function reduces the number of observations in the light curve array based on a specified fraction. If no fraction is provided, a random fraction between 0.1 and 1.0 is chosen. The truncation always keeps at least one observation.

Parameters:
  • x_ts (numpy.ndarray) – A 2D array representing the light curve data, where each row corresponds to an observation and the columns represent different features.

  • f (float, optional) – A fraction between 0.0 and 1.0 to determine the portion of the light curve to retain. If None, a random fraction in the range [0.1, 1.0] is used.

Returns:

The truncated light curve, containing only the first portion of the observations as determined by the fraction.

Return type:

numpy.ndarray
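A minimal sketch of the fractional truncation follows. The function name is hypothetical, and whether the module rounds f * len(x_ts) down, as assumed here with int(), is not stated on this page; max(1, ...) enforces the documented guarantee that at least one observation remains.

```python
import numpy as np


def truncate_fraction(x_ts, f=None, rng=None):
    """Keep the first fraction f of the observations (at least one row)."""
    if f is None:
        rng = np.random.default_rng() if rng is None else rng
        f = rng.uniform(0.1, 1.0)  # random fraction in [0.1, 1.0)
    n_keep = max(1, int(f * len(x_ts)))  # round down, but never below one row
    return x_ts[:n_keep]
```

For a 10-row light curve, f = 0.35 keeps the first 3 rows and f = 0.0 still keeps 1.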