示例#1
0
 def _maybe_assert_valid_sample(self, x):
     """Checks the dtype of a candidate sample and optionally its sign.

     Args:
       x: Candidate sample; its dtype is checked against `self.dtype`.

     Returns:
       `x` unchanged when `self.validate_args` is `False`; otherwise `x`
       wrapped so that it depends on an `assert_positive(x)` check.
     """
     dtype_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
     if self.validate_args:
         positivity_check = assert_util.assert_positive(x)
         return distribution_util.with_dependencies([positivity_check], x)
     return x
示例#2
0
 def _maybe_assert_valid_sample(self, x):
   """Returns `x`, gated on a positivity assertion when validating args.

   The dtype of `x` is always checked against `self.dtype`, even when
   `self.validate_args` is `False`.
   """
   dtype_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
   if self.validate_args:
     checks = [assert_util.assert_positive(x)]
     with tf.control_dependencies(checks):
       # `tf.identity` creates an op that carries the control dependency.
       return tf.identity(x)
   return x
 def _maybe_assert_valid_sample(self, x):
     """Builds the runtime assertions a sample `x` must satisfy.

     The dtype of `x` is always checked against `self.dtype`.

     Returns:
       An empty list when `self.validate_args` is `False`; otherwise a
       one-element list holding an `assert_positive` op for `x`.
     """
     dtype_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
     if self.validate_args:
         return [
             assert_util.assert_positive(
                 x, message='Sample must be positive.'),
         ]
     return []
示例#4
0
  def _parameter_control_dependencies(self, is_init):
    """Collects checks on `df` and `scale`, raising early when possible.

    The dtype and squareness checks run only at construction time. The
    requirement `df >= dimension(scale)` is enforced with a Python exception
    when both values are statically known at construction. In every other
    case, when `validate_args` is set, a runtime assertion op may be
    returned instead.

    Args:
      is_init: Python `bool`; `True` when invoked during construction.

    Returns:
      A (possibly empty) list of runtime assertion ops.

    Raises:
      TypeError: at construction, if `scale.dtype` is not floating point.
      ValueError: at construction, if `scale` is not square, or if `df` is
        statically known to be smaller than the dimension of `scale`.
        `dtype_util.assert_same_float_dtype` may also raise if `df` and
        `scale` disagree in dtype.
    """
    assertions = []
    if is_init:
      # NOTE(review): `self._scale` exposes `.is_square` and `.dtype`, so it
      # appears to be a LinearOperator rather than a plain Tensor -- confirm.
      if not dtype_util.is_floating(self._scale.dtype):
        raise TypeError(
            'scale.dtype={} is not a floating-point type.'.format(
                self._scale.dtype))
      if not self._scale.is_square:
        raise ValueError('scale must be square.')
      dtype_util.assert_same_float_dtype([self._df, self._scale])

    # Each static value is `None` when it cannot be determined at
    # graph-construction time (hence the `is not None` guards below).
    df_val = tf.get_static_value(self._df)
    dim_val = tf.compat.dimension_value(self._scale.shape[-1])
    msg = ('Degrees of freedom (`df = {}`) cannot be less than dimension of '
           'scale matrix (`scale.dimension = {}`).')
    if is_init and df_val is not None and dim_val is not None:
      # Both values are known up front: check eagerly in NumPy.
      df_val = np.asarray(df_val)
      dim_val = np.asarray(dim_val)
      # Promote scalars to rank 1 before the elementwise comparison.
      if not dim_val.shape:
        dim_val = dim_val[np.newaxis, ...]
      if not df_val.shape:
        df_val = df_val[np.newaxis, ...]
      if np.any(df_val < dim_val):
        raise ValueError(msg.format(df_val, dim_val))

    elif self.validate_args:
      # At init, emit the check when either parameter is a plain (non-ref)
      # tensor. After init, emit it only when either parameter is a
      # variable-like ref, presumably because only refs can change after
      # construction.
      if (is_init != tensor_util.is_ref(self._df) or
          is_init != tensor_util.is_ref(self._scale)):
        df = tf.convert_to_tensor(self._df)
        dimension = self._dimension()
        assertions.append(assert_util.assert_less_equal(
            dimension, df, message=(msg.format(df, dimension))))

    return assertions
  def __init__(self,
               loc,
               scale,
               low,
               high,
               validate_args=False,
               allow_nan_stats=True,
               name='TruncatedNormal'):
    """Builds a normal distribution truncated to the bounds `low` and `high`.

    All parameters are broadcast against one another, and the resulting
    `batch_shape` is their common broadcast shape.

    Args:
      loc: Floating-point `Tensor`; mean of the underlying (untruncated)
        normal distribution(s). The truncated distribution's own mean differs,
        because the bounds reshape the density.
      scale: Floating-point `Tensor`; standard deviation of the underlying
        normal distribution(s).
      low: Floating-point `Tensor`; lower bound of the support. Requires
        `low < high`.
      high: Floating-point `Tensor`; upper bound of the support. Requires
        `low < high`.
      validate_args: Python `bool`, default `False`. If `True`, distribution
        parameters are checked at run-time.
      allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`. If
        `False`, an exception is raised for any undefined batch member.
      name: Python `str` prefix for ops created by this class.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale, low, high], tf.float32)

      def _as_param(value, label):
        # Variables and other refs are kept as-is; everything else becomes a
        # Tensor of the shared dtype.
        return tensor_util.convert_nonref_to_tensor(
            value, name=label, dtype=dtype)

      self._loc = _as_param(loc, 'loc')
      self._scale = _as_param(scale, 'scale')
      self._low = _as_param(low, 'low')
      self._high = _as_param(high, 'high')
      dtype_util.assert_same_float_dtype(
          [self._loc, self._scale, self._low, self._high])

      super(TruncatedNormal, self).__init__(
          dtype=dtype,
          # Fully reparameterized. Gradients flow straight through `loc` and
          # `scale`. Gradients w.r.t. the bounds come from custom expressions
          # derived via implicit differentiation; they handle arbitrary bounds
          # and generalize Sec 9.1.1 of https://arxiv.org/pdf/1806.01851.pdf
          # (which covers a zero lower bound with a positive upper bound).
          reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)
示例#6
0
    def __init__(self, loc=0., scale=1., validate_args=False, name="gumbel"):
        """Instantiates the `Gumbel` bijector.

        The forward map is `Y = g(X) = exp(-exp(-(X - loc) / scale))`.

        Args:
          loc: Float-like `Tensor`; same dtype as `scale` and broadcastable
            with it. Plays the role of `loc` in the forward map.
          scale: Positive float-like `Tensor`; same dtype as `loc` and
            broadcastable with it. Plays the role of `scale` in the
            forward map.
          validate_args: Python `bool`; when `True`, `scale` is gated on a
            positivity assertion.
          name: Python `str` name for ops managed by this object.
        """
        self._validate_args = validate_args
        self._name = name
        self._graph_parents = []
        with self._name_scope("init"):
            self._loc = tf.convert_to_tensor(value=loc, name="loc")
            self._scale = tf.convert_to_tensor(value=scale, name="scale")
            dtype_util.assert_same_float_dtype([self._loc, self._scale])
            if validate_args:
                scale_is_positive = assert_util.assert_positive(
                    self._scale, message="Argument scale was not positive")
                self._scale = distribution_util.with_dependencies(
                    [scale_is_positive], self._scale)

        super(Gumbel, self).__init__(
            validate_args=validate_args,
            forward_min_event_ndims=0,
            name=name)
示例#7
0
  def __init__(self,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="Gumbel"):
    """Construct Gumbel distributions with location and scale `loc` and `scale`.

    The parameters `loc` and `scale` must be shaped in a way that supports
    broadcasting (e.g. `loc + scale` is a valid operation).

    The distribution is built as a uniform base distribution pushed through
    the inverse of the `Gumbel` bijector, i.e. the Gumbel quantile function.

    Args:
      loc: Floating point tensor, the means of the distribution(s).
      scale: Floating point tensor, the scales of the distribution(s).
        scale must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Gumbel'`.

    Raises:
      TypeError: if loc and scale are different dtypes.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
      loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
      scale = tf.convert_to_tensor(scale, name="scale", dtype=dtype)
      # When validating, the `tf.identity` copies of `loc` and `scale` below
      # are gated on the positivity check of `scale`.
      with tf.control_dependencies(
          [assert_util.assert_positive(scale)] if validate_args else []):
        loc = tf.identity(loc, name="loc")
        scale = tf.identity(scale, name="scale")
        dtype_util.assert_same_float_dtype([loc, scale])
        self._gumbel_bijector = gumbel_bijector.Gumbel(
            loc=loc, scale=scale, validate_args=validate_args)

      # The uniform sampler draws from `[0, 1)`. Samples are produced by the
      # Gumbel quantile function (the inverse of the bijector's forward CDF),
      # which maps 0 to `-inf`. With a lower bound of 0, samples could
      # therefore lie in `[-inf, inf)` instead of `(-inf, inf)`. To rule that
      # out, the lower bound is `np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny`,
      # the smallest positive "normal" number representable in `dtype`.
      super(Gumbel, self).__init__(
          distribution=uniform.Uniform(
              low=np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny,
              high=tf.ones([], dtype=loc.dtype),
              allow_nan_stats=allow_nan_stats),
          # The Gumbel bijector encodes the quantile
          # function as the forward, and hence needs to
          # be inverted.
          bijector=invert_bijector.Invert(self._gumbel_bijector),
          batch_shape=distribution_util.get_broadcast_shape(loc, scale),
          parameters=parameters,
          name=name)
