oracle.custom_datasets.ELAsTiCC module
Custom dataset class for the ELAsTiCC light curve dataset for LSST.
- class oracle.custom_datasets.ELAsTiCC.ELAsTiCC_LC_Dataset(parquet_file_path, max_n_per_class=None, include_lc_plots=False, transform=None, excluded_classes=[])
Bases: Dataset
- clean_up_dataset()
- Cleans and transforms the dataset contained in self.parquet_df by performing a series of operations:
Replaces band labels with their corresponding mean wavelengths using a predefined mapping.
- Removes saturation-affected data points from time-dependent feature series (excluding PHOTFLAG) by:
Filtering the data using a bitmask to remove values with saturation (determined by the presence of a specific bit in the photometric flag).
- Processes the PHOTFLAG series by:
Removing data points flagged as saturated.
Converting the remaining bitmask values into binary flags (1 if a specific detection bit is set, 0 otherwise).
Normalizes the MJD (Modified Julian Date) values by subtracting the time of the first observation, effectively realigning the time series.
Replaces missing values in time-independent features with a predetermined flag value.
- Returns:
None
Note
- This method defines two helper functions locally:
remove_saturations_from_series: Filters a given series based on the photometric flag to remove saturations.
replace_missing_flags: Substitutes missing data flags with a specified flag value.
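The cleaning steps above can be sketched for a single light curve held as parallel arrays. This is a minimal sketch, not the module's implementation, and several values are assumptions not stated in the source: the SNANA-style PHOTFLAG bits (1024 for saturation, 4096 for detection), the flag value -9.0, the hypothetical band-to-wavelength table (approximate LSST effective wavelengths), the -999.0 missing-value sentinel, and the column names in the returned dict. The real method works on the list columns of self.parquet_df rather than on per-curve arrays.

```python
import numpy as np

# Assumptions (not stated in the source): SNANA-style PHOTFLAG bits and a
# placeholder flag value for missing static features.
SATURATION_BIT = 1024
DETECTION_BIT = 4096
FLAG_VALUE = -9.0

# Hypothetical band -> approximate LSST effective wavelength (Angstrom).
BAND_TO_WAVELENGTH = {"u": 3671.0, "g": 4827.0, "r": 6223.0,
                      "i": 7546.0, "z": 8691.0, "Y": 9712.0}


def clean_light_curve(mjd, band, flux, flux_err, photflag):
    """Clean one light curve stored as parallel arrays (illustrative only)."""
    photflag = np.asarray(photflag, dtype=np.int64)
    # Drop every observation whose PHOTFLAG has the saturation bit set.
    keep = (photflag & SATURATION_BIT) == 0
    mjd = np.asarray(mjd, dtype=float)[keep]
    flux = np.asarray(flux, dtype=float)[keep]
    flux_err = np.asarray(flux_err, dtype=float)[keep]
    band = np.asarray(band)[keep]
    photflag = photflag[keep]
    # Replace band labels with their mean wavelengths.
    wavelength = np.array([BAND_TO_WAVELENGTH[b] for b in band], dtype=float)
    # Binary detection flag: 1 where the detection bit is set, 0 otherwise.
    detected = ((photflag & DETECTION_BIT) != 0).astype(np.int64)
    # Realign time so the first remaining observation sits at t = 0
    # (assumes observations are sorted by MJD).
    mjd = mjd - mjd[0]
    return {"MJD": mjd, "BAND": wavelength, "FLUXCAL": flux,
            "FLUXCALERR": flux_err, "PHOTFLAG": detected}


def replace_missing_flags(values, missing=(-999.0,), flag_value=FLAG_VALUE):
    """Replace NaNs and the given sentinel values with flag_value."""
    values = np.asarray(values, dtype=float)
    is_missing = np.isnan(values) | np.isin(values, missing)
    return np.where(is_missing, flag_value, values)
```

Masking all per-observation arrays with the same boolean array keeps them aligned, which is why saturation removal happens before the PHOTFLAG conversion and time realignment.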
- exclude_classes()
Exclude classes specified in self.excluded_classes from the dataset. This method filters the records in self.parquet_df, removing any rows where the “class” column matches an entry in self.excluded_classes. It concatenates the remaining dataframes per class and updates self.parquet_df with the result.
- Returns:
None
- Side Effects:
Modifies self.parquet_df by excluding rows of unwanted classes.
- get_all_labels()
Retrieves all labels from the parquet dataframe’s ‘class’ column.
- Returns:
A list of labels extracted from the ‘class’ column.
- Return type:
list
- get_lc_plots(x_ts)
Generates a light curve plot image from time series data and returns it as a Torch tensor.
- Parameters:
x_ts (numpy.ndarray) – 2D array where each row corresponds to a time step and columns represent various features including ‘jd’ (Julian Date), ‘flux’ (observed flux), ‘flux_err’ (flux error), and ‘fid’ (filter identifier). The feature indices are determined using the global variable time_dependent_feature_list.
- Returns:
A tensor representing the RGB image of the generated light curve plot with shape (3, H, W), where H and W are the height and width of the image in pixels.
- Return type:
torch.Tensor
Note
The function uses matplotlib to create the plot and PIL to handle image conversion.
It iterates over wavelengths defined in the global dictionary ZTF_wavelength_to_color, plotting error bars for each wavelength filtered by the ‘fid’ feature.
The output image is saved to an in-memory buffer at 100 dpi, then converted from a PIL Image to a NumPy array and finally to a Torch tensor.
Warning
[TODO] This function could be optimized further by avoiding matplotlib and PIL altogether; it is currently very slow.
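The matplotlib-to-PIL pipeline described in the Note can be sketched as follows. This is a minimal sketch, not the module's implementation: it takes parallel arrays instead of x_ts, the function name lc_plot_to_chw_array and the wavelength-to-colour mapping passed in are hypothetical, the figure size is an assumption, and it returns a NumPy array in (3, H, W) order rather than a tensor.

```python
import io

import matplotlib
matplotlib.use("Agg")  # render off-screen; no display needed
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


def lc_plot_to_chw_array(time, flux, flux_err, wavelength, wavelength_to_color,
                         figsize=(4, 3), dpi=100):
    """Render a light curve and return it as a (3, H, W) uint8 array."""
    time, flux, flux_err, wavelength = map(
        np.asarray, (time, flux, flux_err, wavelength))
    fig, ax = plt.subplots(figsize=figsize)
    # One error-bar series per passband, coloured via the supplied mapping.
    for wl, color in wavelength_to_color.items():
        mask = wavelength == wl
        ax.errorbar(time[mask], flux[mask], yerr=flux_err[mask],
                    fmt="o", color=color)
    # Save to an in-memory PNG buffer, then decode it with PIL.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    plt.close(fig)  # free the figure; important when called per sample
    buf.seek(0)
    img = Image.open(buf).convert("RGB")  # drop the PNG alpha channel
    # PIL yields (H, W, 3); move channels first for PyTorch-style layout.
    return np.ascontiguousarray(np.asarray(img).transpose(2, 0, 1))
```

At 100 dpi a 4 x 3 inch figure yields a 300 x 400 pixel image. torch.from_numpy(arr) would then wrap the result as a tensor without copying. The PNG encode and decode round trip is a large part of why this approach is slow.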
- limit_max_samples_per_class()
Limits the number of samples for each class in the DataFrame. This method processes the DataFrame contained in the instance attribute self.parquet_df by:
Determining the unique classes present in the “class” column.
For each unique class, selecting only the first self.max_n_per_class samples.
Concatenating the limited samples from all classes back into self.parquet_df.
It also prints the maximum allowed samples per class and the resulting sample count for each class.
Note
Assumes self.parquet_df is a Polars DataFrame.
- map_models_to_classes()
Maps ELAsTiCC classes to astrophysical classes. The values in the ‘ELASTICC_class’ column are replaced with the corresponding astrophysical class names and stored in a new column named ‘class’; self.parquet_df is updated in place.
- Returns:
None
- Side Effects:
Modifies self.parquet_df by adding/updating the ‘class’ column.
- oracle.custom_datasets.ELAsTiCC.custom_collate_ELAsTiCC(batch)
Custom collation function for processing a batch of ELAsTiCC dataset samples.
- Parameters:
batch (list) – A list of dictionaries, each representing a sample.
- Returns:
- A dictionary containing the collated batch with the following keys:
‘ts’: A padded tensor of time series data with shape (batch_size, max_length, …), where padding is applied using the predefined flag_value.
‘static’: A tensor of static features with shape (batch_size, n_static_features).
‘length’: A tensor containing the lengths of each time series in the batch.
‘label’: A numpy array of labels for the batch (array-like).
‘raw_label’: A numpy array of raw ELAsTiCC class labels (array-like).
‘id’: A numpy array of SNIDs corresponding to each sample.
‘lc_plot’ (if present in the input samples): A tensor of light curve plots with shape (batch_size, n_channels, img_height, img_width).
- Return type:
dict
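The collation described above can be sketched with torch.nn.utils.rnn.pad_sequence. This is a minimal sketch, not the module's function: the padding value FLAG_VALUE = -9.0 is an assumption (the source only says "the predefined flag_value"), and each sample is assumed to be a dict with the keys listed above, holding 2D time series of shape (length, n_features).

```python
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

FLAG_VALUE = -9.0  # assumed padding value; the real flag_value may differ


def collate_sketch(batch):
    """Collate a list of sample dicts into padded tensors and label arrays."""
    ts = [torch.as_tensor(s["ts"], dtype=torch.float32) for s in batch]
    lengths = torch.tensor([t.shape[0] for t in ts])
    # Pad every series to the longest one: (batch_size, max_length, n_features).
    padded = pad_sequence(ts, batch_first=True, padding_value=FLAG_VALUE)
    static = torch.stack(
        [torch.as_tensor(s["static"], dtype=torch.float32) for s in batch])
    out = {
        "ts": padded,
        "static": static,
        "length": lengths,
        "label": np.array([s["label"] for s in batch]),
        "raw_label": np.array([s["raw_label"] for s in batch]),
        "id": np.array([s["id"] for s in batch]),
    }
    # Light curve images are optional and only stacked when present.
    if "lc_plot" in batch[0]:
        out["lc_plot"] = torch.stack([s["lc_plot"] for s in batch])
    return out
```

Returning ‘length’ alongside the padded tensor lets a downstream model mask or pack the padded steps instead of treating flag_value entries as data.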
- oracle.custom_datasets.ELAsTiCC.show_batch(images, labels, n=16)
Display a grid of images with corresponding labels. This function creates a visual representation of the first n images from the provided collection, arranges them in a square grid, and annotates each image with its corresponding label.
- Parameters:
images (Tensor or array-like) – Collection of images to be displayed. Each image is expected to have the shape (C, H, W), where C is the number of channels. For grayscale images, C should be 1.
labels (Sequence) – Sequence of labels corresponding to each image.
n (int, optional) – The number of images to display. The function uses the first n images from the collection. Defaults to 16.
- Displays:
A matplotlib figure containing a grid of images, each annotated with its respective label.
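A square grid of side ceil(sqrt(n)) is enough to hold n images. This is a minimal sketch with a hypothetical name: it returns the figure instead of displaying it, so the caller can call plt.show() on an interactive backend, and the figure size per cell is an assumption.

```python
import math

import matplotlib
matplotlib.use("Agg")  # off-screen backend for this sketch
import matplotlib.pyplot as plt
import numpy as np


def show_batch_sketch(images, labels, n=16):
    """Draw the first n (C, H, W) images in a square grid titled by label."""
    n = min(n, len(images))
    side = math.ceil(math.sqrt(n))
    fig, axes = plt.subplots(side, side, figsize=(2 * side, 2 * side),
                             squeeze=False)
    for i, ax in enumerate(axes.ravel()):
        ax.axis("off")  # hide ticks; unused cells stay blank
        if i >= n:
            continue
        img = np.asarray(images[i])
        if img.shape[0] == 1:
            ax.imshow(img[0], cmap="gray")  # single channel -> grayscale
        else:
            ax.imshow(np.transpose(img, (1, 2, 0)))  # (C, H, W) -> (H, W, C)
        ax.set_title(str(labels[i]))
    fig.tight_layout()
    return fig
```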
- oracle.custom_datasets.ELAsTiCC.truncate_ELAsTiCC_light_curve_by_days_since_trigger(x_ts, d=None)
Truncates the light curve data to only include observations within a specified number of days since the first detection.
- Parameters:
x_ts (np.ndarray) – A 2D array representing the time series light curve data, where each row is an observation and columns correspond to different features.
d (float, optional) – The number of days after the first detection to use as the cutoff for truncation. If None, d is drawn as 2**u with u sampled uniformly from [0, 11], so d lies between 1 and 2048 days on a log-uniform scale.
- Returns:
The truncated light curve array containing only the observations within ‘d’ days of the first detection.
- Return type:
np.ndarray
Note
Assumes that the columns corresponding to ‘PHOTFLAG’ (indicating detection status) and ‘MJD’ (the modified Julian date) exist in x_ts.
The indices for ‘PHOTFLAG’ and ‘MJD’ are obtained using the global list ‘time_dependent_feature_list’.
Raises an IndexError if no detection (i.e., a value of 1 in the ‘PHOTFLAG’ column) is found.
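The truncation rule can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions: the column order in the hypothetical time_dependent_feature_list below is invented for illustration, and the sketch keeps every row with MJD up to the first detection plus d, including earlier non-detections, which is one reading of the description above.

```python
import numpy as np

# Hypothetical column order; the real module reads indices from its global
# time_dependent_feature_list, whose exact contents are not given here.
time_dependent_feature_list = ["MJD", "FLUXCAL", "FLUXCALERR", "BAND", "PHOTFLAG"]


def truncate_by_days_sketch(x_ts, d=None):
    """Keep rows with MJD <= (MJD of first detection) + d."""
    if d is None:
        # Exponent drawn uniformly from [0, 11], so d lies in [1, 2048] days.
        d = 2.0 ** np.random.uniform(0, 11)
    mjd_idx = time_dependent_feature_list.index("MJD")
    flag_idx = time_dependent_feature_list.index("PHOTFLAG")
    # Index of the first detection; [0] raises IndexError if there is none.
    first_det = np.where(x_ts[:, flag_idx] == 1)[0][0]
    cutoff = x_ts[first_det, mjd_idx] + d
    return x_ts[x_ts[:, mjd_idx] <= cutoff]
```

Sampling the exponent rather than d itself spreads truncation points evenly across timescales, so early-time cutoffs of a few days are drawn as often as cutoffs of hundreds of days.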
- oracle.custom_datasets.ELAsTiCC.truncate_ELAsTiCC_light_curve_fractionally(x_ts, f=None)
Truncate an ELAsTiCC light curve by a fractional amount. This function reduces the number of observations in the light curve array based on a specified fraction. If no fraction is provided, a random fraction between 0.1 and 1.0 is chosen. The truncation ensures that at least one observation remains.
- Parameters:
x_ts (numpy.ndarray) – A 2D array representing the light curve data, where each row corresponds to an observation and the columns represent different features.
f (float, optional) – A fraction between 0.0 and 1.0 to determine the portion of the light curve to retain. If None, a random fraction in the range [0.1, 1.0] is used.
- Returns:
The truncated light curve, containing only the first portion of the observations as determined by the fraction.
- Return type:
numpy.ndarray
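Fractional truncation amounts to keeping a prefix of the rows. This is a minimal sketch with a hypothetical name; rounding the kept count down with int() is an assumption, since the source does not say how fractional row counts are rounded.

```python
import numpy as np


def truncate_fractionally_sketch(x_ts, f=None):
    """Keep the first max(1, int(f * n_obs)) rows of x_ts."""
    if f is None:
        f = np.random.uniform(0.1, 1.0)
    # Round down, but never drop every observation.
    n_keep = max(1, int(len(x_ts) * f))
    return x_ts[:n_keep]
```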