def _assert_valid_sample(self, x):
    """Return `x`, guarded by sample-validity assertions when requested.

    If `self.validate_args` is false, `x` is returned unchanged and no
    assertion ops are built. Otherwise `x` is returned with control
    dependencies on two checks:

    * every entry of `x` is non-positive, and
    * `reduce_logsumexp(x, axis=[-1])` is close to zero, i.e. the
      exponentiated entries sum to one along the last axis.

    Each check carries a descriptive message so a failure names the
    violated property.

    Args:
        x: Sample tensor to validate (presumably log-space values -- confirm
            against the owning distribution).

    Returns:
        `x`, wrapped with control dependencies on the checks when
        `self.validate_args` is set.
    """
    if not self.validate_args:
        return x
    return control_flow_ops.with_dependencies([
        # Values are log-space, so no entry may exceed 0.
        tf.assert_non_positive(x, message="samples must be non-positive"),
        # logsumexp(x) == 0  <=>  sum(exp(x)) == 1 along the last axis.
        distribution_util.assert_close(
            tf.zeros([], dtype=self.dtype),
            tf.reduce_logsumexp(x, axis=[-1]),
            message="sample last-dimension must `reduce_logsumexp` to `0`"),
    ], x)
def _assert_valid_sample(self, x):
    """Return `x`, guarded by sample-validity assertions when requested.

    If `self.validate_args` is false, `x` is returned unchanged and no
    assertion ops are built. Otherwise `x` is returned with control
    dependencies on two checks:

    * every entry of `x` is non-positive, and
    * `reduce_logsumexp(x, axis=[-1])` is close to zero, i.e. the
      exponentiated entries sum to one along the last axis.

    Each check carries a descriptive message so a failure names the
    violated property.

    Args:
        x: Sample tensor to validate (presumably log-space values -- confirm
            against the owning distribution).

    Returns:
        `x`, wrapped with control dependencies on the checks when
        `self.validate_args` is set.
    """
    if not self.validate_args:
        return x
    return control_flow_ops.with_dependencies([
        # Values are log-space, so no entry may exceed 0.
        check_ops.assert_non_positive(
            x, message="samples must be non-positive"),
        # logsumexp(x) == 0  <=>  sum(exp(x)) == 1 along the last axis.
        distribution_util.assert_close(
            array_ops.zeros([], dtype=self.dtype),
            math_ops.reduce_logsumexp(x, axis=[-1]),
            message="sample last-dimension must `reduce_logsumexp` to `0`"),
    ], x)
def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample.

    With `self.validate_args` enabled, `x` is returned with control
    dependencies asserting that all entries are strictly positive and that
    the entries along the last axis sum to (approximately) one. With it
    disabled, `x` passes through untouched and no ops are created.
    """
    if self.validate_args:
        checks = [
            check_ops.assert_positive(x, message="samples must be positive"),
            # Compare a scalar 1 against each last-axis sum.
            distribution_util.assert_close(
                array_ops.ones([], dtype=self.dtype),
                math_ops.reduce_sum(x, -1),
                message="sample last-dimension must sum to `1`"),
        ]
        return control_flow_ops.with_dependencies(checks, x)
    return x
def testAssertCloseEpsilon(self):
    """`assert_close` accepts a 1e-8 gap and rejects a 0.1 gap."""
    reference = [0., 5, 10, 15, 20]
    # First entry off by 0.1 from `reference`: must not be considered close.
    far = [0.1, 5, 10, 15, 20]
    # First entry off by 1e-8 from `reference`: must be considered close.
    near = [1e-8, 5, 10, 15, 20]
    with self.test_session():
        with ops.control_dependencies(
                [distribution_util.assert_close(reference, near)]):
            array_ops.identity(reference).eval()
        # Each of these pairs differs by 0.1 and must trip the assertion.
        for lhs, rhs in ((reference, far), (far, near)):
            with self.assertRaisesOpError("Condition x ~= y"):
                with ops.control_dependencies(
                        [distribution_util.assert_close(lhs, rhs)]):
                    array_ops.identity(lhs).eval()
def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample.

    Returns `x` unchanged when `self.validate_args` is false. Otherwise
    returns `x` gated on two assertions: every entry is positive, and the
    sum over the last axis is close to one.
    """
    if not self.validate_args:
        return x
    positive_check = check_ops.assert_positive(
        x, message="samples must be positive")
    one = array_ops.ones([], dtype=self.dtype)
    last_axis_sums = math_ops.reduce_sum(x, -1)
    sums_to_one_check = distribution_util.assert_close(
        one, last_axis_sums,
        message="sample last-dimension must sum to `1`")
    return control_flow_ops.with_dependencies(
        [positive_check, sums_to_one_check], x)
def testAssertCloseNonIntegerDtype(self):
    """`assert_close` on float32 placeholders.

    A placeholder and the same placeholder plus 1e-8 must compare close in
    either argument order. Either of them compared against a second
    placeholder whose first fed entry differs by 1.0 must fail.
    """
    a = array_ops.placeholder(dtypes.float32)
    a_shifted = a + 1e-8
    b = array_ops.placeholder(dtypes.float32)
    feeds = {a: [1., 5, 10, 15, 20], b: [2., 5, 10, 15, 20]}
    with self.test_session():
        # Close pairs, checked in both argument orders.
        for lhs, rhs in ((a, a_shifted), (a_shifted, a)):
            with ops.control_dependencies(
                    [distribution_util.assert_close(lhs, rhs)]):
                array_ops.identity(a).eval(feed_dict=feeds)
        # Pairs that differ by 1.0 in the first entry must trip the check.
        for lhs, rhs in ((a, b), (a_shifted, b)):
            with self.assertRaisesOpError("Condition x ~= y"):
                with ops.control_dependencies(
                        [distribution_util.assert_close(lhs, rhs)]):
                    array_ops.identity(lhs).eval(feed_dict=feeds)