import yaml
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from scipy.stats import gaussian_kde
import pandas as pd
from multiprocessing import cpu_count
from multiprocessing import Pool
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
import ili
from ili.dataloaders import NumpyLoader, TorchLoader
from ili.inference import InferenceRunner
from ili.validation.metrics import PosteriorCoverage, PlotSinglePosterior
from CASBI.utils.CNN import ConvNet_halo, ConvNet_subhalo
from sklearn.model_selection import train_test_split
"""
==========
Inference
==========
In this script we will perform inference on the model trained in the training script. There are also the plot functions to evaluate
on a single test set object and a function to evaluate on the whole test set. Ltu-ili allow for a flexible yaml file interface,
we provide the function to create the yaml file.
"""
[docs]
class CustomDataset_halo(Dataset):
"""
Custom dataset class for the training and validation datasets.
Parameters
----------
observation : torch.Tensor
The observation data.
parameters : torch.Tensor
The parameter data.
"""
def __init__(self, observation, parameters, device):
self.observation = observation
self.parameters = parameters
self.device = device
self.tensors = [self.observation, self.parameters]
[docs]
def __len__(self):
return len(self.observation)
[docs]
def __getitem__(self, idx):
observation = self.observation[idx].to(self.device) #this should put just the batch on the gpu
parameters = self.parameters[idx, :2].to(self.device) #when training we are not interested in the galaxy and subhalo index
return observation, parameters
[docs]
class CustomDataset_subhalo(Dataset, ):
"""
Custom dataset class for the training and validation datasets.
Parameters
----------
observation : torch.Tensor
The observation data.
parameters : torch.Tensor
The parameter data.
"""
def __init__(self, observation, parameters, device):
self.observation = observation
self.parameters = parameters
self.device = device
self.tensors = [self.observation, self.parameters]
[docs]
def __len__(self):
return len(self.observation)
[docs]
def __getitem__(self, idx):
observation = self.observation[idx].to(self.device ) #this should put just the batch on the gpu
parameters = self.parameters[idx].to(self.device ) #when training we are not interested in the galaxy and subhalo index
return observation, parameters
[docs]
def train_inference(x:torch.Tensor,
theta:torch.Tensor,
validation_fraction:float=0.2,
output_dir:str='./',
device:str='cuda',
N_nets=4,
hidden_feature:int=100,
num_transforms:int=20,
model:str='nsf',
embedding_net:str = ConvNet_halo(output_dim=32),
custom_dataset: torch.utils.data.Dataset = CustomDataset_halo,
custom_dataloader: torch.utils.data.DataLoader = torch.utils.data.DataLoader,
minimum_theta:list=[3.5, -2.],
maximum_theta:list=[10, 1.15],
batch_size:int=2048,
learning_rate:float=0.00001,
stop_after_epochs:int=20,
norm_x = False,
norm_theta = True,):
"""
Train a ltu-ili ensable model on (x, theta) couples. The training data is split into training and validation data. The model is saved in the output_dir directory.
Parameters
----------
x : torch.Tensor
The observation data.
theta : torch.Tensor
The parameter data.
validation_fraction : float, optional
The fraction of the data to use for validation, by default 0.2
output_dir : str, optional
The directory where the model is saved, by default './'
device : str, optional
The device to use, by default 'cuda'
N_nets : int, optional
The number of nets to ensable, by default 4
hidden_feature : int, optional
The number of hidden features of the model, by default 100
num_transforms : int, optional
The number of transforms of the model, by default 10
model : str, optional
The model to use, by default 'nsf'
embedding_net : str, optional
The embedding network to use, by default 'CNN'. The CNN.py file should be in the same directory as the training script.
output_dim : int, optional
The output dimension of the embedding network, by default 32
minimum_theta : list, optional
The minimum value of the theta parameters, by default [3.5, -2.]
maximum_theta : list, optional
The maximum value of the theta parameters, by default [10, 1.15]
batch_size : int, optional
The batch size, by default 1024
learning_rate : float, optional
The learning rate, by default 0.00001
stop_after_epochs : int, optional
The number of epochs to train, by default 20
Returns
-------
posterior_ensemble : ili.posterior.PosteriorEnsemble
The posterior ensemble model.
summaries : dict
The summaries of the training process.
"""
#traing arguments
train_args = {
'training_batch_size': batch_size,
'learning_rate': learning_rate,
'stop_after_epochs': stop_after_epochs,
}
#embedding network device
embedding_net.set_device(device)
#ensable model
runner = InferenceRunner.load(
backend ='lampe',
engine ='NPE',
prior = ili.utils.Uniform(low=minimum_theta, high=maximum_theta, device=device),
nets = [ili.utils.load_nde_lampe(model=model,
hidden_features=hidden_feature,
num_transforms=num_transforms,
embedding_net=embedding_net.to(device),
x_normalize=norm_x,
theta_normalize=norm_theta,
device=device) for j in range(N_nets)],
device=device,
train_args=train_args,
proposal=None,
out_dir=output_dir,
)
# Assert if the data are PyTorch tensors
if not torch.is_tensor(x):
x = torch.tensor(x)
if not torch.is_tensor(theta):
theta = torch.tensor(theta)
#split validataion and training data
train_data, val_data, train_targets, val_targets = train_test_split(x, theta, test_size=validation_fraction, random_state=42)
# Now you can create your DataLoaders
train_loader = custom_dataloader(custom_dataset(train_data, train_targets, device=device), shuffle=True, batch_size=batch_size, )
val_loader = custom_dataloader(custom_dataset(val_data, val_targets, device=device), shuffle=False, batch_size=batch_size, )
loader = TorchLoader(train_loader=train_loader, val_loader=val_loader)
posterior_ensemble, summaries = runner(loader=loader,)
return posterior_ensemble, summaries