示例#8
0
    def __init__(self,
                 loc,
                 scale,
                 low,
                 high,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='TruncatedCauchy'):
        """Builds a Cauchy distribution truncated to `low` and `high`.

        All parameters are broadcast against one another, and the resulting
        `batch_shape` is their common broadcast shape.

        Args:
          loc: Floating-point `Tensor`; mode(s) of the underlying
            (untruncated) Cauchy distribution(s).
          scale: Floating-point `Tensor`; scale(s) of the distribution(s).
            Must contain only positive values.
          low: Floating-point `Tensor`; lower bound of the support. Requires
            `low < high`.
          high: Floating-point `Tensor`; upper bound of the support. Requires
            `low < high`.
          validate_args: Python `bool`, default `False`. If `True`,
            distribution parameters are checked at run-time.
          allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
            statistics (e.g. mean, mode, variance) are reported as `NaN`. If
            `False`, an exception is raised for any undefined batch member.
          name: Python `str` prefix for ops created by this class.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, low, high],
                                            tf.float32)

            def _to_tensor(value, label):
                return tensor_util.convert_nonref_to_tensor(value,
                                                            name=label,
                                                            dtype=dtype)

            self._loc = _to_tensor(loc, 'loc')
            self._scale = _to_tensor(scale, 'scale')
            self._low = _to_tensor(low, 'low')
            self._high = _to_tensor(high, 'high')
            dtype_util.assert_same_float_dtype(
                [self._loc, self._scale, self._low, self._high])

            super(TruncatedCauchy, self).__init__(
                dtype=dtype,
                # Not reparameterized: samples carry no gradients w.r.t. the
                # bounds `_low` and `_high`.
                # TODO(b/161297284): Implement these gradients.
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
示例#9
0
    def __init__(self,
                 df,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="StudentT"):
        """Builds Student's t distributions from `df`, `loc` and `scale`.

        `df`, `loc` and `scale` must be mutually broadcastable (e.g.
        `df + loc + scale` is a valid operation).

        Args:
          df: Floating-point `Tensor`; degrees of freedom of the
            distribution(s). Must contain only positive values.
          loc: Floating-point `Tensor`; mean(s) of the distribution(s).
          scale: Floating-point `Tensor`; scaling factor(s) of the
            distribution(s). This is not the standard deviation, although it
            behaves more like a standard deviation than a variance.
          validate_args: Python `bool`, default `False`. If `True`, `df` is
            checked for positivity at run-time, at some performance cost.
            If `False`, invalid inputs may silently give incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
            statistics (e.g. mean, mode, variance) are reported as `NaN`. If
            `False`, an exception is raised for any undefined batch member.
          name: Python `str` prefix for ops created by this class.

        Raises:
          TypeError: if loc and scale are different dtypes.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([df, loc, scale], tf.float32)
            df = tf.convert_to_tensor(value=df, name="df", dtype=dtype)
            loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
            scale = tf.convert_to_tensor(value=scale, name="scale",
                                         dtype=dtype)
            df_checks = []
            if validate_args:
                df_checks.append(assert_util.assert_positive(df))
            with tf.control_dependencies(df_checks):
                # The identities are the tensors that carry the dependency.
                self._df = tf.identity(df)
                self._loc = tf.identity(loc)
                self._scale = tf.identity(scale)
                dtype_util.assert_same_float_dtype(
                    (self._df, self._loc, self._scale))
        super(StudentT, self).__init__(
            dtype=self._scale.dtype,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._df, self._loc, self._scale],
            name=name)
