def test_construct_precision_matrix(self):
	# Check that construct_precision_matrix turns a flat vector of lower
	# triangular elements into the expected precision matrix, the expected
	# (log) diagonal, and a lower triangular factor L with precision = L L^T.
	# The check is run once for a single example and once with batching.

	# Case 1: four parameters, a single example in the batch.
	n_params = 4
	loss_funcs = bnn_alexnet.LensingLossFunctions([],n_params)

	# Ten lower triangular elements. The diagonal entries (1,3,6,10) appear
	# exponentiated in the reference factor below.
	flat_elements = np.array([[1,2,3,4,5,6,7,8,9,10]],dtype=float)
	lower = np.array([
		[np.exp(1),0,0,0],
		[2,np.exp(3),0,0],
		[4,5,np.exp(6),0],
		[7,8,9,np.exp(10)]])
	expected_prec = lower @ lower.T

	prec_tf, log_diag_tf, lower_tf = loss_funcs.construct_precision_matrix(
		tf.constant(flat_elements))
	prec_np = prec_tf.numpy()
	lower_np = lower_tf.numpy()

	np.testing.assert_almost_equal(prec_np[0],expected_prec,decimal=5)
	np.testing.assert_almost_equal(log_diag_tf.numpy()[0],
		np.array([1,3,6,10]))
	# The returned precision matrix must equal L L^T for every example.
	for prec_i, lower_i in zip(prec_np,lower_np):
		np.testing.assert_almost_equal(prec_i,np.dot(lower_i,lower_i.T))

	# Case 2: three parameters, the same elements repeated over a batch of 2.
	n_params = 3
	loss_funcs = bnn_alexnet.LensingLossFunctions([],n_params)

	flat_elements = np.array([[1,2,3,4,5,6]]*2,dtype=float)
	lower = np.array([[np.exp(1),0,0],[2,np.exp(3),0],[4,5,np.exp(6)]])
	expected_prec = lower @ lower.T

	prec_tf, log_diag_tf, _ = loss_funcs.construct_precision_matrix(
		tf.constant(flat_elements))

	# Every batch entry must reproduce the same matrix and diagonal.
	for prec_i in prec_tf.numpy():
		np.testing.assert_almost_equal(prec_i,expected_prec)
	expected_diag = np.array([1,3,6])
	for diag_i in log_diag_tf.numpy():
		np.testing.assert_almost_equal(diag_i,expected_diag)
	def test_mse_loss(self):
		# Test that for a variety of number of parameters and bnn types, the
		# algorithm always returns the MSE loss computed only on the first
		# num_params outputs (the predicted means), whatever follows them.
		flip_pairs = []
		for num_params in range(1,20):
			loss_class = bnn_alexnet.LensingLossFunctions(flip_pairs,num_params)
			# Number of lower triangular elements in a full covariance output.
			L_elements_len = int(num_params*(num_params+1)/2)
			# Network output sizes for each bnn type:
			# - diagonal covariance: means plus one scale term per parameter.
			# - full covariance: means plus the lower triangular elements.
			# - GMM with two full covariance components: two means, two sets of
			#   lower triangular elements and one mixture weight.
			output_sizes = [2*num_params, num_params+L_elements_len,
				2*(num_params+L_elements_len)+1]
			for output_size in output_sizes:
				y_true = np.random.randn(num_params).reshape(1,-1)
				y_pred = np.random.randn(output_size).reshape(1,-1)
				mse_tensor = loss_class.mse_loss(
					tf.constant(y_true,dtype=tf.float32),
					tf.constant(y_pred,dtype=tf.float32))
				self.assertAlmostEqual(mse_tensor.numpy()[0],np.mean(np.square(
					y_true-y_pred[:,:num_params])),places=5)

		# Now an explicit test that flip_pairs is working
		flip_pairs = [[1,2]]
		num_params = 5
		loss_class = bnn_alexnet.LensingLossFunctions(flip_pairs,num_params)
		y_true = np.ones((4,num_params))
		y_pred = np.ones((4,num_params))
		y_pred[:,1:3] *= -1
		mse_tensor = loss_class.mse_loss(tf.constant(y_true,dtype=tf.float32),
			tf.constant(y_pred,dtype=tf.float32))
		self.assertEqual(np.sum(mse_tensor.numpy()),0)

		# Make sure flipping other pairs does not return 0
		y_pred[:,4] *= -1
		mse_tensor = loss_class.mse_loss(tf.constant(y_true,dtype=tf.float32),
			tf.constant(y_pred,dtype=tf.float32))
		self.assertGreater(np.sum(mse_tensor.numpy()),0.1)
	def test_log_gauss_full(self):
		# Compare log_gauss_full against scipy's multivariate normal for a
		# range of dimensions. Each case uses a random truth, prediction and
		# precision matrix. The constructor requires flip pairs, but none are
		# needed here.
		for n_params in range(1,10):
			loss_funcs = bnn_alexnet.LensingLossFunctions([],n_params)
			truth = np.random.randn(n_params)
			pred = np.random.randn(n_params)

			n_l_elements = int(n_params*(n_params+1)/2)
			l_elements = np.random.randn(n_l_elements)[np.newaxis,:]
			prec_tf, log_diag_tf, _ = loss_funcs.construct_precision_matrix(
				tf.constant(l_elements,dtype=tf.float32))

			nlp = loss_funcs.log_gauss_full(
				tf.constant(truth[np.newaxis,:],dtype=float),
				tf.constant(pred[np.newaxis,:],dtype=float),
				prec_tf,log_diag_tf)

			# The reference removes scipy's 2 pi normalization term before
			# comparing.
			cov = np.linalg.inv(prec_tf.numpy()[0])
			expected = (-multivariate_normal.logpdf(truth,pred,cov)
				- np.log(2 * np.pi) * n_params/2)
			# Inverting the precision matrix costs accuracy, so only one
			# decimal place is compared.
			self.assertAlmostEqual(np.sum(nlp.numpy()),expected,places=1)
	def test_log_gauss_diag(self):
		# Compare log_gauss_diag against scipy's multivariate normal. The
		# scipy covariance is diagonal with variances exp(log_var). The
		# constructor requires flip pairs, but none are needed here.
		for n_params in range(1,20):
			loss_funcs = bnn_alexnet.LensingLossFunctions([],n_params)
			truth, pred, log_var = (np.random.randn(n_params)
				for _ in range(3))

			nlp = loss_funcs.log_gauss_diag(tf.constant(truth),
				tf.constant(pred),tf.constant(log_var))

			# The reference removes scipy's 2 pi normalization term.
			cov = np.diag(np.exp(log_var))
			expected = (-multivariate_normal.logpdf(truth,pred,cov)
				- np.log(2 * np.pi) * n_params/2)
			self.assertAlmostEqual(nlp.numpy(),expected)
	def test_log_gauss_gm_full(self):
		# Compare the two component Gaussian mixture loss against a scipy
		# reference. The reference is -logaddexp of the weighted component
		# log densities. The constructor requires flip pairs, but none are
		# needed here.
		def as_row_tensor(vec):
			# Add a leading batch axis of size one and convert to a tensor.
			return tf.constant(np.expand_dims(vec,axis=0),dtype=float)

		for n_params in range(1,10):
			loss_funcs = bnn_alexnet.LensingLossFunctions([],n_params)
			truth = np.random.randn(n_params)
			pred_a = np.random.randn(n_params)
			pred_b = np.random.randn(n_params)
			weight = np.random.rand()
			weight_tf = tf.constant(np.array([[weight]]),dtype=float)

			# One random set of lower triangular elements per component.
			n_l_elements = int(n_params*(n_params+1)/2)
			l_element_tensors = [tf.constant(np.expand_dims(
				np.random.randn(n_l_elements),axis=0),dtype=tf.float32)
				for _ in range(2)]
			built = [loss_funcs.construct_precision_matrix(elements)
				for elements in l_element_tensors]
			prec_list = [prec for prec, _, _ in built]
			log_diag_list = [log_diag for _, log_diag, _ in built]

			nlp = loss_funcs.log_gauss_gm_full(as_row_tensor(truth),
				[as_row_tensor(pred_a),as_row_tensor(pred_b)],prec_list,
				log_diag_list,[weight_tf,1-weight_tf])

			# Weighted component log densities with scipy's 2 pi normalization
			# removed, then combined with logaddexp.
			offset = np.log(2 * np.pi) * n_params/2
			comp_log_probs = [
				multivariate_normal.logpdf(truth,pred,
					np.linalg.inv(prec.numpy()[0])) + offset + np.log(w)
				for pred, prec, w in zip((pred_a,pred_b),prec_list,
					(weight,1-weight))]
			expected = -np.logaddexp(*comp_log_probs)
			# Inverting the precision matrices costs accuracy, so only one
			# decimal place is compared.
			self.assertAlmostEqual(np.sum(nlp.numpy()),expected,places=1)
    def test_gen_samples_full(self):
        # Check that, for a full covariance model, gen_samples produces
        # samples whose mean, standard deviation, covariance and aleatoric
        # covariance match the covariance implied by the model's lower
        # triangular elements. Generated files are removed even if an
        # assertion fails, so they cannot leak into later tests.

        self.infer_class = bnn_inference.InferenceClass(self.cfg)
        # Delete the tf record file made during the initialization of the
        # inference class.
        os.remove(self.root_path + 'tf_record_test_val')
        os.remove(self.root_path + 'new_metadata.csv')
        # Get rid of the normalization file.
        os.remove(self.normalization_constants_path)

        # A fake model whose statistics are very well defined.
        class ToyModel():
            def __init__(self, mean, batch_size, L_elements):
                # We want to make sure our performance is consistent for a
                # test, so seed the global numpy generator.
                np.random.seed(6)
                self.mean = mean
                self.num_params = len(mean)
                self.batch_size = batch_size
                self.L_elements = L_elements
                self.L_elements_len = int(self.num_params *
                                          (self.num_params + 1) / 2)

            def predict(self, image):
                # The image is ignored: every row of the batch gets the same
                # mean followed by the same lower triangular elements.
                means = np.zeros(
                    (self.batch_size, self.num_params)) + self.mean
                l_elements = np.zeros(
                    (self.batch_size, self.L_elements_len)) + self.L_elements
                return tf.constant(
                    np.concatenate([means, l_elements], axis=-1), tf.float32)

        # We don't want any flipping going on
        self.infer_class.flip_mat_list = [np.diag(np.ones(self.num_params))]
        self.infer_class.bnn_type = 'full'

        # Create tf record. This won't be used, but it has to be there for
        # the function to be able to pull some images. Fake normalization
        # data must be written first.
        fake_norms = pd.DataFrame(data={
            lens_param: np.array([0.0, 1.0])
            for lens_param in self.lens_params
        })
        fake_norms.to_csv(self.normalization_constants_path, index=False)
        try:
            data_tools.generate_tf_record(self.root_path, self.lens_params,
                                          self.lens_params_path,
                                          self.tf_record_path)

            # Replace the real model with a fake model that has zero mean and
            # constant lower triangular elements, then generate samples.
            mean = np.zeros(self.num_params)
            loss_class = bnn_alexnet.LensingLossFunctions([], self.num_params)
            L_elements_len = int(self.num_params * (self.num_params + 1) / 2)
            L_elements = np.ones((1, L_elements_len)) * 0.2
            self.infer_class.model = ToyModel(mean, self.batch_size,
                                              L_elements)
            self.infer_class.gen_samples(1000)

            # The covariance is the inverse of the precision matrix L L^T,
            # i.e. L^-T L^-1.
            _, _, L_mat = loss_class.construct_precision_matrix(
                tf.constant(L_elements))
            L_inv_T = np.linalg.inv(L_mat.numpy()[0].T)
            cov_mat = np.dot(L_inv_T, L_inv_T.T)

            # Make sure these samples follow the required statistics.
            self.assertAlmostEqual(
                np.mean(np.abs(self.infer_class.y_pred - mean)), 0, places=1)
            self.assertAlmostEqual(
                np.mean(np.abs(self.infer_class.y_std -
                               np.sqrt(np.diag(cov_mat)))),
                0,
                places=1)
            self.assertAlmostEqual(
                np.mean(np.abs(self.infer_class.y_cov - cov_mat)),
                0,
                places=1)
            self.assertTupleEqual(
                self.infer_class.al_cov.shape,
                (self.batch_size, self.num_params, self.num_params))
            self.assertAlmostEqual(
                np.mean(np.abs(self.infer_class.al_cov - cov_mat)), 0)
        finally:
            # Clean up the files we generated. Skip files that were never
            # written so a test failure is not masked by a FileNotFoundError.
            for path in (self.normalization_constants_path,
                         self.tf_record_path):
                if os.path.exists(path):
                    os.remove(path)
	def test_gm_full_covariance_loss(self):
		# Test that the two component Gaussian mixture loss with full
		# covariance matches a scipy reference, takes the minimum over the
		# allowed flip pairs, and works with batching.
		flip_pairs = [[1,2],[3,4],[1,2,3,4]]
		num_params = 6
		loss_class = bnn_alexnet.LensingLossFunctions(flip_pairs,num_params)

		# Set up a couple of test functions to make sure that the minimum loss
		# is taken: each prediction differs from y_pred only by an allowed flip.
		y_true = np.ones((1,num_params))
		y_pred = np.ones((1,num_params))
		y_pred1 = np.ones((1,num_params)); y_pred1[:,[1,2]] = -1
		y_pred2 = np.ones((1,num_params)); y_pred2[:,[3,4]] = -1
		y_pred3 = np.ones((1,num_params)); y_pred3[:,[1,2,3,4]] = -1
		y_preds = [y_pred,y_pred1,y_pred2,y_pred3]
		L_elements_len = int(num_params*(num_params+1)/2)
		# Have to keep this matrix simple so that we still get a reasonable
		# answer when we invert it for scipy check
		L_elements = np.zeros((1,L_elements_len))+1e-2
		# The mixture weight is fed to the loss as a logit; pi is its sigmoid.
		pi_logit = 2
		pi = np.exp(pi_logit)/(np.exp(pi_logit)+1)
		pi_arr = np.array([[pi_logit]])

		# Get out the covariance matrix in numpy
		l_mat_elements_tf = tf.constant(L_elements,dtype=tf.float32)
		p_mat_tf, _, _ = loss_class.construct_precision_matrix(
			l_mat_elements_tf)
		cov_mat = np.linalg.inv(p_mat_tf.numpy()[0])
		yttf = tf.constant(y_true,dtype=tf.float32)

		def gm_output(yp1,yp2):
			# Network output: two means, each followed by its lower triangular
			# elements, then the mixture weight logit.
			return np.concatenate([yp1,L_elements,yp2,L_elements,pi_arr],
				axis=-1)

		def scipy_gm_nlp(yp):
			# Reference nlp when both components predict yp with cov_mat:
			# -logaddexp of the weighted component log densities, with scipy's
			# 2 pi normalization removed.
			log_prob = (multivariate_normal.logpdf(y_true[0],yp[0],cov_mat)
				+ np.log(2 * np.pi) * num_params/2)
			return -np.logaddexp(log_prob + np.log(pi),
				log_prob + np.log(1-pi))

		def check_all_pairs(expected):
			# Every pair of flip-equivalent predictions must give the
			# reference loss.
			for yp1 in y_preds:
				for yp2 in y_preds:
					yptf = tf.constant(gm_output(yp1,yp2),dtype=tf.float32)
					gm_loss = loss_class.gm_full_covariance_loss(yttf,yptf)
					self.assertAlmostEqual(np.sum(gm_loss.numpy()),expected,
						places=4)

		check_all_pairs(scipy_gm_nlp(y_pred))

		# Repeat this exercise, but introducing error in prediction. The arrays
		# are modified in place, so y_pred changes too.
		for yp in y_preds:
			yp[:,0] = 10
		scipy_nlp = scipy_gm_nlp(y_pred)
		check_all_pairs(scipy_nlp)

		# Confirm that when the wrong pair is flipped, it does not
		# return the same answer.
		y_pred4 = np.ones((1,num_params)); y_pred4[:,[5,2]] = -1
		y_pred4[:,0] = 10
		yptf = tf.constant(gm_output(y_pred4,y_pred),dtype=tf.float32)
		gm_loss = loss_class.gm_full_covariance_loss(yttf,yptf)

		self.assertGreater(np.abs(np.sum(gm_loss.numpy())-scipy_nlp),0.1)

		# Finally, confirm that batching works
		yptf = tf.constant(np.concatenate([gm_output(y_pred2,y_pred),
			gm_output(y_pred3,y_pred)],axis=0),dtype=tf.float32)
		self.assertEqual(yptf.shape,[2,55])
		gm_loss = loss_class.gm_full_covariance_loss(yttf,yptf).numpy()
		self.assertEqual(gm_loss.shape,(2,))
		self.assertEqual(gm_loss[0],gm_loss[1])
		self.assertAlmostEqual(gm_loss[0],scipy_nlp,places=4)
	def test_full_covariance_loss(self):
		# Test that the full covariance loss matches the scipy negative log
		# likelihood, takes the minimum over the allowed flip pairs, and
		# works with batching.
		flip_pairs = [[1,2],[3,4],[1,2,3,4]]
		num_params = 6
		loss_class = bnn_alexnet.LensingLossFunctions(flip_pairs,num_params)

		# Set up a couple of test functions to make sure that the minimum loss
		# is taken: each prediction differs from y_pred only by an allowed flip.
		y_true = np.ones((1,num_params))
		y_pred = np.ones((1,num_params))
		y_pred1 = np.ones((1,num_params)); y_pred1[:,[1,2]] = -1
		y_pred2 = np.ones((1,num_params)); y_pred2[:,[3,4]] = -1
		y_pred3 = np.ones((1,num_params)); y_pred3[:,[1,2,3,4]] = -1
		y_preds = [y_pred,y_pred1,y_pred2,y_pred3]
		L_elements_len = int(num_params*(num_params+1)/2)
		# Have to keep this matrix simple so that we still get a reasonable
		# answer when we invert it for scipy check
		L_elements = np.zeros((1,L_elements_len))+1e-2

		# Get out the covariance matrix in numpy
		l_mat_elements_tf = tf.constant(L_elements,dtype=tf.float32)
		p_mat_tf, _, _ = loss_class.construct_precision_matrix(
			l_mat_elements_tf)
		cov_mat = np.linalg.inv(p_mat_tf.numpy()[0])
		yttf = tf.constant(y_true,dtype=tf.float32)

		def scipy_nlp_for(yp):
			# Reference nlp with scipy's 2 pi normalization removed.
			return (-multivariate_normal.logpdf(y_true.flatten(),yp.flatten(),
				cov_mat) - np.log(2 * np.pi)*num_params/2)

		def full_loss_for(yp):
			# Network output: the mean followed by the lower triangular
			# elements.
			yptf = tf.constant(np.concatenate([yp,L_elements],axis=-1),
				dtype=tf.float32)
			return loss_class.full_covariance_loss(yttf,yptf)

		def check_flip_invariance(expected):
			# Every flip-equivalent prediction must give the reference loss.
			for yp in y_preds:
				self.assertAlmostEqual(np.sum(full_loss_for(yp).numpy()),
					expected,places=4)

		check_flip_invariance(scipy_nlp_for(y_pred))

		# Repeat this exercise, but introducing error in prediction. The arrays
		# are modified in place, so y_pred changes too.
		for yp in y_preds:
			yp[:,0] = 10
		scipy_nlp = scipy_nlp_for(y_pred)
		check_flip_invariance(scipy_nlp)

		# Confirm that when the wrong pair is flipped, it does not
		# return the same answer.
		y_pred4 = np.ones((1,num_params)); y_pred4[:,[5,2]] = -1
		y_pred4[:,0] = 10
		full_loss = full_loss_for(y_pred4)

		self.assertGreater(np.abs(np.sum(full_loss.numpy())-scipy_nlp),1)

		# Make sure it is still consistent with the true nlp
		self.assertAlmostEqual(np.sum(full_loss.numpy()),
			scipy_nlp_for(y_pred4),places=2)

		# Finally, confirm that batching works and gives the expected value
		yptf = tf.constant(np.concatenate(
			[np.concatenate([y_pred,L_elements],axis=-1),
			np.concatenate([y_pred1,L_elements],axis=-1)],axis=0),
			dtype=tf.float32)
		self.assertEqual(yptf.shape,[2,27])
		batch_loss = loss_class.full_covariance_loss(yttf,yptf).numpy()
		self.assertEqual(batch_loss.shape,(2,))
		self.assertEqual(batch_loss[0],batch_loss[1])
		self.assertAlmostEqual(batch_loss[0],scipy_nlp,places=4)
	def test_diagonal_covariance_loss(self):
		# Test that the diagonal covariance loss gives the correct values,
		# takes the minimum over the allowed flip pairs, and works with
		# batching. The network output is the prediction followed by
		# std_pred. In the scipy reference, exp(std_pred) is the variance.
		# The loss is computed in float32, so values are compared to 4
		# decimal places, like the full covariance tests. 7 places is below
		# float32 resolution for losses of this size.
		flip_pairs = [[1,2],[3,4],[1,2,3,4]]
		num_params = 6
		loss_class = bnn_alexnet.LensingLossFunctions(flip_pairs,num_params)

		# Set up a couple of test functions to make sure that the minimum loss
		# is taken: each prediction differs from y_pred only by an allowed flip.
		y_true = np.ones((1,num_params))
		y_pred = np.ones((1,num_params))
		y_pred1 = np.ones((1,num_params)); y_pred1[:,[1,2]] = -1
		y_pred2 = np.ones((1,num_params)); y_pred2[:,[3,4]] = -1
		y_pred3 = np.ones((1,num_params)); y_pred3[:,[1,2,3,4]] = -1
		y_preds = [y_pred,y_pred1,y_pred2,y_pred3]
		std_pred = np.ones((1,num_params))
		cov_mat = np.diag(np.exp(std_pred.flatten()))
		yttf = tf.constant(y_true,dtype=tf.float32)

		def scipy_nlp_for(yp):
			# Reference nlp with scipy's 2 pi normalization removed.
			return (-multivariate_normal.logpdf(y_true.flatten(),yp.flatten(),
				cov_mat) - np.log(2 * np.pi)*num_params/2)

		def diag_loss_for(yp):
			yptf = tf.constant(np.concatenate([yp,std_pred],axis=-1),
				dtype=tf.float32)
			return loss_class.diagonal_covariance_loss(yttf,yptf)

		def check_flip_invariance(expected):
			# Every flip-equivalent prediction must give the reference loss.
			for yp in y_preds:
				self.assertAlmostEqual(np.sum(diag_loss_for(yp).numpy()),
					expected,places=4)

		check_flip_invariance(scipy_nlp_for(y_pred))

		# Repeat this exercise, but introducing error in prediction. The arrays
		# are modified in place, so y_pred changes too.
		for yp in y_preds:
			yp[:,0] = 10
		scipy_nlp = scipy_nlp_for(y_pred)
		check_flip_invariance(scipy_nlp)

		# Confirm that when the wrong pair is flipped, it does not
		# return the same answer.
		y_pred4 = np.ones((1,num_params))
		y_pred4[:,[5,2]] = -1
		y_pred4[:,0] = 10
		diag_loss = diag_loss_for(y_pred4)

		self.assertGreater(np.abs(np.sum(diag_loss.numpy())-scipy_nlp),1)

		# Make sure it is still consistent with the true nlp
		self.assertAlmostEqual(np.sum(diag_loss.numpy()),
			scipy_nlp_for(y_pred4),places=4)

		# Finally, confirm that batching works and gives the expected value
		yptf = tf.constant(np.concatenate(
			[np.concatenate([y_pred,std_pred],axis=-1),
			np.concatenate([y_pred1,std_pred],axis=-1)],axis=0),dtype=tf.float32)
		self.assertEqual(yptf.shape,[2,12])
		batch_loss = loss_class.diagonal_covariance_loss(yttf,yptf).numpy()
		self.assertEqual(batch_loss.shape,(2,))
		self.assertEqual(batch_loss[0],batch_loss[1])
		self.assertAlmostEqual(batch_loss[0],scipy_nlp,places=4)
