    def test_specify_test_set_path(self):
        # Pass a specific test_set_path to the inference class and make sure
        # it behaves as expected.
        test_set_path = self.root_path

        # Check that the file doesn't already exist.
        self.assertFalse(os.path.isfile(test_set_path + 'tf_record_test_val'))

        # We will again have to simulate training so that the desired
        # normalization path exists.
        model_trainer.prepare_tf_record(self.cfg, self.root_path,
                                        self.tf_record_path, self.final_params,
                                        'train')
        os.remove(self.tf_record_path)
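        # Only the normalization file produced above is needed, so the
        # training TFRecord itself can be deleted.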

        _ = bnn_inference.InferenceClass(self.cfg,
                                         test_set_path=test_set_path,
                                         lite_class=True)

        # Check that a new tf_record was generated
        self.assertTrue(os.path.isfile(test_set_path + 'tf_record_test_val'))

        # Check that passing a fake test_set_path raises an error.
        fake_test_path = self.root_path + 'fake_data'
        os.mkdir(fake_test_path)

        with self.assertRaises(FileNotFoundError):
            _ = bnn_inference.InferenceClass(self.cfg,
                                             test_set_path=fake_test_path,
                                             lite_class=True)

        # Test cleanup
        os.rmdir(fake_test_path)
        os.remove(test_set_path + 'tf_record_test_val')
        os.remove(self.root_path + 'new_metadata.csv')
        os.remove(self.normalization_constants_path)

    def test_undo_param_norm(self):
        # Test that undoing the lens parameter normalization works correctly.

        self.infer_class = bnn_inference.InferenceClass(self.cfg,
                                                        lite_class=True)
        # Delete the tf record file made during the initialization of the
        # inference class.
        os.remove(self.root_path + 'tf_record_test_val')
        os.remove(self.root_path + 'new_metadata.csv')
        # Get rid of the normalization file.
        os.remove(self.normalization_constants_path)

        train_or_test = 'train'
        data_tools.normalize_lens_parameters(self.lens_params,
                                             self.lens_params_path,
                                             self.normalized_param_path,
                                             self.normalization_constants_path,
                                             train_or_test=train_or_test)

        lens_params_csv = pd.read_csv(self.lens_params_path, index_col=None)
        norm_params_csv = pd.read_csv(self.normalized_param_path,
                                      index_col=None)

        # Pull lens parameters out of the csv files.
        lens_params_numpy = []
        norms_params_numpy = []
        for lens_param in self.lens_params:
            lens_params_numpy.append(lens_params_csv[lens_param])
            norms_params_numpy.append(norm_params_csv[lens_param])
        lens_params_numpy = np.array(lens_params_numpy).T
        norms_params_numpy = np.array(norms_params_numpy).T
        predict_samps = np.tile(norms_params_numpy, (3, 1, 1))
        # TODO: write a good test for al_samps!
        al_samps = np.ones((3, 3, self.num_params, self.num_params))

        # Try to denormalize everything
        self.infer_class.undo_param_norm(predict_samps, norms_params_numpy,
                                         al_samps)

        self.assertAlmostEqual(
            np.mean(np.abs(norms_params_numpy - lens_params_numpy)), 0)
        self.assertAlmostEqual(
            np.mean(np.abs(predict_samps - lens_params_numpy)), 0)

        # Clean up the file now that we're done
        os.remove(self.normalized_param_path)
        os.remove(self.normalization_constants_path)

    def test_fix_flip_pairs(self):
        # Check that fix_flip_pairs always selects the best possible configuration
        # to return.
        self.infer_class = bnn_inference.InferenceClass(self.cfg,
                                                        lite_class=True)
        # Delete the tf record file made during the initialization of the
        # inference class.
        os.remove(self.root_path + 'tf_record_test_val')
        os.remove(self.root_path + 'new_metadata.csv')
        # Get rid of the normalization file.
        os.remove(self.normalization_constants_path)

        # Get the set of all flip pairs we want to check
        flip_pairs = self.cfg['training_params']['flip_pairs']
        flip_set = set()
        for flip_pair in flip_pairs:
            flip_set.update(flip_pair)

        y_test = np.ones((self.batch_size, self.num_params))
        predict_samps = np.ones((10, self.batch_size, self.num_params))

        # Flip the sign of every flip-pair parameter in the first sample.
        pi = 0
        for flip_index in flip_set:
            predict_samps[pi, :, flip_index] = -1

        # fix_flip_pairs should restore the flipped pairs exactly.
        self.infer_class.fix_flip_pairs(predict_samps, y_test, self.batch_size)

        self.assertEqual(np.sum(np.abs(predict_samps - y_test)), 0)

        dont_flip_set = set(range(self.num_params))
        dont_flip_set = dont_flip_set.difference(flip_set)

        # Now flip parameters that are not part of any flip pair.
        pi = 0
        for flip_index in dont_flip_set:
            predict_samps[pi, :, flip_index] = -1

        # fix_flip_pairs should leave these parameters untouched.
        self.infer_class.fix_flip_pairs(predict_samps, y_test, self.batch_size)

        self.assertEqual(np.sum(np.abs(predict_samps - y_test)),
                         2 * self.batch_size * len(dont_flip_set))

    def test_calc_p_dlt(self):

        self.infer_class = bnn_inference.InferenceClass(self.cfg,
                                                        lite_class=True)
        # Delete the tf record file made during the initialization of the
        # inference class.
        os.remove(self.root_path + 'tf_record_test_val')
        os.remove(self.root_path + 'new_metadata.csv')
        # Get rid of the normalization file.
        os.remove(self.normalization_constants_path)

        # Test that the calc_p_dlt returns the correct percentages for some
        # toy examples

        # Check a simple case
        size = int(1e6)
        self.infer_class.predict_samps = np.random.normal(
            size=size * 2).reshape((size // 10, 10, 2))
        self.infer_class.predict_samps[:, :, 1] = 0
        self.infer_class.y_pred = np.mean(self.infer_class.predict_samps,
                                          axis=0)
        self.infer_class.y_test = np.array(
            [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            dtype=np.float32).T

        self.infer_class.calc_p_dlt(cov_emp=np.diag(np.ones(2)))
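        # These are the probabilities enclosed within 1 to 5 sigma of a
        # standard normal; beyond that the enclosed fraction is effectively 1.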
        percentages = [0.682689, 0.954499, 0.997300, 0.999936, 0.999999
                       ] + [1.0] * 5
        for p_i in range(len(percentages)):
            self.assertAlmostEqual(percentages[p_i],
                                   self.infer_class.p_dlt[p_i],
                                   places=2)

        # Shift the mean
        size = int(1e6)
        self.infer_class.predict_samps = np.random.normal(
            loc=2, size=size * 2).reshape((size // 10, 10, 2))
        self.infer_class.predict_samps[:, :, 1] = 0
        self.infer_class.y_pred = np.mean(self.infer_class.predict_samps,
                                          axis=0)
        self.infer_class.y_test = np.array(
            [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
            dtype=np.float32).T
        self.infer_class.calc_p_dlt(cov_emp=np.diag(np.ones(2)))
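        # The samples are now centered at 2, so the relevant distances are
        # |d - 2| and the standard-normal probabilities shift accordingly.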
        percentages = [0.682689, 0, 0.682689, 0.954499, 0.997300, 0.999936
                       ] + [1.0] * 4
        for p_i in range(len(percentages)):
            self.assertAlmostEqual(percentages[p_i],
                                   self.infer_class.p_dlt[p_i],
                                   places=2)

        # Now use both dimensions: unit vectors scaled by a uniform radius.
        size = int(1e6)
        self.infer_class.predict_samps = np.random.normal(
            loc=0, size=size * 2).reshape((size // 10, 10, 2))
        self.infer_class.predict_samps /= np.sqrt(
            np.sum(np.square(self.infer_class.predict_samps),
                   axis=-1,
                   keepdims=True))
        self.infer_class.predict_samps *= np.random.random(size=size).reshape(
            (size // 10, 10, 1)) * 5
        self.infer_class.y_pred = np.mean(self.infer_class.predict_samps,
                                          axis=0)
        self.infer_class.y_test = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                                            [0] * 10]).T
        self.infer_class.calc_p_dlt(cov_emp=np.diag(np.ones(2)))
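        # With radii drawn uniformly from [0, 5), the fraction of samples
        # within a distance d of the prediction is min(d / 5, 1).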
        percentages = [1 / 5, 2 / 5, 3 / 5, 4 / 5, 1, 1] + [1.0] * 4
        for p_i in range(len(percentages)):
            self.assertAlmostEqual(percentages[p_i],
                                   self.infer_class.p_dlt[p_i],
                                   places=2)

        # Check the default behaviour when no empirical covariance is passed.
        size = int(1e6)
        self.infer_class.predict_samps = np.random.normal(
            loc=0, size=size * 2).reshape((size // 2, 2, 2)) * 5
        self.infer_class.predict_samps[:, :, 1] = 0
        self.infer_class.y_pred = np.mean(self.infer_class.predict_samps,
                                          axis=0)
        self.infer_class.y_test = np.array([[0, np.sqrt(2)], [0] * 2]).T
        self.infer_class.calc_p_dlt()
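        # Without cov_emp the spread of the samples themselves (std ~5 along
        # the first parameter) sets the scale, so an offset of sqrt(2) only
        # encloses about 22% of the samples.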
        percentages = [0, 0.223356]
        for p_i in range(len(percentages)):
            self.assertAlmostEqual(percentages[p_i],
                                   self.infer_class.p_dlt[p_i],
                                   places=2)

    def test_gen_samples_save(self):

        self.infer_class = bnn_inference.InferenceClass(self.cfg)
        # Delete the tf record file made during the initialization of the
        # inference class.
        os.remove(self.root_path + 'tf_record_test_val')
        os.remove(self.root_path + 'new_metadata.csv')
        # Get rid of the normalization file.
        os.remove(self.normalization_constants_path)

        # First we have to make a fake model whose statistics are very well
        # defined.
        class ToyModel():
            def __init__(self, mean, covariance, batch_size, al_std):
                # We want to make sure our performance is consistent for a
                # test
                np.random.seed(4)
                self.mean = mean
                self.covariance = covariance
                self.batch_size = batch_size
                self.al_std = al_std

            def predict(self, image):
                # We won't actually be using the image. We just want it for
                # testing.
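                # The toy output mimics a diagonal BNN head: predicted means
                # (with epistemic scatter) followed by a constant log
                # aleatoric std for every parameter.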
                return tf.constant(
                    np.concatenate([
                        np.random.multivariate_normal(
                            self.mean, self.covariance, self.batch_size),
                        np.zeros(
                            (self.batch_size, len(self.mean))) + self.al_std
                    ],
                                   axis=-1), tf.float32)

        # Start with a simple covariance matrix example.
        mean = np.ones(self.num_params) * 2
        covariance = np.diag(np.ones(self.num_params))
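        # A log std of -1000 makes the aleatoric covariance effectively zero.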
        al_std = -1000
        diag_model = ToyModel(mean, covariance, self.batch_size, al_std)

        # We don't want any flipping going on
        self.infer_class.flip_mat_list = [np.diag(np.ones(self.num_params))]

        # Create tf record. This won't be used, but it has to be there for
        # the function to be able to pull some images.
        # Make fake norms data
        fake_norms = {}
        for lens_param in self.lens_params:
            fake_norms[lens_param] = np.array([0.0, 1.0])
        fake_norms = pd.DataFrame(data=fake_norms)
        fake_norms.to_csv(self.normalization_constants_path, index=False)
        data_tools.generate_tf_record(self.root_path, self.lens_params,
                                      self.lens_params_path,
                                      self.tf_record_path)

        # Replace the real model with our fake model and generate samples
        self.infer_class.model = diag_model
        # Provide a save path to then check that we get the same data
        save_path = self.root_path + 'test_gen_samps/'
        self.infer_class.gen_samples(10000, save_path)

        pred_1 = np.copy(self.infer_class.predict_samps)
        # Generate again and make sure they are equivalent
        self.infer_class.gen_samples(10000, save_path)

        np.testing.assert_almost_equal(pred_1, self.infer_class.predict_samps)

        # Test that none of the plotting routines break
        self.infer_class.gen_coverage_plots(block=False)
        plt.close('all')
        self.infer_class.report_stats()
        self.infer_class.plot_posterior_contours(1, block=False)
        plt.close('all')
        self.infer_class.comp_al_ep_unc(block=False)
        plt.close('all')
        self.infer_class.comp_al_ep_unc(block=False, norm_diagonal=False)
        plt.close('all')
        self.infer_class.plot_calibration(block=False, title='test')
        plt.close('all')

        # Clean up the files we generated
        os.remove(self.normalization_constants_path)
        os.remove(self.tf_record_path)
        os.remove(save_path + 'pred.npy')
        os.remove(save_path + 'al_cov.npy')
        os.remove(save_path + 'images.npy')
        os.remove(save_path + 'y_test.npy')
        os.rmdir(save_path)

    def test_gen_samples_gmm(self):

        self.infer_class = bnn_inference.InferenceClass(self.cfg)
        # Delete the tf record file made during the initialization of the
        # inference class.
        os.remove(self.root_path + 'tf_record_test_val')
        os.remove(self.root_path + 'new_metadata.csv')
        # Get rid of the normalization file.
        os.remove(self.normalization_constants_path)

        # First we have to make a fake model whose statistics are very well
        # defined.

        class ToyModel():
            def __init__(self, mean1, covariance1, mean2, covariance2,
                         batch_size, L_elements1, L_elements2, pi_logit):
                # We want to make sure our performance is consistent for a
                # test
                np.random.seed(6)
                self.mean1 = mean1
                self.mean2 = mean2
                self.covariance1 = covariance1
                self.covariance2 = covariance2
                self.num_params = len(mean1)
                self.batch_size = batch_size
                self.L_elements1 = L_elements1
                self.L_elements2 = L_elements2
                self.pi_logit = pi_logit
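                # n(n+1)/2 entries of a lower-triangular matrix, matching the
                # length of the L_elements outputs.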
                self.L_elements_len = int(self.num_params *
                                          (self.num_params + 1) / 2)

            def predict(self, image):
                # We won't actually be using the image. We just want it for
                # testing.
                return tf.constant(
                    np.concatenate([
                        np.random.multivariate_normal(
                            self.mean1, self.covariance1, self.batch_size),
                        np.zeros((self.batch_size, self.L_elements_len)) +
                        self.L_elements1,
                        np.random.multivariate_normal(
                            self.mean2, self.covariance2, self.batch_size),
                        np.zeros((self.batch_size, self.L_elements_len)) +
                        self.L_elements2,
                        np.zeros((self.batch_size, 1)) + self.pi_logit
                    ],
                                   axis=-1), tf.float32)

        # Start with a simple covariance matrix example where both gmms
        # are the same. This is just checking the base case.
        mean1 = np.ones(self.num_params) * 2
        mean2 = np.ones(self.num_params) * 2
        covariance1 = np.diag(np.ones(self.num_params) * 0.000001)
        covariance2 = np.diag(np.ones(self.num_params) * 0.000001)
        L_elements1 = np.array([np.log(1)] * self.num_params +
                               [0] * int(self.num_params *
                                         (self.num_params - 1) / 2))
        L_elements2 = np.array([np.log(1)] * self.num_params +
                               [0] * int(self.num_params *
                                         (self.num_params - 1) / 2))
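        # A logit of 0 gives the two Gaussians equal mixture weight.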
        pi_logit = 0
        gmm_model = ToyModel(mean1, covariance1, mean2, covariance2,
                             self.batch_size, L_elements1, L_elements2,
                             pi_logit)

        # We don't want any flipping going on
        self.infer_class.flip_mat_list = [np.diag(np.ones(self.num_params))]

        # Create tf record. This won't be used, but it has to be there for
        # the function to be able to pull some images.
        # Make fake norms data
        fake_norms = {}
        for lens_param in self.lens_params:
            fake_norms[lens_param] = np.array([0.0, 1.0])
        fake_norms = pd.DataFrame(data=fake_norms)
        fake_norms.to_csv(self.normalization_constants_path, index=False)
        data_tools.generate_tf_record(self.root_path, self.lens_params,
                                      self.lens_params_path,
                                      self.tf_record_path)

        # Replace the real model with our fake model and generate samples
        self.infer_class.model = gmm_model
        self.infer_class.bnn_type = 'gmm'
        self.infer_class.gen_samples(1000)

        # Make sure these samples follow the required statistics.
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.y_pred -
                                              mean1)),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.y_std - 1)),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(
            np.abs(self.infer_class.y_cov - np.eye(self.num_params))),
                               0,
                               places=1)
        self.assertTupleEqual(
            self.infer_class.al_cov.shape,
            (self.batch_size, self.num_params, self.num_params))
        self.assertAlmostEqual(
            np.mean(np.abs(self.infer_class.al_cov - np.eye(self.num_params))),
            0)

        # Now we try an example where all the samples should be drawn from the
        # first of the two Gaussians because of the logit.
        mean1 = np.ones(self.num_params) * 2
        mean2 = np.ones(self.num_params) * 200
        covariance1 = np.diag(np.ones(self.num_params) * 0.000001)
        covariance2 = np.diag(np.ones(self.num_params) * 0.000001)
        L_elements1 = np.array([np.log(1)] * self.num_params +
                               [0] * int(self.num_params *
                                         (self.num_params - 1) / 2))
        L_elements2 = np.array([np.log(10)] * self.num_params +
                               [0] * int(self.num_params *
                                         (self.num_params - 1) / 2))
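        # This logit puts essentially all of the mixture weight on the first
        # Gaussian, so the distant second component should never be sampled.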
        pi_logit = np.log(0.99999) - np.log(0.00001)
        gmm_model = ToyModel(mean1, covariance1, mean2, covariance2,
                             self.batch_size, L_elements1, L_elements2,
                             pi_logit)
        self.infer_class.model = gmm_model
        self.infer_class.gen_samples(1000)

        # Make sure these samples follow the required statistics.
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.y_pred -
                                              mean1)),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.y_std - 1)),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(
            np.abs(self.infer_class.y_cov - np.eye(self.num_params))),
                               0,
                               places=1)
        self.assertTupleEqual(
            self.infer_class.al_cov.shape,
            (self.batch_size, self.num_params, self.num_params))
        self.assertAlmostEqual(
            np.mean(np.abs(self.infer_class.al_cov - np.eye(self.num_params))),
            0)

        # Now test that it takes a combination of them correctly
        mean1 = np.ones(self.num_params) * 2
        mean2 = np.ones(self.num_params) * 6
        covariance1 = np.diag(np.ones(self.num_params) * 0.000001)
        covariance2 = np.diag(np.ones(self.num_params) * 0.000001)
        L_elements1 = np.array([np.log(10)] * self.num_params +
                               [0] * int(self.num_params *
                                         (self.num_params - 1) / 2))
        L_elements2 = np.array([np.log(1)] * self.num_params +
                               [0] * int(self.num_params *
                                         (self.num_params - 1) / 2))
        pi_logit = np.log(0.0001) - np.log(0.9999)
        gmm_model = ToyModel(mean1, covariance1, mean2, covariance2,
                             self.batch_size, L_elements1, L_elements2,
                             pi_logit)
        self.infer_class.model = gmm_model
        self.infer_class.gen_samples(2000)

        # Make sure these samples follow the required statistics.
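        # With both components contributing, the mixture mean should land
        # between mean1 and mean2 (here at 4).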
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.y_pred - 4)),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(
            np.abs(self.infer_class.y_std - np.sqrt(5))),
                               0,
                               places=0)
        self.assertTupleEqual(
            self.infer_class.al_cov.shape,
            (self.batch_size, self.num_params, self.num_params))

        # The first Gaussian is always favored in the current parameterization,
        # so we can't test the scenario where the second is favored.

        # Clean up the files we generated
        os.remove(self.normalization_constants_path)
        os.remove(self.tf_record_path)

    def test_gen_samples_full(self):

        self.infer_class = bnn_inference.InferenceClass(self.cfg)
        # Delete the tf record file made during the initialization of the
        # inference class.
        os.remove(self.root_path + 'tf_record_test_val')
        os.remove(self.root_path + 'new_metadata.csv')
        # Get rid of the normalization file.
        os.remove(self.normalization_constants_path)

        # First we have to make a fake model whose statistics are very well
        # defined.

        class ToyModel():
            def __init__(self, mean, covariance, batch_size, L_elements):
                # We want to make sure our performance is consistent for a
                # test
                np.random.seed(6)
                self.mean = mean
                self.num_params = len(mean)
                self.covariance = covariance
                self.batch_size = batch_size
                self.L_elements = L_elements
                self.L_elements_len = int(self.num_params *
                                          (self.num_params + 1) / 2)

            def predict(self, image):
                # We won't actually be using the image. We just want it for
                # testing.
                return tf.constant(
                    np.concatenate([
                        np.zeros(
                            (self.batch_size, self.num_params)) + self.mean,
                        np.zeros((self.batch_size, self.L_elements_len)) +
                        self.L_elements
                    ],
                                   axis=-1), tf.float32)

        # Start with a simple covariance matrix example.
        mean = np.ones(self.num_params) * 2
        covariance = np.diag(np.ones(self.num_params) * 0.000001)
        L_elements = np.array([np.log(1)] * self.num_params +
                              [0] * int(self.num_params *
                                        (self.num_params - 1) / 2))
        full_model = ToyModel(mean, covariance, self.batch_size, L_elements)

        # We don't want any flipping going on
        self.infer_class.flip_mat_list = [np.diag(np.ones(self.num_params))]

        # Create tf record. This won't be used, but it has to be there for
        # the function to be able to pull some images.
        # Make fake norms data
        fake_norms = {}
        for lens_param in self.lens_params:
            fake_norms[lens_param] = np.array([0.0, 1.0])
        fake_norms = pd.DataFrame(data=fake_norms)
        fake_norms.to_csv(self.normalization_constants_path, index=False)
        data_tools.generate_tf_record(self.root_path, self.lens_params,
                                      self.lens_params_path,
                                      self.tf_record_path)

        # Replace the real model with our fake model and generate samples
        self.infer_class.model = full_model
        self.infer_class.bnn_type = 'full'
        # self.infer_class.gen_samples(1000)

        # # Make sure these samples follow the required statistics.
        # self.assertAlmostEqual(
        #     np.mean(np.abs(self.infer_class.y_pred - mean)), 0, places=1)
        # self.assertAlmostEqual(
        #     np.mean(np.abs(self.infer_class.y_std - 1)), 0, places=1)
        # self.assertAlmostEqual(
        #     np.mean(np.abs(self.infer_class.y_cov - np.eye(self.num_params))),
        #     0, places=1)
        # self.assertTupleEqual(
        #     self.infer_class.al_cov.shape,
        #     (self.batch_size, self.num_params, self.num_params))
        # self.assertAlmostEqual(
        #     np.mean(np.abs(self.infer_class.al_cov - np.eye(self.num_params))),
        #     0)

        mean = np.zeros(self.num_params)
        loss_class = bnn_alexnet.LensingLossFunctions([], self.num_params)
        L_elements = np.ones((1, len(L_elements))) * 0.2
        full_model = ToyModel(mean, covariance, self.batch_size, L_elements)
        self.infer_class.model = full_model
        self.infer_class.gen_samples(1000)

        # Calculate the corresponding covariance matrix
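        # construct_precision_matrix returns the Cholesky factor of the
        # precision matrix; inverting it yields the implied covariance.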
        _, _, L_mat = loss_class.construct_precision_matrix(
            tf.constant(L_elements))
        L_mat = np.linalg.inv(L_mat.numpy()[0].T)
        cov_mat = np.dot(L_mat, L_mat.T)

        # Make sure these samples follow the required statistics.
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.y_pred - mean)),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(
            np.abs(self.infer_class.y_std - np.sqrt(np.diag(cov_mat)))),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(
            np.abs((self.infer_class.y_cov - cov_mat))),
                               0,
                               places=1)
        self.assertTupleEqual(
            self.infer_class.al_cov.shape,
            (self.batch_size, self.num_params, self.num_params))
        self.assertAlmostEqual(
            np.mean(np.abs(self.infer_class.al_cov - cov_mat)), 0)

        # Clean up the files we generated
        os.remove(self.normalization_constants_path)
        os.remove(self.tf_record_path)

    def test_gen_samples_diag(self):

        self.infer_class = bnn_inference.InferenceClass(self.cfg)
        # Delete the tf record file made during the initialization of the
        # inference class.
        os.remove(self.root_path + 'tf_record_test_val')
        os.remove(self.root_path + 'new_metadata.csv')
        # Get rid of the normalization file.
        os.remove(self.normalization_constants_path)

        # First we have to make a fake model whose statistics are very well
        # defined.

        class ToyModel():
            def __init__(self, mean, covariance, batch_size, al_std):
                # We want to make sure our performance is consistent for a
                # test
                np.random.seed(4)
                self.mean = mean
                self.covariance = covariance
                self.batch_size = batch_size
                self.al_std = al_std

            def predict(self, image):
                # We won't actually be using the image. We just want it for
                # testing.
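                # The toy output mimics a diagonal BNN head: predicted means
                # (with epistemic scatter) followed by a constant log
                # aleatoric std for every parameter.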
                return tf.constant(
                    np.concatenate([
                        np.random.multivariate_normal(
                            self.mean, self.covariance, self.batch_size),
                        np.zeros(
                            (self.batch_size, len(self.mean))) + self.al_std
                    ],
                                   axis=-1), tf.float32)

        # Start with a simple covariance matrix example.
        mean = np.ones(self.num_params) * 2
        covariance = np.diag(np.ones(self.num_params))
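        # A log std of -1000 makes the aleatoric covariance effectively zero.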
        al_std = -1000
        diag_model = ToyModel(mean, covariance, self.batch_size, al_std)

        # We don't want any flipping going on
        self.infer_class.flip_mat_list = [np.diag(np.ones(self.num_params))]

        # Create tf record. This won't be used, but it has to be there for
        # the function to be able to pull some images.
        # Make fake norms data
        fake_norms = {}
        for lens_param in self.lens_params:
            fake_norms[lens_param] = np.array([0.0, 1.0])
        fake_norms = pd.DataFrame(data=fake_norms)
        fake_norms.to_csv(self.normalization_constants_path, index=False)
        data_tools.generate_tf_record(self.root_path, self.lens_params,
                                      self.lens_params_path,
                                      self.tf_record_path)

        # Replace the real model with our fake model and generate samples
        self.infer_class.model = diag_model
        self.infer_class.gen_samples(10000)

        # Make sure these samples follow the required statistics.
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.y_pred - mean)),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(
            np.abs(self.infer_class.y_std - np.diag(covariance))),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(
            np.abs(self.infer_class.y_cov - covariance)),
                               0,
                               places=1)
        self.assertTupleEqual(
            self.infer_class.al_cov.shape,
            (self.batch_size, self.num_params, self.num_params))
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.al_cov)), 0)

        # Repeat this process again with a new covariance matrix and means
        mean = np.random.rand(self.num_params)
        covariance = np.random.rand(self.num_params, self.num_params)
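        # A log std of 0 corresponds to a unit aleatoric std, i.e. an identity
        # aleatoric covariance on top of the epistemic covariance.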
        al_std = 0
        # Make sure covariance is positive semidefinite
        covariance = np.dot(covariance, covariance.T)
        diag_model = ToyModel(mean, covariance, self.batch_size, al_std)
        self.infer_class.model = diag_model
        self.infer_class.gen_samples(10000)
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.y_pred - mean)),
                               0,
                               places=1)
        # The total predictive covariance is the epistemic covariance plus the
        # identity aleatoric covariance.
        covariance = covariance + np.eye(self.num_params)
        self.assertAlmostEqual(np.mean(
            np.abs(self.infer_class.y_std - np.sqrt(np.diag(covariance)))),
                               0,
                               places=1)
        self.assertAlmostEqual(np.mean(
            np.abs(self.infer_class.y_cov - covariance)),
                               0,
                               places=1)
        self.assertTupleEqual(
            self.infer_class.al_cov.shape,
            (self.batch_size, self.num_params, self.num_params))
        self.assertAlmostEqual(
            np.mean(np.abs(self.infer_class.al_cov - np.eye(self.num_params))),
            0)

        # Make sure the test would actually catch statistics that don't match.
        wrong_mean = np.random.randn(self.num_params)
        wrong_covariance = np.random.rand(self.num_params, self.num_params)
        al_std = -1000
        # Make sure covariance is positive semidefinite
        wrong_covariance = np.dot(wrong_covariance, wrong_covariance.T)
        diag_model = ToyModel(wrong_mean, wrong_covariance, self.batch_size,
                              al_std)
        self.infer_class.model = diag_model
        self.infer_class.gen_samples(10000)
        self.assertGreater(np.mean(np.abs(self.infer_class.y_pred - mean)),
                           0.05)
        self.assertGreater(
            np.mean(
                np.abs(self.infer_class.y_std - np.sqrt(np.diag(covariance)))),
            0.05)
        self.assertGreater(
            np.mean(np.abs(self.infer_class.y_cov - covariance)), 0.05)
        self.assertTupleEqual(
            self.infer_class.al_cov.shape,
            (self.batch_size, self.num_params, self.num_params))
        self.assertAlmostEqual(np.mean(np.abs(self.infer_class.al_cov)), 0)

        # Clean up the files we generated
        os.remove(self.normalization_constants_path)
        os.remove(self.tf_record_path)
	def __init__(self,cfg,interim_baobab_omega_path,target_ovejero_omega_path,
		test_dataset_path,test_dataset_tf_record_path,
		target_baobab_omega_path=None,train_to_test_param_map=None,
		lite_class=False):
		# Initialize our class.
		self.cfg = cfg
		# Pull the needed param information from the config file.
		self.lens_params_train = cfg['dataset_params']['lens_params']
		self.lens_params_test = copy.deepcopy(self.lens_params_train)
		# We will need to encode the difference between the test and train
		# parameter names.
		if train_to_test_param_map is not None:
			self.lens_params_change_ind = []
			# Go through each parameter, mark its index, and make the swap
			for li, lens_param in enumerate(
				train_to_test_param_map['orig_params']):
				self.lens_params_change_ind.append(self.lens_params_train.index(
					lens_param))
				self.lens_params_test[self.lens_params_change_ind[-1]] = (
					train_to_test_param_map['new_params'][li])

		self.lens_params_log = cfg['dataset_params']['lens_params_log']
		self.gampsi = cfg['dataset_params']['gampsi']
		self.final_params = cfg['training_params']['final_params']

		# Read the config files and turn them into evaluation dictionaries
		self.interim_baobab_omega = configs.BaobabConfig.from_file(
			interim_baobab_omega_path)
		self.target_baobab_omega = load_prior_config(target_ovejero_omega_path)
		self.interim_eval_dict = build_eval_dict(self.interim_baobab_omega,
			self.lens_params_train,baobab_config=True)
		self.target_eval_dict = build_eval_dict(self.target_baobab_omega,
			self.lens_params_test,baobab_config=False)
		self.train_to_test_param_map = train_to_test_param_map

		# Get the number of parameters and set the batch size to the full
		# test set.
		self.num_params = len(self.lens_params_train)
		self.norm_images = cfg['training_params']['norm_images']
		n_npy_files = len(glob.glob(os.path.join(test_dataset_path,'X*.npy')))
		self.cfg['training_params']['batch_size'] = n_npy_files

		# Make our inference class we'll use to generate samples.
		self.infer_class = bnn_inference.InferenceClass(self.cfg,lite_class)

		# The inference class will load the validation set specified in the
		# config file. We do not want that here, so we point it at the test
		# set instead.
		if not os.path.exists(test_dataset_tf_record_path):
			print('Generating new TFRecord at %s'%(test_dataset_tf_record_path))
			model_trainer.prepare_tf_record(cfg,test_dataset_path,
				test_dataset_tf_record_path,self.final_params,
				train_or_test='test')
		else:
			print('TFRecord found at %s'%(test_dataset_tf_record_path))
		self.tf_dataset = data_tools.build_tf_dataset(
			test_dataset_tf_record_path,self.final_params,n_npy_files,1,
			target_baobab_omega_path,norm_images=self.norm_images)
		self.infer_class.tf_dataset_v = self.tf_dataset

		# Track if the sampler has been initialized yet.
		self.sampler_init = False

		# Initialize our probability class
		self.prob_class = ProbabilityClass(self.target_eval_dict,
			self.interim_eval_dict,self.lens_params_train,self.lens_params_test)

		# If a baobab config path was provided for the test set we will extract
		# the true values of the hyperparameters from it
		if target_baobab_omega_path is not None:
			temp_config = configs.BaobabConfig.from_file(
				target_baobab_omega_path)
			temp_eval_dict = build_eval_dict(temp_config,self.lens_params_test,
				baobab_config=True)
			# Go through the target_eval_dict and extract the true values
			# from the temp_eval_dict (i.e. the eval dict generated by the
			# baobab config used to make the test set).
			self.true_hyp_values = []
			for name in self.target_eval_dict['hyp_names']:
				temp_index = temp_eval_dict['hyp_names'].index(name)
				self.true_hyp_values.append(temp_eval_dict['hyp_values'][
					temp_index])
		else:
			self.true_hyp_values = None