示例#10
0
    def __init__(self,
                 distribution,
                 shift,
                 scale,
                 tailweight=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="LambertWDistribution"):
        """Wraps `distribution` in a Lambert W x F tail transformation.

        Args:
          distribution: `tf.Distribution`-like instance; the base
            distribution F that is transformed into Lambert W x F.
          shift: Shift applied before and after the tail transformation. For
            a location-scale family `distribution` (e.g. `Normal` or
            `StudentT`) this is usually its mean / location parameter. For a
            scale family `distribution` (e.g. `Gamma` or `Fisher`) it must be
            0 so the transformation stays on the positive real line.
          scale: Scaling factor applied before and after the tail
            transformation; usually the standard deviation or scale
            parameter of `distribution`.
          tailweight: Tail parameter `delta` of the resulting Lambert W x F
            distribution(s). `None` is treated as `0.`.
          validate_args: Python `bool`, default `False`. If `True`,
            parameters are checked for validity at some performance cost.
            If `False`, invalid inputs may silently give incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
            statistics (e.g. mean, mode, variance) are reported as `NaN`. If
            `False`, an exception is raised for any undefined batch member.
            Stored on the instance; not forwarded to the parent constructor.
          name: A name for the operation (optional).
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            # The dtype is inferred from the raw arguments, before a `None`
            # tailweight is replaced by 0.
            dtype = dtype_util.common_dtype([tailweight, shift, scale],
                                            tf.float32)
            if tailweight is None:
                tailweight = 0.

            def _to_param(value, label):
                return tensor_util.convert_nonref_to_tensor(
                    value, name=label, dtype=dtype)

            self._tailweight = _to_param(tailweight, "tailweight")
            self._shift = _to_param(shift, "shift")
            self._scale = _to_param(scale, "scale")
            dtype_util.assert_same_float_dtype(
                (self.tailweight, self.shift, self.scale))
            self._allow_nan_stats = allow_nan_stats
            # The bijector receives the original (unconverted) arguments.
            lambertw_tail = tfb.LambertWTail(shift=shift,
                                             scale=scale,
                                             tailweight=tailweight,
                                             validate_args=validate_args)
            super(LambertWDistribution, self).__init__(
                distribution=distribution,
                bijector=lambertw_tail,
                parameters=parameters,
                validate_args=validate_args,
                name=name)
示例#11
0
  def __init__(self,
               loc,
               scale,
               quadrature_size=8,
               quadrature_fn=quadrature_scheme_lognormal_quantiles,
               validate_args=False,
               allow_nan_stats=True,
               name='PoissonLogNormalQuadratureCompound'):
    """Constructs the `PoissonLogNormalQuadratureCompound`.

    Note: the `probs` returned by `quadrature_fn` are presumed to be either a
    length-`quadrature_size` vector or a batch of vectors matching the
    returned `grid` one-to-one, so broadcasting is only partially supported.

    Args:
      loc: `float`-like (batch of) scalar `Tensor`; location parameter of the
        LogNormal prior.
      scale: `float`-like (batch of) scalar `Tensor`; scale parameter of the
        LogNormal prior.
      quadrature_size: Python `int` scalar; number of quadrature points.
      quadrature_fn: Python callable taking `loc`, `scale`,
        `quadrature_size` and `validate_args`, and returning
        `tuple(grid, probs)`: the LogNormal grid and the matching normalized
        weights.
        Default value: `quadrature_scheme_lognormal_quantiles`.
      validate_args: Python `bool`, default `False`. If `True`, parameters
        are checked for validity at some performance cost. If `False`,
        invalid inputs may silently give incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
        statistics (e.g. mean, mode, variance) are reported as `NaN`. If
        `False`, an exception is raised for any undefined batch member.
      name: Python `str` prefix for ops created by this class.

    Raises:
      TypeError: if `quadrature_grid` and `quadrature_probs` have different
        base `dtype`. NOTE(review): `quadrature_fn` is only stored here, not
        called, so this is presumably raised wherever it is invoked.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale], tf.float32)

      def _convert(value, label):
        return tensor_util.convert_nonref_to_tensor(
            value, name=label, dtype=dtype)

      self._loc = _convert(loc, 'loc')
      self._scale = _convert(scale, 'scale')
      self._quadrature_fn = quadrature_fn
      dtype_util.assert_same_float_dtype([self._loc, self._scale])
      self._quadrature_size = quadrature_size

      super(PoissonLogNormalQuadratureCompound, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)
示例#12
0
 def _sample_control_dependencies(self, x):
   """Returns the runtime assertions required of a sample `x`.

   The dtype of `x` is always checked against `self.dtype`. When
   `self.validate_args` is `True`, the result holds a non-negativity
   assertion for `x`; otherwise it is empty.
   """
   dtype_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
   if not self.validate_args:
     return []
   return [
       assert_util.assert_non_negative(
           x, message='Sample must be non-negative.'),
   ]