Beispiel #10
0
    def __init__(self, cfg, lite_class=False, test_set_path=None):
        """
        Initialize the InferenceClass instance using the parameters of the
        configuration file.

        Parameters:
        cfg (dict): The dictionary obtained from reading the json config file.
            When test_set_path is given, the instance stores a copy of cfg
            with the updated validation root path. The caller's dictionary
            is left unchanged.
        lite_class (bool): If True, do not bother loading the BNN model weights.
            This allows the user to save on memory, but will cause an error
            if the BNN samples have not already been drawn.
        test_set_path (str): The path to the set of images that the
            forward modeling image will be pulled from. If None, the path to
            the validation set images will be used.
        """

        # Replace the validation path with the test_set_path if specified.
        # Copy only the two dictionary levels being modified, so the caller's
        # configuration is not mutated as a side effect.
        if test_set_path is not None:
            cfg = dict(cfg)
            cfg['validation_params'] = dict(cfg['validation_params'])
            cfg['validation_params']['root_path'] = test_set_path
        self.cfg = cfg

        self.lite_class = lite_class
        if self.lite_class:
            self.model = None
            self.loss = None
        else:
            self.model, self.loss = model_trainer.model_loss_builder(
                cfg, verbose=True)

        # Load the validation set we're going to use.
        self.tf_record_path_v = os.path.join(
            cfg['validation_params']['root_path'],
            cfg['validation_params']['tf_record_path'])
        # Load the parameters and the batch size needed for computation
        self.final_params = cfg['training_params']['final_params']
        self.final_params_print_names = cfg['inference_params'][
            'final_params_print_names']
        self.num_params = len(self.final_params)
        self.batch_size = cfg['training_params']['batch_size']
        self.norm_images = cfg['training_params']['norm_images']
        self.baobab_config_path = cfg['training_params']['baobab_config_path']

        # Build the validation TFRecord only if it does not exist yet.
        if not os.path.exists(self.tf_record_path_v):
            print('Generating new TFRecord at %s' % (self.tf_record_path_v))
            model_trainer.prepare_tf_record(
                cfg,
                cfg['validation_params']['root_path'],
                self.tf_record_path_v,
                self.final_params,
                train_or_test='test')
        else:
            print('TFRecord found at %s' % (self.tf_record_path_v))

        self.tf_dataset_v = data_tools.build_tf_dataset(
            self.tf_record_path_v,
            self.final_params,
            self.batch_size,
            1,
            self.baobab_config_path,
            norm_images=self.norm_images)

        self.bnn_type = cfg['training_params']['bnn_type']

        # This code is borrowed from the LensingLossFunctions initializer
        self.flip_pairs = cfg['training_params']['flip_pairs']
        # Always include no flips as an option.
        self.flip_mat_list = [np.diag(np.ones(self.num_params))]
        for flip_pair in self.flip_pairs:
            # Initialize a numpy array since this is the easiest way
            # to flexibly set the tensor. Flipped parameters get -1 on the
            # diagonal.
            const_initializer = np.ones(self.num_params)
            const_initializer[flip_pair] = -1
            self.flip_mat_list.append(np.diag(const_initializer))

        self.loss_class = bnn_alexnet.LensingLossFunctions(
            self.flip_pairs, self.num_params)

        # Sample and prediction state, populated later.
        self.y_pred = None
        self.y_cov = None
        self.y_std = None
        self.y_test = None
        self.predict_samps = None
        self.samples_init = False