def testHandlesNanFromKinetic(self):
    """Checks the log-acceptance correction and its gradient on non-finite momenta.

    Every (momentum, proposed_momentum) pair drawn from
    {1, inf, -inf, nan} x {1, inf, -inf, nan} is evaluated as its own
    single-dimension chain (16 chains total). Only the (1, 1) chain is expected
    to give a finite correction (zero); every other chain is expected to give
    `-inf`. The gradient w.r.t. `momentums` is expected to be finite wherever
    the momentum is 1, and `nan` wherever the momentum is `inf`, `-inf`, `nan`.
    """
    values = [1, np.inf, -np.inf, np.nan]
    # `np.meshgrid(values, values)` returns two 4x4 grids: in the first, entry
    # [i, j] is values[j]; in the second, values[i]. Flattening each into a
    # column gives one single-dimension chain per (momentum, proposed) pair.
    # The comprehension variable is named `grid` so it does not shadow
    # `values`, which is read again below.
    momentums, proposed_momentums = [
        [np.reshape(self.dtype(grid), [-1, 1])]
        for grid in np.meshgrid(values, values)]
    num_chains = len(momentums[0])

    momentums = [tf.convert_to_tensor(momentums[0])]
    proposed_momentums = [tf.convert_to_tensor(proposed_momentums[0])]

    log_acceptance_correction = _compute_log_acceptance_correction(
        momentums,
        proposed_momentums,
        independent_chain_ndims=1)
    grads = tf.gradients(log_acceptance_correction, momentums)

    [actual_log_acceptance_correction, grads_] = self.evaluate([
        log_acceptance_correction, grads])

    # Chain 0 is the only (1, 1) pair and must give a correction of 0 (the
    # `-0.` below compares equal to it); every chain involving `inf` or `nan`
    # must report NEGATIVE infinity.
    expected_log_acceptance_correction = -(
        self.dtype([0] + [np.inf] * (num_chains - 1)))
    self.assertAllEqual(expected_log_acceptance_correction,
                        actual_log_acceptance_correction)

    # Put the gradients back into the meshgrid layout: column j holds the
    # chains whose momentum is values[j].
    grid_grads = grads_[0].reshape([len(values), len(values)])

    # Column 0 has momentum 1, so its gradient must be finite regardless of the
    # proposed momentum. Builtin `bool` is used because the `np.bool` alias
    # was removed in NumPy 1.24.
    g = grid_grads[:, 0]
    self.assertAllEqual(np.ones_like(g, dtype=bool), np.isfinite(g))

    # The remaining gradients are nan because the momentum was itself nan or
    # inf.
    g = grid_grads[:, 1:]
    self.assertAllEqual(np.ones_like(g, dtype=bool), np.isnan(g))
def testHandlesNanFromKinetic(self):
    """Checks the log-acceptance correction and its gradient on non-finite momenta.

    Every (momentum, proposed_momentum) pair drawn from
    {1, inf, -inf, nan} x {1, inf, -inf, nan} is evaluated as its own
    single-dimension chain (16 chains total). Only the (1, 1) chain is expected
    to give a finite correction (zero); every other chain is expected to give
    `-inf`. The gradient w.r.t. `momentums` is expected to be finite wherever
    the momentum is 1, and `nan` wherever the momentum is `inf`, `-inf`, `nan`.
    """
    values = [1, np.inf, -np.inf, np.nan]
    # `np.meshgrid(values, values)` returns two 4x4 grids: in the first, entry
    # [i, j] is values[j]; in the second, values[i]. Flattening each into a
    # column gives one single-dimension chain per (momentum, proposed) pair.
    # The comprehension variable is named `grid` so it does not shadow
    # `values`, which is read again below.
    momentums, proposed_momentums = [
        [np.reshape(self.dtype(grid), [-1, 1])]
        for grid in np.meshgrid(values, values)]
    num_chains = len(momentums[0])

    momentums = [tf.convert_to_tensor(momentums[0])]
    proposed_momentums = [tf.convert_to_tensor(proposed_momentums[0])]

    log_acceptance_correction = _compute_log_acceptance_correction(
        momentums,
        proposed_momentums,
        independent_chain_ndims=1)
    grads = tf.gradients(log_acceptance_correction, momentums)

    [actual_log_acceptance_correction, grads_] = self.evaluate([
        log_acceptance_correction, grads])

    # Chain 0 is the only (1, 1) pair and must give a correction of 0 (the
    # `-0.` below compares equal to it); every chain involving `inf` or `nan`
    # must report NEGATIVE infinity.
    expected_log_acceptance_correction = -(
        self.dtype([0] + [np.inf] * (num_chains - 1)))
    self.assertAllEqual(expected_log_acceptance_correction,
                        actual_log_acceptance_correction)

    # Put the gradients back into the meshgrid layout: column j holds the
    # chains whose momentum is values[j].
    grid_grads = grads_[0].reshape([len(values), len(values)])

    # Column 0 has momentum 1, so its gradient must be finite regardless of the
    # proposed momentum. Builtin `bool` is used because the `np.bool` alias
    # was removed in NumPy 1.24.
    g = grid_grads[:, 0]
    self.assertAllEqual(np.ones_like(g, dtype=bool), np.isfinite(g))

    # The remaining gradients are nan because the momentum was itself nan or
    # inf.
    g = grid_grads[:, 1:]
    self.assertAllEqual(np.ones_like(g, dtype=bool), np.isnan(g))