示例#13
0
    def __init__(self,
                 df,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HalfStudentT"):
        """Builds a half-Student's t distribution from `df`, `loc`, `scale`.

        Args
        ----
            df: Floating-point `Tensor`; degrees of freedom of the
                distribution(s). Must contain only positive values.
            loc: Floating-point `Tensor`; location(s) of the distribution(s).
            scale: Floating-point `Tensor`; scale(s) of the distribution(s).
                Must contain only positive values.
            validate_args: Python `bool`, default `False`. If `True`,
                parameters are checked for validity at some performance cost.
                If `False`, invalid inputs may silently give incorrect
                outputs.
            allow_nan_stats: Python `bool`, default `True`. If `True`,
                undefined statistics (e.g. mean, mode, variance) are reported
                as `NaN`. If `False`, an exception is raised for any
                undefined batch member.
            name: Python `str` prefix for ops created by this class.
                Default value: 'HalfStudentT'.

        Raises
        ------
            TypeError: if `df`, `loc`, or `scale` are different dtypes.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([df, loc, scale],
                                            dtype_hint=tf.float32)

            def _as_param(value, label):
                return tensor_util.convert_nonref_to_tensor(value,
                                                            name=label,
                                                            dtype=dtype)

            self._df = _as_param(df, "df")
            self._loc = _as_param(loc, "loc")
            self._scale = _as_param(scale, "scale")
            dtype_util.assert_same_float_dtype(
                (self._df, self._loc, self._scale))
            super(HalfStudentT, self).__init__(
                dtype=dtype,
                reparameterization_type=(
                    reparameterization.FULLY_REPARAMETERIZED),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name,
            )
示例#14
0
 def _parameter_control_dependencies(self, is_init):
   """Returns runtime assertions on the parameters.

   At construction (`is_init=True`), `loc` and `scale` are checked for a
   common float dtype. A positivity assertion on `scale` is produced only
   when `validate_args` is set, and only at init for non-ref `scale`, or
   after init for ref (variable-like) `scale`.
   """
   if is_init:
     dtype_util.assert_same_float_dtype([self.loc, self.scale])
   if not self.validate_args:
     return []
   scale_needs_check = is_init != tensor_util.is_ref(self._scale)
   if not scale_needs_check:
     return []
   return [
       assert_util.assert_positive(
           self._scale, message='Argument `scale` must be positive.'),
   ]
示例#15
0
    def __init__(self,
                 loc,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VonMises"):
        """Builds von Mises distributions from `loc` and `concentration`.

        `loc` and `concentration` must be mutually broadcastable (e.g.
        `loc + concentration` is a valid operation).

        Args:
          loc: Floating-point `Tensor`; circular mean(s) of the
            distribution(s).
          concentration: Floating-point `Tensor`; how tightly the
            distribution(s) concentrate around `loc`. Must be non-negative.
            `concentration = 0` gives a Uniform distribution, and
            `concentration = +inf` a Deterministic distribution at `loc`.
          validate_args: Python `bool`, default `False`. If `True`,
            `concentration` is checked for non-negativity at run-time, at
            some performance cost. If `False`, invalid inputs may silently
            give incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
            statistics (e.g. mean, mode, variance) are reported as `NaN`. If
            `False`, an exception is raised for any undefined batch member.
          name: Python `str` prefix for ops created by this class.

        Raises:
          TypeError: if loc and concentration are different dtypes.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, concentration],
                                            dtype_hint=tf.float32)
            loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
            concentration = tf.convert_to_tensor(value=concentration,
                                                 name="concentration",
                                                 dtype=dtype)
            concentration_checks = []
            if validate_args:
                concentration_checks.append(
                    assert_util.assert_non_negative(concentration))
            with tf.control_dependencies(concentration_checks):
                self._loc = tf.identity(loc, name="loc")
                self._concentration = tf.identity(concentration,
                                                  name="concentration")
                dtype_util.assert_same_float_dtype(
                    [self._loc, self._concentration])
        super(VonMises, self).__init__(
            dtype=self._concentration.dtype,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._loc, self._concentration],
            name=name)
