예제 #1
0
 def testScalarCongruency(self):
     """Runs the scalar congruency check on a Gumbel(loc=0.3, scale=20.) bijector.

     The check is performed by `bijector_test_util.assert_scalar_congruency`
     over x in [1, 100], evaluated with `self.evaluate`, at rtol=0.05.
     """
     gumbel = tfb.Gumbel(loc=0.3, scale=20.)
     bijector_test_util.assert_scalar_congruency(
         gumbel,
         lower_x=1.,
         upper_x=100.,
         eval_func=self.evaluate,
         rtol=0.05,
     )
예제 #2
0
 def testVariableScale(self):
   """A `tf.Variable` scale is held by reference and re-validated on use."""
   scale_var = tf.Variable(1.)
   gumbel = tfb.Gumbel(loc=0., scale=scale_var, validate_args=True)
   self.evaluate(scale_var.initializer)
   # The bijector must expose the very same Variable object, not a copy.
   self.assertIs(scale_var, gumbel.scale)
   forward_value = self.evaluate(gumbel.forward(-3.))
   self.assertEqual((), forward_value.shape)
   # Once the variable is assigned a negative value, forward must fail the
   # positivity check enabled by validate_args=True.
   with self.assertRaisesOpError("Argument `scale` must be positive."):
     with tf.control_dependencies([scale_var.assign(-1.)]):
       self.evaluate(gumbel.forward(-3.))
예제 #3
0
 def testBijectiveAndFinite(self):
     """Checks Gumbel(loc=0, scale=3) is bijective and finite on sample points."""
     gumbel = tfb.Gumbel(loc=0., scale=3.0, validate_args=True)
     # Ten float32 x samples spanning [-10, 10] ...
     x_samples = np.linspace(-10., 10., num=10).astype(np.float32)
     # ... and ten float32 y samples strictly inside (0, 1).
     y_samples = np.linspace(0.01, 0.99, num=10).astype(np.float32)
     bijector_test_util.assert_bijective_and_finite(
         gumbel, x_samples, y_samples,
         eval_func=self.evaluate, event_ndims=0, rtol=1e-3)
예제 #4
0
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Gumbel"):
        """Construct Gumbel distributions with location `loc` and scale `scale`.

        The distribution is built as a transformed distribution: a standard
        `Uniform(0, 1)` base distribution pushed through the inverse of the
        `bijectors.Gumbel` bijector.

        The parameters `loc` and `scale` must be shaped in a way that supports
        broadcasting (e.g. `loc + scale` is a valid operation).

        Args:
          loc: Floating point tensor, the location parameter(s) of the
            distribution(s). Note that `loc` is the mode, not the mean: the
            mean of a Gumbel distribution is `loc + scale * euler_gamma`
            (euler_gamma ~= 0.5772).
          scale: Floating point tensor, the scales of the distribution(s).
            scale must contain only positive values.
          validate_args: Python `bool`, default `False`. When `True` distribution
            parameters are checked for validity despite possibly degrading runtime
            performance. When `False` invalid inputs may silently render incorrect
            outputs.
            Default value: `False`.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is raised
            if one or more of the statistic's batch members are undefined.
            Default value: `True`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: `'Gumbel'`.

        Raises:
          TypeError: if loc and scale are different dtypes.
          ValueError: from `tf.assert_same_float_dtype` if the common dtype of
            loc and scale is not a floating point type.
        """
        with tf.name_scope(name, values=[loc, scale]) as name:
            # Both parameters are converted to one shared dtype (float32 when
            # it cannot be inferred from the inputs).
            dtype = dtype_util.common_dtype([loc, scale],
                                            preferred_dtype=tf.float32)
            loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
            scale = tf.convert_to_tensor(scale, name="scale", dtype=dtype)
            # With validate_args, the identities below (and therefore every op
            # built from them) depend on the scale > 0 assertion.
            with tf.control_dependencies(
                [tf.assert_positive(scale)] if validate_args else []):
                loc = tf.identity(loc, name="loc")
                scale = tf.identity(scale, name="scale")
                tf.assert_same_float_dtype([loc, scale])
                self._gumbel_bijector = bijectors.Gumbel(
                    loc=loc, scale=scale, validate_args=validate_args)

            # NOTE(review): `validate_args` is applied to `scale` and to the
            # bijector above but is not forwarded to the parent constructor;
            # confirm whether the parent's `validate_args` should reflect it.
            super(Gumbel, self).__init__(
                distribution=uniform.Uniform(low=tf.zeros([], dtype=loc.dtype),
                                             high=tf.ones([], dtype=loc.dtype),
                                             allow_nan_stats=allow_nan_stats),
                # The Gumbel bijector encodes the CDF as its forward
                # transformation, so it is inverted here: the inverse of the
                # CDF (the quantile function) maps Uniform(0, 1) samples to
                # Gumbel samples.
                bijector=bijectors.Invert(self._gumbel_bijector),
                batch_shape=distribution_util.get_broadcast_shape(loc, scale),
                name=name)
예제 #5
0
 def testBijector(self):
   """Gumbel bijector agrees with scipy's gumbel_r CDF and log-density."""
   loc, scale = 0.3, 5.
   bijector = tfb.Gumbel(loc=loc, scale=scale, validate_args=True)
   self.assertStartsWith(bijector.name, "gumbel")
   x = np.array([[[-3.], [0.], [0.5], [4.2], [12.]]], dtype=np.float32)
   # scipy reference: the bijector's forward should be this CDF.
   reference = stats.gumbel_r(loc=loc, scale=scale)
   y = reference.cdf(x).astype(np.float32)
   self.assertAllClose(y, self.evaluate(bijector.forward(x)))
   self.assertAllClose(x, self.evaluate(bijector.inverse(y)))
   # With event_ndims=1 the size-1 trailing axis is reduced, so the expected
   # log-density drops that axis as well.
   expected_fldj = np.squeeze(reference.logpdf(x), axis=-1)
   self.assertAllClose(
       expected_fldj,
       self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=1)))
   # Forward and (negated) inverse log-det-Jacobians must agree.
   negated_ildj = self.evaluate(
       -bijector.inverse_log_det_jacobian(y, event_ndims=1))
   fldj = self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=1))
   self.assertAllClose(negated_ildj, fldj, rtol=1e-4, atol=0.)
예제 #6
0
 def testScalarCongruency(self):
     """Runs the scalar congruency check on a Gumbel(loc=0.3, scale=20.) bijector.

     The check uses `assert_scalar_congruency` over x in [1, 100] at
     rtol=0.02, inside `self.test_session()`.
     """
     with self.test_session():
         # Built inside the session context, as in the original, so the ops
         # land in the same graph.
         gumbel = tfb.Gumbel(loc=0.3, scale=20.)
         assert_scalar_congruency(
             gumbel, lower_x=1., upper_x=100., rtol=0.02)