Example #1
  def testEventShape(self, num_or_size_splits, expected_split_sizes):
    num_or_size_splits = self.build_input(num_or_size_splits)
    total_size = np.sum(expected_split_sizes)
    shape_in_static = tf.TensorShape([total_size, 2])
    shape_out_static = [
        tf.TensorShape([d, 2]) for d in expected_split_sizes]
    bijector = tfb.Split(
        num_or_size_splits=num_or_size_splits, axis=-2, validate_args=True)

    # Test that forward_ and inverse_event_shape are correct when
    # event_shape_in/_out are statically known, even when the input shapes
    # are only partially specified.
    self.assertAllEqual(
        bijector.forward_event_shape(shape_in_static), shape_out_static)
    self.assertEqual(
        bijector.inverse_event_shape(shape_out_static), shape_in_static)

    # Shape is always known for splitting in eager mode, so we skip these tests.
    if tf.executing_eagerly():
      return
    self.assertAllEqual(
        [s.as_list() for s in bijector.forward_event_shape(
            tf.TensorShape([total_size, None]))],
        [[d, None] for d in expected_split_sizes])

    if bijector.split_sizes is None:
      static_split_sizes = tensorshape_util.constant_value_as_shape(
          expected_split_sizes).as_list()
    else:
      static_split_sizes = tensorshape_util.constant_value_as_shape(
          num_or_size_splits).as_list()

    static_total_size = None if None in static_split_sizes else total_size

    # Test correctness with an inverse input dimension of None that coincides
    # with the `-1` element in not-fully-specified `split_sizes`.
    shape_with_maybe_unknown_dim = (
        [[None, 3]] + [[d, 3] for d in expected_split_sizes[1:]])
    self.assertAllEqual(
        bijector.inverse_event_shape(shape_with_maybe_unknown_dim).as_list(),
        [static_total_size, 3])

    # Test correctness with an input dimension of None that does not coincide
    # with a `-1` split_size.
    shape_with_deducable_dim = [[d, 3] for d in expected_split_sizes]
    shape_with_deducable_dim[2] = [None, 3]
    self.assertAllEqual(
        bijector.inverse_event_shape(
            shape_with_deducable_dim).as_list(), [total_size, 3])

    # Test correctness for an input shape of known rank only.
    if bijector.split_sizes is not None:
      shape_with_unknown_total = (
          [[d, None] for d in static_split_sizes])
    else:
      shape_with_unknown_total = [[None, None]] * len(expected_split_sizes)
    self.assertAllEqual(
        [s.as_list() for s in bijector.forward_event_shape(
            tf.TensorShape([None, None]))],
        shape_with_unknown_total)
Example #2
    def _inverse_event_shape_tensor(self, output_shapes):
        """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      output_shapes: An iterable of `Tensor`, `int32` vectors indicating
        event-shapes passed into `inverse` function. The length of the iterable
        must be equal to the number of splits.

    Returns:
      inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `inverse`.
    """
        # Validate `output_shapes` statically if possible and get assertions.
        is_validated = self._validate_output_shapes([
            tensorshape_util.constant_value_as_shape(s) for s in output_shapes
        ])
        if is_validated or not self.validate_args:
            assertions = []
        else:
            assertions = self._validate_output_shape_tensors(output_shapes)

        with tf.control_dependencies(assertions):
            total_size = tf.reduce_sum([t[self.axis] for t in output_shapes])
            inverse_event_shape = tf.tensor_scatter_nd_update(
                output_shapes[0],
                [[prefer_static.rank_from_shape(output_shapes[0]) + self.axis]],
                [total_size])
            return tf.identity(
                tf.convert_to_tensor(inverse_event_shape,
                                     dtype_hint=tf.int32,
                                     name='inverse_event_shape'))
Example #3
def validate_init_args_statically(distribution, batch_shape):
    """Helper to __init__ which makes or raises assertions."""
    if tensorshape_util.rank(batch_shape.shape) is not None:
        if tensorshape_util.rank(batch_shape.shape) != 1:
            raise ValueError('`batch_shape` must be a vector '
                             '(saw rank: {}).'.format(
                                 tensorshape_util.rank(batch_shape.shape)))

    batch_shape_static = tensorshape_util.constant_value_as_shape(batch_shape)
    batch_size_static = tensorshape_util.num_elements(batch_shape_static)
    dist_batch_size_static = tensorshape_util.num_elements(
        distribution.batch_shape)

    if batch_size_static is not None and dist_batch_size_static is not None:
        if batch_size_static != dist_batch_size_static:
            raise ValueError('`batch_shape` size ({}) must match '
                             '`distribution.batch_shape` size ({}).'.format(
                                 batch_size_static, dist_batch_size_static))

    if tensorshape_util.dims(batch_shape_static) is not None:
        if any(
                tf.compat.dimension_value(dim) is not None
                and tf.compat.dimension_value(dim) < 1
                for dim in batch_shape_static):
            raise ValueError('`batch_shape` elements must be >=-1.')
Example #4
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
    batch_shape_static = tensorshape_util.constant_value_as_shape(new_shape)
    if tensorshape_util.is_fully_defined(batch_shape_static):
        return np.int32(batch_shape_static), batch_shape_static, []
    with tf.name_scope(name or 'calculate_reshape'):
        original_size = tf.reduce_prod(original_shape)
        implicit_dim = tf.equal(new_shape, -1)
        size_implicit_dim = (original_size //
                             tf.maximum(1, -tf.reduce_prod(new_shape)))
        expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
            implicit_dim, size_implicit_dim, new_shape)
        validations = [] if not validate else [  # pylint: disable=g-long-ternary
            assert_util.assert_rank(
                original_shape, 1, message='Original shape must be a vector.'),
            assert_util.assert_rank(
                new_shape, 1, message='New shape must be a vector.'),
            assert_util.assert_less_equal(
                tf.math.count_nonzero(implicit_dim, dtype=tf.int32),
                1,
                message='At most one dimension can be unknown.'),
            assert_util.assert_positive(
                expanded_new_shape, message='Shape elements must be >=-1.'),
            assert_util.assert_equal(tf.reduce_prod(expanded_new_shape),
                                     original_size,
                                     message='Shape sizes do not match.'),
        ]
        return expanded_new_shape, batch_shape_static, validations
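
A minimal standalone sketch (plain TensorFlow, hypothetical inputs) of the `-1`-replacement arithmetic used by `calculate_reshape` above:

import tensorflow as tf

# Hypothetical inputs: reshape a size-6 batch into [3, -1].
original_shape = tf.constant([2, 3])   # original batch shape; 2 * 3 = 6 elements
new_shape = tf.constant([3, -1])       # requested shape with one implicit dimension

original_size = tf.reduce_prod(original_shape)                                   # 6
size_implicit_dim = original_size // tf.maximum(1, -tf.reduce_prod(new_shape))   # 6 // 3 = 2
expanded_new_shape = tf.where(tf.equal(new_shape, -1), size_implicit_dim, new_shape)
print(expanded_new_shape)  # => [3 2] (in eager mode)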
Example #5
    def test_constant_value_as_shape(self):
        x = np.array([1, 2, 3, 4], dtype=np.int32)
        s = tensorshape_util.constant_value_as_shape(x)
        self.assertIsInstance(s, tf.TensorShape)
        self.assertAllEqual(x, s)

        x = tf.Variable([3])
        s = tensorshape_util.constant_value_as_shape(x)
        # `s` could be `TensorShape(None)` or `TensorShape([None])`, depending on
        # whether or not we're executing eagerly.  We could improve
        # `constant_value_as_shape` to always return `TensorShape([None])`.
        self.assertFalse(s.is_fully_defined())

        self.assertEqual(
            tf.TensorShape(None),
            tensorshape_util.constant_value_as_shape(
                tf.Variable([7, 2], shape=tf.TensorShape([None]))))
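
A minimal sketch of what `constant_value_as_shape` returns for a few inputs, assuming the TFP-internal module path `tensorflow_probability.python.internal.tensorshape_util`:

import numpy as np
from tensorflow_probability.python.internal import tensorshape_util

# A fully known shape vector yields a fully defined TensorShape.
print(tensorshape_util.constant_value_as_shape(np.array([2, 3], dtype=np.int32)))   # (2, 3)

# A -1 entry is mapped to an unknown (None) dimension.
print(tensorshape_util.constant_value_as_shape(np.array([2, -1], dtype=np.int32)))  # (2, None)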
Example #6
    def _forward_event_shape(self, input_shape):
        """Shape of a single sample from a single batch as a list of `TensorShape`s.

    Same meaning as `forward_event_shape_tensor`. May be only partially defined.

    Args:
      input_shape: `TensorShape` indicating event-portion shape passed into
        `forward` function.

    Returns:
      forward_event_shape: A list of (possibly unknown) `TensorShape`s
        indicating event-portion shape after applying `forward`. The length of
        the list is equal to the number of splits.
    """
        self._validate_input_shape(input_shape)
        if tensorshape_util.rank(input_shape) is None:
            output_shapes = [None] * self.num_splits
        else:
            input_shape = tf.TensorShape(input_shape).as_list()
            axis = tf.get_static_value(self.axis)

            if self.split_sizes is None:
                # Calculate `split_sizes` from `input_shape` and `num_splits`, if
                # possible.
                split_size = (None if input_shape[axis] is None else
                              input_shape[axis] // self.num_splits)
                split_sizes = [split_size] * self.num_splits

            else:
                static_split_sizes = tf.get_static_value(self.split_sizes)
                if static_split_sizes is None:
                    static_split_sizes = [None] * self.num_splits
                split_sizes = tensorshape_util.constant_value_as_shape(
                    static_split_sizes).as_list()

                # If there is a single unknown element of `split_sizes` and the input
                # dimension is known, set the unknown element equal to the difference
                # between the input dimension and the sum of the known elements of
                # `split_sizes`.
                if sum(s is None for s in split_sizes) == 1:
                    if input_shape is not None and input_shape[
                            axis] is not None:
                        total_size = input_shape[axis]
                        deduced_split_size = (
                            total_size -
                            sum(s for s in split_sizes if s is not None))
                        split_sizes = [
                            deduced_split_size if s is None else s
                            for s in split_sizes
                        ]

            output_shapes = []
            for split_size in split_sizes:
                output_shape = input_shape[:]
                output_shape[axis] = split_size
                output_shapes.append(output_shape)

        return [tf.TensorShape(shape) for shape in output_shapes]
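
A small usage sketch of the deduction implemented above, assuming TFP's public `tfp.bijectors.Split`:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# One split size is -1 (unknown); it is deduced from the known input dimension.
split = tfb.Split(num_or_size_splits=[2, -1, 3], axis=-1)
print(split.forward_event_shape(tf.TensorShape([10])))
# => [TensorShape([2]), TensorShape([5]), TensorShape([3])]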
Example #7
    def _forward_event_shape_tensor(self, input_shape):
        """Shape of a sample from a single batch as a list of `int32` 1D `Tensor`s.

    Args:
      input_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `forward` function.

    Returns:
      forward_event_shape_tensor: A list of `Tensor`, `int32` vectors indicating
        event-portion shape after applying `forward`. The length of the list is
        equal to the number of splits.
    """
        # Validate `input_shape` statically if possible and get assertions.
        is_validated = self._validate_input_shape(
            tensorshape_util.constant_value_as_shape(input_shape))
        if is_validated or not self.validate_args:
            assertions = []
        else:
            assertions = self._validate_input_shape_tensor(input_shape)

        with tf.control_dependencies(assertions):
            if self.split_sizes is None:
                split_sizes = tf.convert_to_tensor(
                    [input_shape[self.axis] // self.num_splits] *
                    self.num_splits)
            else:
                # Deduce the value of the unknown element of `split_sizes`, if any.
                split_sizes = tf.convert_to_tensor(self.split_sizes)
                split_sizes = tf.where(
                    split_sizes < 0,
                    input_shape[self.axis] - tf.reduce_sum(split_sizes) -
                    1,  # Cancel the unknown size `-1`.
                    split_sizes)

            # Each element of the `output_shape_tensor` list is equal to the
            # `input_shape`, with the corresponding element of `split_sizes`
            # substituted in the `axis` position.
            positive_axis = prefer_static.rank_from_shape(
                input_shape) + self.axis
            tiled_input_shape = tf.tile(input_shape[tf.newaxis, :],
                                        [self.num_splits, 1])
            fused_output_shapes = tf.concat([
                tiled_input_shape[:, :positive_axis],
                split_sizes[..., tf.newaxis],
                tiled_input_shape[:, positive_axis + 1:]
            ], axis=1)

            output_shapes = tf.unstack(fused_output_shapes,
                                       num=self.num_splits)
            return [
                tf.identity(
                    tf.convert_to_tensor(t,
                                         dtype_hint=tf.int32,
                                         name='forward_event_shape'))
                for t in output_shapes
            ]
Example #8
 def _event_shape(self):
   s = tf.get_static_value(self.sample_shape)
   if tensorshape_util.rank(s) == 1:
     sample_shape = tf.TensorShape(s)
   else:
     sample_shape = tensorshape_util.constant_value_as_shape(self.sample_shape)
   if (tensorshape_util.rank(sample_shape) is None or
       tensorshape_util.rank(self.distribution.event_shape) is None):
     return tf.TensorShape(None)
   return tensorshape_util.concatenate(sample_shape,
                                       self.distribution.event_shape)
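
A hypothetical usage sketch of the resulting behavior via `tfp.distributions.Sample`:

import tensorflow_probability as tfp

tfd = tfp.distributions

# event_shape is sample_shape concatenated with the base distribution's event_shape.
dist = tfd.Sample(tfd.MultivariateNormalDiag(loc=[0., 0., 0.]), sample_shape=[4, 2])
print(dist.event_shape)  # (4, 2, 3)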
Example #9
    def __init__(self,
                 distribution,
                 batch_shape,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct BatchReshape distribution.

    Args:
      distribution: The base distribution instance to reshape. Typically an
        instance of `Distribution`.
      batch_shape: Positive `int`-like vector-shaped `Tensor` representing
        the new shape of the batch dimensions. Up to one dimension may contain
        `-1`, meaning the remainder of the batch size.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value `NaN` to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: The name to give Ops created by the initializer.
        Default value: `"BatchReshape" + distribution.name`.

    Raises:
      ValueError: if `batch_shape` is not a vector.
      ValueError: if `batch_shape` has non-positive elements.
      ValueError: if `batch_shape` size is not the same as a
        `distribution.batch_shape` size.
    """
        parameters = dict(locals())
        name = name or 'BatchReshape' + distribution.name
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([batch_shape], dtype_hint=tf.int32)
            # The unexpanded batch shape may contain up to one dimension of -1.
            self._batch_shape_unexpanded = tensor_util.convert_nonref_to_tensor(
                batch_shape,
                dtype=dtype,
                name='batch_shape',
                as_shape_tensor=True)
            validate_init_args_statically(distribution,
                                          self._batch_shape_unexpanded)
            self._distribution = distribution
            self._batch_shape_static = tensorshape_util.constant_value_as_shape(
                self._batch_shape_unexpanded)
            super(BatchReshape, self).__init__(
                dtype=distribution.dtype,
                reparameterization_type=distribution.reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
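
A hypothetical usage sketch: `batch_shape` may contain a single `-1`, which remains statically unknown but is resolved against the base distribution's batch size at run time:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([6]), scale=1.)  # batch_shape == [6]
reshaped = tfd.BatchReshape(base, batch_shape=[2, -1], validate_args=True)
print(reshaped.batch_shape_tensor())  # [2 3]; the -1 is resolved from 6 // 2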
Example #10
    def __init__(self,
                 dimension,
                 batch_shape=tuple(),
                 dtype=tf.float32,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='SphericalUniform'):
        """Creates a new `SphericalUniform` instance.

    Args:
      dimension: Python `int`. The dimension of the embedded space where the
        sphere resides.
      batch_shape: Positive `int`-like vector-shaped `Tensor` representing
        the new shape of the batch dimensions.
        Default value: [].
      dtype: DType of the generated samples.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event dimension.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            if dimension < 0:
                raise ValueError(
                    'Cannot sample negative-dimension unit vectors.')
            shape_dtype = dtype_util.common_dtype([batch_shape],
                                                  dtype_hint=tf.int32)
            self._dimension = dimension
            self._batch_shape_parameter = tensor_util.convert_nonref_to_tensor(
                batch_shape,
                dtype=shape_dtype,
                name='batch_shape',
                as_shape_tensor=True)
            self._batch_shape_static = tensorshape_util.constant_value_as_shape(
                self._batch_shape_parameter)

            super(SphericalUniform, self).__init__(
                dtype=dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
                parameters=parameters,
                name=name)
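
A hypothetical usage sketch, assuming TFP's public `tfp.distributions.SphericalUniform`:

import tensorflow_probability as tfp

tfd = tfp.distributions

# Uniform distribution on the unit sphere embedded in R^3, with batch shape [2].
su = tfd.SphericalUniform(dimension=3, batch_shape=[2])
print(su.batch_shape, su.event_shape)  # (2,) (3,)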
Example #11
    def _inverse_event_shape(self, output_shapes):
        """Shape of a sample from a single batch as a [nested] `TensorShape`.

    Same meaning as `inverse_event_shape_tensor`. May be only partially defined.

    Args:
      output_shapes: Iterable of `TensorShape`s indicating the event shapes
        passed into `inverse` function. The length of the iterable must be equal
        to the number of splits.

    Returns:
      inverse_event_shape: `TensorShape` indicating event-portion shape after
        applying `inverse`. Possibly unknown.
    """
        self._validate_output_shapes(output_shapes)
        shapes = []
        for s in output_shapes:
            if tensorshape_util.rank(s) is None:
                return tf.TensorShape(None)
            shapes.append(tf.TensorShape(s).as_list())
        axis = tf.get_static_value(self.axis)

        if self.split_sizes is None:
            split_size = None
            for shape in output_shapes:
                if shape[axis] is not None:
                    split_size = shape[axis]
            split_sizes = [split_size] * self.num_splits
        else:
            static_split_sizes = tf.get_static_value(self.split_sizes)
            if static_split_sizes is None:
                static_split_sizes = [None] * self.num_splits
            split_sizes = tensorshape_util.constant_value_as_shape(
                static_split_sizes).as_list()

        # Deduce as much static information about `inverse_event_shape` as possible.
        # If all elements of `split_sizes` are known, the concatenated dimension
        # of `inverse_event_shape` is the sum of `split_sizes`.
        if not any(s is None for s in split_sizes):
            total_size = sum(split_sizes)
        else:
            # If at least one of `split_sizes` and `output_shape[axis]` is known
            # for each split, we can determine `total_size`.
            total_size = 0
            for split, output_shape in zip(split_sizes, shapes):
                if split is None and output_shape[axis] is None:
                    total_size = None
                    break
                total_size += split or output_shape[axis]

        shape = shapes[0]
        shape[axis] = total_size
        return tf.TensorShape(shape)
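
A small sketch of the inverse direction shown above (again assuming `tfp.bijectors.Split`): when the `-1` split size is unknown but the corresponding output dimension is known, the concatenated size can still be deduced:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

split = tfb.Split(num_or_size_splits=[2, -1, 3], axis=-1)
parts = [tf.TensorShape([2]), tf.TensorShape([5]), tf.TensorShape([3])]
print(split.inverse_event_shape(parts))  # (10,)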
Example #12
  def test_transform_joint_to_joint(self, split_sizes):
    dist_batch_shape = tf.nest.pack_sequence_as(
        split_sizes,
        [tensorshape_util.constant_value_as_shape(s)
         for s in [[2, 3], [2, 1], [1, 3]]])
    bijector_batch_shape = [1, 3]

    # Build a joint distribution with parts of the specified sizes.
    seed = test_util.test_seed_stream()
    component_dists = tf.nest.map_structure(
        lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
            loc=tf.random.normal(batch_shape + [size], seed=seed()),
            scale_diag=tf.random.uniform(
                minval=1., maxval=2.,
                shape=batch_shape + [size], seed=seed())),
        split_sizes, dist_batch_shape)
    if isinstance(split_sizes, dict):
      base_dist = tfd.JointDistributionNamed(component_dists)
    else:
      base_dist = tfd.JointDistributionSequential(component_dists)

    # Transform the distribution by applying a separate bijector to each part.
    bijectors = [tfb.Exp(),
                 tfb.Scale(
                     tf.random.uniform(
                         minval=1., maxval=2.,
                         shape=bijector_batch_shape, seed=seed())),
                 tfb.Reshape([2, 1])]
    bijector = tfb.JointMap(tf.nest.pack_sequence_as(split_sizes, bijectors),
                            validate_args=True)

    # Transform a joint distribution that has different batch shape components
    transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

    self.assertRegex(
        str(transformed_dist),
        '{}.*batch_shape.*event_shape.*dtype'.format(transformed_dist.name))

    self.assertAllEqualNested(
        transformed_dist.event_shape,
        bijector.forward_event_shape(base_dist.event_shape))
    self.assertAllEqualNested(*self.evaluate((
        transformed_dist.event_shape_tensor(),
        bijector.forward_event_shape_tensor(base_dist.event_shape_tensor()))))

    # Test that the batch shape components of the input are the same as those of
    # the output.
    self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
    self.assertAllEqualNested(
        self.evaluate(transformed_dist.batch_shape_tensor()), dist_batch_shape)
    self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)
Example #13
 def _batch_shape(self):
     # If there's a chance that the batch_shape has been overridden, we return
     # what we statically know about the `batch_shape_override`. This works
     # because: `_is_maybe_batch_override` means `static_override` is `None` or a
     # non-empty list, i.e., we don't statically know the `batch_shape` or we do.
     #
     # Notice that this implementation parallels the `_event_shape` except that
     # the `bijector` doesn't get to alter the `batch_shape`. Recall that
     # `batch_shape` is a property of a distribution while `event_shape` is
     # shared between both the `distribution` instance and the `bijector`.
     static_override = tensorshape_util.constant_value_as_shape(
         self._override_batch_shape)
     return (static_override if self._is_maybe_batch_override else
             self.distribution.batch_shape)
Example #14
 def _event_shape(self):
     # If there's a chance that the event_shape has been overridden, we return
     # what we statically know about the `event_shape_override`. This works
     # because: `_is_maybe_event_override` means `static_override` is `None` or a
     # non-empty list, i.e., we don't statically know the `event_shape` or we do.
     #
     # Since the `bijector` may change the `event_shape`, we then forward what we
     # know to the bijector. This allows the `bijector` to have final say in the
     # `event_shape`.
     static_override = tensorshape_util.constant_value_as_shape(
         self._override_event_shape)
     return self.bijector.forward_event_shape(
         static_override if self._is_maybe_event_override
         else self.distribution.event_shape)
Example #15
  def __init__(self,
               loc,
               presoftplus_scale,
               batch_shape=tuple(),
               dtype=tf.float32,
               validate_args=False,
               allow_nan_stats=True,
               name='Radial'):
    r"""Constructor.

    Args:
      loc: `Tensor` representing the mean of the distribution.
      presoftplus_scale: `Tensor` representing the pre-softplus scale, `\rho`.
      batch_shape: Positive `int`-like vector-shaped `Tensor` representing
        the new shape of the batch dimensions. Default value: [].
      dtype: the data type of the distribution.
      validate_args: Python `bool`, default `False`. When `True`, distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the result
        is undefined. When `False`, an exception is raised if one or more of the
        statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event
        dimension.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      shape_dtype = dtype_util.common_dtype([batch_shape], dtype_hint=tf.int32)
      self._loc = loc
      self._presoftplus_scale = presoftplus_scale
      self._batch_shape_parameter = tensor_util.convert_nonref_to_tensor(
          batch_shape, dtype=shape_dtype, name='batch_shape')
      self._batch_shape_static = (
          tensorshape_util.constant_value_as_shape(self._batch_shape_parameter))

      super(Radial, self).__init__(
          dtype=dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=(tfp.distributions.FULLY_REPARAMETERIZED),
          parameters=parameters,
          name=name)
Example #16
    def _batch_shape(self):
        # If there's a chance that the batch_shape has been overridden, we return
        # what we statically know about the `override_batch_shape`. This works
        # because: `_is_maybe_batch_override` means that the `constant_value()` of
        # `override_batch_shape` is `None` or a non-empty list, i.e., we don't
        # statically know the `batch_shape` or we do.
        #
        # Notice that this implementation parallels the `_event_shape` except that
        # the `bijector` doesn't get to alter the `batch_shape`. Recall that
        # `batch_shape` is a property of a distribution while `event_shape` is
        # shared between both the `distribution` instance and the `bijector`.
        if self._is_maybe_batch_override:
            return tensorshape_util.constant_value_as_shape(
                self._override_batch_shape)

        # As with `batch_shape_tensor`, if the base distribution is joint with
        # structured batch shape and the transformed distribution is not joint,
        # the batch shape components of the base distribution are broadcast to
        # obtain the batch shape of the transformed distribution.
        batch_shape = self.distribution.batch_shape
        if tf.nest.is_nested(batch_shape) and not self._is_joint:
            batch_shape = functools.reduce(tf.broadcast_static_shape,
                                           tf.nest.flatten(batch_shape))
        return batch_shape
Example #17
 def _event_shape(self):
     return tensorshape_util.constant_value_as_shape(
         tf.expand_dims(self._k, axis=0))
Example #18
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    if is_init:
      axis_ = tf.get_static_value(self._axis)
      if axis_ is not None and axis_ < 0:
        raise ValueError('Axis should be positive, %d was given' % axis_)
      if axis_ is None:
        assertions.append(tf.assert_greater_equal(self._axis, 0))

      all_event_shapes = [d.event_shape for d in self._distributions]
      if all(tensorshape_util.is_fully_defined(event_shape)
             for event_shape in all_event_shapes):
        if all_event_shapes[1:] != all_event_shapes[:-1]:
          raise ValueError('Distributions must have the same `event_shape`; '
                           'found: {}'.format(all_event_shapes))

      all_batch_shapes = [d.batch_shape for d in self._distributions]
      if all(tensorshape_util.is_fully_defined(batch_shape)
             for batch_shape in all_batch_shapes):
        batch_shape = all_batch_shapes[0].as_list()
        batch_shape[self._axis] = 1
        for b in all_batch_shapes[1:]:
          b = b.as_list()
          if len(batch_shape) != len(b):
            raise ValueError('Incompatible batch shape %s with %s' %
                             (batch_shape, b))
          b[self._axis] = 1
          tf.broadcast_static_shape(
              tensorshape_util.constant_value_as_shape(batch_shape),
              tensorshape_util.constant_value_as_shape(b))

    if not self.validate_args:
      return []

    if self.validate_args:
      # Validate that event shapes all match.
      all_event_shapes = [d.event_shape for d in self._distributions]
      if not all(tensorshape_util.is_fully_defined(event_shape)
                 for event_shape in all_event_shapes):
        all_event_shape_tensors = [d.event_shape_tensor() for
                                   d in self._distributions]
        def _get_shapes(static_shape, dynamic_shape):
          if tensorshape_util.is_fully_defined(static_shape):
            return static_shape
          else:
            return dynamic_shape
        event_shapes = tf.nest.map_structure(_get_shapes,
                                             all_event_shapes,
                                             all_event_shape_tensors)
        event_shapes = tf.nest.flatten(event_shapes)
        assertions.extend(
            assert_util.assert_equal(
                e1, e2, message='Distributions should have same event shapes.')
            for e1, e2 in zip(event_shapes[1:], event_shapes[:-1]))

      # Validate that batch shapes are broadcastable and concatenable along
      # the specified axis.
      if not all(tensorshape_util.is_fully_defined(d.batch_shape)
                 for d in self._distributions):
        for i, d in enumerate(self._distributions[:-1]):
          assertions.append(tf.assert_equal(
              tf.size(d.batch_shape_tensor()),
              tf.size(self._distributions[i+1].batch_shape_tensor())))

        batch_shape_tensors = [
            ps.tensor_scatter_nd_update(d.batch_shape_tensor(), updates=1,
                                        indices=[self._axis])
            for d in self._distributions
        ]
        assertions.append(
            functools.reduce(tf.broadcast_dynamic_shape,
                             batch_shape_tensors[1:],
                             batch_shape_tensors[0]))
    return assertions
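
A minimal sketch (hypothetical shapes, concatenation axis fixed to 0) of the static broadcast check performed above: the concatenation axis is set to 1 in every batch shape, and the results must broadcast:

import functools
import tensorflow as tf

batch_shapes = [tf.TensorShape([2, 3]), tf.TensorShape([4, 3])]
# Replace the concatenation axis (here axis 0) with 1, then broadcast all shapes.
normalized = [tf.TensorShape([1]).concatenate(s[1:]) for s in batch_shapes]
print(functools.reduce(tf.broadcast_static_shape, normalized))  # (1, 3)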
Example #19
  def test_transform_joint_to_joint(self, split_sizes):
    dist_batch_shape = tf.nest.pack_sequence_as(
        split_sizes,
        [tensorshape_util.constant_value_as_shape(s)
         for s in [[2, 3], [2, 1], [1, 3]]])
    bijector_batch_shape = [1, 3]

    # Build a joint distribution with parts of the specified sizes.
    seed = test_util.test_seed_stream()
    component_dists = tf.nest.map_structure(
        lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
            loc=tf.random.normal(batch_shape + [size], seed=seed()),
            scale_diag=tf.exp(
                tf.random.normal(batch_shape + [size], seed=seed()))),
        split_sizes, dist_batch_shape)
    if isinstance(split_sizes, dict):
      base_dist = tfd.JointDistributionNamed(component_dists)
    else:
      base_dist = tfd.JointDistributionSequential(component_dists)

    # Transform the distribution by applying a separate bijector to each part.
    bijectors = [tfb.Exp(),
                 tfb.Scale(tf.random.normal(bijector_batch_shape, seed=seed())),
                 tfb.Reshape([2, 1])]
    bijector = ToyZipMap(tf.nest.pack_sequence_as(split_sizes, bijectors))

    with self.assertRaisesRegexp(ValueError, 'Overriding the batch shape'):
      tfd.TransformedDistribution(base_dist, bijector, batch_shape=[3])

    with self.assertRaisesRegexp(ValueError, 'Overriding the event shape'):
      tfd.TransformedDistribution(base_dist, bijector, event_shape=[3])

    # Transform a joint distribution that has different batch shape components
    transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

    self.assertAllEqualNested(
        transformed_dist.event_shape,
        bijector.forward_event_shape(base_dist.event_shape))
    self.assertAllEqualNested(*self.evaluate((
        transformed_dist.event_shape_tensor(),
        bijector.forward_event_shape_tensor(base_dist.event_shape_tensor()))))

    # Test that the batch shape components of the input are the same as those of
    # the output.
    self.assertAllEqualNested(transformed_dist.batch_shape, dist_batch_shape)
    self.assertAllEqualNested(
        self.evaluate(transformed_dist.batch_shape_tensor()), dist_batch_shape)
    self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)

    # Check transformed `log_prob` against the base distribution.
    sample_shape = [3]
    sample = base_dist.sample(sample_shape, seed=seed())
    x = tf.nest.map_structure(tf.zeros_like, sample)
    y = bijector.forward(x)
    base_logprob = base_dist.log_prob(x)
    event_ndims = tf.nest.map_structure(lambda s: s.ndims,
                                        transformed_dist.event_shape)
    ildj = bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)

    (transformed_logprob,
     base_logprob_plus_ildj,
     log_transformed_prob
    ) = self.evaluate([
        transformed_dist.log_prob(y),
        base_logprob + ildj,
        tf.math.log(transformed_dist.prob(y))
    ])
    self.assertAllClose(base_logprob_plus_ildj, transformed_logprob)
    self.assertAllClose(transformed_logprob, log_transformed_prob)

    # Test that `.sample()` works and returns a result of the expected structure
    # and shape.
    y_sampled = transformed_dist.sample(sample_shape, seed=seed())
    self.assertAllEqual(tf.nest.map_structure(lambda y: y.shape, y),
                        tf.nest.map_structure(lambda y: y.shape, y_sampled))
Example #20
 def _batch_shape(self):
   return tensorshape_util.constant_value_as_shape(
       self._calculate_batch_shape())
Example #21
    def test_transform_joint_to_joint(self, split_sizes):
        dist_batch_shape = tf.nest.pack_sequence_as(split_sizes, [
            tensorshape_util.constant_value_as_shape(s)
            for s in [[2, 3], [2, 1], [1, 3]]
        ])
        bijector_batch_shape = [1, 3]

        # Build a joint distribution with parts of the specified sizes.
        seed = test_util.test_seed_stream()
        component_dists = tf.nest.map_structure(
            lambda size, batch_shape: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
                loc=tf.random.normal(batch_shape + [size], seed=seed()),
                scale_diag=tf.random.uniform(minval=1.,
                                             maxval=2.,
                                             shape=batch_shape + [size],
                                             seed=seed())),
            split_sizes,
            dist_batch_shape)
        if isinstance(split_sizes, dict):
            base_dist = tfd.JointDistributionNamed(component_dists)
        else:
            base_dist = tfd.JointDistributionSequential(component_dists)

        # Transform the distribution by applying a separate bijector to each part.
        bijectors = [
            tfb.Exp(),
            tfb.Scale(
                tf.random.uniform(minval=1.,
                                  maxval=2.,
                                  shape=bijector_batch_shape,
                                  seed=seed())),
            tfb.Reshape([2, 1])
        ]
        bijector = tfb.JointMap(
            tf.nest.pack_sequence_as(split_sizes, bijectors),
            validate_args=True)

        # Transform a joint distribution that has different batch shape components
        transformed_dist = tfd.TransformedDistribution(base_dist, bijector)

        self.assertRegex(
            str(transformed_dist),
            '{}.*batch_shape.*event_shape.*dtype'.format(
                transformed_dist.name))

        self.assertAllEqualNested(
            transformed_dist.event_shape,
            bijector.forward_event_shape(base_dist.event_shape))
        self.assertAllEqualNested(
            *self.evaluate((transformed_dist.event_shape_tensor(),
                            bijector.forward_event_shape_tensor(
                                base_dist.event_shape_tensor()))))

        # Test that the batch shape components of the input are the same as those of
        # the output.
        self.assertAllEqualNested(transformed_dist.batch_shape,
                                  dist_batch_shape)
        self.assertAllEqualNested(
            self.evaluate(transformed_dist.batch_shape_tensor()),
            dist_batch_shape)
        self.assertAllEqualNested(dist_batch_shape, base_dist.batch_shape)

        # Check transformed `log_prob` against the base distribution.
        sample_shape = [3]
        sample = base_dist.sample(sample_shape, seed=seed())
        x = tf.nest.map_structure(tf.zeros_like, sample)
        y = bijector.forward(x)
        base_logprob = base_dist.log_prob(x)
        event_ndims = tf.nest.map_structure(lambda s: s.ndims,
                                            transformed_dist.event_shape)
        ildj = bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)

        (transformed_logprob, base_logprob_plus_ildj,
         log_transformed_prob) = self.evaluate([
             transformed_dist.log_prob(y), base_logprob + ildj,
             tf.math.log(transformed_dist.prob(y))
         ])
        self.assertAllClose(base_logprob_plus_ildj, transformed_logprob)
        self.assertAllClose(transformed_logprob, log_transformed_prob)

        # Test that `.sample()` works and returns a result of the expected structure
        # and shape.
        y_sampled = transformed_dist.sample(sample_shape, seed=seed())
        self.assertAllEqual(
            tf.nest.map_structure(lambda y: y.shape, y),
            tf.nest.map_structure(lambda y: y.shape, y_sampled))

        # Test that a `Restructure` bijector applied to a `JointDistribution` works
        # as expected.
        num_components = len(split_sizes)
        input_keys = (split_sizes.keys() if isinstance(split_sizes, dict) else
                      range(num_components))
        output_keys = [str(i) for i in range(num_components)]
        output_structure = {k: v for k, v in zip(output_keys, input_keys)}
        restructure = tfb.Restructure(output_structure)
        restructured_dist = tfd.TransformedDistribution(base_dist,
                                                        bijector=restructure,
                                                        validate_args=True)

        # Check that attributes of the restructured distribution have the same
        # nested structure as the `output_structure` of the bijector. Pass a no-op
        # as the `assert_fn` since the contents of the structures are not
        # required to be the same.
        noop_assert_fn = lambda *_: None
        self.assertAllAssertsNested(noop_assert_fn,
                                    restructured_dist.event_shape,
                                    output_structure)
        self.assertAllAssertsNested(noop_assert_fn,
                                    restructured_dist.batch_shape,
                                    output_structure)
        self.assertAllAssertsNested(
            noop_assert_fn,
            self.evaluate(restructured_dist.event_shape_tensor()),
            output_structure)
        self.assertAllAssertsNested(
            noop_assert_fn,
            self.evaluate(restructured_dist.batch_shape_tensor()),
            output_structure)
        self.assertAllAssertsNested(
            noop_assert_fn,
            self.evaluate(
                restructured_dist.sample(seed=test_util.test_seed())))
Example #22
def _replace_event_shape_in_shape_tensor(
    input_shape, event_shape_in, event_shape_out, validate_args):
  """Replaces the rightmost dims in a `Tensor` representing a shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers
    event_shape_in: the event shape expected to be present in rightmost dims
      of `input_shape`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
  """
  output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
      tensorshape_util.constant_value_as_shape(input_shape),
      event_shape_in,
      event_shape_out)

  # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
  # correctly supports control_dependencies.
  validation_dependencies = (
      map(tf.identity, (event_shape_in, event_shape_out))
      if validate_args else ())

  if (tensorshape_util.is_fully_defined(output_tensorshape) and
      (is_validated or not validate_args)):
    with tf.control_dependencies(validation_dependencies):
      output_shape = tf.convert_to_tensor(
          output_tensorshape, name='output_shape', dtype_hint=tf.int32)
    return output_shape, output_tensorshape

  with tf.control_dependencies(validation_dependencies):
    event_shape_in_ndims = (
        tf.size(event_shape_in)
        if tensorshape_util.num_elements(event_shape_in.shape) is None else
        tensorshape_util.num_elements(event_shape_in.shape))
    input_non_event_shape, input_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

  additional_assertions = []
  if is_validated:
    pass
  elif validate_args:
    # Check that `input_event_shape` and `event_shape_in` are compatible in the
    # sense that they have equal entries in any position that isn't a `-1` in
    # `event_shape_in`. Note that our validations at construction time ensure
    # there is at most one such entry in `event_shape_in`.
    mask = event_shape_in >= 0
    explicit_input_event_shape = tf.boolean_mask(input_event_shape, mask=mask)
    explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
    additional_assertions.append(
        assert_util.assert_equal(
            explicit_input_event_shape,
            explicit_event_shape_in,
            message='Input `event_shape` does not match `event_shape_in`.'))
    # We don't explicitly additionally verify
    # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
    # already makes this assertion.

  with tf.control_dependencies(additional_assertions):
    output_shape = tf.concat([input_non_event_shape, event_shape_out], axis=0,
                             name='output_shape')

  return output_shape, output_tensorshape
Example #23
def _replace_event_shape_in_tensorshape(
    input_tensorshape, event_shape_in, event_shape_out):
  """Replaces the event shape dims of a `TensorShape`.

  Args:
    input_tensorshape: a `TensorShape` instance in which to attempt replacing
      event shape.
    event_shape_in: `Tensor` shape representing the event shape expected to
      be present in (rightmost dims of) `input_tensorshape`. Must be compatible
      with the rightmost dims of `input_tensorshape`.
    event_shape_out: `Tensor` shape representing the new event shape, i.e.,
      the replacement of `event_shape_in`.

  Returns:
    output_tensorshape: `TensorShape` with the rightmost `event_shape_in`
      replaced by `event_shape_out`. Might be partially defined, i.e.,
      `TensorShape(None)`.
    is_validated: Python `bool` indicating static validation happened.

  Raises:
    ValueError: if the event shape portion of `input_tensorshape` and
      `event_shape_in` can both be determined statically and they
      are not compatible. "Compatible" here means that they are identical on
      any dims that are not -1 in `event_shape_in`.
  """
  event_shape_in_ndims = tensorshape_util.num_elements(event_shape_in.shape)
  if tensorshape_util.rank(
      input_tensorshape) is None or event_shape_in_ndims is None:
    return tf.TensorShape(None), False  # Not is_validated.

  input_non_event_ndims = tensorshape_util.rank(
      input_tensorshape) - event_shape_in_ndims
  if input_non_event_ndims < 0:
    raise ValueError(
        'Input has fewer ndims ({}) than event shape ndims ({}).'.format(
            tensorshape_util.rank(input_tensorshape), event_shape_in_ndims))

  input_non_event_tensorshape = input_tensorshape[:input_non_event_ndims]
  input_event_tensorshape = input_tensorshape[input_non_event_ndims:]

  # Check that `input_event_shape_` and `event_shape_in` are compatible in the
  # sense that they have equal entries in any position that isn't a `-1` in
  # `event_shape_in`. Note that our validations at construction time ensure
  # there is at most one such entry in `event_shape_in`.
  event_shape_in_ = tf.get_static_value(event_shape_in)
  is_validated = (
      tensorshape_util.is_fully_defined(input_event_tensorshape) and
      event_shape_in_ is not None)
  if is_validated:
    input_event_shape_ = np.int32(input_event_tensorshape)
    mask = event_shape_in_ >= 0
    explicit_input_event_shape_ = input_event_shape_[mask]
    explicit_event_shape_in_ = event_shape_in_[mask]
    if not all(explicit_input_event_shape_ == explicit_event_shape_in_):
      raise ValueError(
          'Input `event_shape` does not match `event_shape_in`. '
          '({} vs {}).'.format(input_event_shape_, event_shape_in_))

  event_tensorshape_out = tensorshape_util.constant_value_as_shape(
      event_shape_out)
  if tensorshape_util.rank(event_tensorshape_out) is None:
    output_tensorshape = tf.TensorShape(None)
  else:
    output_tensorshape = tensorshape_util.concatenate(
        input_non_event_tensorshape, event_tensorshape_out)

  return output_tensorshape, is_validated
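
A hypothetical usage sketch of the replacement logic above via the public `tfp.bijectors.Reshape`:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# The rightmost event dims [2, 3] match event_shape_in [2, -1] and are replaced by [6].
reshape = tfb.Reshape(event_shape_out=[6], event_shape_in=[2, -1])
print(reshape.forward_event_shape(tf.TensorShape([4, 2, 3])))  # (4, 6)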
Example #24
    def __init__(self,
                 image_shape,
                 conditional_shape=None,
                 num_resnet=5,
                 num_hierarchies=3,
                 num_filters=160,
                 num_logistic_mix=10,
                 receptive_field_dims=(3, 3),
                 dropout_p=0.5,
                 resnet_activation='concat_elu',
                 use_weight_norm=True,
                 use_data_init=True,
                 high=255,
                 low=0,
                 dtype=tf.float32,
                 name='PixelCNN'):
        """Construct Pixel CNN++ distribution.

    Args:
      image_shape: 3D `TensorShape` or tuple for the `[height, width, channels]`
        dimensions of the image.
      conditional_shape: `TensorShape` or tuple for the shape of the
        conditional input, or `None` if there is no conditional input.
      num_resnet: `int`, the number of layers (shown in Figure 2 of [2]) within
        each highest-level block of Figure 2 of [1].
      num_hierarchies: `int`, the number of highest-level blocks (separated by
        expansions/contractions of dimensions in Figure 2 of [1]).
      num_filters: `int`, the number of convolutional filters.
      num_logistic_mix: `int`, number of components in the logistic mixture
        distribution.
      receptive_field_dims: `tuple`, height and width in pixels of the receptive
        field of the convolutional layers above and to the left of a given
        pixel. The width (second element of the tuple) should be odd. Figure 1
        (middle) of [2] shows a receptive field of (3, 5) (the row containing
        the current pixel is included in the height). The default of (3, 3) was
        used to produce the results in [1].
      dropout_p: `float`, the dropout probability. Should be between 0 and 1.
      resnet_activation: `string`, the type of activation to use in the resnet
        blocks. May be 'concat_elu', 'elu', or 'relu'.
      use_weight_norm: `bool`, if `True` then use weight normalization (works
        only in Eager mode).
      use_data_init: `bool`, if `True` then use data-dependent initialization
        (has no effect if `use_weight_norm` is `False`).
      high: `int`, the maximum value of the input data (255 for an 8-bit image).
      low: `int`, the minimum value of the input data.
      dtype: Data type of the `Distribution`.
      name: `string`, the name of the `Distribution`.
    """

        parameters = dict(locals())
        with tf.name_scope(name) as name:
            super(PixelCNN, self).__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=False,
                allow_nan_stats=True,
                parameters=parameters,
                name=name)

            if not tensorshape_util.is_fully_defined(image_shape):
                raise ValueError('`image_shape` must be fully defined.')

            if (conditional_shape is not None and
                    not tensorshape_util.is_fully_defined(conditional_shape)):
                raise ValueError('`conditional_shape` must be fully defined.')

            if tensorshape_util.rank(image_shape) != 3:
                raise ValueError(
                    '`image_shape` must have length 3, representing '
                    '[height, width, channels] dimensions.')

            self._high = tf.cast(high, self.dtype)
            self._low = tf.cast(low, self.dtype)
            self._num_logistic_mix = num_logistic_mix
            self.network = _PixelCNNNetwork(
                dropout_p=dropout_p,
                num_resnet=num_resnet,
                num_hierarchies=num_hierarchies,
                num_filters=num_filters,
                num_logistic_mix=num_logistic_mix,
                receptive_field_dims=receptive_field_dims,
                resnet_activation=resnet_activation,
                use_weight_norm=use_weight_norm,
                use_data_init=use_data_init,
                dtype=dtype)

            image_shape = tensorshape_util.constant_value_as_shape(image_shape)
            conditional_shape = (
                None if conditional_shape is None else
                tensorshape_util.constant_value_as_shape(conditional_shape))

            image_input_shape = tensorshape_util.concatenate([None],
                                                             image_shape)
            if conditional_shape is None:
                input_shape = image_input_shape
            else:
                conditional_input_shape = tensorshape_util.concatenate(
                    [None], conditional_shape)
                input_shape = [image_input_shape, conditional_input_shape]

            self.image_shape = image_shape
            self.conditional_shape = conditional_shape
            self.network.build(input_shape)