示例#16
0
    def __init__(self,
                 concentration,
                 rate,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Gamma"):
        """Builds Gamma distributions from `concentration` and `rate`.

        `concentration` and `rate` must be mutually broadcastable (e.g.
        `concentration + rate` is a valid operation).

        Args:
          concentration: Floating-point `Tensor`; concentration parameter(s)
            of the distribution(s). Must contain only positive values.
          rate: Floating-point `Tensor`; inverse-scale parameter(s) of the
            distribution(s). Must contain only positive values.
          validate_args: Python `bool`, default `False`. If `True`, both
            parameters are checked for positivity at run-time, at some
            performance cost. If `False`, invalid inputs may silently give
            incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
            statistics (e.g. mean, mode, variance) are reported as `NaN`. If
            `False`, an exception is raised for any undefined batch member.
          name: Python `str` prefix for ops created by this class.

        Raises:
          TypeError: if `concentration` and `rate` are different dtypes.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([concentration, rate], tf.float32)
            concentration = tf.convert_to_tensor(value=concentration,
                                                 name="concentration",
                                                 dtype=dtype)
            rate = tf.convert_to_tensor(value=rate, name="rate", dtype=dtype)
            positivity_checks = []
            if validate_args:
                positivity_checks = [
                    assert_util.assert_positive(concentration),
                    assert_util.assert_positive(rate),
                ]
            with tf.control_dependencies(positivity_checks):
                self._concentration = tf.identity(concentration)
                self._rate = tf.identity(rate)
                dtype_util.assert_same_float_dtype(
                    [self._concentration, self._rate])
        super(Gamma, self).__init__(
            dtype=dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            parameters=parameters,
            graph_parents=[self._concentration, self._rate],
            name=name)
    def __init__(self,
                 shift,
                 scale,
                 tailweight,
                 validate_args=False,
                 name="lambertw_tail"):
        """Builds a location-scale heavy-tail Lambert W bijector.

        `shift`, `scale` and `tailweight` must be mutually broadcastable
        (e.g. `shift + scale + tailweight` is a valid operation).

        Args:
          shift: Floating-point `Tensor`; shift used to center the input and
            un-center the output random variable(s).
          scale: Floating-point `Tensor`; factor used to scale the input and
            un-scale the output random variable(s). Must contain only
            positive values.
          tailweight: Floating-point `Tensor`; tail behavior of the output
            random variable(s). Must contain only non-negative values.
          validate_args: Python `bool`, default `False`. Forwarded to the
            parent `Chain` constructor.
          name: Python `str` used for the enclosing `tf.name_scope`. It is
            not passed on to the parent constructor.

        Raises:
          TypeError: if `shift`, `scale` and `tailweight` have different
            `dtype`.
        """
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([tailweight, shift, scale],
                                            tf.float32)

            def _as_param(value, label):
                return tensor_util.convert_nonref_to_tensor(value,
                                                            name=label,
                                                            dtype=dtype)

            self._tailweight = _as_param(tailweight, "tailweight")
            self._shift = _as_param(shift, "shift")
            self._scale = _as_param(scale, "scale")
            dtype_util.assert_same_float_dtype(
                (self._tailweight, self._shift, self._scale))

            shift_and_scale = chain.Chain(
                [tfb_shift.Shift(self._shift), tfb_scale.Scale(self._scale)])
            self._shift_and_scale = shift_and_scale
            # `Chain` applies its bijectors last-to-first. The forward pass is
            # therefore: undo shift/scale (the `Invert`), apply the heavy
            # tail, then re-apply shift/scale.
            stages = [
                shift_and_scale,
                _HeavyTailOnly(tailweight=self._tailweight),
                invert.Invert(shift_and_scale),
            ]
            super(LambertWTail, self).__init__(bijectors=stages,
                                               validate_args=validate_args)
示例#18
0
  def test_assert_same_float_dtype(self):
    """Covers dtype inference and rejection in `assert_same_float_dtype`."""
    check = dtype_util.assert_same_float_dtype

    # With nothing concrete to inspect, float32 comes back whether or not it
    # was requested explicitly.
    for tensors in (None, [], [None, None]):
      for requested in (None, tf.float32):
        self.assertIs(tf.float32, check(tensors, requested))

    dense_float = tf.constant(3.0, dtype=tf.float32)
    self.assertIs(tf.float32, check([dense_float], tf.float32))
    self.assertRaises(ValueError, check, [dense_float], tf.int32)

    if not hasattr(tf, 'SparseTensor'):
      # The numpy and jax backends have no SparseTensor; nothing left to test.
      return

    sparse_float = tf.SparseTensor(
        tf.constant([[111], [232]], tf.int64),
        tf.constant([23.4, -43.2], tf.float32),
        tf.constant([500], tf.int64))
    dense_int = tf.constant(3, dtype=tf.int32)

    # Argument tuples for which float32 must be inferred.
    accepted = (
        ([sparse_float], tf.float32),
        ([dense_float, sparse_float],),
        ([dense_float, sparse_float], tf.float32),
    )
    for args in accepted:
      self.assertIs(tf.float32, check(*args))

    # Argument tuples that must be rejected: a requested dtype that disagrees
    # with the tensors, or an integer tensor anywhere in the list.
    rejected = (
        ([sparse_float], tf.int32),
        ([dense_float, None, sparse_float], tf.float64),
        ([sparse_float, dense_int],),
        ([sparse_float, dense_int], tf.int32),
        ([sparse_float, dense_int], tf.float32),
        ([dense_int],),
    )
    for args in rejected:
      self.assertRaises(ValueError, check, *args)
示例#19
0
    def __init__(self,
                 loc,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="InverseGaussian"):
        """Builds an inverse Gaussian distribution from `loc` and `concentration`.

        Args:
          loc: Floating-point `Tensor` of location parameters; every entry must
            be positive.
          concentration: Floating-point `Tensor` of concentration parameters;
            every entry must be positive.
          validate_args: Python `bool`. If `True`, positivity assertions on
            `loc` and `concentration` become control dependencies of the stored
            parameters, at some runtime cost. If `False`, invalid inputs may
            silently produce incorrect outputs.
            Default value: `False` (i.e. do not validate args).
          allow_nan_stats: Python `bool`. If `True`, undefined statistics
            (e.g. mean, mode, variance) are reported as "`NaN`"; if `False`, an
            exception is raised when any batch member's statistic is undefined.
            Default value: `True`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: 'InverseGaussian'.
        """
        parameters = dict(locals())
        with tf.name_scope(name):
            dtype = dtype_util.common_dtype([loc, concentration],
                                            preferred_dtype=tf.float32)
            loc_tensor = tf.convert_to_tensor(value=loc, name="loc",
                                              dtype=dtype)
            concentration_tensor = tf.convert_to_tensor(
                value=concentration, name="concentration", dtype=dtype)

            positivity_checks = []
            if validate_args:
                positivity_checks = [
                    assert_util.assert_positive(loc_tensor),
                    assert_util.assert_positive(concentration_tensor),
                ]
            with tf.control_dependencies(positivity_checks):
                self._loc = tf.identity(loc_tensor, name="loc")
                self._concentration = tf.identity(concentration_tensor,
                                                  name="concentration")
            dtype_util.assert_same_float_dtype(
                [self._loc, self._concentration])
        super(InverseGaussian, self).__init__(
            dtype=self._loc.dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._loc, self._concentration],
            name=name)
示例#20
0
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Laplace"):
        """Creates Laplace distribution(s) parameterized by `loc` and `scale`.

        `loc` and `scale` must broadcast against each other (for instance,
        `loc / scale` must be a valid operation).

        Args:
          loc: Floating point tensor giving the location (center) of the
            distribution.
          scale: Positive floating point tensor giving the spread of the
            distribution.
          validate_args: Python `bool`, default `False`. If `True`, a positivity
            assertion on `scale` becomes a control dependency of the stored
            parameters, at some runtime cost. If `False`, invalid inputs may
            silently produce incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
            statistics (e.g. mean, mode, variance) are reported as "`NaN`"; if
            `False`, an exception is raised when any batch member's statistic
            is undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          TypeError: if `loc` and `scale` are of different dtype.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale], tf.float32)
            loc_tensor = tf.convert_to_tensor(value=loc, name="loc",
                                              dtype=dtype)
            scale_tensor = tf.convert_to_tensor(value=scale, name="scale",
                                                dtype=dtype)
            scale_checks = ([assert_util.assert_positive(scale_tensor)]
                            if validate_args else [])
            with tf.control_dependencies(scale_checks):
                self._loc = tf.identity(loc_tensor)
                self._scale = tf.identity(scale_tensor)
                dtype_util.assert_same_float_dtype([self._loc, self._scale])
            super(Laplace, self).__init__(
                dtype=dtype,
                reparameterization_type=(
                    reparameterization.FULLY_REPARAMETERIZED),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=[self._loc, self._scale],
                name=name)
