예제 #1
0
 def testBijector(self):
     """Identity must echo its input and report a zero log-det-Jacobian."""
     with self.test_session():
         identity = identity_lib.Identity()
         self.assertEqual("identity", identity.name)
         inputs = [[[0.], [1.]]]
         # Both directions of the transform are expected to return the input
         # unchanged.
         for transform in (identity.forward, identity.inverse):
             self.assertAllEqual(inputs, transform(inputs).eval())
         # Both log-det-Jacobians are expected to be zero; the inverse one is
         # checked first, then the forward one.
         for log_det_jacobian in (identity.inverse_log_det_jacobian,
                                  identity.forward_log_det_jacobian):
             self.assertAllEqual(0., log_det_jacobian(inputs).eval())
예제 #2
0
 def testScalarCongruency(self):
     """Run the shared scalar-congruency check on Identity over [-2, 2]."""
     with self.test_session():
         identity = identity_lib.Identity()
         # Interval endpoints handed to the congruency helper.
         bounds = dict(lower_x=-2., upper_x=2.)
         bijector_test_util.assert_scalar_congruency(identity, **bounds)
  def __init__(self,
               distribution,
               bijector=None,
               batch_shape=None,
               event_shape=None,
               validate_args=False,
               name=None):
    """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`. `None` means `Identity()`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
    # Snapshot the constructor arguments into a fresh dict. The mapping
    # returned by `locals()` can be refreshed by the interpreter on later
    # `locals()` calls or frame introspection (e.g. under a debugger), which
    # would leak locals defined below (and the rebound `name`/`bijector`) into
    # the recorded parameters.
    parameters = dict(locals())
    # The default name is computed before `bijector` is defaulted below, so a
    # `None` bijector contributes an empty prefix.
    name = name or (("" if bijector is None else bijector.name) +
                    distribution.name)
    with ops.name_scope(name, values=[event_shape, batch_shape]):
      # For convenience we define some handy constants.
      self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero")
      self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")

      if bijector is None:
        bijector = identity_lib.Identity(validate_args=validate_args)

      # We will keep track of a static and dynamic version of
      # self._is_{batch,event}_override. This way we can do more prior to graph
      # execution, including possibly raising Python exceptions.

      self._override_batch_shape = self._maybe_validate_shape_override(
          batch_shape, distribution.is_scalar_batch(), validate_args,
          "batch_shape")
      self._is_batch_override = _logical_not(_logical_equal(
          _ndims_from_shape(self._override_batch_shape), self._zero))
      # Static view: "maybe" overriding unless the shape is statically known
      # and has no elements.
      static_batch_shape = tensor_util.constant_value(
          self._override_batch_shape)
      self._is_maybe_batch_override = bool(
          static_batch_shape is None or static_batch_shape.size != 0)

      self._override_event_shape = self._maybe_validate_shape_override(
          event_shape, distribution.is_scalar_event(), validate_args,
          "event_shape")
      self._is_event_override = _logical_not(_logical_equal(
          _ndims_from_shape(self._override_event_shape), self._zero))
      # Same static/dynamic split as for the batch shape above.
      static_event_shape = tensor_util.constant_value(
          self._override_event_shape)
      self._is_maybe_event_override = bool(
          static_event_shape is None or static_event_shape.size != 0)

      # To convert a scalar distribution into a multivariate distribution we
      # will draw dims from the sample dims, which are otherwise iid. This is
      # easy to do except in the case that the base distribution has batch dims
      # and we're overriding event shape. When that case happens the event dims
      # will incorrectly be to the left of the batch dims. In this case we'll
      # cyclically permute left the new dims.
      self._needs_rotation = _logical_and(
          self._is_event_override,
          _logical_not(self._is_batch_override),
          _logical_not(distribution.is_scalar_batch()))
      override_event_ndims = _ndims_from_shape(self._override_event_shape)
      self._rotate_ndims = _pick_scalar_condition(
          self._needs_rotation, override_event_ndims, 0)
      # We'll be reducing the head dims (if at all), i.e., this will be []
      # if we don't need to reduce.
      self._reduce_event_indices = math_ops.range(
          self._rotate_ndims - override_event_ndims, self._rotate_ndims)

    self._distribution = distribution
    self._bijector = bijector
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        reparameterization_type=self._distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        # We let TransformedDistribution access _graph_parents since this class
        # is more like a baseclass than derived.
        graph_parents=(distribution._graph_parents +  # pylint: disable=protected-access
                       bijector.graph_parents),
        name=name)