CASBI.inference
Module Contents
Classes
CustomDataset_halo | Custom dataset class for the training and validation datasets.
CustomDataset_subhalo | Custom dataset class for the training and validation datasets.
Functions
train_inference | Train an ltu-ili ensemble model on (x, theta) pairs. The data is split into training and validation sets, and the trained model is saved in the output_dir directory.
API
- class CASBI.inference.CustomDataset_halo(observation, parameters, device)
Bases: torch.utils.data.Dataset
Custom dataset class for the training and validation datasets.
Parameters
- observation : torch.Tensor
The observation data.
- parameters : torch.Tensor
The parameter data.
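A minimal sketch of a map-style dataset in the spirit of CustomDataset_halo, assuming the class moves both tensors to `device` once and returns one (observation, parameter) pair per index. The class name `PairDataset` and the toy tensor shapes are hypothetical; the real class may index or transform the data differently.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical stand-in for CustomDataset_halo: serves (observation, parameter) pairs."""

    def __init__(self, observation: torch.Tensor, parameters: torch.Tensor, device: str = "cpu"):
        # Both tensors must describe the same samples along dimension 0.
        if observation.shape[0] != parameters.shape[0]:
            raise ValueError("observation and parameters must have the same number of samples")
        # Assumption: data is moved to the target device once, at construction time.
        self.observation = observation.to(device)
        self.parameters = parameters.to(device)

    def __len__(self) -> int:
        # Number of samples, used by DataLoader to build batches.
        return self.observation.shape[0]

    def __getitem__(self, idx: int):
        # One sample: the observation and its matching parameter vector.
        return self.observation[idx], self.parameters[idx]


# Toy data: 10 samples, each a 1x8x8 "image" paired with 2 parameters.
x = torch.randn(10, 1, 8, 8)
theta = torch.rand(10, 2)
dataset = PairDataset(x, theta)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch_x, batch_theta = next(iter(loader))
print(batch_x.shape, batch_theta.shape)  # torch.Size([4, 1, 8, 8]) torch.Size([4, 2])
```

Subclassing torch.utils.data.Dataset only requires `__len__` and `__getitem__`; the default DataLoader collate function then stacks the per-sample tensors into batches.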
- class CASBI.inference.CustomDataset_subhalo(observation, parameters, device)
Bases: torch.utils.data.Dataset
Custom dataset class for the training and validation datasets.
Parameters
- observation : torch.Tensor
The observation data.
- parameters : torch.Tensor
The parameter data.
- CASBI.inference.train_inference(x: torch.Tensor, theta: torch.Tensor, validation_fraction: float = 0.2, output_dir: str = './', device: str = 'cuda', N_nets=4, hidden_feature: int = 100, num_transforms: int = 20, model: str = 'nsf', embedding_net: str = ConvNet_halo(output_dim=32), custom_dataset: torch.utils.data.Dataset = CustomDataset_halo, custom_dataloader: torch.utils.data.DataLoader = torch.utils.data.DataLoader, minimum_theta: list = [3.5, -2.0], maximum_theta: list = [10, 1.15], batch_size: int = 2048, learning_rate: float = 1e-05, stop_after_epochs: int = 20, norm_x=False, norm_theta=True)
Train an ltu-ili ensemble model on (x, theta) pairs. The data is split into training and validation sets according to validation_fraction, and the trained model is saved in the output_dir directory.
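The docstring does not say how the training/validation split is drawn. The sketch below assumes a seeded random shuffle of sample indices, with a `validation_fraction` share of them held out; `split_indices` and `seed` are hypothetical names, not part of CASBI.

```python
import random


def split_indices(n_samples: int, validation_fraction: float = 0.2, seed: int = 0):
    """Return (train_idx, val_idx): disjoint index lists that together cover range(n_samples)."""
    if not 0.0 < validation_fraction < 1.0:
        raise ValueError("validation_fraction must be in (0, 1)")
    indices = list(range(n_samples))
    # Seeded shuffle so the split is reproducible across runs.
    random.Random(seed).shuffle(indices)
    n_val = int(round(validation_fraction * n_samples))
    # The first n_val shuffled indices form the validation set.
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(1000, validation_fraction=0.2)
print(len(train_idx), len(val_idx))  # 800 200
```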
Parameters
- x : torch.Tensor
The observation data.
- theta : torch.Tensor
The parameter data.
- validation_fraction : float, optional
The fraction of the data to use for validation, by default 0.2
- output_dir : str, optional
The directory where the model is saved, by default './'
- device : str, optional
The device to use, by default 'cuda'
- N_nets : int, optional
The number of networks in the ensemble, by default 4
- hidden_feature : int, optional
The number of hidden features of the model, by default 100
- num_transforms : int, optional
The number of transforms of the model, by default 20
- model : str, optional
The model to use, by default 'nsf' (neural spline flow)
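- embedding_net : optional
The embedding network applied to the observations, by default ConvNet_halo(output_dim=32), i.e. a convolutional network with a 32-dimensional output. The CNN.py file should be in the same directory as the training script.
- custom_dataset : torch.utils.data.Dataset, optional
The dataset class used to wrap the (x, theta) data, by default CustomDataset_halo
- custom_dataloader : torch.utils.data.DataLoader, optional
The data loader class, by default torch.utils.data.DataLoader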
- minimum_theta : list, optional
The minimum value of each theta parameter, by default [3.5, -2.0]
- maximum_theta : list, optional
The maximum value of each theta parameter, by default [10, 1.15]
- batch_size : int, optional
The batch size, by default 2048
- learning_rate : float, optional
The learning rate, by default 1e-05
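- stop_after_epochs : int, optional
The number of epochs without improvement in the validation loss after which training stops (early-stopping patience), by default 20
- norm_x : bool, optional
Whether to normalize x, by default False
- norm_theta : bool, optional
Whether to normalize theta, by default True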
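The docstring does not say how minimum_theta and maximum_theta are used. One common pattern, assumed here and not confirmed by the source, is to rescale each parameter component to [0, 1] when theta normalization is enabled (norm_theta=True in the signature); the same bounds could equally serve as the limits of a uniform prior. The helper names below are hypothetical.

```python
def minmax_normalize(theta, minimum_theta, maximum_theta):
    """Map each component of theta from [min, max] to [0, 1] (hypothetical helper)."""
    out = []
    for t, lo, hi in zip(theta, minimum_theta, maximum_theta):
        if hi == lo:
            raise ValueError("each maximum must differ from its minimum")
        out.append((t - lo) / (hi - lo))
    return out


def minmax_denormalize(u, minimum_theta, maximum_theta):
    """Inverse of minmax_normalize: map [0, 1] back to [min, max]."""
    return [lo + v * (hi - lo) for v, lo, hi in zip(u, minimum_theta, maximum_theta)]


# Bounds taken from the train_inference defaults.
minimum_theta = [3.5, -2.0]
maximum_theta = [10, 1.15]

print(minmax_normalize([3.5, -2.0], minimum_theta, maximum_theta))  # [0.0, 0.0]
print(minmax_normalize([10, 1.15], minimum_theta, maximum_theta))   # [1.0, 1.0]
```

Rescaling both parameters to a common range keeps one parameter from dominating the loss simply because its numerical range is wider.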
Returns
- posterior_ensemble : ili.posterior.PosteriorEnsemble
The posterior ensemble model.
- summaries : dict
The summaries of the training process.
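The returned posterior_ensemble combines N_nets trained density estimators. One standard way to evaluate such an ensemble, assumed here, is as a weighted mixture, log p(θ|x) = log Σ_i w_i p_i(θ|x), computed with the log-sum-exp trick. The function `ensemble_log_prob` is hypothetical, the member log-probabilities are supplied by hand, and equal weights are assumed because the weighting ltu-ili actually applies is not stated here.

```python
import math


def ensemble_log_prob(member_log_probs, weights=None):
    """log of sum_i w_i * p_i(theta|x), given each member's log p_i(theta|x)."""
    n = len(member_log_probs)
    if weights is None:
        weights = [1.0 / n] * n  # assumption: equal weights for all members
    if len(weights) != n:
        raise ValueError("one weight per ensemble member is required")
    terms = [math.log(w) + lp for w, lp in zip(weights, member_log_probs)]
    m = max(terms)  # log-sum-exp: factor out the largest term for numerical stability
    return m + math.log(sum(math.exp(t - m) for t in terms))


# Four members (N_nets=4) that agree on the density return that density.
print(ensemble_log_prob([-1.0, -1.0, -1.0, -1.0]))  # approximately -1.0
```

Working in log space keeps the sum stable even when member log-probabilities differ by hundreds of nats.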