示例#21
0
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HalfCauchy"):
        """Construct a half-Cauchy distribution with `loc` and `scale`.

        Args:
          loc: Floating-point `Tensor`; the location(s) of the distribution(s).
          scale: Floating-point `Tensor`; the scale(s) of the distribution(s).
            Must contain only positive values.
          validate_args: Python `bool`, default `False`. When `True`, a
            positivity assertion on `scale` is added as a control dependency of
            the stored parameters, possibly degrading runtime performance. When
            `False` invalid inputs may silently render incorrect outputs.
            Default value: `False` (i.e. do not validate args).
          allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
            (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
            result is undefined. When `False`, an exception is raised if one or
            more of the statistic's batch members are undefined.
            Default value: `True`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: 'HalfCauchy'.

        Raises:
          TypeError: if `loc` and `scale` have different `dtype`.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale],
                                            preferred_dtype=tf.float32)
            loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
            scale = tf.convert_to_tensor(value=scale,
                                         name="scale",
                                         dtype=dtype)
            with tf.control_dependencies(
                [assert_util.assert_positive(scale)] if validate_args else []):
                self._loc = tf.identity(loc, name="loc")
                # Name the scale op "scale" (it was mislabeled "loc").
                self._scale = tf.identity(scale, name="scale")
            dtype_util.assert_same_float_dtype([self._loc, self._scale])
        super(HalfCauchy, self).__init__(
            dtype=self._scale.dtype,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._loc, self._scale],
            name=name)
示例#22
0
    def __init__(self,
                 low=0.,
                 high=1.,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Uniform"):
        """Initialize a batch of Uniform distributions.

        Args:
          low: Floating point tensor, lower boundary of the output interval. Must
            have `low < high`.
          high: Floating point tensor, upper boundary of the output interval.
            Must have `low < high`.
          validate_args: Python `bool`, default `False`. When `True`, an
            assertion that `low < high` is added as a control dependency of the
            stored parameters, possibly degrading runtime performance. When
            `False` invalid inputs may silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
            (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
            result is undefined. When `False`, an exception is raised if one or
            more of the statistic's batch members are undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          InvalidArgumentError: if `low >= high` and `validate_args=True`.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([low, high], tf.float32)
            low = tf.convert_to_tensor(value=low, name="low", dtype=dtype)
            high = tf.convert_to_tensor(value=high, name="high", dtype=dtype)
            # The ordering assertion is only built when validation is requested.
            with tf.control_dependencies([  # pylint: disable=g-long-ternary
                    assert_util.assert_less(
                        low,
                        high,
                        message="uniform not defined when low >= high.")
            ] if validate_args else []):
                self._low = tf.identity(low)
                self._high = tf.identity(high)
                dtype_util.assert_same_float_dtype([self._low, self._high])
        super(Uniform, self).__init__(
            dtype=self._low.dtype,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._low, self._high],
            name=name)
示例#23
0
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='Cauchy'):
        """Construct Cauchy distributions.

        The parameters `loc` and `scale` must be shaped in a way that supports
        broadcasting (e.g. `loc + scale` is a valid operation).

        Args:
          loc: Floating point tensor; the modes of the distribution(s).
          scale: Floating point tensor; the scales of the distribution(s).
            Must contain only positive values.
          validate_args: Python `bool`, default `False`. When `True` distribution
            parameters are checked for validity despite possibly degrading
            runtime performance. When `False` invalid inputs may silently render
            incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value '`NaN`' to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          TypeError: if `loc` and `scale` have different `dtype`.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale], tf.float32)
            # Non-ref inputs become tensors; variable-like inputs are kept as-is
            # (presumably so later updates are observed -- confirm in tensor_util).
            self._loc = tensor_util.convert_nonref_to_tensor(
                loc, name='loc', dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(
                scale, name='scale', dtype=dtype)
            dtype_util.assert_same_float_dtype([self._loc, self._scale])
            # NOTE(review): no positivity check on `scale` is built here even
            # when `validate_args=True`; presumably it happens elsewhere in the
            # class -- confirm.
            super(Cauchy, self).__init__(
                dtype=self._scale.dtype,
                reparameterization_type=(
                    reparameterization.FULLY_REPARAMETERIZED),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
示例#24
0
    def __init__(self,
                 loc,
                 scale,
                 tailweight=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="LambertWNormal"):
        """Builds a Lambert W x Normal distribution.

        Refer to `tfp.distributions.LambertWDistribution` for the general
        construction.

        Args:
          loc: Location `loc` of the underlying Normal distribution(s); it is
            also the location parameter of the resulting LambertWNormal.
          scale: Scale `scale` of the underlying Normal distribution(s).
          tailweight: Tail parameter `delta` of the distribution(s). `None`
            means 0, in which case LambertWNormal == Normal.
          validate_args: Python `bool`, default `False`. If `True`, distribution
            parameters are checked for validity at some runtime cost. If
            `False`, invalid inputs may silently produce incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. If `True`, undefined
            statistics (e.g. mean, mode, variance) are reported as '`NaN`'; if
            `False`, an exception is raised when any batch member's statistic
            is undefined.
          name: Optional name for the operation.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([tailweight, loc, scale],
                                            tf.float32)
            base_distribution = normal.Normal(loc=loc, scale=scale)
            super(LambertWNormal, self).__init__(
                distribution=base_distribution,
                shift=loc,
                scale=scale,
                tailweight=tailweight,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=name)
            # Store this class's own constructor arguments as its parameters,
            # overwriting anything the parent constructor may have recorded.
            self._parameters = parameters
            self._loc = tensor_util.convert_nonref_to_tensor(
                loc, name="loc", dtype=dtype)
            dtype_util.assert_same_float_dtype(
                (self.tailweight, self.loc, self.scale))
示例#25
0
    def __init__(self,
                 scale=1.,
                 concentration=1.,
                 validate_args=False,
                 name="weibull"):
        """Instantiates the `Weibull` bijector.

        Args:
          scale: Positive Float-type `Tensor` that is the same dtype and is
            broadcastable with `concentration`.
            This is `l` in `Y = g(X) = 1 - exp(-(x / l) ** k)`.
          concentration: Positive Float-type `Tensor` that is the same dtype and
            is broadcastable with `scale`.
            This is `k` in `Y = g(X) = 1 - exp(-(x / l) ** k)`.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness. When `True`, positivity assertions are
            attached to `scale` and `concentration`.
          name: Python `str` name given to ops managed by this object.
        """
        # These attributes are set before `super().__init__` because
        # `self._name_scope` below is used before the base class is initialized
        # (presumably it reads `self._name` -- confirm in the base class).
        self._graph_parents = []
        self._name = name
        self._validate_args = validate_args
        with self._name_scope("init"):
            self._scale = tf.convert_to_tensor(value=scale, name="scale")
            self._concentration = tf.convert_to_tensor(value=concentration,
                                                       name="concentration")
            dtype_util.assert_same_float_dtype(
                [self._scale, self._concentration])
            if validate_args:
                self._scale = distribution_util.with_dependencies([
                    assert_util.assert_positive(
                        self._scale, message="Argument scale was not positive")
                ], self._scale)
                self._concentration = distribution_util.with_dependencies([
                    assert_util.assert_positive(
                        self._concentration,
                        message="Argument concentration was not positive")
                ], self._concentration)

        super(Weibull, self).__init__(forward_min_event_ndims=0,
                                      validate_args=validate_args,
                                      name=name)
