oracle.custom_datasets.BTS module
Custom dataset class for the ZTF Bright Transient Survey light curve dataset.
- class oracle.custom_datasets.BTS.BTS_LC_Dataset(parquet_file_path, mapper=None, max_n_per_class=None, include_postage_stamps=False, include_lc_plots=False, transform=None, over_sample=False, excluded_classes=[])
Bases: Dataset
A custom PyTorch Dataset class for handling BTS light curve data stored in a parquet file.
- clean_up_dataset()
- Clean up the dataset by executing a series of transformations:
Adjusts the observation times by subtracting the time of the first observation from each ‘jd’ entry.
Converts band labels in the ‘fid’ column to their corresponding mean wavelengths based on a predefined mapping.
Maps BTS sample explorer classes from the ‘bts_class’ column to astrophysical classes using a provided mapper.
- Computes new WISE color indices:
‘W1_minus_W3’ as the difference between ‘W1mag’ and ‘W3mag’.
‘W2_minus_W3’ as the difference between ‘W2mag’ and ‘W3mag’.
- Transforms celestial coordinates:
Converts the right ascension (‘ra’) and declination (‘dec’) into galactic longitude (‘l’) and galactic latitude (‘b’) using astropy’s SkyCoord.
Prints status messages at each step to indicate progress.
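The clean-up steps can be sketched for a single source stored as a dict of per-observation lists. This is a minimal illustration, not the module’s code: the helper name clean_up_source and the wavelength values in FID_TO_WAVELENGTH are hypothetical placeholders (the module uses its own predefined mapping), and the galactic-coordinate step is omitted to keep the sketch dependency-free (astropy’s SkyCoord(...).galactic would supply ‘l’ and ‘b’).

```python
import numpy as np

# Placeholder mapping from ZTF filter id to an approximate mean wavelength (Angstrom).
# The module's own predefined mapping may use different values.
FID_TO_WAVELENGTH = {1: 4800.0, 2: 6400.0, 3: 7900.0}


def clean_up_source(row, mapper):
    """Hypothetical per-source version of the clean-up steps."""
    out = dict(row)
    jd = np.asarray(row["jd"], dtype=float)
    # Shift times so the earliest observation sits at t = 0.
    out["jd"] = jd - jd.min()
    # Replace filter ids with mean wavelengths.
    out["fid"] = np.array([FID_TO_WAVELENGTH[f] for f in row["fid"]])
    # Map the BTS sample explorer class to an astrophysical class.
    out["class"] = mapper[row["bts_class"]]
    # WISE colour indices.
    out["W1_minus_W3"] = row["W1mag"] - row["W3mag"]
    out["W2_minus_W3"] = row["W2mag"] - row["W3mag"]
    return out
```

In the dataset itself these operations are applied column-wise to self.parquet_df rather than row by row.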
- convert_mags_to_flux()
Converts magnitude measurements to flux values and computes the corresponding flux uncertainties. The method processes the dataframe stored in self.parquet_df by adding two new columns:
- “flux”: Computed from the “magpsf” column using the relation: flux = F0 * 10^(-0.4 * magpsf)
where F0 is set to 3631.0 * 1e6 micro-Janskys (µJy).
- “flux_err”: Calculated via error propagation using both “magpsf” and “sigmapsf” columns:
flux_err = (0.4 * ln(10)) * F0 * 10^(-0.4 * magpsf) * sigmapsf
Note
Element-wise operations map the magnitude values (and their uncertainties) to the corresponding flux values across the dataframe.
The resulting “flux” and “flux_err” columns are lists of floats.
- Returns:
None. The dataframe self.parquet_df is modified in place.
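The two formulas above can be checked on a single light curve. This sketch assumes ‘magpsf’ and ‘sigmapsf’ arrive as per-source lists, as the list-valued output columns suggest; the function name mags_to_flux is hypothetical. A useful sanity check is that an AB magnitude of 23.9 corresponds to roughly 1 µJy.

```python
import numpy as np

F0_UJY = 3631.0 * 1e6  # AB zero point: 3631 Jy expressed in micro-Janskys


def mags_to_flux(magpsf, sigmapsf):
    """Convert PSF magnitudes and their 1-sigma errors to flux and flux error in uJy."""
    mag = np.asarray(magpsf, dtype=float)
    sig = np.asarray(sigmapsf, dtype=float)
    flux = F0_UJY * 10.0 ** (-0.4 * mag)
    # |dF/dm| = 0.4 * ln(10) * F, so sigma_F = 0.4 * ln(10) * F * sigma_m
    flux_err = 0.4 * np.log(10.0) * flux * sig
    # Return plain lists of floats, matching the list-valued dataframe columns.
    return flux.tolist(), flux_err.tolist()
```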
- exclude_classes()
Exclude specified classes from the dataset. This method filters the dataset contained in self.parquet_df by removing any rows whose ‘class’ value is present in the self.excluded_classes list.
- Returns:
None
Note
Assumes that the dataframe self.parquet_df has a column named ‘class’.
The method leverages the filtering function of the dataframe library (e.g., Polars).
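The exclusion amounts to keeping every row whose label is not in the excluded list. The sketch below expresses it as a NumPy boolean mask; with Polars the equivalent filter would be something like df.filter(~pl.col("class").is_in(excluded)), although the module’s exact call is not shown here. The helper name keep_mask is hypothetical.

```python
import numpy as np


def keep_mask(labels, excluded_classes):
    """Boolean mask that is True for rows whose class is not excluded."""
    labels = np.asarray(labels)
    return ~np.isin(labels, list(excluded_classes))
```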
- get_all_labels()
Retrieves all labels from the parquet dataframe’s ‘class’ column.
- Returns:
A list of labels extracted from the ‘class’ column.
- Return type:
list
- get_lc_plots(x_ts)
Generates a light curve plot image from time series data and returns it as a Torch tensor.
- Parameters:
x_ts (numpy.ndarray) – 2D array where each row corresponds to a time step and columns represent various features including ‘jd’ (Julian Date), ‘flux’ (observed flux), ‘flux_err’ (flux error), and ‘fid’ (filter identifier). The feature indices are determined using the global variable time_dependent_feature_list.
- Returns:
A tensor representing the RGB image of the generated light curve plot with shape (3, H, W), where H and W are the height and width of the image in pixels.
- Return type:
torch.Tensor
Note
The function uses matplotlib to create the plot and PIL to handle image conversion.
It iterates over wavelengths defined in the global dictionary ZTF_wavelength_to_color, plotting error bars for each wavelength filtered by the ‘fid’ feature.
The output image is saved to an in-memory buffer at 100 dpi, then converted from a PIL Image to a NumPy array and finally to a Torch tensor.
Warning
[TODO] This function can be optimized further to avoid using matplotlib and PIL altogether. It’s really slow right now.
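Before plotting, the light curve has to be split by filter so that each wavelength gets its own error-bar series. The sketch below shows only that grouping step. The column order in FEATURES is a hypothetical stand-in for time_dependent_feature_list, and split_by_filter is an illustrative name; each group would then be passed to something like plt.errorbar(jd, flux, yerr=flux_err).

```python
import numpy as np

# Hypothetical column order; the module looks indices up in time_dependent_feature_list.
FEATURES = ["jd", "flux", "flux_err", "fid"]


def split_by_filter(x_ts):
    """Group a (T, F) light-curve array into per-filter (jd, flux, flux_err) arrays."""
    jd = x_ts[:, FEATURES.index("jd")]
    flux = x_ts[:, FEATURES.index("flux")]
    err = x_ts[:, FEATURES.index("flux_err")]
    fid = x_ts[:, FEATURES.index("fid")]
    groups = {}
    for value in np.unique(fid):
        sel = fid == value  # observations taken in this filter
        groups[float(value)] = (jd[sel], flux[sel], err[sel])
    return groups
```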
- get_postage_stamp(examples)
Generates a postage stamp image tensor for the provided source. This function iterates over a set of predefined filters (ztf_filters) and extracts the corresponding reference images from the ‘examples’ dictionary. Each extracted image is stored in a canvas array, which is then scaled from a range [0, 1] to [0, 255] and converted into a PyTorch tensor.
- Parameters:
examples (dict) – A dictionary containing image data corresponding to each filter reference, with keys formatted as ‘{filter}_reference’.
- Returns:
A tensor of shape (len(ztf_filters), img_length, img_length) representing the postage stamp images, with pixel values scaled to the range [0, 255].
- Return type:
torch.Tensor
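A minimal sketch of the canvas construction, returning a NumPy array instead of a tensor (torch.from_numpy could wrap the result). ZTF_FILTERS is a hypothetical stand-in for the module’s ztf_filters, the helper name is illustrative, and the input cutouts are assumed to already lie in [0, 1].

```python
import numpy as np

ZTF_FILTERS = ["g", "r", "i"]  # hypothetical stand-in for the module's ztf_filters


def postage_stamp_array(examples, img_length):
    """Stack per-filter reference cutouts into a (n_filters, L, L) array scaled to [0, 255]."""
    canvas = np.zeros((len(ZTF_FILTERS), img_length, img_length), dtype=np.float32)
    for i, f in enumerate(ZTF_FILTERS):
        canvas[i] = examples[f"{f}_reference"]
    # Rescale from [0, 1] to [0, 255].
    return canvas * 255.0
```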
- get_postage_stamp_plot(examples)
Generates a postage stamp plot from the given image examples and returns it as a PyTorch tensor. The function assembles a canvas by arranging image sections corresponding to various filters (defined in the global ztf_filters) side-by-side. Each section is populated using the “{filter}_reference” key from the provided examples dictionary. The canvas is then plotted using matplotlib, without axis ticks or spines, and saved into a PNG image buffer. This image is subsequently loaded via PIL, converted to a NumPy array with channel-first ordering, and finally wrapped into a PyTorch tensor.
- Parameters:
examples (dict) – A dictionary containing image data. For every filter in the global variable ztf_filters, the key corresponding to the image is expected to be formatted as “{filter}_reference”.
- Returns:
The resulting image as a PyTorch tensor with shape (C, H, W), where C is the number of color channels, and H and W correspond to the height and width of the assembled image, respectively.
- Return type:
torch.Tensor
Warning
[TODO] This function can be optimized further to avoid using matplotlib and PIL altogether. It’s really slow right now.
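The side-by-side canvas that gets passed to matplotlib can be assembled with a single concatenation along the width axis. This sketch covers only that assembly step, not the PNG round trip; ZTF_FILTERS is again a hypothetical stand-in for the module’s ztf_filters, and all cutouts are assumed to share the same height.

```python
import numpy as np

ZTF_FILTERS = ["g", "r", "i"]  # hypothetical stand-in for the module's ztf_filters


def assemble_side_by_side(examples):
    """Place each filter's reference image to the right of the previous one."""
    images = [np.asarray(examples[f"{f}_reference"]) for f in ZTF_FILTERS]
    return np.concatenate(images, axis=1)  # axis 1 is the image width
```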
- limit_max_samples_per_class()
Limits the number of samples for each class in the dataset to a maximum specified by self.max_n_per_class. This method performs the following steps:
Prints a message indicating that the dataset is being limited to self.max_n_per_class samples per class.
Retrieves the unique class labels from the ‘class’ column of self.parquet_df.
For each unique class, it filters the dataframe to include only entries belonging to that class and slices the result to retain only the first self.max_n_per_class rows.
Collects the limited dataframes for each class and concatenates them back into a single dataframe.
Prints the number of samples retained for each class.
Updates self.parquet_df with the concatenated, limited dataframe.
Note
Assumes self.parquet_df is a Polars DataFrame.
Uses NumPy to determine unique class labels.
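The per-class limiting can be sketched on a plain label array by collecting row indices rather than filtering a Polars DataFrame (where the analogous steps would be a filter, head, and concat). The helper name limit_per_class is hypothetical; like the method, it keeps the first rows of each class.

```python
import numpy as np


def limit_per_class(labels, max_n_per_class):
    """Row indices that keep at most max_n_per_class rows per class (first rows kept)."""
    labels = np.asarray(labels)
    keep = []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)[:max_n_per_class]
        print(f"{cls}: {len(idx)} samples")
        keep.append(idx)
    return np.concatenate(keep)
```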
- over_sample_minority_classes()
Oversamples the minority classes in the dataset contained in self.parquet_df. This method identifies the class with the maximum sample count and for every other class, performs sampling with replacement so that each class has the same number of samples as the majority class. The oversampled dataframes for the minority classes are concatenated and reassigned to self.parquet_df.
- Side Effects:
Modifies self.parquet_df by oversampling minority classes.
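Random oversampling to the majority-class count can be sketched with NumPy’s Generator.choice. The sketch assumes the majority class is kept as-is while every other class is resampled with replacement; the helper name oversample_indices and the seed parameter are illustrative.

```python
import numpy as np


def oversample_indices(labels, seed=None):
    """Row indices after resampling every minority class up to the majority-class count."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    target = counts.max()
    parts = []
    for cls, count in zip(classes, counts):
        idx = np.flatnonzero(labels == cls)
        if count == target:
            parts.append(idx)  # majority class kept unchanged
        else:
            parts.append(rng.choice(idx, size=target, replace=True))
    return np.concatenate(parts)
```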
- print_dataset_composition()
Prints a summary of the dataset composition before any transformations or mappings are applied. This method extracts the unique classes and their respective counts from the ‘bts_class’ column of the dataframe stored in self.parquet_df, constructs a dictionary with this information, and then prints it as a formatted table.
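The composition summary reduces to counting unique values. This sketch uses np.unique with return_counts and prints a plain two-column table; the module’s actual table formatting may differ, and the helper name dataset_composition is hypothetical.

```python
import numpy as np


def dataset_composition(bts_classes):
    """Count samples per raw BTS class and print them as a two-column table."""
    classes, counts = np.unique(np.asarray(bts_classes), return_counts=True)
    composition = {str(c): int(n) for c, n in zip(classes, counts)}
    width = max((len(c) for c in composition), default=0)
    for cls, n in composition.items():
        print(f"{cls:<{width}}  {n}")
    return composition
```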
- oracle.custom_datasets.BTS.custom_collate_BTS(batch)
Collate function for batching BTS dataset samples. This function takes a list of sample dictionaries and collates them into a single batch dictionary suitable for training or inference. It pads the time-series data, concatenates static and meta features, and stacks the optional image data (postage_stamp and lc_plot) when they are present in the sample dictionaries.
- Parameters:
batch (list) – A list where each element is a dictionary containing the following keys:
- Required keys:
‘ts’ (numpy.ndarray): Time-series data for the sample.
‘label’: Label corresponding to the sample.
‘ZTFID’: Identifier for the sample.
‘bts_class’: Raw label or class of the sample.
‘meta’ (numpy.ndarray): Meta features array from which the last row is used.
‘static’ (numpy.ndarray): Static features for the sample.
- Optional keys:
‘postage_stamp’: Tensor representing the postage stamp image.
‘lc_plot’: Tensor representing the light curve plot image.
- Returns:
- A dictionary with the following entries:
‘ts’ (Tensor): Padded time-series data tensor with shape (batch_size, max_seq_length, …).
‘static’ (Tensor): Concatenated tensor of static and meta features.
‘length’ (Tensor): Tensor containing the original lengths of each time-series sample.
‘label’ (numpy.ndarray): Array of sample labels.
‘raw_label’ (numpy.ndarray): Array of BTS raw labels.
‘id’ (numpy.ndarray): Array of sample identifiers.
‘postage_stamp’ (Tensor, optional): Stacked tensor of postage stamp images, if available.
‘lc_plot’ (Tensor, optional): Stacked tensor of light curve plot images, if available.
- Return type:
dict
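The padding and feature concatenation can be sketched with NumPy. This assumes zero padding at the end of each sequence, which is what torch.nn.utils.rnn.pad_sequence(..., batch_first=True) does by default, and it omits the optional image keys. The helper name collate is hypothetical, and arrays are returned where the real function returns tensors.

```python
import numpy as np


def collate(batch):
    """Pad time series to a common length and join static features with the last meta row."""
    lengths = np.array([len(s["ts"]) for s in batch])
    n_feat = batch[0]["ts"].shape[1]
    ts = np.zeros((len(batch), lengths.max(), n_feat), dtype=np.float32)
    for i, s in enumerate(batch):
        ts[i, : lengths[i]] = s["ts"]  # real observations first, zeros afterwards
    static = np.stack([np.concatenate([s["static"], s["meta"][-1]]) for s in batch])
    return {
        "ts": ts,
        "static": static,
        "length": lengths,
        "label": np.array([s["label"] for s in batch]),
        "raw_label": np.array([s["bts_class"] for s in batch]),
        "id": np.array([s["ZTFID"] for s in batch]),
    }
```

Keeping ‘length’ alongside the padded tensor lets a downstream model mask out or pack the padded steps.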
- oracle.custom_datasets.BTS.show_batch(images, labels, n=16)
Display a grid of images with corresponding labels. This function creates a visual representation of the first n images from the provided dataset. It arranges the images in a square grid and annotates each image with its corresponding label.
- Parameters:
images (Tensor or array-like) – Collection of images to be displayed. Each image is expected to have the shape (C, H, W), where C is the number of channels. For grayscale images, C should be 1.
labels (Sequence) – Sequence of labels corresponding to each image.
n (int, optional) – The number of images to display. The function uses the first n images from the collection. Defaults to 16.
- Displays:
A matplotlib figure containing a grid of images, each annotated with its respective label.
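The square-grid layout can be illustrated without matplotlib by tiling the images into a single mosaic array: with k images the grid has ceil(sqrt(k)) rows and columns, filled row by row. This is an illustrative substitute for the figure that show_batch draws, not its implementation, and the helper name image_grid is hypothetical.

```python
import math

import numpy as np


def image_grid(images, n=16):
    """Tile the first n (C, H, W) images into one (C, side*H, side*W) mosaic."""
    images = np.asarray(images)[:n]
    k, c, h, w = images.shape
    side = math.ceil(math.sqrt(k))
    grid = np.zeros((c, side * h, side * w), dtype=images.dtype)
    for i, img in enumerate(images):
        r, col = divmod(i, side)  # row-major placement
        grid[:, r * h:(r + 1) * h, col * w:(col + 1) * w] = img
    return grid
```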
- oracle.custom_datasets.BTS.truncate_BTS_light_curve_by_days_since_trigger(x_ts, x_static, d=None, add_jitter=False, normalize_flux=False)
Truncate the BTS light curve based on the number of days since the first trigger. This function selects observations from the time-series data (x_ts) that occur within a specified number of days (d) from the first detection (trigger). It also optionally adds jitter to the flux measurements and normalizes the flux values.
- Parameters:
x_ts (numpy.ndarray) –
Time-dependent features array where each row represents an observation. Expected to contain the following features:
‘jd’: Julian date of the observation.
‘magpsf’: PSF-fit magnitude.
‘sigmapsf’: Uncertainty of the magnitude.
‘fid’: Filter identifier.
‘photflag’: Photometric flag (added as an extra column to maintain compatibility with ZTF sims and always assumed to be 1).
x_static (numpy.ndarray) – Static features array associated with the light curve.
d (float, optional) – Maximum number of days from the trigger within which to keep observations. If None, a random threshold is generated using 2^(uniform(0, 11)).
add_jitter (bool, optional) – If True, adds Gaussian noise to the flux (‘magpsf’), using its uncertainty (‘sigmapsf’) as the 1σ scale.
normalize_flux (bool, optional) – If True, normalizes the flux (‘magpsf’) using its mean and standard deviation.
- Returns:
- A tuple containing:
x_ts (numpy.ndarray): The truncated time-dependent features array.
x_static (numpy.ndarray): The static features array.
- Return type:
tuple
Note
The function assumes that the dataset does not contain any non-detections.
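The truncation logic can be sketched as follows. Assumptions: the column order in the hypothetical constants JD, MAG, MAG_ERR, FID, PHOTFLAG; x_static is passed through unchanged; and a zero standard deviation falls back to 1 to avoid division by zero (a guard added here, not documented for the module). When d is None the horizon is d = 2^U with U ~ Uniform(0, 11), so it lies between 1 and 2048 days and is uniform in log2(d).

```python
import numpy as np

# Hypothetical column order for x_ts; the module's actual ordering may differ.
JD, MAG, MAG_ERR, FID, PHOTFLAG = range(5)


def truncate_by_days(x_ts, x_static, d=None, add_jitter=False, normalize_flux=False, rng=None):
    """Keep observations within d days of the first detection, with optional jitter/normalization."""
    rng = np.random.default_rng() if rng is None else rng
    if d is None:
        d = 2 ** rng.uniform(0, 11)  # random horizon in [1, 2048] days
    x_ts = np.array(x_ts, dtype=float, copy=True)
    # No non-detections, so the earliest row is the trigger.
    keep = (x_ts[:, JD] - x_ts[:, JD].min()) <= d
    x_ts = x_ts[keep]
    if add_jitter:
        # Per-observation Gaussian noise with sigmapsf as the 1-sigma scale.
        x_ts[:, MAG] += rng.normal(0.0, x_ts[:, MAG_ERR])
    if normalize_flux:
        mu, sd = x_ts[:, MAG].mean(), x_ts[:, MAG].std()
        x_ts[:, MAG] = (x_ts[:, MAG] - mu) / (sd if sd > 0 else 1.0)
    return x_ts, x_static
```

Because the trigger row always satisfies the cut, the truncated light curve contains at least one observation.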