コード例 #1
0
def transform(x, *args, **kwargs):
    """Transform a continuous random variable to the unconstrained space.

    Transform selects among a number of default transformations which depend
    on the support of the provided random variable.

    Args:
      x : RandomVariable.
        Continuous random variable to transform.
      *args, **kwargs : optional.
        Arguments to overwrite when forming the ``TransformedDistribution``.
        For example, one can manually specify the transformation by
        passing in the ``bijector`` argument. Passing ``bijector=None`` is
        treated the same as not passing it: the default bijector for the
        support of ``x`` is chosen.

    Returns:
      RandomVariable.
      A ``TransformedDistribution`` random variable, or the provided random
      variable if no transformation was applied (supports ``'real'`` and
      ``'multivariate_real'`` are returned as-is).

    Raises:
      ValueError: If no explicit transformation is given and ``x`` has no
        ``support`` attribute.
      NotImplementedError: If no explicit transformation is given and
        ``x.support`` is not one of ``'01'``, ``'nonnegative'``,
        ``'simplex'``, ``'real'`` or ``'multivariate_real'``.

    #### Examples

    ```python
    x = Gamma(1.0, 1.0)
    y = ed.transform(x)
    sess = tf.Session()
    sess.run(y)
    -2.2279539
    ```
    """
    # Positional arguments or a non-None bijector mean the caller chose the
    # transformation explicitly; forward everything unchanged.
    if args or kwargs.get('bijector') is not None:
        return TransformedDistribution(x, *args, **kwargs)

    # An explicit ``bijector=None`` falls through to the default choice below.
    # Drop it so it does not collide with the bijector passed positionally
    # to ``TransformedDistribution`` (which would raise a TypeError).
    kwargs.pop('bijector', None)

    try:
        support = x.support
    except AttributeError:
        msg = "'{}' object has no 'support' so cannot be transformed.".format(
            type(x).__name__)
        raise ValueError(msg)

    if support == '01':
        # (0, 1) -> R via the inverse of the sigmoid.
        bij = bijectors.Invert(bijectors.Sigmoid())
    elif support == 'nonnegative':
        # Positive reals -> R via the inverse of softplus.
        bij = bijectors.Invert(bijectors.Softplus())
    elif support == 'simplex':
        # Simplex -> unconstrained vector via the inverse of SoftmaxCentered.
        bij = bijectors.Invert(bijectors.SoftmaxCentered(event_ndims=1))
    elif support in ('real', 'multivariate_real'):
        # Already unconstrained: nothing to transform.
        return x
    else:
        msg = "'transform' does not handle supports of type '{}'".format(
            support)
        raise NotImplementedError(msg)

    # ``args`` is necessarily empty here (non-empty args returned above).
    return TransformedDistribution(x, bij, **kwargs)
コード例 #2
0
 def test_kwargs(self):
     # A caller-supplied bijector takes priority over the support-based
     # default in ed.transform: pushing Normal draws centred far below zero
     # through Softplus must still yield only non-negative values.
     with self.test_session():
         base = Normal(-100.0, 1.0)
         transformed = ed.transform(base, bijector=bijectors.Softplus())
         draws = transformed.sample(10).eval()
         all_nonnegative = (draws >= 0.0).all()
         self.assertTrue(all_nonnegative)