示例#1
0
 def _assert_valid_sample(self, x):
     """Returns `x`, guarded by validity assertions when `validate_args` is set.

     When `self.validate_args` is true, the returned tensor carries control
     dependencies asserting that every element of `x` is non-positive and
     that `reduce_logsumexp` of `x` over its last axis is close to zero.
     NOTE(review): this reads like a check that `x` holds log-probabilities
     normalized over the last axis -- confirm against the owning distribution.

     Args:
       x: Sample tensor to validate.

     Returns:
       `x` itself when `validate_args` is false; otherwise `x` with the two
       assertion ops attached as control dependencies.
     """
     if not self.validate_args:
         return x
     return control_flow_ops.with_dependencies([
         # Explicit messages make a failing assertion name the violated
         # property; TF prepends them to the default "Condition ..." text.
         tf.assert_non_positive(x, message="samples must be non-positive"),
         distribution_util.assert_close(
             tf.zeros([], dtype=self.dtype),
             tf.reduce_logsumexp(x, axis=[-1]),
             message="sample last-dimension must logsumexp to `0`"),
     ], x)
示例#2
0
 def _assert_valid_sample(self, x):
   """Returns `x`, guarded by validity assertions when `validate_args` is set.

   When `self.validate_args` is true, the returned tensor carries control
   dependencies asserting that every element of `x` is non-positive and that
   `reduce_logsumexp` of `x` over its last axis is close to zero.
   NOTE(review): this reads like a check that `x` holds log-probabilities
   normalized over the last axis -- confirm against the owning distribution.

   Args:
     x: Sample tensor to validate.

   Returns:
     `x` itself when `validate_args` is false; otherwise `x` with the two
     assertion ops attached as control dependencies.
   """
   if not self.validate_args:
     return x
   return control_flow_ops.with_dependencies([
       # Explicit messages make a failing assertion name the violated
       # property; TF prepends them to the default "Condition ..." text.
       check_ops.assert_non_positive(
           x, message="samples must be non-positive"),
       distribution_util.assert_close(
           array_ops.zeros([], dtype=self.dtype),
           math_ops.reduce_logsumexp(x, axis=[-1]),
           message="sample last-dimension must logsumexp to `0`"),
   ], x)
示例#3
0
 def _maybe_assert_valid_sample(self, x):
     """Attaches sample-validity assertions to `x` when `validate_args` is on.

     The checks are that every element of `x` is positive and that the sum
     of `x` over its last axis is close to one.

     Args:
       x: Sample tensor to check.

     Returns:
       `x` unchanged if `self.validate_args` is false; otherwise `x` with the
       positivity and sum-to-one assertions as control dependencies.
     """
     if self.validate_args:
         positivity_check = check_ops.assert_positive(
             x, message="samples must be positive")
         sums_to_one_check = distribution_util.assert_close(
             array_ops.ones([], dtype=self.dtype),
             math_ops.reduce_sum(x, -1),
             message="sample last-dimension must sum to `1`")
         return control_flow_ops.with_dependencies(
             [positivity_check, sums_to_one_check], x)
     return x
示例#4
0
    def testAssertCloseEpsilon(self):
        base = [0., 5, 10, 15, 20]
        # Differs from `base` by 0.1 in the first element: expected to fail.
        far = [0.1, 5, 10, 15, 20]
        # Differs from `base` by 1e-8 in the first element: expected to pass.
        near = [1e-8, 5, 10, 15, 20]

        def run_guarded(lhs, rhs, result):
            # Evaluate `result` only after assert_close(lhs, rhs) has run.
            with ops.control_dependencies(
                [distribution_util.assert_close(lhs, rhs)]):
                array_ops.identity(result).eval()

        with self.test_session():
            run_guarded(base, near, base)

            for lhs, rhs in ((base, far), (far, near)):
                with self.assertRaisesOpError("Condition x ~= y"):
                    run_guarded(lhs, rhs, lhs)
示例#5
0
 def _maybe_assert_valid_sample(self, x):
   """Returns `x`, optionally gated on sample-validity assertions.

   With `self.validate_args` disabled the input passes straight through.
   Otherwise the result depends on two assertions: every entry of `x` is
   positive, and the sum of `x` over its last axis is close to one.

   Args:
     x: Sample tensor.

   Returns:
     `x`, with control dependencies on the assertions above when validation
     is enabled.
   """
   if not self.validate_args:
     return x
   guards = [check_ops.assert_positive(x, message="samples must be positive")]
   expected_total = array_ops.ones([], dtype=self.dtype)
   actual_total = math_ops.reduce_sum(x, -1)
   guards.append(distribution_util.assert_close(
       expected_total,
       actual_total,
       message="sample last-dimension must sum to `1`"))
   return control_flow_ops.with_dependencies(guards, x)
示例#6
0
    def testAssertCloseNonIntegerDtype(self):
        # `x` and `z` are float32 placeholders fed below; `y` is `x` shifted
        # by 1e-8, which the passing cases expect to count as equal to `x`.
        x = array_ops.placeholder(dtypes.float32)
        y = x + 1e-8
        z = array_ops.placeholder(dtypes.float32)
        feed_dict = {x: [1., 5, 10, 15, 20], z: [2., 5, 10, 15, 20]}

        def eval_after_assert_close(lhs, rhs, output):
            # Evaluate `output` only after assert_close(lhs, rhs) has run.
            with ops.control_dependencies(
                [distribution_util.assert_close(lhs, rhs)]):
                array_ops.identity(output).eval(feed_dict=feed_dict)

        with self.test_session():
            # Nearly-equal tensors pass in either argument order.
            for lhs, rhs in ((x, y), (y, x)):
                eval_after_assert_close(lhs, rhs, x)

            # `z` differs by 1 in the first element, so the assertion fires.
            for lhs, rhs in ((x, z), (y, z)):
                with self.assertRaisesOpError("Condition x ~= y"):
                    eval_after_assert_close(lhs, rhs, lhs)