def test_broadcast_success(self):
    """Checks `smart_where` results whose shape comes from broadcasting.

    In both cases the condition is uniformly true (or uniformly false) and
    the branch on the opposite side simply returns `None`.
    """
    # Length-2 all-true condition with a [10, 1] first branch -> [10, 2].
    expected_zeros = tf.zeros([10, 2])
    from_first_branch = ps.smart_where(
        tf.constant([True, True]),
        lambda: tf.zeros([10, 1]),
        lambda: None)
    self.assertAllEqual(expected_zeros, from_first_branch)

    # [2, 1] all-false condition with a [10] second branch -> [2, 10].
    expected_ones = tf.ones([2, 10])
    from_second_branch = ps.smart_where(
        tf.constant([[False], [False]]),
        lambda: None,
        lambda: tf.ones([10]))
    self.assertAllEqual(expected_ones, from_second_branch)
def _survival_function(self, y, **kwargs):
    """Evaluates the survival function at `y` via the base distribution.

    `y` is pulled back through `self.bijector.inverse`, and two candidate
    evaluations of the base distribution at that point are handed to
    `ps.smart_where`: its survival function (first callable) and its CDF
    (second callable), with the bijector's increasing flag as the condition.

    Args:
      y: Point(s) at which to evaluate; forwarded to the bijector inverse.
      **kwargs: Split by `self._kwargs_split_fn` into distribution keyword
        arguments and bijector keyword arguments.

    Returns:
      The value produced by `ps.smart_where`.

    Raises:
      NotImplementedError: If `self.bijector` is not injective.
    """
    if not self.bijector._is_injective:  # pylint: disable=protected-access
        raise NotImplementedError('`survival_function` is not implemented when '
                                  '`bijector` is not injective.')

    dist_kwargs, bij_kwargs = self._kwargs_split_fn(kwargs)
    base_point = self.bijector.inverse(y, **bij_kwargs)

    def _base_survival():
        return self.distribution.survival_function(base_point, **dist_kwargs)

    def _base_cdf():
        return self.distribution.cdf(base_point, **dist_kwargs)

    # TODO(b/141130733): Check/fix any gradient numerics issues.
    increasing = self.bijector._internal_is_increasing(  # pylint: disable=protected-access
        **bij_kwargs)
    return ps.smart_where(increasing, _base_survival, _base_cdf)
def _quantile(self, value, **kwargs):
    """Computes the quantile by pushing a base quantile through the bijector.

    The probability handed to the base distribution is chosen by
    `ps.smart_where` between `value` (first callable) and `1 - value`
    (second callable), with the bijector's increasing flag as the condition.

    Args:
      value: Probability level(s) at which to evaluate the quantile.
      **kwargs: Split by `self._kwargs_split_fn` into distribution keyword
        arguments and bijector keyword arguments.

    Returns:
      `self.bijector.forward` applied to the base distribution's quantile.

    Raises:
      NotImplementedError: If `self.bijector` is not injective.
    """
    if not self.bijector._is_injective:  # pylint: disable=protected-access
        raise NotImplementedError('`quantile` is not implemented when '
                                  '`bijector` is not injective.')

    dist_kwargs, bij_kwargs = self._kwargs_split_fn(kwargs)
    increasing = self.bijector._internal_is_increasing(  # pylint: disable=protected-access
        **bij_kwargs)
    base_level = ps.smart_where(
        increasing,
        lambda: value,
        lambda: 1 - value)

    # With Y = g(X): x_q is the qth quantile of X iff q = P[X <= x_q]. Because
    # X = g^{-1}(Y), q = P[X <= x_q] = P[g^{-1}(Y) <= x_q] = P[Y <= g(x_q)],
    # so the qth quantile of Y is g(x_q).
    base_quantile = self.distribution.quantile(base_level, **dist_kwargs)
    return self.bijector.forward(base_quantile, **bij_kwargs)
def _survival_function(self, y, **kwargs):
    """Evaluates the survival function at `y` via the base distribution.

    `y` is pulled back through `self.bijector.inverse`, and two candidate
    evaluations of the base distribution at that point are handed to
    `prefer_static.smart_where`: its survival function (first callable) and
    its CDF (second callable), with the bijector's increasing flag as the
    condition.

    Args:
      y: Point(s) at which to evaluate; forwarded to the bijector inverse.
      **kwargs: Split by `self._kwargs_split_fn` into distribution keyword
        arguments and bijector keyword arguments.

    Returns:
      The value produced by `prefer_static.smart_where`.

    Raises:
      NotImplementedError: If `self._is_maybe_event_override` is truthy, or
        if `self.bijector` is not injective.
    """
    if self._is_maybe_event_override:
        raise NotImplementedError("survival_function is not implemented when "
                                  "overriding event_shape")
    if not self.bijector._is_injective:  # pylint: disable=protected-access
        raise NotImplementedError("survival_function is not implemented when "
                                  "bijector is not injective.")

    dist_kwargs, bij_kwargs = self._kwargs_split_fn(kwargs)
    base_point = self.bijector.inverse(y, **bij_kwargs)

    def _base_survival():
        return self.distribution.survival_function(base_point, **dist_kwargs)

    def _base_cdf():
        return self.distribution.cdf(base_point, **dist_kwargs)

    # TODO(b/141130733): Check/fix any gradient numerics issues.
    increasing = self.bijector._internal_is_increasing(  # pylint: disable=protected-access
        **bij_kwargs)
    return prefer_static.smart_where(increasing, _base_survival, _base_cdf)