示例#26
0
    def __init__(self,
                 skewness=None,
                 tailweight=None,
                 validate_args=False,
                 name="SinhArcsinh"):
        """Creates a `SinhArcsinh` bijector.

        Args:
          skewness: Float-type `Tensor` skewness parameter. `None` means `0`
            (of type `float32`).
          tailweight: Positive `Tensor` tailweight parameter with the same
            `dtype` as `skewness` and a broadcast-compatible `shape`. `None`
            means `1` (of type `float32`).
          validate_args: Python `bool`; if `True`, a positivity assertion is
            attached to `tailweight`.
          name: Python `str` name given to ops managed by this object.
        """
        self._graph_parents = []
        self._name = name
        self._validate_args = validate_args
        with self._name_scope("init"):
            if tailweight is None:
                tailweight = 1.
            if skewness is None:
                skewness = 0.
            self._skewness = tf.convert_to_tensor(value=skewness,
                                                  name="skewness")
            # `tailweight` is converted using the dtype inferred for `skewness`.
            self._tailweight = tf.convert_to_tensor(
                value=tailweight,
                name="tailweight",
                dtype=self._skewness.dtype)
            dtype_util.assert_same_float_dtype(
                [self._skewness, self._tailweight])
            if validate_args:
                tailweight_is_positive = assert_util.assert_positive(
                    self._tailweight,
                    message="Argument tailweight was not positive")
                self._tailweight = distribution_util.with_dependencies(
                    [tailweight_is_positive], self._tailweight)
        super(SinhArcsinh, self).__init__(forward_min_event_ndims=0,
                                          validate_args=validate_args,
                                          name=name)
示例#27
0
  def __init__(self,
               low=0.,
               high=1.,
               peak=0.5,
               validate_args=False,
               allow_nan_stats=True,
               name="Triangular"):
    """Creates a batch of Triangular distributions.

    Args:
      low: Floating point tensor; the lower end of the output interval.
        Requires `low < high`.
        Default value: `0`.
      high: Floating point tensor; the upper end of the output interval.
        Requires `low < high`.
        Default value: `1`.
      peak: Floating point tensor; the mode of the output interval. Requires
        `low <= peak <= high`.
        Default value: `0.5`.
      validate_args: Python `bool`. If `True`, ordering assertions on `low`,
        `peak` and `high` become control dependencies of the stored parameters,
        at some runtime cost. If `False`, invalid inputs may silently produce
        incorrect outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`. If `True`, undefined statistics (e.g.
        mean, mode, variance) are reported as "`NaN`"; if `False`, an exception
        is raised when any batch member's statistic is undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Triangular'`.

    Raises:
      InvalidArgumentError: when `validate_args=True` and any of these holds:
        * `low >= high`.
        * `peak > high`.
        * `low > peak`.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([low, high, peak], tf.float32)
      low_tensor = tf.convert_to_tensor(value=low, name="low", dtype=dtype)
      high_tensor = tf.convert_to_tensor(value=high, name="high", dtype=dtype)
      peak_tensor = tf.convert_to_tensor(value=peak, name="peak", dtype=dtype)

      ordering_checks = []
      if validate_args:
        ordering_checks = [
            assert_util.assert_less(
                low_tensor, high_tensor,
                message="triangular not defined when low >= high."),
            assert_util.assert_less_equal(
                low_tensor, peak_tensor,
                message="triangular not defined when low > peak."),
            assert_util.assert_less_equal(
                peak_tensor, high_tensor,
                message="triangular not defined when peak > high."),
        ]
      with tf.control_dependencies(ordering_checks):
        self._low = tf.identity(low_tensor, name="low")
        self._high = tf.identity(high_tensor, name="high")
        self._peak = tf.identity(peak_tensor, name="peak")
        dtype_util.assert_same_float_dtype(
            [self._low, self._high, self._peak])
    super(Triangular, self).__init__(
        dtype=self._low.dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._low, self._high, self._peak],
        name=name)
示例#28
0
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='Moyal'):
        """Construct Moyal distributions with location and scale `loc` and `scale`.

        The parameters `loc` and `scale` must be shaped in a way that supports
        broadcasting (e.g. `loc + scale` is a valid operation).

        The distribution is a transformed `Uniform`: uniform samples are pushed
        through the inverse of the Moyal CDF bijector.

        Args:
          loc: Floating point tensor, the means of the distribution(s).
          scale: Floating point tensor, the scales of the distribution(s).
            scale must contain only positive values.
          validate_args: Python `bool`, default `False`. When `True` distribution
            parameters are checked for validity despite possibly degrading
            runtime performance. When `False` invalid inputs may silently render
            incorrect outputs.
            Default value: `False`.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value `NaN` to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
            Default value: `True`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: `'Moyal'`.

        Raises:
          TypeError: if loc and scale are different dtypes.


        #### References

        [1] J.E. Moyal, "XXX. Theory of ionization fluctuations",
           The London, Edinburgh, and Dublin Philosophical Magazine
           and Journal of Science.
           https://www.tandfonline.com/doi/abs/10.1080/14786440308521076
        [2] G. Cordeiro, J. Nobre, R. Pescim, E. Ortega,
            "The beta Moyal: a useful skew distribution",
            https://www.arpapress.com/Volumes/Vol10Issue2/IJRRAS_10_2_02.pdf
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale],
                                            dtype_hint=tf.float32)
            loc = tensor_util.convert_nonref_to_tensor(loc,
                                                       name='loc',
                                                       dtype=dtype)
            scale = tensor_util.convert_nonref_to_tensor(scale,
                                                         name='scale',
                                                         dtype=dtype)
            dtype_util.assert_same_float_dtype([loc, scale])
            # Positive scale is asserted by the incorporated Moyal bijector.
            self._moyal_bijector = moyal_cdf_bijector.MoyalCDF(
                loc=loc, scale=scale, validate_args=validate_args)

            batch_shape = distribution_util.get_broadcast_shape(loc, scale)
            # The uniform sampler draws from `[0, 1)`, and the inverse CDF maps
            # 0 to `-inf`, so samples could include `-inf` (i.e. lie in
            # `[-inf, inf)` rather than `(-inf, inf)`). To exclude 0, the lower
            # bound is `np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny`, the
            # smallest positive 'normal' number of `dtype`.
            super(Moyal, self).__init__(
                # TODO(b/137665504): Use batch-adding meta-distribution to set
                # the batch shape instead of tf.ones.
                distribution=uniform.Uniform(
                    low=np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny,
                    high=tf.ones(batch_shape, dtype=dtype),
                    allow_nan_stats=allow_nan_stats),
                # The Moyal bijector encodes the CDF function as the forward,
                # and hence needs to be inverted.
                bijector=invert_bijector.Invert(self._moyal_bijector,
                                                validate_args=validate_args),
                parameters=parameters,
                name=name)
