Exemplo n.º 1
0
 def test_broadcast_success(self):
     """Checks the result shape when `cond` and the chosen branch broadcast.

     In both cases the branch that is not expected to be selected returns
     `None`, so the assertion only passes if the other branch is used and its
     value is broadcast to the expected shape.
     """
     # cond has shape [2] and is all True; the first branch yields [10, 1].
     expected_zeros = tf.zeros([10, 2])
     all_true = tf.constant([True, True])
     from_first_branch = ps.smart_where(
         all_true, lambda: tf.zeros([10, 1]), lambda: None)
     self.assertAllEqual(expected_zeros, from_first_branch)

     # cond has shape [2, 1] and is all False; the second branch yields [10].
     expected_ones = tf.ones([2, 10])
     all_false = tf.constant([[False], [False]])
     from_second_branch = ps.smart_where(
         all_false, lambda: None, lambda: tf.ones([10]))
     self.assertAllEqual(expected_ones, from_second_branch)
Exemplo n.º 2
0
 def _survival_function(self, y, **kwargs):
   """Computes the survival function at `y` via the inverse of `bijector`.

   Args:
     y: Point(s) at which to evaluate; passed to `self.bijector.inverse`.
     **kwargs: Keyword arguments split by `self._kwargs_split_fn` into those
       for the base distribution and those for the bijector.

   Returns:
     `self.distribution.survival_function(x)` when the bijector reports that
     it is increasing, otherwise `self.distribution.cdf(x)`, where
     `x = self.bijector.inverse(y)`.

   Raises:
     NotImplementedError: If the bijector is not injective.
   """
   if not self.bijector._is_injective:  # pylint: disable=protected-access
     raise NotImplementedError(
         '`survival_function` is not implemented when '
         '`bijector` is not injective.')

   distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
   x = self.bijector.inverse(y, **bijector_kwargs)

   def _base_survival():
     # Increasing bijector: the upper tail of Y maps to the upper tail of X.
     return self.distribution.survival_function(x, **distribution_kwargs)

   def _base_cdf():
     # Decreasing bijector: the upper tail of Y maps to the lower tail of X.
     return self.distribution.cdf(x, **distribution_kwargs)

   # TODO(b/141130733): Check/fix any gradient numerics issues.
   is_increasing = self.bijector._internal_is_increasing(  # pylint: disable=protected-access
       **bijector_kwargs)
   return ps.smart_where(is_increasing, _base_survival, _base_cdf)
Exemplo n.º 3
0
 def _quantile(self, value, **kwargs):
     """Computes the quantile of the transformed distribution at `value`.

     Args:
       value: Probability level(s) at which to evaluate the quantile.
       **kwargs: Keyword arguments split by `self._kwargs_split_fn` into
         those for the base distribution and those for the bijector.

     Returns:
       The base distribution's quantile pushed through
       `self.bijector.forward`. When the bijector reports that it is not
       increasing, the base quantile is taken at `1 - value`.

     Raises:
       NotImplementedError: If the bijector is not injective.
     """
     if not self.bijector._is_injective:  # pylint: disable=protected-access
         raise NotImplementedError(
             '`quantile` is not implemented when '
             '`bijector` is not injective.')

     distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

     def _same_level():
         return value

     def _mirrored_level():
         return 1 - value

     is_increasing = self.bijector._internal_is_increasing(  # pylint: disable=protected-access
         **bijector_kwargs)
     base_level = ps.smart_where(is_increasing, _same_level, _mirrored_level)

     # x_q is the "qth quantile" of X iff q = P[X <= x_q].  Now, since X =
     # g^{-1}(Y), q = P[X <= x_q] = P[g^{-1}(Y) <= x_q] = P[Y <= g(x_q)],
     # implies the qth quantile of Y is g(x_q).
     base_quantile = self.distribution.quantile(
         base_level, **distribution_kwargs)
     return self.bijector.forward(base_quantile, **bijector_kwargs)
 def _survival_function(self, y, **kwargs):
   """Computes the survival function at `y` via the inverse of `bijector`.

   Args:
     y: Point(s) at which to evaluate; passed to `self.bijector.inverse`.
     **kwargs: Keyword arguments split by `self._kwargs_split_fn` into those
       for the base distribution and those for the bijector.

   Returns:
     `self.distribution.survival_function(x)` when the bijector reports that
     it is increasing, otherwise `self.distribution.cdf(x)`, where
     `x = self.bijector.inverse(y)`.

   Raises:
     NotImplementedError: If `self._is_maybe_event_override` is truthy, or if
       the bijector is not injective.
   """
   if self._is_maybe_event_override:
     raise NotImplementedError(
         "survival_function is not implemented when "
         "overriding event_shape")
   if not self.bijector._is_injective:  # pylint: disable=protected-access
     raise NotImplementedError(
         "survival_function is not implemented when "
         "bijector is not injective.")

   distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
   x = self.bijector.inverse(y, **bijector_kwargs)

   def _base_survival():
     # Increasing bijector: the upper tail of Y maps to the upper tail of X.
     return self.distribution.survival_function(x, **distribution_kwargs)

   def _base_cdf():
     # Decreasing bijector: the upper tail of Y maps to the lower tail of X.
     return self.distribution.cdf(x, **distribution_kwargs)

   # TODO(b/141130733): Check/fix any gradient numerics issues.
   is_increasing = self.bijector._internal_is_increasing(  # pylint: disable=protected-access
       **bijector_kwargs)
   return prefer_static.smart_where(is_increasing, _base_survival, _base_cdf)
 def _quantile(self, value, **kwargs):
   """Computes the quantile of the transformed distribution at `value`.

   Args:
     value: Probability level(s) at which to evaluate the quantile.
     **kwargs: Keyword arguments split by `self._kwargs_split_fn` into those
       for the base distribution and those for the bijector.

   Returns:
     The base distribution's quantile pushed through `self.bijector.forward`.
     When the bijector reports that it is not increasing, the base quantile
     is taken at `1 - value`.

   Raises:
     NotImplementedError: If `self._is_maybe_event_override` is truthy, or if
       the bijector is not injective.
   """
   if self._is_maybe_event_override:
     raise NotImplementedError(
         "quantile is not implemented when overriding "
         "event_shape")
   if not self.bijector._is_injective:  # pylint: disable=protected-access
     raise NotImplementedError(
         "quantile is not implemented when "
         "bijector is not injective.")

   distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

   def _same_level():
     return value

   def _mirrored_level():
     return 1 - value

   is_increasing = self.bijector._internal_is_increasing(  # pylint: disable=protected-access
       **bijector_kwargs)
   base_level = prefer_static.smart_where(
       is_increasing, _same_level, _mirrored_level)

   # x_q is the "qth quantile" of X iff q = P[X <= x_q].  Now, since X =
   # g^{-1}(Y), q = P[X <= x_q] = P[g^{-1}(Y) <= x_q] = P[Y <= g(x_q)],
   # implies the qth quantile of Y is g(x_q).
   base_quantile = self.distribution.quantile(
       base_level, **distribution_kwargs)
   return self.bijector.forward(base_quantile, **bijector_kwargs)