def _quantile(self, value, **kwargs):
    """Computes the quantile by pushing a base quantile through the bijector.

    The probability handed to the base distribution is chosen by
    `prefer_static.smart_where` between `value` (first callable) and
    `1 - value` (second callable), with the bijector's increasing flag as
    the condition.

    Args:
      value: Probability level(s) at which to evaluate the quantile.
      **kwargs: Split by `self._kwargs_split_fn` into distribution keyword
        arguments and bijector keyword arguments.

    Returns:
      `self.bijector.forward` applied to the base distribution's quantile.

    Raises:
      NotImplementedError: If `self._is_maybe_event_override` is truthy, or
        if `self.bijector` is not injective.
    """
    if self._is_maybe_event_override:
        raise NotImplementedError("quantile is not implemented when overriding "
                                  "event_shape")
    if not self.bijector._is_injective:  # pylint: disable=protected-access
        raise NotImplementedError("quantile is not implemented when "
                                  "bijector is not injective.")

    dist_kwargs, bij_kwargs = self._kwargs_split_fn(kwargs)
    increasing = self.bijector._internal_is_increasing(  # pylint: disable=protected-access
        **bij_kwargs)
    base_level = prefer_static.smart_where(
        increasing,
        lambda: value,
        lambda: 1 - value)

    # With Y = g(X): x_q is the qth quantile of X iff q = P[X <= x_q]. Because
    # X = g^{-1}(Y), q = P[X <= x_q] = P[g^{-1}(Y) <= x_q] = P[Y <= g(x_q)],
    # so the qth quantile of Y is g(x_q).
    base_quantile = self.distribution.quantile(base_level, **dist_kwargs)
    return self.bijector.forward(base_quantile, **bij_kwargs)
def test_where_fallback(self):
    """Checks elementwise selection for a mixed, non-scalar condition.

    With condition `[True, False]` and scalar branches of one and zero, the
    expected result is `[1., 0.]`.
    """
    # NOTE(review): the test name suggests this covers the path where
    # `smart_where` falls back to an elementwise select -- confirm against
    # the `smart_where` implementation.
    mixed_condition = tf.constant([True, False])
    selected = ps.smart_where(
        mixed_condition,
        lambda: tf.ones([]),
        lambda: tf.zeros([]))
    self.assertAllEqual([1., 0.], selected)
def test_cond_y_broadcast_error(self):
    """Expects an 'Incompatible shapes' op error for a mismatched y branch.

    The condition has length 2 and is all false; the second branch returns a
    length-3 tensor while the first branch returns `None`.
    """
    with self.assertRaisesOpError('Incompatible shapes'):
        # Everything is built inside the context so the error is caught
        # whether it surfaces at construction or at evaluation.
        all_false = tf.constant([False, False])
        mismatched = ps.smart_where(
            all_false,
            lambda: None,
            lambda: tf.zeros([3]))
        self.evaluate(mismatched)
def test_static_scalar_condition(self):
    """Checks that scalar conditions invoke exactly one branch per call.

    Each of the bool pair `(False, True)` and the int pair `(0, 1)` is tried
    as a plain Python value, a `tf.constant` and an `np.array`. After every
    `smart_where` call the per-branch call counters are compared with the
    expected running totals, so only the chosen branch may have run.
    """
    fn_calls = [0, 0]
    ones = tf.ones([10])
    zeros = tf.zeros([10])

    def fn1():
        fn_calls[0] += 1
        return ones

    def fn2():
        fn_calls[1] += 1
        return zeros

    # Running totals that `fn_calls` must match after each call.
    expected_calls = [0, 0]
    # Ways of presenting the scalar: as-is, as a TF constant, as a NumPy array.
    wrappers = (lambda scalar: scalar, tf.constant, np.array)

    for falsy, truthy in ((False, True), (0, 1)):
        for wrap in wrappers:
            # Falsy condition: only the second branch should run.
            self.assertAllEqual(zeros, ps.smart_where(wrap(falsy), fn1, fn2))
            expected_calls[1] += 1
            self.assertEqual(expected_calls, fn_calls)

            # Truthy condition: only the first branch should run.
            self.assertAllEqual(ones, ps.smart_where(wrap(truthy), fn1, fn2))
            expected_calls[0] += 1
            self.assertEqual(expected_calls, fn_calls)
def test_cond_x_broadcast_error(self):
    """Expects an 'Incompatible shapes' op error for a mismatched x branch.

    The condition has length 2 and is all true; the first branch returns a
    length-3 tensor while the second branch returns `None`.
    """
    with self.assertRaisesOpError('Incompatible shapes'):
        # Everything is built inside the context so the error is caught
        # whether it surfaces at construction or at evaluation.
        all_true = tf.constant([True, True])
        mismatched = prefer_static.smart_where(
            all_true,
            lambda: tf.zeros([3]),
            lambda: None)
        self.evaluate(mismatched)