Example #1
    def test_train_val_absent(self):
        """Test if an error is raised when the either the train or val baobab config is not passed in

        """
        train_val_dict = copy.deepcopy(self.train_val_dict)
        train_val_dict['data']['val_baobab_cfg_path'] = 'some_path'
        with np.testing.assert_raises(ValueError):
            train_val_cfg = TrainValConfig(train_val_dict)
        train_val_dict = copy.deepcopy(self.train_val_dict)
        train_val_dict['data']['train_baobab_cfg_path'] = 'some_path'
        with np.testing.assert_raises(ValueError):
            train_val_cfg = TrainValConfig(train_val_dict)
Example #2
    def test_train_val_config_constructor(self):
        """Test the instantiation of TrainValConfig from a dictionary with minimum required keys

        """
        train_val_dict = copy.deepcopy(self.train_val_dict)
        train_val_dict['data']['train_baobab_cfg_path'] = 'some_path'
        train_val_dict['data']['val_baobab_cfg_path'] = 'some_other_path'
        train_val_cfg = TrainValConfig(train_val_dict)
Example #3
def main():
    args = script_utils.parse_inference_args()
    test_cfg = TestConfig.from_file(args.test_config_file_path)
    baobab_cfg = BaobabConfig.from_file(test_cfg.data.test_baobab_cfg_path)
    cfg = TrainValConfig.from_file(test_cfg.train_val_config_file_path)
    # Set device and default data type
    device = torch.device(test_cfg.device_type)
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')
    script_utils.seed_everything(test_cfg.global_seed)
    
    ############
    # Data I/O #
    ############
    train_data = XYData(is_train=True, 
                        Y_cols=cfg.data.Y_cols, 
                        float_type=cfg.data.float_type, 
                        define_src_pos_wrt_lens=cfg.data.define_src_pos_wrt_lens, 
                        rescale_pixels=cfg.data.rescale_pixels, 
                        log_pixels=cfg.data.log_pixels, 
                        add_pixel_noise=cfg.data.add_pixel_noise, 
                        eff_exposure_time=cfg.data.eff_exposure_time, 
                        train_Y_mean=None, 
                        train_Y_std=None, 
                        train_baobab_cfg_path=cfg.data.train_baobab_cfg_path, 
                        val_baobab_cfg_path=test_cfg.data.test_baobab_cfg_path, 
                        for_cosmology=False)
    # Define test data and loader
    test_data = XYData(is_train=False, 
                       Y_cols=cfg.data.Y_cols, 
                       float_type=cfg.data.float_type, 
                       define_src_pos_wrt_lens=cfg.data.define_src_pos_wrt_lens, 
                       rescale_pixels=cfg.data.rescale_pixels, 
                       log_pixels=cfg.data.log_pixels, 
                       add_pixel_noise=cfg.data.add_pixel_noise, 
                       eff_exposure_time=cfg.data.eff_exposure_time, 
                       train_Y_mean=train_data.train_Y_mean, 
                       train_Y_std=train_data.train_Y_std, 
                       train_baobab_cfg_path=cfg.data.train_baobab_cfg_path, 
                       val_baobab_cfg_path=test_cfg.data.test_baobab_cfg_path, 
                       for_cosmology=True)
    cosmo_df = test_data.Y_df
    if test_cfg.data.lens_indices is None:
        n_test = test_cfg.data.n_test # number of lenses in the test set
        lens_range = range(n_test)
    else: # if specific lenses are specified
        lens_range = test_cfg.data.lens_indices
        n_test = len(lens_range)
        print("Performing H0 inference on {:d} specified lenses...".format(n_test))
    batch_size = max(lens_range) + 1
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=True)
    # Output directory into which the H0 histograms and H0 samples will be saved
    out_dir = test_cfg.out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print("Destination folder path: {:s}".format(out_dir))
    else:
        raise OSError("Destination folder already exists.")

    ######################
    # Load trained state #
    ######################
    # Instantiate loss function
    loss_fn = getattr(h0rton.losses, cfg.model.likelihood_class)(Y_dim=train_data.Y_dim, device=device)
    # Instantiate posterior (for logging)
    bnn_post = getattr(h0rton.h0_inference.gaussian_bnn_posterior, loss_fn.posterior_name)(train_data.Y_dim, device, train_data.train_Y_mean, train_data.train_Y_std)
    with torch.no_grad(): # TODO: skip this if lens_posterior_type == 'truth'
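        # Only the first batch is needed; it already contains all batch_size lenses.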
        for X_, Y_ in test_loader:
            X = X_.to(device)
            Y = Y_.to(device)
            break

    # Export the input images X for later error analysis
    if test_cfg.export.images:
        for lens_i in range(n_test):
            X_img_path = os.path.join(out_dir, 'X_{0:04d}.npy'.format(lens_i))
            np.save(X_img_path, X[lens_i, 0, :, :].cpu().numpy())

    ################
    # H0 Posterior #
    ################
    h0_prior = getattr(stats, test_cfg.h0_prior.dist)(**test_cfg.h0_prior.kwargs)
    kappa_ext_prior = getattr(stats, test_cfg.kappa_ext_prior.dist)(**test_cfg.kappa_ext_prior.kwargs)
    aniso_param_prior = getattr(stats, test_cfg.aniso_param_prior.dist)(**test_cfg.aniso_param_prior.kwargs)
    # FIXME: hardcoded
    kwargs_model = dict(
                        lens_model_list=['PEMD', 'SHEAR'],
                        lens_light_model_list=['SERSIC_ELLIPSE'],
                        source_light_model_list=['SERSIC_ELLIPSE'],
                        point_source_model_list=['SOURCE_POSITION'],
                        cosmo=FlatLambdaCDM(H0=70.0, Om0=0.3)
                       #'point_source_model_list' : ['LENSED_POSITION']
                       )
    h0_post = H0Posterior(
                          H0_prior=h0_prior,
                          kappa_ext_prior=kappa_ext_prior,
                          aniso_param_prior=aniso_param_prior,
                          exclude_vel_disp=test_cfg.h0_posterior.exclude_velocity_dispersion,
                          kwargs_model=kwargs_model,
                          baobab_time_delays=test_cfg.time_delay_likelihood.baobab_time_delays,
                          kinematics=baobab_cfg.bnn_omega.kinematics,
                          kappa_transformed=test_cfg.kappa_ext_prior.transformed,
                          Om0=baobab_cfg.bnn_omega.cosmology.Om0,
                          define_src_pos_wrt_lens=cfg.data.define_src_pos_wrt_lens,
                          kwargs_lens_eqn_solver={'min_distance': 0.05, 'search_window': baobab_cfg.instrument['pixel_scale']*baobab_cfg.image['num_pix'], 'num_iter_max': 100}
                          )
    # Get H0 samples for each system
    if not test_cfg.time_delay_likelihood.baobab_time_delays:
        if 'abcd_ordering_i' not in cosmo_df:
            raise ValueError("If the time delay measurements were not generated using Baobab, the user must specify the order of image positions in which the time delays are listed, in order of increasing dec.")
    required_params = h0_post.required_params

    ########################
    # Lens Model Posterior #
    ########################
    n_samples = test_cfg.h0_posterior.n_samples # number of h0 samples per lens
    sampling_buffer = test_cfg.h0_posterior.sampling_buffer # FIXME: dynamically sample more if we run out of samples
    actual_n_samples = int(n_samples*sampling_buffer)

    # Add artificial noise around the truth values
    Y_orig = bnn_post.transform_back_mu(Y).cpu().numpy().reshape(batch_size, test_data.Y_dim)
    Y_orig_df = pd.DataFrame(Y_orig, columns=cfg.data.Y_cols)
    Y_orig_values = Y_orig_df[required_params].values[:, np.newaxis, :] # [batch_size, 1, Y_dim]
    artificial_noise = np.random.randn(batch_size, actual_n_samples, test_data.Y_dim)*Y_orig_values*test_cfg.fractional_error_added_to_truth # [batch_size, buffer*n_samples, Y_dim]
    lens_model_samples_values = Y_orig_values + artificial_noise # [batch_size, buffer*n_samples, Y_dim]

    # Placeholders for mean and std of H0 samples per system
    mean_h0_set = np.zeros(n_test)
    std_h0_set = np.zeros(n_test)
    inference_time_set = np.zeros(n_test)
    # For each lens system...
    total_progress = tqdm(total=n_test)
    sampling_progress = tqdm(total=n_samples)
    prerealized_time_delays = test_cfg.error_model.prerealized_time_delays
    if prerealized_time_delays:
        realized_time_delays = pd.read_csv(test_cfg.error_model.realized_time_delays_path, index_col=None)
    else:
        realized_time_delays = pd.DataFrame()
        realized_time_delays['measured_td_wrt0'] = [[] for _ in lens_range]
    for i, lens_i in enumerate(lens_range):
        lens_i_start_time = time.time()
        # Each lens gets a unique random state for td and vd measurement error realizations.
        rs_lens = np.random.RandomState(lens_i)
        # BNN samples for lens_i
        bnn_sample_df = pd.DataFrame(lens_model_samples_values[lens_i, :, :], columns=required_params)
        # Cosmology observables for lens_i
        cosmo = cosmo_df.iloc[lens_i]
        true_td = np.array(literal_eval(cosmo['true_td']))
        true_img_dec = np.array(literal_eval(cosmo['y_image']))
        true_img_ra = np.array(literal_eval(cosmo['x_image']))
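        # Order the images by increasing dec, the convention assumed for the time-delay likelihood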
        increasing_dec_i = np.argsort(true_img_dec)
        true_img_dec = true_img_dec[increasing_dec_i]
        true_img_ra = true_img_ra[increasing_dec_i]        
        measured_vd = cosmo['true_vd']*(1.0 + rs_lens.randn()*test_cfg.error_model.velocity_dispersion_frac_error)
        if prerealized_time_delays:
            measured_td_wrt0 = np.array(literal_eval(realized_time_delays.iloc[lens_i]['measured_td_wrt0']))
        else:
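            # Realize the measured delays: reorder the true delays by dec, take them relative to image 0, then add Gaussian noise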
            true_td = true_td[increasing_dec_i]
            true_td = true_td[1:] - true_td[0]
            measured_td_wrt0 = true_td + rs_lens.randn(*true_td.shape)*test_cfg.error_model.time_delay_error # [n_img -1,]
            realized_time_delays.at[lens_i, 'measured_td_wrt0'] = list(measured_td_wrt0)
        #print("True: ", true_td)
        #print("True img: ", true_img_dec)
        #print("measured td: ", measured_td_wrt0)
        h0_post.set_cosmology_observables(
                                          z_lens=cosmo['z_lens'], 
                                          z_src=cosmo['z_src'], 
                                          measured_vd=measured_vd, 
                                          measured_vd_err=test_cfg.velocity_dispersion_likelihood.sigma, 
                                          measured_td_wrt0=measured_td_wrt0,
                                          measured_td_err=test_cfg.time_delay_likelihood.sigma, 
                                          abcd_ordering_i=np.arange(len(true_td) + 1),
                                          true_img_dec=true_img_dec,
                                          true_img_ra=true_img_ra,
                                          kappa_ext=cosmo['kappa_ext'], # not necessary
                                          )
        h0_post.set_truth_lens_model(sampled_lens_model_raw=bnn_sample_df.iloc[0])
        # Initialize output array
        h0_samples = np.full(n_samples, np.nan)
        h0_weights = np.zeros(n_samples)
        # For each sample from the lens model posterior of this lens system...
        sampling_progress.reset()
        valid_sample_i = 0
        sample_i = 0
        while valid_sample_i < n_samples:
            if sample_i > actual_n_samples - 1:
                break
            try:
                # Each sample for a given lens gets a unique random state for H0, k_ext, and aniso_param realizations.
                # (zfill(5) keeps seeds unique across samples as long as sample_i < 100000.)
                rs_sample = np.random.RandomState(int(str(lens_i) + str(sample_i).zfill(5)))
                h0, weight = h0_post.get_h0_sample_truth(rs_sample)
                h0_samples[valid_sample_i] = h0
                h0_weights[valid_sample_i] = weight
                sampling_progress.update(1)
                time.sleep(0.001)
                valid_sample_i += 1
                sample_i += 1
            except Exception:
                # Failed draws count toward the sampling budget (sample_i) but not the output.
                sample_i += 1
                continue
        sampling_progress.refresh()
        lens_i_end_time = time.time()
        inference_time = (lens_i_end_time - lens_i_start_time)/60.0 # min
        h0_dict = dict(
                       h0_samples=h0_samples,
                       h0_weights=h0_weights,
                       n_sampling_attempts=sample_i,
                       measured_td_wrt0=measured_td_wrt0,
                       inference_time=inference_time
                       )
        h0_dict_save_path = os.path.join(out_dir, 'h0_dict_{0:04d}.npy'.format(lens_i))
        np.save(h0_dict_save_path, h0_dict)
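        # Note: a dict saved via np.save must be read back with np.load(path, allow_pickle=True).item()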
        h0_stats = plot_weighted_h0_histogram(h0_samples, h0_weights, lens_i, cosmo['H0'], include_fit_gaussian=test_cfg.plotting.include_fit_gaussian, save_dir=out_dir)
        mean_h0_set[i] = h0_stats['mean']
        std_h0_set[i] = h0_stats['std']
        inference_time_set[i] = inference_time
        total_progress.update(1)
    total_progress.close()
    if not prerealized_time_delays:
        realized_time_delays.to_csv(os.path.join(out_dir, 'realized_time_delays.csv'), index=None)
    h0_stats = dict(
                    mean=mean_h0_set,
                    std=std_h0_set,
                    inference_time=inference_time_set,
                    )
    h0_stats_save_path = os.path.join(out_dir, 'h0_stats')
    np.save(h0_stats_save_path, h0_stats)
Example #4
def main():
    args = parse_args()
    test_cfg = TestConfig.from_file(args.test_config_file_path)
    train_val_cfg = TrainValConfig.from_file(test_cfg.train_val_config_file_path)
    baobab_cfg = get_baobab_config(test_cfg.data.test_dir)
    # Set device and default data type
    device = torch.device(test_cfg.device_type)
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')
    seed_everything(test_cfg.global_seed)
    
    ############
    # Data I/O #
    ############
    test_data = XYCosmoData(test_cfg.data.test_dir, data_cfg=train_val_cfg.data)
    master_truth = test_data.cosmo_df
    master_truth = metadata_utils.add_qphi_columns(master_truth)
    master_truth = metadata_utils.add_gamma_psi_ext_columns(master_truth)
    if test_cfg.data.lens_indices is None:
        if args.lens_indices_path is None:
            # Test on all n_test lenses in the test set
            n_test = test_cfg.data.n_test 
            lens_range = range(n_test)
        else:
            # Test on the lens indices in a text file at the specified path
            lens_range = []
            with open(args.lens_indices_path, "r") as f:
                for line in f:
                    lens_range.append(int(line.strip()))
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(n_test))
    else:
        if args.lens_indices_path is None:
            # Test on the lens indices specified in the test config file
            lens_range = test_cfg.data.lens_indices
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(n_test))
        else:
            raise ValueError("Specific lens indices were specified in both the test config file and the command-line argument.")
    batch_size = max(lens_range) + 1
    # Output directory into which the H0 histograms and H0 samples will be saved
    out_dir = test_cfg.out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print("Destination folder path: {:s}".format(out_dir))
    else:
        raise OSError("Destination folder already exists.")

    ######################
    # Load trained state #
    ######################
    # Instantiate loss function, to append to the MCMC objective as the prior
    orig_Y_cols = train_val_cfg.data.Y_cols
    # Instantiate MCMC parameter penalty function
    params_to_remove = ['lens_light_R_sersic']  # consider also removing 'src_light_R_sersic'
    mcmc_Y_cols = [col for col in orig_Y_cols if col not in params_to_remove]
    mcmc_Y_dim = len(mcmc_Y_cols)
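    # null_spread=True since the walkers are initialized at the truth values (cf. lens_posterior_type == 'truth' in the BNN variants)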
    null_spread = True
    #init_D_dt = np.random.uniform(0.0, 10000.0, size=(batch_size, n_walkers, 1)) # FIXME: init H0 hardcoded

    kwargs_model = dict(lens_model_list=['PEMD', 'SHEAR'],
                        point_source_model_list=['SOURCE_POSITION'],
                        source_light_model_list=['SERSIC_ELLIPSE'])
    astro_sig = test_cfg.image_position_likelihood.sigma
    # Get H0 samples for each system
    if not test_cfg.time_delay_likelihood.baobab_time_delays:
        if 'abcd_ordering_i' not in master_truth:
            raise ValueError("If the time delay measurements were not generated using Baobab, the user must specify the order of image positions in which the time delays are listed, in order of increasing dec.")
    kwargs_lens_eq_solver = {'min_distance': 0.05, 'search_window': baobab_cfg.instrument.pixel_scale*baobab_cfg.image.num_pix, 'num_iter_max': 100}
    #n_walkers = test_cfg.numerics.mcmc.walkerRatio*(mcmc_Y_dim + 1) # BNN params + H0 times walker ratio
    #init_pos = np.tile(master_truth[mcmc_Y_cols].iloc[:batch_size].values[:, np.newaxis, :], [1, n_walkers, 1])
    #init_D_dt = np.random.uniform(0.0, 10000.0, size=(batch_size, n_walkers, 1))
    #print(init_pos.shape, init_D_dt.shape)

    total_progress = tqdm(total=n_test)
    # For each lens system...
    for i, lens_i in enumerate(lens_range):
        # Each lens gets a unique random state for td and vd measurement error realizations.
        rs_lens = np.random.RandomState(lens_i)
        ###########################
        # Relevant data and prior #
        ###########################
        data_i = master_truth.iloc[lens_i].copy()
        # Init values for the lens model params
        init_info = dict(zip(mcmc_Y_cols, data_i[mcmc_Y_cols].values)) # truth params
        lcdm = LCDM(z_lens=data_i['z_lens'], z_source=data_i['z_src'], flat=True)
        true_img_dec = np.array(literal_eval(data_i['y_image']))
        n_img = len(true_img_dec)
        true_td = np.array(literal_eval(data_i['true_td']))
        measured_td = true_td + rs_lens.randn(*true_td.shape)*test_cfg.error_model.time_delay_error
        measured_td_sig = test_cfg.time_delay_likelihood.sigma  # scalar; could be np.ones(n_img - 1) * sigma for per-delay uncertainties
        measured_img_dec = true_img_dec + rs_lens.randn(n_img)*astro_sig
        increasing_dec_i = np.argsort(true_img_dec)  # TODO: order by measured_img_dec instead, since only measured dec is available in practice
        measured_td = h0_utils.reorder_to_tdlmc(measured_td, increasing_dec_i, range(n_img))
        measured_img_dec = h0_utils.reorder_to_tdlmc(measured_img_dec, increasing_dec_i, range(n_img))
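        # Take the delays relative to image 0 in the dec-ordered convention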
        measured_td_wrt0 = measured_td[1:] - measured_td[0]   
        kwargs_data_joint = dict(time_delays_measured=measured_td_wrt0,
                                 time_delays_uncertainties=measured_td_sig,
                                 )

        #############################
        # Parameter init and bounds #
        #############################
        lens_kwargs = mcmc_utils.get_lens_kwargs(init_info, null_spread=null_spread)
        ps_kwargs = mcmc_utils.get_ps_kwargs_src_plane(init_info, astro_sig, null_spread=null_spread)
        src_light_kwargs = mcmc_utils.get_light_kwargs(init_info['src_light_R_sersic'], null_spread=null_spread)
        special_kwargs = mcmc_utils.get_special_kwargs(n_img, astro_sig, D_dt_sigma=2000, null_spread=null_spread) # image position offset and time delay distance, aka the "special" parameters
        kwargs_params = {'lens_model': lens_kwargs,
                         'point_source_model': ps_kwargs,
                         'source_model': src_light_kwargs,
                         'special': special_kwargs,}
        if test_cfg.numerics.solver_type == 'NONE':
            solver_type = 'NONE'
        else:
            solver_type = 'PROFILE_SHEAR' if n_img == 4 else 'CENTER'
        #solver_type = 'NONE'
        kwargs_constraints = {'num_point_source_list': [n_img],  
                              'Ddt_sampling': True,
                              'solver_type': solver_type,}

        kwargs_likelihood = {'time_delay_likelihood': True,
                             'sort_images_by_dec': True,
                             'prior_lens': [],
                             'prior_special': [],
                             'check_bounds': True, 
                             'check_matched_source_position': False,
                             'source_position_tolerance': 0.01,
                             'source_position_sigma': 0.01,
                             'source_position_likelihood': False,
                             'custom_logL_addition': None,
                             'kwargs_lens_eq_solver': kwargs_lens_eq_solver}

        ###########################
        # MCMC posterior sampling #
        ###########################
        fitting_seq = FittingSequence(kwargs_data_joint, kwargs_model, kwargs_constraints, kwargs_likelihood, kwargs_params, verbose=False, mpi=False)
        if i == 0:
            param_class = fitting_seq._updateManager.param_class
            n_params, param_class_Y_cols = param_class.num_param()
            #init_pos = mcmc_utils.reorder_to_param_class(mcmc_Y_cols, param_class_Y_cols, init_pos, init_D_dt)
        # MCMC sample from the post-processed BNN posterior jointly with cosmology
        lens_i_start_time = time.time()
        #test_cfg.numerics.mcmc.update(init_samples=init_pos[lens_i, :, :])
        fitting_kwargs_list_mcmc = [['MCMC', test_cfg.numerics.mcmc]]
        #with HiddenPrints():
        try:
            chain_list_mcmc = fitting_seq.fit_sequence(fitting_kwargs_list_mcmc)
            kwargs_result_mcmc = fitting_seq.best_fit()
        except Exception:
            print("lens {:d} skipped".format(lens_i))
            total_progress.update(1)
            continue
        lens_i_end_time = time.time()
        inference_time = (lens_i_end_time - lens_i_start_time)/60.0 # min

        #############################
        # Plotting the MCMC samples #
        #############################
        # sampler_type : 'EMCEE'
        # samples_mcmc : np.array of shape `[n_mcmc_eval, n_params]`
        # param_mcmc : list of str of length n_params, the parameter names
        sampler_type, samples_mcmc, param_mcmc, _  = chain_list_mcmc[0]
        new_samples_mcmc = mcmc_utils.postprocess_mcmc_chain(kwargs_result_mcmc, samples_mcmc, kwargs_model, lens_kwargs[2], ps_kwargs[2], src_light_kwargs[2], special_kwargs[2], kwargs_constraints)
        # Plot D_dt histogram
        D_dt_samples = new_samples_mcmc['D_dt'].values
        true_D_dt = lcdm.D_dt(H_0=data_i['H0'], Om0=0.3)
        data_i['D_dt'] = true_D_dt
        # Export D_dt samples for this lens
        lens_inference_dict = dict(
                                   D_dt_samples=D_dt_samples, # kappa_ext=0 for these samples
                                   inference_time=inference_time,
                                   true_D_dt=true_D_dt, 
                                   )
        lens_inference_dict_save_path = os.path.join(out_dir, 'D_dt_dict_{0:04d}.npy'.format(lens_i))
        np.save(lens_inference_dict_save_path, lens_inference_dict)
        # Optionally export the MCMC samples
        if test_cfg.export.mcmc_samples:
            mcmc_samples_path = os.path.join(out_dir, 'mcmc_samples_{0:04d}.csv'.format(lens_i))
            new_samples_mcmc.to_csv(mcmc_samples_path, index=None)
        # Optionally export the D_dt histogram
        if test_cfg.export.D_dt_histogram:
            cleaned_D_dt_samples = h0_utils.remove_outliers_from_lognormal(D_dt_samples, 3)
            _ = plotting_utils.plot_D_dt_histogram(cleaned_D_dt_samples, lens_i, true_D_dt, save_dir=out_dir)
        # Optionally export the plot of MCMC chain
        if test_cfg.export.mcmc_chain:
            mcmc_chain_path = os.path.join(out_dir, 'mcmc_chain_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_chain(chain_list_mcmc, mcmc_chain_path)
        # Optionally export posterior cornerplot of select lens model parameters with D_dt
        if test_cfg.export.mcmc_corner:
            mcmc_corner_path = os.path.join(out_dir, 'mcmc_corner_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_corner(new_samples_mcmc[test_cfg.export.mcmc_cols], data_i[test_cfg.export.mcmc_cols], test_cfg.export.mcmc_col_labels, mcmc_corner_path)
        total_progress.update(1)
        gc.collect()
    total_progress.close()
Example #5
def main():
    args = parse_args()
    test_cfg = TestConfig.from_file(args.test_config_file_path)
    train_val_cfg = TrainValConfig.from_file(
        test_cfg.train_val_config_file_path)
    # Set device and default data type
    device = torch.device(test_cfg.device_type)
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')
    seed_everything(test_cfg.global_seed)

    ############
    # Data I/O #
    ############
    test_data = TDLMCData(data_cfg=train_val_cfg.data, rung_i=args.rung_idx)
    master_truth = test_data.cosmo_df
    if test_cfg.data.lens_indices is None:
        if args.lens_indices_path is None:
            # Test on all n_test lenses in the test set
            n_test = test_cfg.data.n_test
            lens_range = range(n_test)
        else:
            # Test on the lens indices in a text file at the specified path
            lens_range = []
            with open(args.lens_indices_path, "r") as f:
                for line in f:
                    lens_range.append(int(line.strip()))
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(
                n_test))
    else:
        if args.lens_indices_path is None:
            # Test on the lens indices specified in the test config file
            lens_range = test_cfg.data.lens_indices
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(
                n_test))
        else:
            raise ValueError(
                "Specific lens indices were specified in both the test config file and the command-line argument."
            )
    batch_size = max(lens_range) + 1
    test_loader = DataLoader(test_data,
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=True)
    # Output directory into which the H0 histograms and H0 samples will be saved
    out_dir = test_cfg.out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print("Destination folder path: {:s}".format(out_dir))
    else:
        raise OSError("Destination folder already exists.")

    ######################
    # Load trained state #
    ######################
    # Instantiate loss function, to append to the MCMC objective as the prior
    orig_Y_cols = train_val_cfg.data.Y_cols
    loss_fn = getattr(h0rton.losses, train_val_cfg.model.likelihood_class)(
        Y_dim=train_val_cfg.data.Y_dim, device=device)
    # Instantiate MCMC parameter penalty function
    params_to_remove = ['lens_light_R_sersic']  # consider also removing 'src_light_R_sersic'
    mcmc_Y_cols = [col for col in orig_Y_cols if col not in params_to_remove]
    mcmc_Y_dim = len(mcmc_Y_cols)
    mcmc_loss_fn = getattr(
        h0rton.losses, train_val_cfg.model.likelihood_class)(
            Y_dim=train_val_cfg.data.Y_dim - len(params_to_remove),
            device=device)
    remove_param_idx, remove_idx = mcmc_utils.get_idx_for_params(
        mcmc_loss_fn.out_dim, orig_Y_cols, params_to_remove,
        train_val_cfg.model.likelihood_class)
    mcmc_train_Y_mean = np.delete(train_val_cfg.data.train_Y_mean,
                                  remove_param_idx)
    mcmc_train_Y_std = np.delete(train_val_cfg.data.train_Y_std,
                                 remove_param_idx)
    parameter_penalty = mcmc_utils.HybridBNNPenalty(
        mcmc_Y_cols, train_val_cfg.model.likelihood_class, mcmc_train_Y_mean,
        mcmc_train_Y_std, test_cfg.h0_posterior.exclude_velocity_dispersion,
        device)
    custom_logL_addition = parameter_penalty.evaluate if test_cfg.lens_posterior_type.startswith(
        'default') else None
    null_spread = (test_cfg.lens_posterior_type == 'truth')
    # Instantiate model
    net = getattr(
        h0rton.models,
        train_val_cfg.model.architecture)(num_classes=loss_fn.out_dim)
    net.to(device)
    # Load trained weights from saved state
    net, epoch = train_utils.load_state_dict_test(test_cfg.state_dict_path,
                                                  net,
                                                  train_val_cfg.optim.n_epochs,
                                                  device)
    with torch.no_grad():
        net.eval()
        for X_ in test_loader:
            X = X_.to(device)
            pred = net(X)
            break

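    # A single deterministic forward pass supplies the BNN predictions for all lenses (no MC dropout in this variant)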
    mcmc_pred = pred.cpu().numpy()
    mcmc_pred = mcmc_utils.remove_parameters_from_pred(mcmc_pred,
                                                       remove_idx,
                                                       return_as_tensor=False)

    # Instantiate posterior for BNN samples, to initialize the walkers
    bnn_post = getattr(h0rton.h0_inference.gaussian_bnn_posterior,
                       loss_fn.posterior_name)(mcmc_Y_dim, device,
                                               mcmc_train_Y_mean,
                                               mcmc_train_Y_std)
    bnn_post.set_sliced_pred(torch.tensor(mcmc_pred))
    n_walkers = test_cfg.numerics.mcmc.walkerRatio * (
        mcmc_Y_dim + 1)  # BNN params + H0 times walker ratio
    init_pos = bnn_post.sample(
        n_walkers, sample_seed=test_cfg.global_seed
    )  # [batch_size, n_walkers, mcmc_Y_dim] contains just the lens model params, no D_dt
    init_D_dt = np.random.uniform(0.0,
                                  10000.0,
                                  size=(batch_size, n_walkers,
                                        1))  # FIXME: init H0 hardcoded

    kwargs_model = dict(lens_model_list=['PEMD', 'SHEAR'],
                        point_source_model_list=['SOURCE_POSITION'],
                        source_light_model_list=['SERSIC_ELLIPSE'])
    astro_sig = test_cfg.image_position_likelihood.sigma
    # Get H0 samples for each system
    if not test_cfg.time_delay_likelihood.baobab_time_delays:
        if 'abcd_ordering_i' not in master_truth:
            raise ValueError(
                "If the time delay measurements were not generated using Baobab, the user must specify the order of image positions in which the time delays are listed, in order of increasing dec."
            )

    lenses_skipped = []  # keeps track of lenses that skipped MCMC
    total_progress = tqdm(total=n_test)
    # For each lens system...
    for i, lens_i in enumerate(lens_range):
        # Each lens gets a unique random state for td and vd measurement error realizations.
        rs_lens = np.random.RandomState(lens_i)
        ###########################
        # Relevant data and prior #
        ###########################
        data_i = master_truth.iloc[lens_i].copy()
        parameter_penalty.set_bnn_post_params(
            mcmc_pred[lens_i, :])  # set the BNN parameters
        # Init values for the lens model params
        if test_cfg.lens_posterior_type == 'default':
            init_info = dict(
                zip(
                    mcmc_Y_cols,
                    mcmc_pred[lens_i, :len(mcmc_Y_cols)] * mcmc_train_Y_std +
                    mcmc_train_Y_mean))  # mean of primary Gaussian
        else:  # types 'hybrid_with_truth_mean' and 'truth'
            init_info = dict(zip(mcmc_Y_cols,
                                 data_i[mcmc_Y_cols].values))  # truth params
        if not test_cfg.h0_posterior.exclude_velocity_dispersion:
            parameter_penalty.set_vel_disp_params()
            raise NotImplementedError
        lcdm = LCDM(z_lens=data_i['z_lens'],
                    z_source=data_i['z_src'],
                    flat=True)
        # Data is BCD - A with a certain ABCD ordering, so inferred time delays should follow this convention.
        measured_td_wrt0 = np.array(data_i['measured_td'])  # [n_img - 1,]
        measured_td_sig = np.array(data_i['measured_td_err'])  # [n_img - 1,]
        abcd_ordering_i = np.array(data_i['abcd_ordering_i'])
        n_img = len(abcd_ordering_i)
        kwargs_data_joint = dict(
            time_delays_measured=measured_td_wrt0,
            time_delays_uncertainties=measured_td_sig,
            abcd_ordering_i=abcd_ordering_i,
            #vel_disp_measured=measured_vd, # TODO: optionally exclude
            #vel_disp_uncertainty=vel_disp_sig,
        )
        if not test_cfg.h0_posterior.exclude_velocity_dispersion:
            measured_vd = data_i['true_vd'] * (
                1.0 + rs_lens.randn() *
                test_cfg.error_model.velocity_dispersion_frac_error)
            kwargs_data_joint['vel_disp_measured'] = measured_vd
            kwargs_data_joint[
                'vel_disp_sig'] = test_cfg.velocity_dispersion_likelihood.sigma

        #############################
        # Parameter init and bounds #
        #############################
        lens_kwargs = mcmc_utils.get_lens_kwargs(init_info,
                                                 null_spread=null_spread)
        ps_kwargs = mcmc_utils.get_ps_kwargs_src_plane(init_info,
                                                       astro_sig,
                                                       null_spread=null_spread)
        src_light_kwargs = mcmc_utils.get_light_kwargs(
            init_info['src_light_R_sersic'], null_spread=null_spread)
        special_kwargs = mcmc_utils.get_special_kwargs(
            n_img, astro_sig, null_spread=null_spread
        )  # image position offset and time delay distance, aka the "special" parameters
        kwargs_params = {
            'lens_model': lens_kwargs,
            'point_source_model': ps_kwargs,
            'source_model': src_light_kwargs,
            'special': special_kwargs,
        }
        if test_cfg.numerics.solver_type == 'NONE':
            solver_type = 'NONE'
        else:
            solver_type = 'PROFILE_SHEAR' if n_img == 4 else 'CENTER'
        #solver_type = 'NONE'
        kwargs_constraints = {
            'num_point_source_list': [n_img],
            'Ddt_sampling': True,
            'solver_type': solver_type,
        }

        kwargs_likelihood = {
            'time_delay_likelihood': True,
            'sort_images_by_dec': True,
            'prior_lens': [],
            'prior_special': [],
            'check_bounds': True,
            'check_matched_source_position': False,
            'source_position_tolerance': 0.01,
            'source_position_sigma': 0.01,
            'source_position_likelihood': False,
            'custom_logL_addition': custom_logL_addition,
        }

        ###########################
        # MCMC posterior sampling #
        ###########################
        fitting_seq = FittingSequence(kwargs_data_joint,
                                      kwargs_model,
                                      kwargs_constraints,
                                      kwargs_likelihood,
                                      kwargs_params,
                                      verbose=False,
                                      mpi=False)
        if i == 0:
            param_class = fitting_seq._updateManager.param_class
            n_params, param_class_Y_cols = param_class.num_param()
            init_pos = mcmc_utils.reorder_to_param_class(
                mcmc_Y_cols, param_class_Y_cols, init_pos, init_D_dt)
        # MCMC sample from the post-processed BNN posterior jointly with cosmology
        lens_i_start_time = time.time()
        if test_cfg.lens_posterior_type == 'default':
            test_cfg.numerics.mcmc.update(init_samples=init_pos[lens_i, :, :])
        fitting_kwargs_list_mcmc = [['MCMC', test_cfg.numerics.mcmc]]
        #with HiddenPrints():
        try:
            chain_list_mcmc = fitting_seq.fit_sequence(
                fitting_kwargs_list_mcmc)
            kwargs_result_mcmc = fitting_seq.best_fit()
        except Exception:
            print("lens {:d} skipped".format(lens_i))
            total_progress.update(1)
            lenses_skipped.append(lens_i)
            continue
        lens_i_end_time = time.time()
        inference_time = (lens_i_end_time - lens_i_start_time) / 60.0  # min

        #############################
        # Plotting the MCMC samples #
        #############################
        # sampler_type : 'EMCEE'
        # samples_mcmc : np.array of shape `[n_mcmc_eval, n_params]`
        # param_mcmc : list of str of length n_params, the parameter names
        sampler_type, samples_mcmc, param_mcmc, _ = chain_list_mcmc[0]
        new_samples_mcmc = mcmc_utils.postprocess_mcmc_chain(
            kwargs_result_mcmc, samples_mcmc, kwargs_model, lens_kwargs[2],
            ps_kwargs[2], src_light_kwargs[2], special_kwargs[2],
            kwargs_constraints)
        # Plot D_dt histogram
        D_dt_samples = new_samples_mcmc['D_dt'].values
        true_D_dt = lcdm.D_dt(H_0=data_i['H0'], Om0=0.27)
        data_i['D_dt'] = true_D_dt
        # Export D_dt samples for this lens
        lens_inference_dict = dict(
            D_dt_samples=D_dt_samples,  # kappa_ext=0 for these samples
            inference_time=inference_time,
            true_D_dt=true_D_dt,
        )
        lens_inference_dict_save_path = os.path.join(
            out_dir, 'D_dt_dict_{0:04d}.npy'.format(lens_i))
        np.save(lens_inference_dict_save_path, lens_inference_dict)
        # Optionally export the MCMC samples
        if test_cfg.export.mcmc_samples:
            mcmc_samples_path = os.path.join(
                out_dir, 'mcmc_samples_{0:04d}.csv'.format(lens_i))
            new_samples_mcmc.to_csv(mcmc_samples_path, index=None)
        # Optionally export the D_dt histogram
        if test_cfg.export.D_dt_histogram:
            cleaned_D_dt_samples = h0_utils.remove_outliers_from_lognormal(
                D_dt_samples, 3)
            _ = plotting_utils.plot_D_dt_histogram(cleaned_D_dt_samples,
                                                   lens_i,
                                                   true_D_dt,
                                                   save_dir=out_dir)
        # Optionally export the plot of MCMC chain
        if test_cfg.export.mcmc_chain:
            mcmc_chain_path = os.path.join(
                out_dir, 'mcmc_chain_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_chain(chain_list_mcmc, mcmc_chain_path)
        # Optionally export posterior cornerplot of select lens model parameters with D_dt
        if test_cfg.export.mcmc_corner:
            mcmc_corner_path = os.path.join(
                out_dir, 'mcmc_corner_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_corner(
                new_samples_mcmc[test_cfg.export.mcmc_cols], None,
                test_cfg.export.mcmc_col_labels, mcmc_corner_path)
        total_progress.update(1)
    total_progress.close()
Example #6
def main():
    args = script_utils.parse_inference_args()
    test_cfg = TestConfig.from_file(args.test_config_file_path)
    baobab_cfg = BaobabConfig.from_file(test_cfg.data.test_baobab_cfg_path)
    cfg = TrainValConfig.from_file(test_cfg.train_val_config_file_path)
    # Set device and default data type
    device = torch.device(test_cfg.device_type)
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.' + cfg.data.float_type)
    else:
        torch.set_default_tensor_type('torch.' + cfg.data.float_type)
    script_utils.seed_everything(test_cfg.global_seed)

    ############
    # Data I/O #
    ############
    train_data = XYData(
        is_train=True,
        Y_cols=cfg.data.Y_cols,
        float_type=cfg.data.float_type,
        define_src_pos_wrt_lens=cfg.data.define_src_pos_wrt_lens,
        rescale_pixels=cfg.data.rescale_pixels,
        rescale_pixels_type=cfg.data.rescale_pixels_type,
        log_pixels=cfg.data.log_pixels,
        add_pixel_noise=cfg.data.add_pixel_noise,
        eff_exposure_time=cfg.data.eff_exposure_time,
        train_Y_mean=None,
        train_Y_std=None,
        train_baobab_cfg_path=cfg.data.train_baobab_cfg_path,
        val_baobab_cfg_path=None,
        for_cosmology=False)
    # Define test data and loader
    test_data = XYData(
        is_train=False,
        Y_cols=cfg.data.Y_cols,
        float_type=cfg.data.float_type,
        define_src_pos_wrt_lens=cfg.data.define_src_pos_wrt_lens,
        rescale_pixels=cfg.data.rescale_pixels,
        rescale_pixels_type=cfg.data.rescale_pixels_type,
        log_pixels=cfg.data.log_pixels,
        add_pixel_noise=cfg.data.add_pixel_noise,
        eff_exposure_time=cfg.data.eff_exposure_time,
        train_Y_mean=train_data.train_Y_mean,
        train_Y_std=train_data.train_Y_std,
        train_baobab_cfg_path=cfg.data.train_baobab_cfg_path,
        val_baobab_cfg_path=test_cfg.data.test_baobab_cfg_path,
        for_cosmology=True)
    master_truth = test_data.Y_df
    master_truth = metadata_utils.add_qphi_columns(master_truth)
    master_truth = metadata_utils.add_gamma_psi_ext_columns(master_truth)
    # Figure out how many lenses BNN will predict on (must be consecutive)
    if test_cfg.data.lens_indices is None:
        if args.lens_indices_path is None:
            # Test on all n_test lenses in the test set
            n_test = test_cfg.data.n_test
            lens_range = range(n_test)
        else:
            # Test on the lens indices in a text file at the specified path
            lens_range = []
            with open(args.lens_indices_path, "r") as f:
                for line in f:
                    lens_range.append(int(line.strip()))
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(
                n_test))
    else:
        if args.lens_indices_path is None:
            # Test on the lens indices specified in the test config file
            lens_range = test_cfg.data.lens_indices
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(
                n_test))
        else:
            raise ValueError(
                "Specific lens indices were specified in both the test config file and the command-line argument."
            )
    batch_size = max(lens_range) + 1
    test_loader = DataLoader(test_data,
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=True)
    # Output directory into which the H0 histograms and H0 samples will be saved
    out_dir = test_cfg.out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print("Destination folder path: {:s}".format(out_dir))
    else:
        raise OSError("Destination folder already exists.")

    #####################
    # Parameter penalty #
    #####################
    # Instantiate original loss function with all BNN-predicted params
    orig_Y_cols = cfg.data.Y_cols
    loss_fn = getattr(h0rton.losses,
                      cfg.model.likelihood_class)(Y_dim=test_data.Y_dim,
                                                  device=device)
    # Not all predicted params will be sampled via MCMC
    params_to_remove = []  # e.g., ['lens_light_R_sersic', 'src_light_R_sersic']
    mcmc_Y_cols = [col for col in orig_Y_cols if col not in params_to_remove]
    mcmc_Y_dim = len(mcmc_Y_cols)
    # Instantiate loss function with just the MCMC params
    mcmc_loss_fn = getattr(h0rton.losses, cfg.model.likelihood_class)(
        Y_dim=test_data.Y_dim - len(params_to_remove), device=device)
    remove_param_idx, remove_idx = mcmc_utils.get_idx_for_params(
        mcmc_loss_fn.out_dim, orig_Y_cols, params_to_remove,
        cfg.model.likelihood_class)
    mcmc_train_Y_mean = np.delete(train_data.train_Y_mean, remove_param_idx)
    mcmc_train_Y_std = np.delete(train_data.train_Y_std, remove_param_idx)
    parameter_penalty = mcmc_utils.HybridBNNPenalty(
        mcmc_Y_cols, cfg.model.likelihood_class, mcmc_train_Y_mean,
        mcmc_train_Y_std, test_cfg.h0_posterior.exclude_velocity_dispersion,
        device)
    custom_logL_addition = parameter_penalty.evaluate
    null_spread = False

    ###################
    # BNN predictions #
    ###################
    # Instantiate BNN model
    net = getattr(h0rton.models,
                  cfg.model.architecture)(num_classes=loss_fn.out_dim,
                                          dropout_rate=cfg.model.dropout_rate)
    net.to(device)
    # Load trained weights from saved state
    net, epoch = train_utils.load_state_dict_test(test_cfg.state_dict_path,
                                                  net, cfg.optim.n_epochs,
                                                  device)
    # When only generating BNN predictions (and not running MCMC), we can afford a larger n_dropout;
    # otherwise, we fix n_dropout = mcmc_Y_dim + 1.
    if test_cfg.export.pred:
        n_dropout = 20
        n_samples_per_dropout = test_cfg.numerics.mcmc.walkerRatio
    else:
        n_walkers = test_cfg.numerics.mcmc.walkerRatio * (
            mcmc_Y_dim + 1)  # (BNN params + D_dt) times walker ratio
        n_dropout = n_walkers // test_cfg.numerics.mcmc.walkerRatio
        n_samples_per_dropout = test_cfg.numerics.mcmc.walkerRatio
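        # Note: n_dropout * n_samples_per_dropout == n_walkers, so the flattened dropout samples seed each walker exactly once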
    # Initialize arrays that will store samples and BNN predictions
    init_pos = np.empty(
        [batch_size, n_dropout, n_samples_per_dropout, mcmc_Y_dim])
    mcmc_pred = np.empty([batch_size, n_dropout, mcmc_loss_fn.out_dim])
    with torch.no_grad():
        net.train()
        # Send some empty forward passes through the test data without backprop to adjust batchnorm weights
        # (This is often not necessary. Beware if using for just 1 lens.)
        for nograd_pass in range(5):
            for X_, Y_ in test_loader:
                X = X_.to(device)
                _ = net(X)
        # Obtain MC dropout samples
        for d in range(n_dropout):
            net.eval()
            for X_, Y_ in test_loader:
                X = X_.to(device)
                Y = Y_.to(device)
                pred = net(X)
                break
            mcmc_pred_d = pred.cpu().numpy()
            # Replace BNN posterior's primary gaussian mean with truth values
            if test_cfg.lens_posterior_type == 'default_with_truth_mean':
                mcmc_pred_d[:, :len(mcmc_Y_cols)] = Y[:, :len(mcmc_Y_cols
                                                              )].cpu().numpy()
            # Leave only the MCMC parameters in pred
            mcmc_pred_d = mcmc_utils.remove_parameters_from_pred(
                mcmc_pred_d, remove_idx, return_as_tensor=False)
            # Populate pred that will define the MCMC penalty function
            mcmc_pred[:, d, :] = mcmc_pred_d
            # Instantiate posterior to generate BNN samples, which will serve as initial positions for walkers
            bnn_post = getattr(h0rton.h0_inference.gaussian_bnn_posterior_cpu,
                               loss_fn.posterior_name + 'CPU')(
                                   mcmc_Y_dim, mcmc_train_Y_mean,
                                   mcmc_train_Y_std)
            bnn_post.set_sliced_pred(mcmc_pred_d)
            init_pos[:, d, :, :] = bnn_post.sample(
                n_samples_per_dropout, sample_seed=test_cfg.global_seed +
                d)  # contains just the lens model params, no D_dt
            gc.collect()
    # Terminate right after generating BNN predictions (no MCMC)
    if test_cfg.export.pred:
        import sys
        samples_path = os.path.join(out_dir, 'samples.npy')
        np.save(samples_path, init_pos)
        sys.exit()

    #############
    # MCMC loop #
    #############
    # Convolve MC dropout iterates with aleatoric samples
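    # Reshape [batch_size, n_dropout, n_samples_per_dropout, Y_dim] -> [batch_size, n_dropout*n_samples_per_dropout, Y_dim]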
    init_pos = init_pos.transpose(0, 3, 1, 2).reshape(
        [batch_size, mcmc_Y_dim,
         -1]).transpose(0, 2, 1)  # [batch_size, n_samples, mcmc_Y_dim]
    init_D_dt = np.random.uniform(0.0,
                                  15000.0,
                                  size=(batch_size, n_walkers, 1))
    pred_mean = np.mean(init_pos, axis=1)  # [batch_size, mcmc_Y_dim]
    # Define assumed model profiles
    kwargs_model = dict(lens_model_list=['PEMD', 'SHEAR'],
                        point_source_model_list=['SOURCE_POSITION'],
                        source_light_model_list=['SERSIC_ELLIPSE'])
    astro_sig = test_cfg.image_position_likelihood.sigma  # astrometric uncertainty
    # Get H0 samples for each system
    if not test_cfg.time_delay_likelihood.baobab_time_delays:
        if 'abcd_ordering_i' not in master_truth:
            raise ValueError(
                "If the time delay measurements were not generated using Baobab, the user must specify the order of image positions in which the time delays are listed, in order of increasing dec."
            )
    kwargs_lens_eqn_solver = {
        'min_distance':
        0.05,
        'search_window':
        baobab_cfg.instrument['pixel_scale'] * baobab_cfg.image['num_pix'],
        'num_iter_max':
        200
    }

    total_progress = tqdm(total=n_test)
    realized_time_delays = pd.read_csv(
        test_cfg.error_model.realized_time_delays, index_col=None)
    # For each lens system...
    for i, lens_i in enumerate(lens_range):
        # Each lens gets a unique random state for time delay measurement error realizations.
        #rs_lens = np.random.RandomState(lens_i) # replaced with externally rendered time delays
        ###########################
        # Relevant data and prior #
        ###########################
        data_i = master_truth.iloc[lens_i].copy()
        # Set BNN pred defining parameter penalty for this lens, batch processes across n_dropout
        parameter_penalty.set_bnn_post_params(mcmc_pred[lens_i, :, :])
        # Initialize lens model params walkers at the predictive mean
        init_info = dict(
            zip(mcmc_Y_cols,
                pred_mean[lens_i, :] * mcmc_train_Y_std + mcmc_train_Y_mean))
        lcdm = LCDM(z_lens=data_i['z_lens'],
                    z_source=data_i['z_src'],
                    flat=True)
        true_img_dec = literal_eval(data_i['y_image'])
        n_img = len(true_img_dec)
        measured_td_sig = test_cfg.time_delay_likelihood.sigma
        measured_td_wrt0 = np.array(
            literal_eval(
                realized_time_delays.iloc[lens_i]['measured_td_wrt0']))
        kwargs_data_joint = dict(
            time_delays_measured=measured_td_wrt0,
            time_delays_uncertainties=measured_td_sig,
        )

        #############################
        # Parameter init and bounds #
        #############################
        lens_kwargs = mcmc_utils.get_lens_kwargs(init_info,
                                                 null_spread=null_spread)
        ps_kwargs = mcmc_utils.get_ps_kwargs_src_plane(init_info, astro_sig)
        src_light_kwargs = mcmc_utils.get_light_kwargs(
            init_info['src_light_R_sersic'], null_spread=null_spread)
        special_kwargs = mcmc_utils.get_special_kwargs(
            n_img, astro_sig
        )  # image position offset and time delay distance, aka the "special" parameters
        kwargs_params = {
            'lens_model': lens_kwargs,
            'point_source_model': ps_kwargs,
            'source_model': src_light_kwargs,
            'special': special_kwargs,
        }
        if test_cfg.numerics.solver_type == 'NONE':
            solver_type = 'NONE'
        else:
            solver_type = 'PROFILE_SHEAR' if n_img == 4 else 'CENTER'
        #solver_type = 'NONE'
        kwargs_constraints = {
            'num_point_source_list': [n_img],
            'Ddt_sampling': True,
            'solver_type': solver_type,
        }

        kwargs_likelihood = {
            'time_delay_likelihood': True,
            'sort_images_by_dec': True,
            'prior_lens': [],
            'prior_special': [],
            'check_bounds': True,
            'check_matched_source_position': False,
            'source_position_tolerance': 0.01,
            'source_position_sigma': 0.01,
            'source_position_likelihood': False,
            'custom_logL_addition': custom_logL_addition,
            'kwargs_lens_eqn_solver': kwargs_lens_eqn_solver
        }

        ###########################
        # MCMC posterior sampling #
        ###########################
        fitting_seq = FittingSequence(kwargs_data_joint,
                                      kwargs_model,
                                      kwargs_constraints,
                                      kwargs_likelihood,
                                      kwargs_params,
                                      verbose=False,
                                      mpi=False)
        if i == 0:
            param_class = fitting_seq._updateManager.param_class
            n_params, param_class_Y_cols = param_class.num_param()
            init_pos = mcmc_utils.reorder_to_param_class(
                mcmc_Y_cols, param_class_Y_cols, init_pos, init_D_dt)
        # MCMC sample from the post-processed BNN posterior jointly with cosmology
        lens_i_start_time = time.time()
        if test_cfg.lens_posterior_type == 'default':
            test_cfg.numerics.mcmc.update(init_samples=init_pos[lens_i, :, :])
        fitting_kwargs_list_mcmc = [['MCMC', test_cfg.numerics.mcmc]]
        with script_utils.HiddenPrints():
            chain_list_mcmc = fitting_seq.fit_sequence(
                fitting_kwargs_list_mcmc)
            kwargs_result_mcmc = fitting_seq.best_fit()
        lens_i_end_time = time.time()
        inference_time = (lens_i_end_time - lens_i_start_time) / 60.0  # min

        #############################
        # Plotting the MCMC samples #
        #############################
        # sampler_type : 'EMCEE'
        # samples_mcmc : np.array of shape `[n_mcmc_eval, n_params]`
        # param_mcmc : list of str of length n_params, the parameter names
        sampler_type, samples_mcmc, param_mcmc, _ = chain_list_mcmc[0]
        new_samples_mcmc = mcmc_utils.postprocess_mcmc_chain(
            kwargs_result_mcmc, samples_mcmc, kwargs_model, lens_kwargs[2],
            ps_kwargs[2], src_light_kwargs[2], special_kwargs[2],
            kwargs_constraints)
        # Plot D_dt histogram
        D_dt_samples = new_samples_mcmc['D_dt'].values
        true_D_dt = lcdm.D_dt(H_0=data_i['H0'], Om0=0.3)
        data_i['D_dt'] = true_D_dt
        # Export D_dt samples for this lens
        lens_inference_dict = dict(
            D_dt_samples=D_dt_samples,  # kappa_ext=0 for these samples
            inference_time=inference_time,
            true_D_dt=true_D_dt,
        )
        lens_inference_dict_save_path = os.path.join(
            out_dir, 'D_dt_dict_{0:04d}.npy'.format(lens_i))
        np.save(lens_inference_dict_save_path, lens_inference_dict)
        # Optionally export the MCMC samples
        if test_cfg.export.mcmc_samples:
            mcmc_samples_path = os.path.join(
                out_dir, 'mcmc_samples_{0:04d}.csv'.format(lens_i))
            new_samples_mcmc.to_csv(mcmc_samples_path, index=None)
        # Optionally export the D_dt histogram
        if test_cfg.export.D_dt_histogram:
            cleaned_D_dt_samples = h0_utils.remove_outliers_from_lognormal(
                D_dt_samples, 3)
            _ = plotting_utils.plot_D_dt_histogram(cleaned_D_dt_samples,
                                                   lens_i,
                                                   true_D_dt,
                                                   save_dir=out_dir)
        # Optionally export the plot of MCMC chain
        if test_cfg.export.mcmc_chain:
            mcmc_chain_path = os.path.join(
                out_dir, 'mcmc_chain_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_chain(chain_list_mcmc, mcmc_chain_path)
        # Optionally export posterior cornerplot of select lens model parameters with D_dt
        if test_cfg.export.mcmc_corner:
            mcmc_corner_path = os.path.join(
                out_dir, 'mcmc_corner_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_corner(
                new_samples_mcmc[test_cfg.export.mcmc_cols],
                data_i[test_cfg.export.mcmc_cols],
                test_cfg.export.mcmc_col_labels, mcmc_corner_path)
        total_progress.update(1)
        gc.collect()
    realized_time_delays.to_csv(os.path.join(out_dir,
                                             'realized_time_delays.csv'),
                                index=False)
    total_progress.close()


def main():
    args = script_utils.parse_inference_args()
    test_cfg = TestConfig.from_file(args.test_config_file_path)
    baobab_cfg = BaobabConfig.from_file(test_cfg.data.test_baobab_cfg_path)
    cfg = TrainValConfig.from_file(test_cfg.train_val_config_file_path)
    # Set device and default data type
    device = torch.device(test_cfg.device_type)
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.' + cfg.data.float_type)
    else:
        torch.set_default_tensor_type('torch.' + cfg.data.float_type)
    script_utils.seed_everything(test_cfg.global_seed)

    ############
    # Data I/O #
    ############
    # Define test data
    test_data = XYData(
        is_train=False,
        Y_cols=cfg.data.Y_cols,
        float_type=cfg.data.float_type,
        define_src_pos_wrt_lens=cfg.data.define_src_pos_wrt_lens,
        rescale_pixels=False,
        rescale_pixels_type=None,
        log_pixels=False,
        add_pixel_noise=cfg.data.add_pixel_noise,
        eff_exposure_time={"TDLMC_F160W": test_cfg.data.eff_exposure_time},
        train_Y_mean=np.zeros((1, len(cfg.data.Y_cols))),
        train_Y_std=np.ones((1, len(cfg.data.Y_cols))),
        train_baobab_cfg_path=cfg.data.train_baobab_cfg_path,
        val_baobab_cfg_path=test_cfg.data.test_baobab_cfg_path,
        for_cosmology=True)
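    # Zero mean / unit std above amount to identity whitening, keeping the
    # truth metadata in physical units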
    master_truth = test_data.Y_df
    master_truth = metadata_utils.add_qphi_columns(master_truth)
    master_truth = metadata_utils.add_gamma_psi_ext_columns(master_truth)
    # Figure out which lenses to perform inference on
    if test_cfg.data.lens_indices is None:
        if args.lens_indices_path is None:
            # Test on all n_test lenses in the test set
            n_test = test_cfg.data.n_test
            lens_range = range(n_test)
        else:
            # Test on the lens indices in a text file at the specified path
            lens_range = []
            with open(args.lens_indices_path, "r") as f:
                for line in f:
                    lens_range.append(int(line.strip()))
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(
                n_test))
    else:
        if args.lens_indices_path is None:
            # Test on the lens indices specified in the test config file
            lens_range = test_cfg.data.lens_indices
            n_test = len(lens_range)
            print("Performing H0 inference on {:d} specified lenses...".format(
                n_test))
        else:
            raise ValueError(
                "Specific lens indices were specified in both the test config file and the command-line argument."
            )
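    # A single batch spanning indices [0, max(lens_range)] lets each
    # requested lens be sliced out of one pass through the loader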
    batch_size = max(lens_range) + 1
    test_loader = DataLoader(test_data,
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=True)
    # Output directory into which the D_dt histograms and samples will be saved
    out_dir = test_cfg.out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print("Destination folder path: {:s}".format(out_dir))
    else:
        warnings.warn("Destination folder already exists.")

    ################
    # Compile data #
    ################
    # Image data
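    # (only the first batch is needed: batch_size already spans every
    # requested lens index)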
    with torch.no_grad():
        for X_, Y_ in test_loader:
            X = X_.to(device)
            break
    X = X.detach().cpu().numpy()

    #############
    # MCMC loop #
    #############
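    # The lens equation solver searches a window spanning the full field of
    # view (pixel scale x number of pixels)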
    kwargs_lens_eqn_solver = dict(
        min_distance=0.05,
        search_window=baobab_cfg.instrument['pixel_scale'] *
        baobab_cfg.image['num_pix'],
        num_iter_max=200)
    fm_posterior = ForwardModelingPosterior(
        kwargs_lens_eqn_solver=kwargs_lens_eqn_solver,
        astrometric_sigma=test_cfg.image_position_likelihood.sigma,
        supersampling_factor=baobab_cfg.numerics.supersampling_factor)
    # Get H0 samples for each system
    if not test_cfg.time_delay_likelihood.baobab_time_delays:
        if 'abcd_ordering_i' not in master_truth:
            raise ValueError(
                "If the time delay measurements were not generated using Baobab, the user must specify the order of image positions in which the time delays are listed, in order of increasing dec."
            )

    total_progress = tqdm(total=n_test)
    realized_time_delays = pd.read_csv(
        test_cfg.error_model.realized_time_delays, index_col=None)
    # For each lens system...
    for i, lens_i in enumerate(lens_range):
        ###########################
        # Relevant data and prior #
        ###########################
        data_i = master_truth.iloc[lens_i].copy()
        lcdm = LCDM(z_lens=data_i['z_lens'],
                    z_source=data_i['z_src'],
                    flat=True)
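        # Measured time delays are stored as a stringified list in the CSV,
        # hence the literal_eval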
        measured_td_wrt0 = np.array(
            literal_eval(
                realized_time_delays.iloc[lens_i]['measured_td_wrt0']))
        n_img = len(measured_td_wrt0) + 1
        fm_posterior.set_kwargs_data_joint(
            image=X[lens_i, 0, :, :],
            measured_td=measured_td_wrt0,
            measured_td_sigma=test_cfg.time_delay_likelihood.sigma,
            survey_object_dict=baobab_cfg.survey_object_dict,
            eff_exposure_time=test_cfg.data.eff_exposure_time,
        )
        # Update solver according to the number of lensed images:
        # quads get 'PROFILE_SHEAR', doubles get 'ELLIPSE'
        if test_cfg.numerics.solver_type == 'NONE':
            fm_posterior.kwargs_constraints['solver_type'] = 'NONE'
        else:
            fm_posterior.kwargs_constraints[
                'solver_type'] = 'PROFILE_SHEAR' if n_img == 4 else 'ELLIPSE'
        fm_posterior.kwargs_constraints['num_point_source_list'] = [n_img]
        true_D_dt = lcdm.D_dt(H_0=data_i['H0'], Om0=0.3)
        # Pull truth param values and initialize walkers there
        if test_cfg.numerics.initialize_walkers_to_truth:
            fm_posterior.kwargs_lens_init = metadata_utils.get_kwargs_lens_mass(
                data_i)
            fm_posterior.kwargs_lens_light_init = metadata_utils.get_kwargs_lens_light(
                data_i)
            fm_posterior.kwargs_source_init = metadata_utils.get_kwargs_src_light(
                data_i)
            fm_posterior.kwargs_ps_init = metadata_utils.get_kwargs_ps_lensed(
                data_i)
            fm_posterior.kwargs_special_init = dict(D_dt=true_D_dt)

        ###########################
        # MCMC posterior sampling #
        ###########################
        lens_i_start_time = time.time()
        chain_list_mcmc, kwargs_result_mcmc = fm_posterior.run_mcmc(
            test_cfg.numerics.mcmc)
        lens_i_end_time = time.time()
        inference_time = (lens_i_end_time - lens_i_start_time) / 60.0  # min

        #############################
        # Plotting the MCMC samples #
        #############################
        # sampler_type : 'EMCEE'
        # samples_mcmc : np.array of shape `[n_mcmc_eval, n_params]`
        # param_mcmc : list of str of length n_params, the parameter names
        sampler_type, samples_mcmc, param_mcmc, _ = chain_list_mcmc[0]
        new_samples_mcmc = mcmc_utils.postprocess_mcmc_chain(
            kwargs_result_mcmc,
            samples_mcmc,
            fm_posterior.kwargs_model,
            fm_posterior.kwargs_params['lens_model'][2],
            fm_posterior.kwargs_params['point_source_model'][2],
            fm_posterior.kwargs_params['source_model'][2],
            fm_posterior.kwargs_params['special'][2],
            fm_posterior.kwargs_constraints,
            kwargs_fixed_lens_light=fm_posterior.kwargs_params['lens_light_model'][2],
            verbose=False)
        model_plot = ModelPlot(fm_posterior.multi_band_list,
                               fm_posterior.kwargs_model,
                               kwargs_result_mcmc,
                               arrow_size=0.02,
                               cmap_string="gist_heat")
        plotting_utils.plot_forward_modeling_comparisons(model_plot, out_dir)

        # Extract D_dt samples (the raw chain may contain negative values)
        D_dt_samples = new_samples_mcmc['D_dt'].values
        data_i['D_dt'] = true_D_dt
        # Export D_dt samples for this lens
        lens_inference_dict = dict(
            D_dt_samples=D_dt_samples,  # kappa_ext=0 for these samples
            inference_time=inference_time,
            true_D_dt=true_D_dt,
        )
        lens_inference_dict_save_path = os.path.join(
            out_dir, 'D_dt_dict_{0:04d}.npy'.format(lens_i))
        np.save(lens_inference_dict_save_path, lens_inference_dict)
        # Optionally export the MCMC samples
        if test_cfg.export.mcmc_samples:
            mcmc_samples_path = os.path.join(
                out_dir, 'mcmc_samples_{0:04d}.csv'.format(lens_i))
            new_samples_mcmc.to_csv(mcmc_samples_path, index=False)
        # Optionally export the D_dt histogram
        if test_cfg.export.D_dt_histogram:
            cleaned_D_dt_samples = h0_utils.remove_outliers_from_lognormal(
                D_dt_samples, 3)
            _ = plotting_utils.plot_D_dt_histogram(cleaned_D_dt_samples,
                                                   lens_i,
                                                   true_D_dt,
                                                   save_dir=out_dir)
        # Optionally export the plot of MCMC chain
        if test_cfg.export.mcmc_chain:
            mcmc_chain_path = os.path.join(
                out_dir, 'mcmc_chain_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_chain(chain_list_mcmc, mcmc_chain_path)
        # Optionally export posterior cornerplot of select lens model parameters with D_dt
        if test_cfg.export.mcmc_corner:
            mcmc_corner_path = os.path.join(
                out_dir, 'mcmc_corner_{0:04d}.png'.format(lens_i))
            plotting_utils.plot_mcmc_corner(
                new_samples_mcmc[test_cfg.export.mcmc_cols],
                data_i[test_cfg.export.mcmc_cols],
                test_cfg.export.mcmc_col_labels, mcmc_corner_path)
        total_progress.update(1)
        gc.collect()
    realized_time_delays.to_csv(os.path.join(out_dir,
                                             'realized_time_delays.csv'),
                                index=False)
    total_progress.close()
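
# A minimal sketch (not part of the original script) of reading back one
# lens's exported D_dt posterior, e.g. for combining lenses downstream. The
# file name pattern matches the np.save call above; allow_pickle=True is
# needed because np.save pickled a Python dict.
def load_lens_D_dt(out_dir, lens_i):
    path = os.path.join(out_dir, 'D_dt_dict_{0:04d}.npy'.format(lens_i))
    lens_dict = np.load(path, allow_pickle=True).item()
    return lens_dict['D_dt_samples'], lens_dict['true_D_dt']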
Example #8
0
def main():
    args = parse_args()
    cfg = TrainValConfig.from_file(args.user_cfg_path)
    # Set device and default data type
    device = torch.device(cfg.device_type)
    if device.type == 'cuda':
        torch.set_default_tensor_type('torch.cuda.' + cfg.data.float_type)
    else:
        torch.set_default_tensor_type('torch.' + cfg.data.float_type)
    script_utils.seed_everything(cfg.global_seed)

    ############
    # Data I/O #
    ############

    # Define training data and loader
    train_data = XYData(
        is_train=True,
        Y_cols=cfg.data.Y_cols,
        float_type=cfg.data.float_type,
        define_src_pos_wrt_lens=cfg.data.define_src_pos_wrt_lens,
        rescale_pixels=cfg.data.rescale_pixels,
        log_pixels=cfg.data.log_pixels,
        add_pixel_noise=cfg.data.add_pixel_noise,
        eff_exposure_time=cfg.data.eff_exposure_time,
        train_Y_mean=None,
        train_Y_std=None,
        train_baobab_cfg_path=cfg.data.train_baobab_cfg_path,
        val_baobab_cfg_path=cfg.data.val_baobab_cfg_path,
        for_cosmology=False)
    train_loader = DataLoader(train_data,
                              batch_size=cfg.optim.batch_size,
                              shuffle=True,
                              drop_last=True)
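    # With drop_last=True, each epoch sees only full batches; n_train is the
    # effective number of training examples per epoch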
    n_train = len(train_data) - (len(train_data) % cfg.optim.batch_size)

    # Define val data and loader
    val_data = XYData(is_train=False,
                      Y_cols=cfg.data.Y_cols,
                      float_type=cfg.data.float_type,
                      define_src_pos_wrt_lens=cfg.data.define_src_pos_wrt_lens,
                      rescale_pixels=cfg.data.rescale_pixels,
                      log_pixels=cfg.data.log_pixels,
                      add_pixel_noise=cfg.data.add_pixel_noise,
                      eff_exposure_time=cfg.data.eff_exposure_time,
                      train_Y_mean=train_data.train_Y_mean,
                      train_Y_std=train_data.train_Y_std,
                      train_baobab_cfg_path=cfg.data.train_baobab_cfg_path,
                      val_baobab_cfg_path=cfg.data.val_baobab_cfg_path,
                      for_cosmology=False)
    val_loader = DataLoader(
        val_data,
        batch_size=min(len(val_data), cfg.optim.batch_size),
        shuffle=False,
        drop_last=True,
    )
    n_val = len(val_data) - (len(val_data) %
                             min(len(val_data), cfg.optim.batch_size))

    #########
    # Model #
    #########
    Y_dim = val_data.Y_dim
    # Instantiate loss function
    loss_fn = getattr(h0rton.losses, cfg.model.likelihood_class)(Y_dim=Y_dim,
                                                                 device=device)
    # Instantiate posterior (for logging)
    bnn_post = getattr(h0rton.h0_inference.gaussian_bnn_posterior,
                       loss_fn.posterior_name)(val_data.Y_dim, device,
                                               val_data.train_Y_mean,
                                               val_data.train_Y_std)
    # Instantiate model
    net = getattr(h0rton.models,
                  cfg.model.architecture)(num_classes=loss_fn.out_dim,
                                          dropout_rate=cfg.model.dropout_rate)
    net.to(device)
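    # num_classes above is the number of distributional parameters
    # (loss_fn.out_dim), not a count of classes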

    ################
    # Optimization #
    ################

    # Instantiate optimizer
    optimizer = optim.Adam(net.parameters(),
                           lr=cfg.optim.learning_rate,
                           amsgrad=False,
                           weight_decay=cfg.optim.weight_decay)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        mode='min',
                                                        factor=0.75,
                                                        patience=50,
                                                        cooldown=50,
                                                        min_lr=1e-5,
                                                        verbose=True)
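    # Note: the plateau scheduler is stepped on the running train loss every
    # batch (see the training loop below)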

    # Saving/loading state dicts
    checkpoint_dir = cfg.checkpoint.save_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if cfg.model.load_state:
        epoch, net, optimizer, train_loss, val_loss = train_utils.load_state_dict(
            cfg.model.state_path, net, optimizer, cfg.optim.n_epochs, device)
        epoch += 1  # resume with next epoch
        last_saved_val_loss = val_loss
        print(lr_scheduler.state_dict())
        print(optimizer.state_dict())
    else:
        epoch = 0
        last_saved_val_loss = np.inf

    logger = SummaryWriter()
    model_path = ''
    print("Training set size: {:d}".format(n_train))
    print("Validation set size: {:d}".format(n_val))

    progress = tqdm(range(epoch, cfg.optim.n_epochs))
    n_iter = 0
    for epoch in progress:
        train_loss = 0.0
        for batch_idx, (X_tr, Y_tr) in enumerate(train_loader):
            n_iter += 1
            net.train()
            X_tr = X_tr.to(device)
            Y_tr = Y_tr.to(device)
            # Update weights
            optimizer.zero_grad()
            pred_tr = net(X_tr)
            loss = loss_fn(pred_tr, Y_tr)
            loss.backward()
            optimizer.step()
            # For logging
            train_loss += (loss.detach().item() - train_loss) / (1 + batch_idx)
            # Step lr_scheduler every batch
            lr_scheduler.step(train_loss)
            tqdm.write("Iter [{}/{}/{}]: TRAIN Loss: {:.4f}".format(
                n_iter, epoch + 1, cfg.optim.n_epochs, train_loss))

            if n_iter % cfg.monitoring.interval == 0:
                net.eval()
                with torch.no_grad():
                    val_loss = 0.0

                    for batch_idx, (X_v, Y_v) in enumerate(val_loader):
                        X_v = X_v.to(device)
                        Y_v = Y_v.to(device)
                        pred_v = net(X_v)
                        nograd_loss_v = loss_fn(pred_v, Y_v)
                        val_loss += (nograd_loss_v.detach().item() -
                                     val_loss) / (1 + batch_idx)

                    tqdm.write("Epoch [{}/{}]: VALID Loss: {:.4f}".format(
                        epoch + 1, cfg.optim.n_epochs, val_loss))

                    # Subset of validation for plotting
                    n_plotting = cfg.monitoring.n_plotting
                    Y_plt_orig = bnn_post.transform_back_mu(
                        Y_v[:n_plotting]).cpu().numpy()
                    pred_plt = pred_v[:n_plotting]
                    # Slice pred_plt into meaningful Gaussian parameters for this batch
                    bnn_post.set_sliced_pred(pred_plt)
                    mu_orig = bnn_post.transform_back_mu(
                        bnn_post.mu).cpu().numpy()
                    # Log train and val metrics
                    loss_dict = {'train': train_loss, 'val': val_loss}
                    logger.add_scalars('metrics/loss', loss_dict, n_iter)
                    mae_dict = train_utils.get_mae(mu_orig, Y_plt_orig,
                                                   cfg.data.Y_cols)
                    logger.add_scalars('metrics/mae', mae_dict, n_iter)
                    # Log the log-determinant of the predicted covariance matrix
                    if cfg.model.likelihood_class in [
                            'DoubleGaussianNLL', 'FullRankGaussianNLL'
                    ]:
                        logdet = train_utils.get_logdet(
                            bnn_post.tril_elements.cpu().numpy(), Y_dim)
                        logger.add_histogram('logdet_cov_mat', logdet, n_iter)
                    # Log second Gaussian stats
                    if cfg.model.likelihood_class in [
                            'DoubleGaussianNLL', 'DoubleLowRankGaussianNLL'
                    ]:
                        # Log histogram of w2
                        logger.add_histogram('val_pred/weight_gaussian2',
                                             bnn_post.w2.cpu().numpy(), n_iter)
                        # Log MAE of the second Gaussian
                        mu2_orig = bnn_post.transform_back_mu(
                            bnn_post.mu2).cpu().numpy()
                        mae2_dict = train_utils.get_mae(
                            mu2_orig, Y_plt_orig, cfg.data.Y_cols)
                        logger.add_scalars('metrics/mae2', mae2_dict, n_iter)
                        # Log logdet of second Gaussian
                        logdet2 = train_utils.get_logdet(
                            bnn_post.tril_elements2.cpu().numpy(), Y_dim)
                        logger.add_histogram('logdet_cov_mat2', logdet2,
                                             n_iter)

                    if val_loss < last_saved_val_loss:
                        if os.path.exists(model_path):
                            os.remove(model_path)
                        model_path = train_utils.save_state_dict(
                            net, optimizer, lr_scheduler, train_loss, val_loss,
                            checkpoint_dir, cfg.model.architecture, epoch)
                        last_saved_val_loss = val_loss

    logger.close()
    # Save final state dict
    if val_loss < last_saved_val_loss:
        if os.path.exists(model_path):
            os.remove(model_path)
        model_path = train_utils.save_state_dict(net, optimizer, lr_scheduler,
                                                 train_loss, val_loss,
                                                 checkpoint_dir,
                                                 cfg.model.architecture, epoch)
        print("Saved model at {:s}".format(os.path.abspath(model_path)))