Exemplo n.º 6
0
 def test_where_fallback(self):
     """A mixed, non-scalar `cond` picks per element from the two branches.

     With `cond = [True, False]` and scalar branches of one and zero, the
     expected result is `[1., 0.]`.
     """
     mixed_cond = tf.constant([True, False])

     def scalar_one():
         return tf.ones([])

     def scalar_zero():
         return tf.zeros([])

     selected = ps.smart_where(mixed_cond, scalar_one, scalar_zero)
     self.assertAllEqual([1., 0.], selected)
Exemplo n.º 7
0
 def test_cond_y_broadcast_error(self):
     """An all-False `cond` of shape [2] with a [3]-shaped `y` must fail.

     The test expects an op error whose message contains
     'Incompatible shapes'.
     """
     all_false = tf.constant([False, False])

     def mismatched_y():
         return tf.zeros([3])

     with self.assertRaisesOpError('Incompatible shapes'):
         # Build the op inside the context as well, so an error raised at
         # construction time is captured just like one raised on evaluation.
         result = ps.smart_where(all_false, lambda: None, mismatched_y)
         self.evaluate(result)
Exemplo n.º 8
0
    def test_static_scalar_condition(self):
        """Scalar conditions invoke exactly one branch function per call.

        The condition is supplied as a Python scalar, a `tf.constant` and a
        0-d `np.array`, first with bool values and then with the ints 0 / 1.
        After every call the per-branch call counters are checked, so a
        branch that is invoked when it should not be fails the test.
        """
        call_counts = [0, 0]  # [calls to the true branch, calls to the false branch]
        ones = tf.ones([10])
        zeros = tf.zeros([10])

        def true_branch():
            call_counts[0] += 1
            return ones

        def false_branch():
            call_counts[1] += 1
            return zeros

        # Ways of packaging a scalar condition, in the order they are tested.
        wrappers = (lambda scalar: scalar, tf.constant, np.array)
        # (falsy, truthy) scalar pairs: bools first, then ints.
        scalar_pairs = ((False, True), (0, 1))

        completed = 0  # Number of falsy/truthy rounds finished so far.
        for falsy, truthy in scalar_pairs:
            for wrap in wrappers:
                self.assertAllEqual(
                    zeros,
                    ps.smart_where(wrap(falsy), true_branch, false_branch))
                self.assertEqual([completed, completed + 1], call_counts)

                self.assertAllEqual(
                    ones,
                    ps.smart_where(wrap(truthy), true_branch, false_branch))
                self.assertEqual([completed + 1, completed + 1], call_counts)

                completed += 1
Exemplo n.º 9
0
 def test_cond_x_broadcast_error(self):
     """An all-True `cond` of shape [2] with a [3]-shaped `x` must fail.

     The test expects an op error whose message contains
     'Incompatible shapes'.
     """
     all_true = tf.constant([True, True])

     def mismatched_x():
         return tf.zeros([3])

     with self.assertRaisesOpError('Incompatible shapes'):
         # Build the op inside the context as well, so an error raised at
         # construction time is captured just like one raised on evaluation.
         result = prefer_static.smart_where(
             all_true, mismatched_x, lambda: None)
         self.evaluate(result)