oracle.custom_datasets.ZTF_sims module
Custom dataset class for the simulated ZTF light curve dataset.
- class oracle.custom_datasets.ZTF_sims.ZTF_SIM_LC_Dataset(parquet_file_path, max_n_per_class=None, include_lc_plots=False, transform=None)
Bases: Dataset
- clean_up_dataset()
Cleans and transforms the dataset stored in the object’s parquet_df attribute. This method performs several dataset cleaning operations, including:
- Replacing band labels in the “FLT” column with their corresponding mean wavelengths, using the mapping defined in ZTF_passband_to_wavelengths.
- Removing saturated measurements from time series features:
For each time-dependent feature, it removes data points where the corresponding “PHOTFLAG” bitmask indicates saturation (using bitwise logic with 1024).
- Replacing the “PHOTFLAG” bitmask values:
Converts the cleaned “PHOTFLAG” list to binary flags: entries carrying the flag 4096 become 1 (indicating detection) and all other entries become 0.
- Adjusting the time series:
Normalizes the “MJD_clean” column by subtracting the time of the first observation from all entries.
- Mapping simulation class labels:
Replaces ZTF simulation classes with astrophysical class labels using ZTF_sims_to_Astrophysical_mappings.
- Handling missing data in time-independent features:
Replaces any feature value that matches a missing data flag with a specified flag_value.
The method updates the parquet_df in place and prints progress messages for each step.
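The time series steps above can be sketched on plain arrays. This is a minimal illustration, not the module's implementation: the function name, the constant names, and the use of a bitwise AND to test for the 4096 flag are assumptions; only the bit values 1024 and 4096 come from the description.

```python
import numpy as np

# Bit values taken from the description above; the names are hypothetical.
SATURATION_BIT = 1024
DETECTION_BIT = 4096


def clean_time_series(mjd, flux, photflag):
    """Drop saturated points, binarize PHOTFLAG, and zero-base the times."""
    mjd = np.asarray(mjd, dtype=float)
    flux = np.asarray(flux, dtype=float)
    photflag = np.asarray(photflag, dtype=int)

    # Keep only observations whose saturation bit is not set.
    keep = (photflag & SATURATION_BIT) == 0
    mjd, flux, photflag = mjd[keep], flux[keep], photflag[keep]

    # 1 where the detection bit is set, 0 otherwise (assumed bitwise test).
    detected = ((photflag & DETECTION_BIT) != 0).astype(int)

    # Times relative to the first remaining observation.
    if mjd.size:
        mjd = mjd - mjd[0]
    return mjd, flux, detected


mjd, flux, detected = clean_time_series(
    mjd=[58000.0, 58001.5, 58003.0, 58010.0],
    flux=[1.0, 250.0, 12.0, 30.0],
    photflag=[0, 1024, 4096, 4096 + 2048],
)
```

In this toy input the second observation is saturated and is dropped, so three points remain with times measured from the first of them.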
- get_all_labels()
Retrieve all labels from the dataset.
This method extracts the ‘class’ column from the parquet dataframe attribute and returns it as a list.
- Returns:
A list containing all labels present in the ‘class’ column.
- Return type:
list
- get_lc_plots(x_ts)
Generates a light curve plot from the provided time series data and returns it as a PyTorch tensor.
- The function performs the following steps:
Extracts light curve parameters such as Julian dates, flux measurements, flux errors, filter identifiers, and photometric flags from the input array.
For each wavelength (as defined in ‘ZTF_wavelength_to_color’), it plots:
Detected data points using the ‘marker_style_detection’.
Non-detected data points using the ‘marker_style_non_detection’. Both are drawn with error bars corresponding to the flux uncertainties.
Overlays a line plot connecting all points for each wavelength.
Configures the matplotlib figure by setting a fixed size, removing axis ticks and all spines.
Saves the plot into an in-memory PNG buffer at a specified DPI, then loads it via PIL.
Converts the image to a NumPy array, permutes the dimensions, and finally converts it into a PyTorch tensor.
- Parameters:
x_ts (np.ndarray) – A 2D NumPy array representing the time series data.
- Returns:
A tensor representation of the generated light curve plot image.
- Return type:
torch.Tensor
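The render-to-buffer pipeline described for get_lc_plots can be sketched as follows. This is a simplified illustration under stated assumptions: it plots a single band with one marker style, the function name, figure size, DPI, and color are hypothetical, and it stops at a channel-first NumPy array. The final conversion to a tensor (for example with torch.from_numpy) is left out so the sketch does not require PyTorch.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


def plot_to_chw_array(t, flux, flux_err, figsize=(2, 2), dpi=64):
    """Render one band of a light curve and return it as a (C, H, W) array."""
    fig, ax = plt.subplots(figsize=figsize)

    # Points with error bars, plus a line connecting them.
    ax.errorbar(t, flux, yerr=flux_err, fmt="o", color="g")
    ax.plot(t, flux, color="g")

    # Remove axis ticks and all spines.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Save to an in-memory PNG buffer, then load it back via PIL.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    plt.close(fig)
    buf.seek(0)
    img = np.array(Image.open(buf).convert("RGB"))  # shape (H, W, 3)

    # Permute to channel-first layout.
    return np.transpose(img, (2, 0, 1))


img_chw = plot_to_chw_array(
    t=[0.0, 2.0, 5.0, 9.0],
    flux=[1.0, 4.0, 3.0, 1.5],
    flux_err=[0.2, 0.3, 0.3, 0.2],
)
```

With a 2 x 2 inch figure at 64 DPI, the image is 128 x 128 pixels, so the result has shape (3, 128, 128).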
- limit_max_samples_per_class()
Limits the number of samples per class in the dataset.
For every unique class in the ‘class’ column of self.parquet_df, this method selects at most self.max_n_per_class entries (i.e., the first self.max_n_per_class samples) and concatenates them into a new dataframe that replaces self.parquet_df.
- Informative messages are printed to indicate:
The overall limit being applied per class.
The resulting number of samples retained for each class.
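A minimal pandas sketch of the per-class cap described above. The standalone function, its parameter names, and the index reset are assumptions for illustration; the actual method operates on self.parquet_df and self.max_n_per_class.

```python
import pandas as pd


def limit_per_class(df, max_n_per_class, label_col="class"):
    """Keep at most max_n_per_class rows (the first ones) for every class."""
    parts = []
    for cls in df[label_col].unique():
        # head() takes the first rows of this class, in their original order.
        parts.append(df[df[label_col] == cls].head(max_n_per_class))
    # Resetting the index is a choice made for this sketch only.
    return pd.concat(parts, ignore_index=True)


df = pd.DataFrame(
    {
        "SNID": [1, 2, 3, 4, 5, 6],
        "class": ["SNIa", "SNIa", "SNIa", "AGN", "SNII", "SNII"],
    }
)
limited = limit_per_class(df, max_n_per_class=2)
counts = limited["class"].value_counts().to_dict()
```

Here the third SNIa row is discarded, while classes with fewer rows than the limit are kept whole.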
- print_dataset_composition()
Prints the composition of the dataset. It formats the composition as a pandas DataFrame and prints the resulting table to the console.
- Returns:
None
- oracle.custom_datasets.ZTF_sims.custom_collate_ZTF_SIM(batch)
Collates a batch of ZTF simulation samples into a single dictionary of tensors suitable for model input.
- Each sample in the input batch is expected to be a dictionary with the following keys:
‘ts’: A tensor representing the time series data. The tensor should have shape (num_time_points, …).
‘label’: The label associated with the sample (e.g., class index or regression target).
‘SNID’: A unique identifier for the sample.
‘static’: A tensor of static features with a predefined number of features (n_static_features).
‘lc_plot’: (Optional) A tensor representing the light curve plot image with shape (n_channels, img_height, img_width).
- Parameters:
batch (list) – A list of sample dictionaries as described above.
- Returns:
A dictionary containing the collated batch, with the same keys as the input samples.
- Return type:
dict
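One plausible way to collate samples of this shape is sketched below. The source does not state how variable-length time series are combined, so zero-padding to the longest series in the batch is an assumption, as are the function name and the extra ‘length’ key. NumPy arrays stand in for tensors to keep the sketch free of a PyTorch dependency.

```python
import numpy as np


def collate_sketch(batch):
    """Stack a list of sample dicts into one dict of batched arrays."""
    lengths = [len(sample["ts"]) for sample in batch]
    n_features = batch[0]["ts"].shape[1]

    # Assumption: pad every time series with zeros up to the longest one.
    ts = np.zeros((len(batch), max(lengths), n_features), dtype=float)
    for i, sample in enumerate(batch):
        ts[i, : lengths[i]] = sample["ts"]

    out = {
        "ts": ts,
        "length": np.array(lengths),  # hypothetical extra key
        "label": [sample["label"] for sample in batch],
        "SNID": [sample["SNID"] for sample in batch],
        "static": np.stack([sample["static"] for sample in batch]),
    }
    # 'lc_plot' is optional, so only stack it when the samples carry it.
    if "lc_plot" in batch[0]:
        out["lc_plot"] = np.stack([sample["lc_plot"] for sample in batch])
    return out


batch = [
    {"ts": np.ones((3, 2)), "label": "SNIa", "SNID": 101, "static": np.zeros(4)},
    {"ts": np.ones((5, 2)), "label": "AGN", "SNID": 102, "static": np.zeros(4)},
]
collated = collate_sketch(batch)
```

Keeping the original lengths next to the padded array lets downstream code mask out the padding.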
- oracle.custom_datasets.ZTF_sims.show_batch(images, labels, n=16)
Display a grid of images with corresponding labels. This function creates a visual representation of the first n images from the provided dataset. It arranges the images in a square grid and annotates each image with its corresponding label.
- Parameters:
images (Tensor or array-like) – Collection of images to be displayed. Each image is expected to have the shape (C, H, W), where C is the number of channels. For grayscale images, C should be 1.
labels (Sequence) – Sequence of labels corresponding to each image.
n (int, optional) – The number of images to display. The function uses the first n images from the collection. Defaults to 16.
- Displays:
A matplotlib figure containing a grid of images, each annotated with its respective label.
- oracle.custom_datasets.ZTF_sims.truncate_ZTF_SIM_light_curve_by_days_since_trigger(x_ts, d)
Truncates a ZTF SIM light curve by retaining only the observations within a specified number of days since the trigger (first detection).
- Parameters:
x_ts (numpy.ndarray) – A 2D array representing the time series data of the light curve. It is expected to have columns corresponding to various features, including ‘PHOTFLAG’ and ‘MJD’, as specified in the global list ‘time_dependent_feature_list’.
d (float) – The time window (in days) from the first detection. Observations beyond this period are removed.
- Returns:
The truncated light curve array containing only the observations within the specified days since the trigger.
- Return type:
numpy.ndarray
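A minimal sketch of truncation by days since trigger. The column positions are passed as hypothetical parameters here, whereas the module looks them up in time_dependent_feature_list. The sketch also assumes that PHOTFLAG has already been converted to 0/1, that observations before the trigger are kept, and that a light curve with no detection is returned unchanged; none of these is confirmed by the source.

```python
import numpy as np


def truncate_by_days_since_trigger(x_ts, d, mjd_col=0, flag_col=1):
    """Keep observations no later than d days after the first detection."""
    detections = np.where(x_ts[:, flag_col] == 1)[0]
    if detections.size == 0:
        return x_ts  # assumption: nothing to truncate without a trigger

    # The trigger is the time of the first detection.
    t_trigger = x_ts[detections[0], mjd_col]
    keep = (x_ts[:, mjd_col] - t_trigger) <= d
    return x_ts[keep]


# Columns: time in days, detection flag.
x_ts = np.array(
    [[0.0, 0], [2.0, 0], [5.0, 1], [7.0, 1], [12.0, 1]]
)
truncated = truncate_by_days_since_trigger(x_ts, d=3)
```

The trigger in this example is at day 5, so with d=3 only the observation at day 12 is removed.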
- oracle.custom_datasets.ZTF_sims.truncate_ZTF_SIM_light_curve_fractionally(x_ts, f=None)
Truncate a ZTF simulation light curve by retaining only a fraction of its observations.
- Parameters:
x_ts (numpy.ndarray) – A 2D array representing the light curve where each row corresponds to an observation.
f (float, optional) – Fraction of the total observations to retain. If not provided (None), a random fraction between 0.1 and 1.0 will be used.
- Returns:
A truncated version of the input light curve containing a fraction (at least one) of the original observations.
- Return type:
numpy.ndarray
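A minimal sketch of fractional truncation. It assumes that the retained observations are the earliest ones and that the count is rounded down with a floor of one; the source only states that a fraction, and at least one observation, is kept. The rng parameter is a hypothetical addition for reproducibility.

```python
import numpy as np


def truncate_fractionally(x_ts, f=None, rng=None):
    """Keep the first fraction f of the rows, and always at least one row."""
    if f is None:
        # Draw a random fraction between 0.1 and 1.0 when none is given.
        rng = np.random.default_rng() if rng is None else rng
        f = rng.uniform(0.1, 1.0)
    n_keep = max(1, int(f * x_ts.shape[0]))
    return x_ts[:n_keep]


x_ts = np.arange(20, dtype=float).reshape(10, 2)
part = truncate_fractionally(x_ts, f=0.35)
tiny = truncate_fractionally(x_ts, f=0.01)
rand = truncate_fractionally(x_ts, rng=np.random.default_rng(0))
```

With ten observations, f=0.35 keeps three rows, and a very small fraction still keeps one.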