Ejemplo n.º 1
0
  def __init__(self,
               df,
               validate_args=False,
               allow_nan_stats=True,
               name='Chi'):
    """Build Chi distribution(s) parameterized by degrees of freedom `df`.

    The distribution is assembled as a `Chi2(df)` base distribution pushed
    through `Invert(Square)`, i.e. the square root of a chi-squared variate.

    Args:
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value `NaN` to indicate the result
        is undefined. When `False`, an exception is raised if one or more of the
        statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Chi'`.
    """
    # Must stay the first statement: it snapshots exactly the constructor
    # arguments, before any other locals exist.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      float_dtype = dtype_util.common_dtype([df], dtype_hint=tf.float32)
      self._df = tensor_util.convert_nonref_to_tensor(
          df, dtype=float_dtype, name='df')

      # Base distribution: chi-squared with the same degrees of freedom.
      base_distribution = chi2.Chi2(
          df=self._df,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)
      # Inverting `Square` yields a square-root transform of the Chi2 samples.
      sqrt_bijector = invert_bijector.Invert(
          square_bijector.Square(validate_args=validate_args))

      super(Chi, self).__init__(
          distribution=base_distribution,
          bijector=sqrt_bijector,
          validate_args=validate_args,
          parameters=parameters,
          name=name)
Ejemplo n.º 2
0
  def _sample_n(self, n, seed=None):
    # Sampling mirrors the univariate Student's t construction: draw a
    # zero-mean multivariate normal with this distribution's scale, divide it
    # by sqrt(chi2 / df) for an independent chi-squared draw, then add `loc`.
    stream = seed_stream.SeedStream(seed, salt="multivariate t")

    broadcast_loc = tf.broadcast_to(self.loc, self._sample_shape())
    centered_normal = mvn_linear_operator.MultivariateNormalLinearOperator(
        loc=tf.zeros_like(broadcast_loc), scale=self.scale)
    gaussian_draws = centered_normal.sample(n, seed=stream())

    batch_df = tf.broadcast_to(self.df, self.batch_shape_tensor())
    chi_squared = chi2_lib.Chi2(df=batch_df)
    chi2_draws = chi_squared.sample(n, seed=stream())

    # One scalar scale factor per draw and batch member, broadcast across the
    # trailing event dimension via the added axis.
    inverse_scale = tf.math.rsqrt(chi2_draws / self._df)
    return self._loc + gaussian_draws * inverse_scale[..., tf.newaxis]
Ejemplo n.º 3
0
  def __init__(self,
               df,
               validate_args=False,
               allow_nan_stats=True,
               name="Chi"):
    """Construct Chi distributions with parameter `df`.

    The distribution is built as a `Chi2(df)` base distribution transformed by
    `Invert(Square)`, i.e. the square root of a chi-squared variate.

    Args:
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs. When `True`, an assertion that `df > 0` is attached to
        `self._df`, and the flag is also forwarded to the base `Chi2`, the
        `Square` bijector and the transformed distribution itself.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value `NaN` to indicate the result
        is undefined. When `False`, an exception is raised if one or more of the
        statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Chi'`.
    """
    with tf.compat.v1.name_scope(name, values=[df]) as name:
      df = tf.convert_to_tensor(
          value=df,
          name="df",
          dtype=dtype_util.common_dtype([df], preferred_dtype=tf.float32))
      # The positivity check is only built when validation is requested. The
      # identity op below carries it as a control dependency, so anything that
      # consumes `self._df` runs the assertion first.
      validation_assertions = ([tf.compat.v1.assert_positive(df)]
                               if validate_args else [])
      with tf.control_dependencies(validation_assertions):
        self._df = tf.identity(df, name="df")

      # Forward `validate_args` and `name` to the transformed distribution.
      # Otherwise it would report `validate_args=False` regardless of the
      # caller's request and take a derived name instead of this scope.
      super(Chi, self).__init__(
          distribution=chi2.Chi2(df=self._df,
                                 validate_args=validate_args,
                                 allow_nan_stats=allow_nan_stats,
                                 name=name),
          bijector=invert_bijector.Invert(
              square_bijector.Square(validate_args=validate_args)),
          validate_args=validate_args,
          name=name)