示例#29
0
  def __init__(self,
               loc,
               scale,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name='GeneralizedExtremeValue'):
    """Construct generalized extreme value distribution.

    The parameters `loc`, `scale`, and `concentration` must be shaped in a way
    that supports broadcasting (e.g. `loc + scale + concentration` is valid).

    The distribution is a transformed `Uniform`: uniform samples are pushed
    through the inverse of the GEV CDF bijector.

    Args:
      loc: Floating point tensor, the location parameter of the distribution(s).
      scale: Floating point tensor, the scales of the distribution(s).
        scale must contain only positive values.
      concentration: Floating point tensor, the concentration of
        the distribution(s).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value `NaN` to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'GeneralizedExtremeValue'`.

    Raises:
      TypeError: if `loc`, `scale` and `concentration` have different dtypes.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale, concentration],
                                      dtype_hint=tf.float32)
      loc = tensor_util.convert_nonref_to_tensor(
          loc, name='loc', dtype=dtype)
      scale = tensor_util.convert_nonref_to_tensor(
          scale, name='scale', dtype=dtype)
      concentration = tensor_util.convert_nonref_to_tensor(
          concentration, name='concentration', dtype=dtype)
      dtype_util.assert_same_float_dtype([loc, scale, concentration])
      # Positive scale is asserted by the incorporated GEV bijector.
      self._gev_bijector = gev_cdf_bijector.GeneralizedExtremeValueCDF(
          loc=loc, scale=scale, concentration=concentration,
          validate_args=validate_args)

      batch_shape = distribution_util.get_broadcast_shape(loc, scale,
                                                          concentration)
      # The uniform sampler draws from `[0, 1)`, and the inverse CDF maps 0 to
      # the lower end of the support, which is `-inf` whenever the support is
      # unbounded below. To exclude 0, the lower bound is
      # `np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny`, the smallest
      # positive 'normal' number of `dtype`.
      super(GeneralizedExtremeValue, self).__init__(
          # TODO(b/137665504): Use batch-adding meta-distribution to set the
          # batch shape instead of tf.ones.
          distribution=uniform.Uniform(
              low=np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny,
              high=tf.ones(batch_shape, dtype=dtype),
              allow_nan_stats=allow_nan_stats),
          # The GEV bijector encodes the CDF function as the forward,
          # and hence needs to be inverted.
          bijector=invert_bijector.Invert(
              self._gev_bijector, validate_args=validate_args),
          parameters=parameters,
          name=name)
示例#30
0
  def __init__(self,
               mean_direction,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name='VonMisesFisher'):
    """Creates a new `VonMisesFisher` instance.

    Args:
      mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
        A unit vector indicating the mode of the distribution, or the
        unit-normalized direction of the mean. (This is *not* in general the
        mean of the distribution; the mean is not generally in the support of
        the distribution.) NOTE: `D` is currently restricted to <= 5.
      concentration: Floating-point `Tensor` having batch shape [B1, ... Bn]
        broadcastable with `mean_direction`. The level of concentration of
        samples around the `mean_direction`. `concentration=0` indicates a
        uniform distribution over the unit hypersphere, and `concentration=+inf`
        indicates a `Deterministic` distribution (delta function) at
        `mean_direction`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event dimension
        (a statically known event size `D > 5`).
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      # Both parameters are converted to one shared floating dtype
      # (float32 unless the inputs dictate otherwise).
      dtype = dtype_util.common_dtype([mean_direction, concentration],
                                      dtype_hint=tf.float32)
      mean_direction = tf.convert_to_tensor(
          mean_direction, name='mean_direction', dtype=dtype)
      concentration = tf.convert_to_tensor(
          concentration, name='concentration', dtype=dtype)
      assertions = [
          assert_util.assert_non_negative(
              concentration, message='`concentration` must be non-negative'),
          assert_util.assert_greater(
              tf.shape(mean_direction)[-1],
              1,
              message='`mean_direction` may not have scalar event shape'),
          # `assert_near` converts its second argument to the dtype of its
          # first. A bare Python `1.` would become float32 and reject a
          # float64 norm, so the reference value is built in `dtype`.
          assert_util.assert_near(
              tf.constant(1., dtype=dtype),
              tf.linalg.norm(mean_direction, axis=-1),
              message='`mean_direction` must be unit-length')
      ] if validate_args else []
      static_event_dim = tf.compat.dimension_value(
          tensorshape_util.with_rank_at_least(mean_direction.shape, 1)[-1])
      if static_event_dim is not None and static_event_dim > 5:
        raise ValueError('vMF ndims > 5 is not currently supported')
      elif validate_args and static_event_dim is None:
        # The event size is only known at runtime, so defer the `D <= 5`
        # check to a graph assertion; a statically known size was already
        # checked above and needs no runtime op.
        assertions += [
            assert_util.assert_less_equal(
                tf.shape(mean_direction)[-1],
                5,
                message='vMF ndims > 5 is not currently supported')
        ]
      # Gate the stored parameters on the (possibly empty) assertion list.
      with tf.control_dependencies(assertions):
        self._mean_direction = tf.identity(mean_direction)
        self._concentration = tf.identity(concentration)
      dtype_util.assert_same_float_dtype(
          [self._mean_direction, self._concentration])
      # mean_direction is always reparameterized.
      # concentration is only for event_dim==3, via an inversion sampler.
      reparameterization_type = (
          reparameterization.FULLY_REPARAMETERIZED
          if static_event_dim == 3 else
          reparameterization.NOT_REPARAMETERIZED)
      super(VonMisesFisher, self).__init__(
          dtype=self._concentration.dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=reparameterization_type,
          parameters=parameters,
          name=name)