Example #1
def _kl_blockwise_blockwise(b0, b1, name=None):
  """Calculate the batched KL divergence KL(b0 || b1) with b0 and b1 Blockwise distributions.

  Args:
    b0: instance of a Blockwise distribution object.
    b1: instance of a Blockwise distribution object.
    name: (optional) Name to use for created operations. Default is
      "kl_blockwise_blockwise".

  Returns:
    kl_blockwise_blockwise: `Tensor`. The batchwise KL(b0 || b1).
  """
  if len(b0.distributions) != len(b1.distributions):
    raise ValueError(
        'Can only compute KL divergence between Blockwise distributions with '
        'the same number of component distributions.')

  # We also need to check that the event shapes match for each one.
  b0_event_sizes = [_event_size(d) for d in b0.distributions]
  b1_event_sizes = [_event_size(d) for d in b1.distributions]

  assertions = []
  message = ('Can only compute KL divergence between Blockwise distributions '
             'with the same pairwise event shapes.')

  if (all(isinstance(event_size, int) for event_size in b0_event_sizes) and
      all(isinstance(event_size, int) for event_size in b1_event_sizes)):
    if b0_event_sizes != b1_event_sizes:
      raise ValueError(message)
  else:
    if b0.validate_args or b1.validate_args:
      assertions.extend(
          assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
              e1, e2, message=message)
          for e1, e2 in zip(b0_event_sizes, b1_event_sizes))

  with tf.name_scope(name or 'kl_blockwise_blockwise'):
    with tf.control_dependencies(assertions):
      return sum([
          kullback_leibler.kl_divergence(d1, d2) for d1, d2 in zip(
              b0.distributions, b1.distributions)])
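For orientation, a minimal usage sketch of where this registration fires, assuming TFP's public `tfd.Blockwise` and `tfd.kl_divergence` API (the concrete distributions below are illustrative, not taken from the snippet above):

import tensorflow_probability as tfp
tfd = tfp.distributions

# Two Blockwise distributions whose components match pairwise in type and event size.
b0 = tfd.Blockwise([tfd.Normal(loc=0., scale=1.), tfd.Dirichlet(concentration=[1., 1., 1.])])
b1 = tfd.Blockwise([tfd.Normal(loc=1., scale=2.), tfd.Dirichlet(concentration=[2., 2., 2.])])
kl = tfd.kl_divergence(b0, b1)  # sum of the component-wise KL divergences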
Example #2
def _kl_sample(a, b, name='kl_sample'):
    """Batched KL divergence `KL(a || b)` for Sample distributions.

  We can leverage the fact that:

  ```
  KL(Sample(a) || Sample(b)) = sum(KL(a || b))
  ```

  where the sum is over the `sample_shape` dims.

  Args:
    a: Instance of `Sample` distribution.
    b: Instance of `Sample` distribution.
    name: (optional) name to use for created ops.
      Default value: `"kl_sample"`'.

  Returns:
    kldiv: Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the `sample_shape` of `a` and `b` don't match.
  """
    assertions = []
    a_ss = tf.get_static_value(a.sample_shape)
    b_ss = tf.get_static_value(b.sample_shape)
    msg = '`a.sample_shape` must be identical to `b.sample_shape`.'
    if a_ss is not None and b_ss is not None:
        if not np.array_equal(a_ss, b_ss):
            raise ValueError(msg)
    elif a.validate_args or b.validate_args:
        assertions.append(
            assert_util.assert_equal(a.sample_shape,
                                     b.sample_shape,
                                     message=msg))
    with tf.control_dependencies(assertions):
        kl = kullback_leibler.kl_divergence(a.distribution,
                                            b.distribution,
                                            name=name)
        n = ps.reduce_prod(a.sample_shape)
        return tf.cast(x=n, dtype=kl.dtype) * kl
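A minimal sketch of the relationship stated in the docstring, assuming the public `tfd.Sample` wrapper (values illustrative):

import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Sample(tfd.Normal(0., 1.), sample_shape=[3])
b = tfd.Sample(tfd.Normal(1., 2.), sample_shape=[3])
kl = tfd.kl_divergence(a, b)  # equals 3 * KL(Normal(0, 1) || Normal(1, 2))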
Example #3
def _assert_batch_shape_matches_weights(distribution, weights_shape, diststr):
    """Checks that all parts of a distribution have the expected batch shape."""
    shapes = [weights_shape] + tf.nest.flatten(
        distribution.batch_shape_tensor())
    static_shapes = [
        tf.get_static_value(ps.convert_to_shape_tensor(s)) for s in shapes
    ]
    static_shapes_not_none = [s for s in static_shapes if s is not None]
    static_shapes_match = all([
        np.all(a == b)  # Also need to check for rank mismatch (below).
        for (a,
             b) in zip(static_shapes_not_none[1:], static_shapes_not_none[:-1])
    ])

    # Build a separate list of static ranks, since rank is often static even when
    # shape is not.
    ranks = [ps.rank_from_shape(s) for s in shapes]
    static_ranks = [int(r) for r in ranks if not tf.is_tensor(r)]
    static_ranks_match = all(
        [a == b for (a, b) in zip(static_ranks[1:], static_ranks[:-1])])

    msg = (
        "The {diststr} distribution's batch shape does not match the particle "
        "weights; a correct {diststr} distribution must return an independent "
        "log-density for each particle. You may be "
        "creating a joint distribution in which some parts do not depend on the "
        "previous particles, and/or you are creating an autobatched joint "
        "distribution without setting `batch_ndims`.".format(diststr=diststr))
    if not (static_ranks_match and static_shapes_match):
        raise ValueError(
            msg + ' ' +
            'Weights have shape {}, but the distribution has batch '
            'shape {}.'.format(weights_shape, distribution.batch_shape))

    assertions = []
    if distribution.validate_args and any([s is None for s in static_shapes]):
        assertions = [
            assert_util.assert_equal(a, b, message=msg)
            for a, b in zip(shapes[1:], shapes[:-1])
        ]
    return assertions
Example #4
    def _sample_control_dependencies(self, samples):
        inner_sample_dim = samples.shape[-1]
        shape_msg = ('Samples must have innermost dimension matching that of '
                     '`self.dimension`. Found {}, expected {}'.format(
                         inner_sample_dim, self.dimension))
        if inner_sample_dim is not None:
            if self.dimension != inner_sample_dim:
                raise ValueError(shape_msg)

        assertions = []
        if not self.validate_args:
            return assertions
        assertions.append(
            assert_util.assert_near(tf.cast(1., dtype=self.dtype),
                                    tf.linalg.norm(samples, axis=-1),
                                    message='Samples must be unit length.'))
        assertions.append(
            assert_util.assert_equal(tf.shape(samples)[-1:],
                                     self.dimension,
                                     message=shape_msg))
        return assertions
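This looks like the sample validation of a distribution over unit vectors (e.g. `tfd.SphericalUniform`); a hedged usage sketch under that assumption (values illustrative):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

d = tfd.SphericalUniform(dimension=3, validate_args=True)
x = tf.math.l2_normalize([1., 2., 2.])  # unit vector with innermost dimension 3
d.log_prob(x)  # passes both assertions above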
Example #5
 def _sample_control_dependencies(self, x):
     assertions = []
     if not self.validate_args:
         return assertions
     loc = tf.convert_to_tensor(self.loc)
     scale = tf.convert_to_tensor(self.scale)
     concentration = tf.convert_to_tensor(self.concentration)
     assertions.append(
         assert_util.assert_greater_equal(
             x,
             loc,
             message='Sample must be greater than or equal to `loc`.'))
     assertions.append(
         assert_util.assert_equal(
             tf.logical_or(tf.greater_equal(concentration, 0),
                           tf.less_equal(x, loc - scale / concentration)),
             True,
             message=('If `concentration < 0`, sample must be less than or '
                      'equal to `loc - scale / concentration`.'),
             summarize=100))
     return assertions
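The two checks match the support of a generalized Pareto distribution; a hedged sketch assuming `tfd.GeneralizedPareto` (values illustrative):

import tensorflow_probability as tfp
tfd = tfp.distributions

d = tfd.GeneralizedPareto(loc=0., scale=1., concentration=-0.5, validate_args=True)
d.log_prob(1.)  # OK: for concentration < 0 the support is [loc, loc - scale/concentration] = [0., 2.]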
Example #6
  def _parameter_control_dependencies(self, is_init):
    if tensorshape_util.is_fully_defined(self.distribution.batch_shape):
      if self.to_shape is not None:
        static_to_shape = tf.get_static_value(self.to_shape)
        if static_to_shape is not None:
          bcast_shp = tf.broadcast_static_shape(
              tf.TensorShape(static_to_shape),
              self.distribution.batch_shape)
          if bcast_shp != static_to_shape:
            raise ValueError(f'Argument `to_shape` ({static_to_shape}) '
                             'is incompatible with underlying distribution '
                             f'batch shape ({self.distribution.batch_shape}).')

      else:
        static_with_shape = tf.get_static_value(self.with_shape)
        if static_with_shape is not None:
          tf.broadcast_static_shape(  # Ensure compatible.
              tf.TensorShape(static_with_shape),
              self.distribution.batch_shape)

    underlying = self.distribution._parameter_control_dependencies(is_init)  # pylint: disable=protected-access
    if not self.validate_args:
      return underlying

    checks = []
    if self.to_shape is not None:
      if tensor_util.is_ref(self.to_shape) != is_init:
        checks += [assert_util.assert_equal(
            self.to_shape,
            ps.broadcast_shape(self.distribution.batch_shape_tensor(),
                               self.to_shape),
            message='Argument `to_shape` is incompatible with underlying '
                    'distribution batch shape.')]
    else:
      if tensor_util.is_ref(self.with_shape) != is_init:
        checks += [tf.broadcast_dynamic_shape(
            self.distribution.batch_shape_tensor(),
            self.with_shape)]

    return tuple(checks) + tuple(underlying)
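These checks appear to belong to a batch-broadcasting wrapper (e.g. `tfd.BatchBroadcast`); a hedged sketch under that assumption (values illustrative):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([3]), scale=1.)
d = tfd.BatchBroadcast(base, with_shape=[2, 3], validate_args=True)
d.batch_shape  # [2, 3]: the underlying batch shape broadcast with `with_shape`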
Example #7
def _kl_power_uniform_spherical(a, b, name=None):
    """Calculate the batched KL divergence KL(a || b).

  Args:
    a: instance of a PowerSpherical distribution object.
    b: instance of a SphericalUniform distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_power_uniform_spherical".

  Returns:
    Batchwise KL(a || b)

  Raises:
    ValueError: If the two distributions are over spheres of different
      dimensions.

  #### References

  [1] Nicola de Cao, Wilker Aziz. The Power Spherical distribution.
      https://arxiv.org/abs/2006.04437.
  """
    with tf.name_scope(name or 'kl_power_uniform_spherical'):
        msg = (
            'Can not compute the KL divergence between a `PowerSpherical` and '
            '`SphericalUniform` of different dimensions.')
        deps = []
        if a.event_shape[-1] is not None:
            if a.event_shape[-1] != b.dimension:
                raise ValueError(
                    (msg + ' Got {} vs. {}').format(a.event_shape[-1],
                                                   b.dimension))
        elif a.validate_args or b.validate_args:
            deps += [
                assert_util.assert_equal(a.event_shape_tensor()[-1],
                                         b.dimension,
                                         message=msg)
            ]

        with tf.control_dependencies(deps):
            return b.entropy() - a.entropy()
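A minimal sketch of the KL this function computes, assuming the public `tfd.PowerSpherical` and `tfd.SphericalUniform` classes (values illustrative):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.PowerSpherical(mean_direction=tf.math.l2_normalize([1., 1., 0.]),
                       concentration=2.)
b = tfd.SphericalUniform(dimension=3)
kl = tfd.kl_divergence(a, b)  # == b.entropy() - a.entropy()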
Example #8
 def _inverse(self, y):
   ndims = ps.rank(y)
   indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
   num_left, num_right = ps.unstack(self.paddings, num=2, axis=-1)
   x = tf.slice(
       y,
       begin=ps.tensor_scatter_nd_update(
           ps.zeros(ndims, dtype=tf.int32),
           indices, num_left),
       size=ps.tensor_scatter_nd_sub(
           ps.shape(y),
           indices, num_left + num_right))
   if not self.validate_args:
     return x
   assertions = [
       assert_util.assert_equal(
           self._forward(x), y,
           message=('Argument `y` to `inverse` was not padded with '
                    '`constant_values`.')),
   ]
   with tf.control_dependencies(assertions):
     return tf.identity(x)
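For context, a hedged usage sketch of the surrounding `Pad` bijector with its default `constant_values=0` (values illustrative):

import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Pad(paddings=[[1, 1]], validate_args=True)
y = bij.forward([1., 2.])          # [0., 1., 2., 0.]
x = bij.inverse([0., 1., 2., 0.])  # [1., 2.]; the assertion above fires if the padding was altered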
Example #9
  def _sample_control_dependencies(self, samples):
    """Check samples for proper shape and whether samples are unit vectors."""
    inner_sample_dim = samples.shape[-1]
    event_size = self.event_shape[-1]
    shape_msg = ('Samples must have innermost dimension matching that of '
                 '`self.mean_direction`.')
    if event_size is not None and inner_sample_dim is not None:
      if event_size != inner_sample_dim:
        raise ValueError(shape_msg)

    assertions = []
    if not self.validate_args:
      return assertions
    assertions.append(assert_util.assert_near(
        1.,
        tf.linalg.norm(samples, axis=-1),
        message='Samples must be unit length.'))
    assertions.append(assert_util.assert_equal(
        tf.shape(samples)[-1:],
        self.event_shape_tensor(),
        message=shape_msg))
    return assertions
Example #10
def vector_size_to_square_matrix_size(d, validate_args, name=None):
    """Convert a vector size to a matrix size."""
    if isinstance(d, (float, int, np.generic, np.ndarray)):
        n = (-1 + np.sqrt(1 + 8 * d)) / 2.
        if float(int(n)) != n:
            raise ValueError(
                'Vector length {} is not a triangular number.'.format(d))
        return int(n)
    else:
        with tf.name_scope(name
                           or 'vector_size_to_square_matrix_size') as name:
            n = (-1. + tf.sqrt(1 + 8. * tf.cast(d, dtype=tf.float32))) / 2.
            if validate_args:
                with tf.control_dependencies([
                        assert_util.assert_equal(
                            tf.cast(tf.cast(n, dtype=tf.int32),
                                    dtype=tf.float32),
                            n,
                            data=[d],
                            message='Vector length is not a triangular number')
                ]):
                    n = tf.identity(n)
            return tf.cast(n, d.dtype)
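The static branch inverts the triangular-number relation d = n(n + 1)/2. A standalone NumPy check of that arithmetic (illustrative only):

import numpy as np

def _tri_vec_len_to_matrix_size(d):
    n = (-1. + np.sqrt(1. + 8. * d)) / 2.  # invert d = n * (n + 1) / 2
    if n != int(n):
        raise ValueError('{} is not a triangular number.'.format(d))
    return int(n)

assert _tri_vec_len_to_matrix_size(6) == 3    # a 3x3 lower triangle has 6 entries
assert _tri_vec_len_to_matrix_size(10) == 4   # a 4x4 lower triangle has 10 entries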
Example #11
 def _parameter_control_dependencies(self, is_init):
   if not self.validate_args:
     # Avoid computing intermediates needed to construct the assertions.
     return []
   assertions = []
   if is_init != tensor_util.is_ref(self._batch_shape_unexpanded):
     implicit_dim_mask = prefer_static.equal(self._batch_shape_unexpanded, -1)
     assertions.append(assert_util.assert_rank(
         self._batch_shape_unexpanded, 1,
         message='New shape must be a vector.'))
     assertions.append(assert_util.assert_less_equal(
         tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32), 1,
         message='At most one dimension can be unknown.'))
     assertions.append(assert_util.assert_non_negative(
         self._batch_shape_unexpanded + 1,
         message='Shape elements must be >=-1.'))
     # Check that the old and new shapes are the same size.
     expanded_new_shape, original_size = self._calculate_new_shape()
     new_size = prefer_static.reduce_prod(expanded_new_shape)
     assertions.append(assert_util.assert_equal(
         new_size, tf.cast(original_size, new_size.dtype),
         message='Shape sizes do not match.'))
   return assertions
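These assertions guard batch-shape overrides of the kind `tfd.BatchReshape` performs; a minimal construction sketch under that assumption (values illustrative):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([6]), scale=1.)
d = tfd.BatchReshape(base, batch_shape=[2, 3], validate_args=True)
d.batch_shape  # [2, 3]; a shape like [2, -1] with one unknown dimension is also accepted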
Example #12
  def _maybe_warn_increased_dof(self,
                                component_name,
                                component_ldj,
                                increased_dof):
    """Warns or raises when `increased_dof` is True."""
    # Short-circuit when the component LDJ is statically zero.
    if (tf.get_static_value(tf.rank(component_ldj)) == 0
        and tf.get_static_value(component_ldj) == 0):
      return

    # Short-circuit when increased_dof is statically False.
    increased_dof_ = tf.get_static_value(increased_dof)
    if increased_dof_ is False:  # pylint: disable=g-bool-id-comparison
      return

    error_message = (
        'Nested component "{}" in composition "{}" operates on inputs '
        'with increased degrees of freedom. This may result in an '
        'incorrect log_det_jacobian.'
        ).format(component_name, self.name)

    # When validate_args is True, we raise on increased DoF.
    if self._validate_args:
      if increased_dof_:
        raise ValueError(error_message)
      return assert_util.assert_equal(False, increased_dof, error_message)

    if (not tf.executing_eagerly() and
        control_flow_util.GraphOrParentsInXlaContext(tf1.get_default_graph())):
      return  # No StringFormat or Print ops in XLA.

    # Otherwise, we print a warning and continue.
    return ps.cond(
        pred=increased_dof,
        false_fn=tf.no_op,
        true_fn=lambda: tf.print(  # pylint: disable=g-long-lambda
            'WARNING: ' + error_message, output_stream=sys.stderr))
Example #13
def _maybe_validate_rightmost_transposed_ndims(
    initial_rightmost_transposed_ndims,
    rightmost_transposed_ndims, validate_args, name=None):
  """Checks that `rightmost_transposed_ndims` is valid."""
  with tf.name_scope(name or 'maybe_validate_rightmost_transposed_ndims'):
    assertions = []

    if tensorshape_util.rank(rightmost_transposed_ndims.shape) is not None:
      if tensorshape_util.rank(rightmost_transposed_ndims.shape) != 0:
        raise ValueError('`rightmost_transposed_ndims` must be a scalar, '
                         'saw rank: {}.'.format(
                             tensorshape_util.rank(
                                 rightmost_transposed_ndims.shape)))
    elif validate_args:
      assertions += [
          assert_util.assert_rank(rightmost_transposed_ndims, 0),
          assert_util.assert_equal(
              rightmost_transposed_ndims,
              initial_rightmost_transposed_ndims,
              message='`rightmost_transposed_ndims` must not change '
                      'from the value set when the `Transpose` '
                      'bijector was constructed.')]

    rightmost_transposed_ndims_ = tf.get_static_value(
        rightmost_transposed_ndims)
    msg = '`rightmost_transposed_ndims` must be non-negative.'
    if rightmost_transposed_ndims_ is not None:
      if rightmost_transposed_ndims_ < 0:
        raise ValueError(msg[:-1] + ', saw: {}.'.format(
            rightmost_transposed_ndims_))
    elif validate_args:
      assertions += [
          assert_util.assert_non_negative(
              rightmost_transposed_ndims, message=msg)
      ]

    return assertions
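These checks back the `rightmost_transposed_ndims` argument of the `Transpose` bijector; a hedged sketch assuming `tfb.Transpose` (values illustrative):

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Transpose(rightmost_transposed_ndims=2, validate_args=True)
bij.forward(tf.reshape(tf.range(6.), [2, 3]))  # transposes the last two dims -> shape [3, 2]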
Example #14
def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
    """Returns list of assertions related to `lu_solve` assumptions."""
    assertions = _lu_reconstruct_assertions(lower_upper, perm, validate_args)

    message = 'Input `rhs` must have at least 2 dimensions.'
    if rhs.shape.ndims is not None:
        if rhs.shape.ndims < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(rhs, rank=2, message=message))

    message = '`lower_upper.shape[-1]` must equal `rhs.shape[-1]`.'
    if (tf.compat.dimension_value(lower_upper.shape[-1]) is not None
            and tf.compat.dimension_value(rhs.shape[-2]) is not None):
        if lower_upper.shape[-1] != rhs.shape[-2]:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(tf.shape(lower_upper)[-1],
                                     tf.shape(rhs)[-2],
                                     message=message))

    return assertions
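A minimal sketch of the call site these assertions support, assuming `tf.linalg.lu` and `tfp.math.lu_solve` (values illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([[4., 3.], [6., 3.]])
lower_upper, perm = tf.linalg.lu(x)
rhs = tf.constant([[1.], [0.]])
# With validate_args=True, lu_solve wires in the assertions built above.
y = tfp.math.lu_solve(lower_upper, perm, rhs, validate_args=True)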
Example #15
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    if is_init != tensor_util.is_ref(self.permutation):
      if not dtype_util.is_integer(self.permutation.dtype):
        raise TypeError('permutation.dtype ({}) should be `int`-like.'.format(
            dtype_util.name(self.permutation.dtype)))

      p = tf.get_static_value(self.permutation)
      if p is not None:
        if set(p) != set(np.arange(p.size)):
          raise ValueError('Permutation over `d` must contain exactly one of '
                           'each of `{0, 1, ..., d}`.')

      if self.validate_args:
        p = tf.sort(self.permutation, axis=-1)
        assertions.append(
            assert_util.assert_equal(
                p,
                tf.range(tf.shape(p)[-1]),
                message=('Permutation over `d` must contain exactly one of '
                         'each of `{0, 1, ..., d}`.')))

    return assertions
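These checks back the `Permute` bijector; a minimal sketch assuming `tfb.Permute` (values illustrative):

import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Permute(permutation=[2, 0, 1], validate_args=True)
bij.forward([1., 2., 3.])  # -> [3., 1., 2.]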
Example #16
  def _parameter_control_dependencies(self, is_init):

    assertions = super(Wishart, self)._parameter_control_dependencies(is_init)

    if not self.validate_args:
      assert not assertions
      return []

    if self._scale_full is None:
      if is_init != tensor_util.is_ref(self._scale_tril):
        shape = prefer_static.shape(self._scale_tril)
        assertions.extend(
            [assert_util.assert_positive(
                tf.linalg.diag_part(self._scale_tril),
                message='`scale_tril` must be positive definite.'),
             assert_util.assert_equal(
                 shape[-1],
                 shape[-2],
                 message='`scale_tril` must be square.')]
            )
    else:
      if is_init != tensor_util.is_ref(self._scale_full):
        assertions.append(distribution_util.assert_symmetric(self._scale_full))
    return assertions
Example #17
    def _maybe_assert_valid_x(self,
                              x,
                              loc=None,
                              scale=None,
                              concentration=None):
        if not self.validate_args:
            return []
        loc = tf.convert_to_tensor(self.loc) if loc is None else loc
        scale = tf.convert_to_tensor(self.scale) if scale is None else scale
        concentration = (tf.convert_to_tensor(self.concentration)
                         if concentration is None else concentration)
        # The support of this bijector depends on the sign of concentration.
        is_in_bounds = tf.where(concentration > 0.,
                                x >= loc - scale / concentration,
                                x <= loc - scale / concentration)
        # For concentration 0, the domain is the whole line.
        is_in_bounds = is_in_bounds | tf.math.equal(concentration, 0.)

        return [
            assert_util.assert_equal(
                is_in_bounds,
                True,
                message='Forward transformation input must be inside domain.')
        ]
Example #18
def prepare_tuple_argument(arg, n, arg_name, validate_args=False):
    """Helper which processes `Tensor`s to tuples in standard form."""
    arg_size = ps.size(arg)
    arg_size_ = tf.get_static_value(arg_size)
    assertions = []
    if arg_size_ is not None:
        if arg_size_ not in (1, n):
            raise ValueError(
                'The size of `{}` must be equal to `1` or to the rank '
                'of the convolution (={}). Saw size = {}'.format(
                    arg_name, n, arg_size_))
    elif validate_args:
        assertions.append(
            assert_util.assert_equal(
                ps.logical_or(arg_size == 1, arg_size == n),
                True,
                message=
                ('The size of `{}` must be equal to `1` or to the rank of the '
                 'convolution (={})'.format(arg_name, n))))

    with tf.control_dependencies(assertions):
        arg = ps.broadcast_to(arg, shape=[n])
        arg = ps.unstack(arg, num=n)
        return arg
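A plain-NumPy analogue of the normalization this helper performs (illustrative only, not the TFP implementation):

import numpy as np

def _prepare_tuple(arg, n):
    arg = np.atleast_1d(arg)
    if arg.size not in (1, n):
        raise ValueError('size must be 1 or {}, saw {}'.format(n, arg.size))
    return list(np.broadcast_to(arg, [n]))

_prepare_tuple(3, n=2)       # [3, 3]
_prepare_tuple([1, 2], n=2)  # [1, 2]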
Example #19
def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args."""
  with tf.name_scope("check_" + name):
    assertions = []
    if tensorshape_util.rank(param.shape) is not None:
      if tensorshape_util.rank(param.shape) == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, tensorshape_util.rank(param.shape)))
    elif validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(
              param,
              1,
              message=("Mixing params must be a (batch of) vector; "
                       "{}.rank is not at least one.".format(name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is not None:
      if tf.compat.dimension_value(param.shape[-1]) != 1:
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "{}.shape[-1]={} is not 1.".format(
                                      name,
                                      tf.compat.dimension_value(
                                          param.shape[-1])))
    elif validate_args:
      assertions.append(
          assert_util.assert_equal(
              tf.shape(input=param)[-1],
              1,
              message=("Currently only bimixtures are supported; "
                       "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return distribution_util.with_dependencies(assertions, param)
    return param
Example #20
    def _maybe_validate_shape_override(self, override_shape, base_is_scalar_fn,
                                       static_base_shape, is_init):
        """Helper which ensures override batch/event_shape are valid."""

        assertions = []
        concretized_shape = None

        # Check valid dtype
        if is_init:  # No xor check because `dtype` cannot change.
            dtype_ = override_shape.dtype
            if dtype_ is None:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                dtype_ = concretized_shape.dtype
            if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
                raise TypeError('Shape override must be integer type; '
                                'saw {}.'.format(dtype_util.name(dtype_)))

        # Check non-negative elements
        if is_init != tensor_util.is_ref(override_shape):
            override_shape_ = tf.get_static_value(override_shape)
            msg = 'Shape override must have non-negative elements.'
            if override_shape_ is not None:
                if np.any(np.array(override_shape_) < 0):
                    raise ValueError('{} Saw: {}'.format(msg, override_shape_))
            elif self.validate_args:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                assertions.append(
                    assert_util.assert_non_negative(concretized_shape,
                                                    message=msg))

        # Check valid shape
        override_ndims_ = tensorshape_util.rank(override_shape.shape)
        if is_init != (override_ndims_ is None):
            msg = 'Shape override must be a vector.'
            if override_ndims_ is not None:
                if override_ndims_ != 1:
                    raise ValueError(msg)
            elif self.validate_args:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                override_rank = tf.rank(concretized_shape)
                assertions.append(
                    assert_util.assert_equal(override_rank, 1, message=msg))

        static_base_rank = tensorshape_util.rank(static_base_shape)

        # Determine if the override shape is `[]` (static_override_dims == [0]),
        # in which case the base distribution may be nonscalar.
        static_override_dims = tensorshape_util.dims(override_shape.shape)

        if is_init != (static_base_rank is None
                       or static_override_dims is None):
            msg = 'Base distribution is not scalar.'
            if static_base_rank is not None and static_override_dims is not None:
                if static_base_rank != 0 and static_override_dims != [0]:
                    raise ValueError(msg)
            elif self.validate_args:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                override_is_empty = tf.logical_not(
                    self._has_nonzero_rank(concretized_shape))
                assertions.append(
                    assert_util.assert_equal(tf.logical_or(
                        base_is_scalar_fn(), override_is_empty),
                                             True,
                                             message=msg))
        return assertions
Example #21
    def _distributional_transform(self, x):
        """Performs distributional transform of the mixture samples.

    Distributional transform removes the parameters from samples of a
    multivariate distribution by applying conditional CDFs:
      (F(x_1), F(x_2 | x_1), ..., F(x_d | x_1, ..., x_d-1))
    (the indexing is over the "flattened" event dimensions).
    The result is a sample of product of Uniform[0, 1] distributions.

    We assume that the components are factorized, so the conditional CDFs become
      F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i),
    where w_i^k is the posterior mixture weight: for i > 0
      w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1)
    and w_0^k = w_k is the mixture probability of the k-th component.

    Arguments:
      x: Sample of mixture distribution

    Returns:
      Result of the distributional transform
    """

        if x.shape.ndims is None:
            # tf.nn.softmax raises an error when applied to inputs of undefined rank.
            raise ValueError(
                "Distributional transform does not support inputs of "
                "undefined rank.")

        # Obtain factorized components distribution and assert that it's
        # a scalar distribution.
        if isinstance(self._components_distribution, independent.Independent):
            univariate_components = self._components_distribution.distribution
        else:
            univariate_components = self._components_distribution

        with tf.control_dependencies([
                assert_util.assert_equal(
                    univariate_components.is_scalar_event(),
                    True,
                    message="`univariate_components` must have scalar event")
        ]):
            x_padded = self._pad_sample_dims(x)  # [S, B, 1, E]
            log_prob_x = univariate_components.log_prob(
                x_padded)  # [S, B, k, E]
            cdf_x = univariate_components.cdf(x_padded)  # [S, B, k, E]

            # log prob_k (x_1, ..., x_i-1)
            cumsum_log_prob_x = tf.reshape(
                tf.math.cumsum(
                    # [S*prod(B)*k, prod(E)]
                    tf.reshape(log_prob_x, [-1, self._event_size]),
                    exclusive=True,
                    axis=-1),
                tf.shape(input=log_prob_x))  # [S, B, k, E]

            logits_mix_prob = distribution_utils.pad_mixture_dimensions(
                self.mixture_distribution.logits, self,
                self.mixture_distribution, self._event_ndims)  # [B, k, 1]

            # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1)
            log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x

            component_axis = x.shape.ndims - self._event_ndims
            posterior_weights_x = tf.nn.softmax(log_posterior_weights_x,
                                                axis=component_axis)
            return tf.reduce_sum(input_tensor=posterior_weights_x * cdf_x,
                                 axis=component_axis)
Example #22
    def __init__(self,
                 mixture_distribution,
                 components_distribution,
                 reparameterize=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MixtureSameFamily"):
        """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      reparameterize: Python `bool`, default `False`. Whether to reparameterize
        samples of the distribution using implicit reparameterization gradients
        [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
        equivalent to the ones described by [(Graves, 2016)][2]. The gradients
        for the components parameters are also computed using implicit
        reparameterization (as opposed to ancestral sampling), meaning that
        all components are updated every step.
        Only works when:
          (1) components_distribution is fully reparameterized;
          (2) components_distribution is either a scalar distribution or
          fully factorized (tfd.Independent applied to a scalar distribution);
          (3) batch shape has a known rank.
        Experimental, may be slow and produce infs/NaNs.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not mixture_distribution.dtype.is_integer`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if `mixture_distribution` categories does not equal
        `components_distribution` rightmost batch shape.

    #### References

    [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
         reparameterization gradients. In _Neural Information Processing
         Systems_, 2018. https://arxiv.org/abs/1805.08498

    [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
         Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._mixture_distribution = mixture_distribution
            self._components_distribution = components_distribution
            self._runtime_assertions = []

            s = components_distribution.event_shape_tensor()
            self._event_ndims = tf.compat.dimension_value(s.shape[0])
            if self._event_ndims is None:
                self._event_ndims = tf.size(input=s)
            self._event_size = tf.reduce_prod(input_tensor=s)

            if not mixture_distribution.dtype.is_integer:
                raise ValueError(
                    "`mixture_distribution.dtype` ({}) is not over integers".
                    format(mixture_distribution.dtype.name))

            if (mixture_distribution.event_shape.ndims is not None
                    and mixture_distribution.event_shape.ndims != 0):
                raise ValueError(
                    "`mixture_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.size(
                            input=mixture_distribution.event_shape_tensor()),
                        0,
                        message=
                        "`mixture_distribution` must have scalar `event_dim`s"
                    ),
                ]

            mdbs = mixture_distribution.batch_shape
            cdbs = components_distribution.batch_shape.with_rank_at_least(
                1)[:-1]
            if mdbs.is_fully_defined() and cdbs.is_fully_defined():
                if mdbs.ndims != 0 and mdbs != cdbs:
                    raise ValueError(
                        "`mixture_distribution.batch_shape` (`{}`) is not "
                        "compatible with `components_distribution.batch_shape` "
                        "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
            elif validate_args:
                mdbs = mixture_distribution.batch_shape_tensor()
                cdbs = components_distribution.batch_shape_tensor()[:-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        distribution_utils.pick_vector(
                            mixture_distribution.is_scalar_batch(), cdbs,
                            mdbs),
                        cdbs,
                        message=
                        ("`mixture_distribution.batch_shape` is not "
                         "compatible with `components_distribution.batch_shape`"
                         ))
                ]

            km = tf.compat.dimension_value(
                mixture_distribution.logits.shape.with_rank_at_least(1)[-1])
            kc = tf.compat.dimension_value(
                components_distribution.batch_shape.with_rank_at_least(1)[-1])
            if km is not None and kc is not None and km != kc:
                raise ValueError(
                    "`mixture_distribution components` ({}) does not "
                    "equal `components_distribution.batch_shape[-1]` "
                    "({})".format(km, kc))
            elif validate_args:
                km = tf.shape(input=mixture_distribution.logits)[-1]
                kc = components_distribution.batch_shape_tensor()[-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        km,
                        kc,
                        message=(
                            "`mixture_distribution components` does not equal "
                            "`components_distribution.batch_shape[-1:]`")),
                ]
            elif km is None:
                km = tf.shape(input=mixture_distribution.logits)[-1]

            self._num_components = km

            self._reparameterize = reparameterize
            if reparameterize:
                # Note: tfd.Independent passes through the reparameterization type hence
                # we do not need separate logic for Independent.
                if (self._components_distribution.reparameterization_type !=
                        reparameterization.FULLY_REPARAMETERIZED):
                    raise ValueError("Cannot reparameterize a mixture of "
                                     "non-reparameterized components.")
                reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
            else:
                reparameterization_type = reparameterization.NOT_REPARAMETERIZED

            super(MixtureSameFamily, self).__init__(
                dtype=self._components_distribution.dtype,
                reparameterization_type=reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=(
                    self._mixture_distribution._graph_parents  # pylint: disable=protected-access
                    + self._components_distribution._graph_parents),  # pylint: disable=protected-access
                name=name)
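For reference, the kind of construction this initializer validates (a standard two-component Gaussian mixture; values illustrative):

import tensorflow_probability as tfp
tfd = tfp.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=[-1., 1.], scale=[0.1, 0.5]),
    validate_args=True)
gm.sample(5)  # samples from a mixture of two scalar Normals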
Example #23
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    # Check num_steps is a scalar that's at least 1.
    if is_init != tensor_util.is_ref(self.num_steps):
      num_steps = tf.convert_to_tensor(self.num_steps)
      num_steps_ = tf.get_static_value(num_steps)
      if num_steps_ is not None:
        if np.ndim(num_steps_) != 0:
          raise ValueError(
              '`num_steps` must be a scalar but it has rank {}'.format(
                  np.ndim(num_steps_)))
        if num_steps_ < 1:
          raise ValueError('`num_steps` must be at least 1.')
      elif self.validate_args:
        message = '`num_steps` must be a scalar'
        assertions.append(
            assert_util.assert_rank_at_most(self.num_steps, 0, message=message))
        assertions.append(
            assert_util.assert_greater_equal(
                num_steps, 1,
                message='`num_steps` must be at least 1.'))

    # Check that the initial distribution has scalar events over the
    # integers.
    if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
      raise ValueError(
          '`initial_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.initial_distribution.dtype)))

    if tensorshape_util.rank(self.initial_distribution.event_shape) is not None:
      if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
        raise ValueError('`initial_distribution` must have scalar `event_dim`s')
    elif self.validate_args:
      assertions += [
          assert_util.assert_equal(
              ps.size(self.initial_distribution.event_shape_tensor()),
              0,
              message='`initial_distribution` must have scalar `event_dim`s'),
      ]

    # Check that the transition distribution is over the integers.
    if (is_init and
        not dtype_util.is_integer(self.transition_distribution.dtype)):
      raise ValueError(
          '`transition_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.transition_distribution.dtype)))

    # Check observations have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
      raise ValueError(
          "`observation_distribution` can't have scalar batches")

    # Check transitions have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
      raise ValueError(
          "`transition_distribution` can't have scalar batches")

    # Check compatibility of transition distribution and observation
    # distribution.
    tdbs = self.transition_distribution.batch_shape
    odbs = self.observation_distribution.batch_shape
    if (tensorshape_util.dims(tdbs) is not None and
        tf.compat.dimension_value(odbs[-1]) is not None):
      if (tf.compat.dimension_value(tdbs[-1]) !=
          tf.compat.dimension_value(odbs[-1])):
        raise ValueError(
            '`transition_distribution` and `observation_distribution` '
            'must agree on last dimension of batch size')
    elif self.validate_args:
      tdbs = self.transition_distribution.batch_shape_tensor()
      odbs = self.observation_distribution.batch_shape_tensor()
      transition_precondition = assert_util.assert_greater(
          ps.size(tdbs), 0,
          message=('`transition_distribution` can\'t have scalar '
                   'batches'))
      observation_precondition = assert_util.assert_greater(
          ps.size(odbs), 0,
          message=('`observation_distribution` can\'t have scalar '
                   'batches'))
      with tf.control_dependencies([
          transition_precondition,
          observation_precondition]):
        assertions += [
            assert_util.assert_equal(
                tdbs[-1],
                odbs[-1],
                message=('`transition_distribution` and '
                         '`observation_distribution` '
                         'must agree on last dimension of batch size'))]

    return assertions
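For reference, a construction these checks are designed to validate (the standard two-state example from the TFP docs; values illustrative):

import tensorflow_probability as tfp
tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7,
    validate_args=True)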
Example #24
def _kl_independent(a, b, name='kl_independent'):
    """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default 'kl_independent'.

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
    p = a.distribution
    q = b.distribution

    # The KL between any two (non)-batched distributions is a scalar.
    # Given that the KL between two factored distributions is the sum, i.e.
    # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
    # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
    if (tensorshape_util.is_fully_defined(a.event_shape)
            and tensorshape_util.is_fully_defined(b.event_shape)):
        if a.event_shape == b.event_shape:
            if p.event_shape == q.event_shape:
                num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                                   tensorshape_util.rank(p.event_shape))
                reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

                return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                    q,
                                                                    name=name),
                                     axis=reduce_dims)
            else:
                raise NotImplementedError(
                    'KL between Independents with different '
                    'event shapes not supported.')
        else:
            raise ValueError('Event shapes do not match.')
    else:
        p_event_shape_tensor = p.event_shape_tensor()
        q_event_shape_tensor = q.event_shape_tensor()
        # NOTE: We could optimize by passing the event_shape_tensor of p and q
        # to a.event_shape_tensor() and b.event_shape_tensor().
        a_event_shape_tensor = a.event_shape_tensor()
        b_event_shape_tensor = b.event_shape_tensor()
        with tf.control_dependencies([
                assert_util.assert_equal(a_event_shape_tensor,
                                         b_event_shape_tensor,
                                         message='Event shapes do not match.'),
                assert_util.assert_equal(p_event_shape_tensor,
                                         q_event_shape_tensor,
                                         message='Event shapes do not match.'),
        ]):
            num_reduce_dims = (prefer_static.rank_from_shape(
                a_event_shape_tensor, a.event_shape) -
                               prefer_static.rank_from_shape(
                                   p_event_shape_tensor, p.event_shape))
            reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1)
            return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                q,
                                                                name=name),
                                 axis=reduce_dims)
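A minimal sketch of when this registration applies, assuming the public `tfd.Independent` wrapper (values illustrative):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Independent(tfd.Normal(tf.zeros([5]), 1.), reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Normal(tf.ones([5]), 2.), reinterpreted_batch_ndims=1)
kl = tfd.kl_divergence(a, b)  # scalar: KL summed over the 5 reinterpreted dims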
Example #25
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        if is_init and not dtype_util.is_integer(
                self.mixture_distribution.dtype):
            raise ValueError(
                '`mixture_distribution.dtype` ({}) is not over integers'.
                format(dtype_util.name(self.mixture_distribution.dtype)))

        if tensorshape_util.rank(
                self.mixture_distribution.event_shape) is not None:
            if tensorshape_util.rank(
                    self.mixture_distribution.event_shape) != 0:
                raise ValueError(
                    '`mixture_distribution` must have scalar `event_dim`s')
        elif self.validate_args:
            assertions += [
                assert_util.assert_equal(
                    tf.size(self.mixture_distribution.event_shape_tensor()),
                    0,
                    message=
                    '`mixture_distribution` must have scalar `event_dim`s'),
            ]

        # pylint: disable=protected-access
        mixture_dist_param = (self.mixture_distribution._probs
                              if self.mixture_distribution._logits is None else
                              self.mixture_distribution._logits)
        km = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                1)[-1])
        kc = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(
                self.components_distribution.batch_shape, 1)[-1])
        component_bst = None
        if km is not None and kc is not None:
            if km != kc:
                raise ValueError(
                    '`mixture_distribution` components ({}) does not '
                    'equal `components_distribution.batch_shape[-1]` '
                    '({})'.format(km, kc))
        elif self.validate_args:
            if km is None:
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                km = tf.shape(mixture_dist_param)[-1]
            if kc is None:
                component_bst = self.components_distribution.batch_shape_tensor(
                )
                kc = component_bst[-1]
            assertions += [
                assert_util.assert_equal(
                    km,
                    kc,
                    message=(
                        '`mixture_distribution` components does not equal '
                        '`components_distribution.batch_shape[-1]`')),
            ]

        mdbs = self.mixture_distribution.batch_shape
        cdbs = tensorshape_util.with_rank_at_least(
            self.components_distribution.batch_shape, 1)[:-1]
        if (tensorshape_util.is_fully_defined(mdbs)
                and tensorshape_util.is_fully_defined(cdbs)):
            if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                raise ValueError(
                    '`mixture_distribution.batch_shape` (`{}`) is not '
                    'compatible with `components_distribution.batch_shape` '
                    '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                                    tensorshape_util.as_list(cdbs)))
        elif self.validate_args:
            if not tensorshape_util.is_fully_defined(mdbs):
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                mdbs = tf.shape(mixture_dist_param)[:-1]
            if not tensorshape_util.is_fully_defined(cdbs):
                if component_bst is None:
                    component_bst = self.components_distribution.batch_shape_tensor(
                    )
                cdbs = component_bst[:-1]
            assertions += [
                assert_util.assert_equal(
                    distribution_utils.pick_vector(
                        tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
                    cdbs,
                    message=(
                        '`mixture_distribution.batch_shape` is not '
                        'compatible with `components_distribution.batch_shape`'
                    ))
            ]

        return assertions
Example #26
    def _distributional_transform(self, x, event_shape):
        """Performs distributional transform of the mixture samples.

    Distributional transform removes the parameters from samples of a
    multivariate distribution by applying conditional CDFs:
      (F(x_1), F(x_2 | x_1), ..., F(x_d | x_1, ..., x_d-1))
    (the indexing is over the 'flattened' event dimensions).
    The result is a sample of product of Uniform[0, 1] distributions.

    We assume that the components are factorized, so the conditional CDFs become
      F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i),
    where w_i^k is the posterior mixture weight: for i > 0
      w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1)
    and w_0^k = w_k is the mixture probability of the k-th component.

    Arguments:
      x: Sample of mixture distribution
      event_shape: The event shape of this distribution

    Returns:
      Result of the distributional transform
    """

        if tensorshape_util.rank(x.shape) is None:
            # tf.math.softmax raises an error when applied to inputs of undefined
            # rank.
            raise ValueError(
                'Distributional transform does not support inputs of '
                'undefined rank.')

        # Obtain factorized components distribution and assert that it's
        # a scalar distribution.
        if isinstance(self._components_distribution, independent.Independent):
            univariate_components = self._components_distribution.distribution
        else:
            univariate_components = self._components_distribution

        with tf.control_dependencies([
                assert_util.assert_equal(
                    univariate_components.is_scalar_event(),
                    True,
                    message='`univariate_components` must have scalar event')
        ]):
            event_ndims = ps.rank_from_shape(event_shape)
            x_padded = self._pad_sample_dims(
                x, event_ndims=event_ndims)  # [S, B, 1, E]
            log_prob_x = univariate_components.log_prob(
                x_padded)  # [S, B, k, E]
            cdf_x = univariate_components.cdf(x_padded)  # [S, B, k, E]

            # log prob_k (x_1, ..., x_i-1)
            event_size = ps.cast(ps.reduce_prod(event_shape), dtype=tf.int32)
            cumsum_log_prob_x = tf.reshape(
                tf.math.cumsum(
                    # [S*prod(B)*k, prod(E)]
                    tf.reshape(log_prob_x, [-1, event_size]),
                    exclusive=True,
                    axis=-1),
                ps.shape(log_prob_x))  # [S, B, k, E]

            event_ndims = ps.rank_from_shape(event_shape)
            logits_mix_prob = self.mixture_distribution.logits_parameter()
            logits_mix_prob = tf.reshape(
                logits_mix_prob,  # [k] or [B, k]
                ps.concat([
                    ps.shape(logits_mix_prob),
                    ps.ones([event_ndims], dtype=tf.int32),
                ],
                          axis=0))  # [k, [1]*e] or [B, k, [1]*e]

            # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1)
            log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x

            component_axis = tensorshape_util.rank(x.shape) - event_ndims
            posterior_weights_x = tf.math.softmax(log_posterior_weights_x,
                                                  axis=component_axis)
            return tf.reduce_sum(posterior_weights_x * cdf_x,
                                 axis=component_axis)
Example #27
  def _sample_control_dependencies(self, x):
    """Helper which validates sample arg, e.g., input to `log_prob`."""
    x_ndims = (
        tf.rank(x) if tensorshape_util.rank(x.shape) is None else
        tensorshape_util.rank(x.shape))
    event_ndims = (
        tf.size(self.event_shape_tensor())
        if tensorshape_util.rank(self.event_shape) is None else
        tensorshape_util.rank(self.event_shape))
    batch_ndims = (
        tf.size(self._batch_shape_unexpanded)
        if tensorshape_util.rank(self.batch_shape) is None else
        tensorshape_util.rank(self.batch_shape))
    expected_batch_event_ndims = batch_ndims + event_ndims

    if (isinstance(x_ndims, int) and
        isinstance(expected_batch_event_ndims, int)):
      if x_ndims < expected_batch_event_ndims:
        raise NotImplementedError(
            'Broadcasting is not supported; too few batch and event dims '
            '(expected at least {}, saw {}).'.format(
                expected_batch_event_ndims, x_ndims))
      ndims_assertion = []
    elif self.validate_args:
      ndims_assertion = [
          assert_util.assert_greater_equal(
              x_ndims,
              expected_batch_event_ndims,
              message=('Broadcasting is not supported; too few '
                       'batch and event dims.'),
              name='assert_batch_and_event_ndims_large_enough'),
      ]

    if (tensorshape_util.is_fully_defined(self.batch_shape) and
        tensorshape_util.is_fully_defined(self.event_shape)):
      expected_batch_event_shape = np.int32(
          tensorshape_util.concatenate(self.batch_shape, self.event_shape))
    else:
      expected_batch_event_shape = tf.concat(
          [
              self.batch_shape_tensor(),
              self.event_shape_tensor(),
          ], axis=0)

    sample_ndims = x_ndims - expected_batch_event_ndims
    if isinstance(sample_ndims, int):
      sample_ndims = max(sample_ndims, 0)
    if (isinstance(sample_ndims, int) and
        tensorshape_util.is_fully_defined(x.shape[sample_ndims:])):
      actual_batch_event_shape = np.int32(x.shape[sample_ndims:])
    else:
      sample_ndims = tf.maximum(sample_ndims, 0)
      actual_batch_event_shape = tf.shape(x)[sample_ndims:]

    assertions = []
    if (isinstance(expected_batch_event_shape, np.ndarray) and
        isinstance(actual_batch_event_shape, np.ndarray)):
      if any(expected_batch_event_shape != actual_batch_event_shape):
        raise NotImplementedError('Broadcasting is not supported; '
                                  'unexpected batch and event shape '
                                  '(expected {}, saw {}).'.format(
                                      expected_batch_event_shape,
                                      actual_batch_event_shape))
      # We need to include `ndims_assertion` in the final runtime assertions
      # since it's possible this assertion was created above. We could add a
      # condition to only do so if `self.validate_args == True`; however this
      # is redundant as `ndims_assertion` already encodes this information.
      assertions.extend(ndims_assertion)
    elif self.validate_args:
      # We need to make the `ndims_assertion` a control dep because otherwise
      # TF itself might raise an exception owing to this assertion being
      # ill-defined, ie, one cannot even compare different rank Tensors.
      with tf.control_dependencies(ndims_assertion):
        shape_assertion = assert_util.assert_equal(
            expected_batch_event_shape,
            actual_batch_event_shape,
            message=('Broadcasting is not supported; '
                     'unexpected batch and event shape.'),
            name='assert_batch_and_event_shape_same')
      assertions.append(shape_assertion)

    return assertions
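To make the shape bookkeeping above concrete, here is a minimal standalone sketch (batch and event shapes assumed purely for illustration) of the core check: the trailing dims of a sample `x` must equal `batch_shape + event_shape`, and any extra leading dims are treated as sample dims.

```
import tensorflow as tf

batch_shape, event_shape = [3], [2]                 # assumed for illustration
expected = tf.constant(batch_shape + event_shape)   # [3, 2]
x = tf.zeros([7, 3, 2])                             # sample_shape = [7]
sample_ndims = len(x.shape) - len(batch_shape) - len(event_shape)
actual = tf.shape(x)[max(sample_ndims, 0):]
# Passes; a trailing shape like [3, 5] would raise instead.
tf.debugging.assert_equal(
    expected, actual,
    message='Broadcasting is not supported; unexpected batch and event shape.')
```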
Exemple #28
0
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance.  The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in the Markov chain. A Python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution` does not have scalar `event_shape`.
      ValueError: if `transition_distribution` does not have scalar
        `event_shape`.
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.
    """

        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            if (initial_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        initial_distribution.event_shape) != 0):
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.event_shape) != 0):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (tensorshape_util.dims(transition_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (tensorshape_util.dims(observation_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (tensorshape_util.dims(transition_distribution.batch_shape)
                     is not None and tensorshape_util.as_list(
                         transition_distribution.batch_shape)[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (tensorshape_util.dims(
                        observation_distribution.batch_shape) is not None
                     and tensorshape_util.as_list(
                         observation_distribution.batch_shape)[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
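As a usage sketch (probabilities, locations, and scales assumed purely for illustration), a two-state hidden Markov model with Gaussian observations can be built and queried like this:

```
import tensorflow_probability as tfp
tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)

samples = hmm.sample(5)            # shape [5, 7]: five length-7 observation chains
log_probs = hmm.log_prob(samples)  # shape [5]
```

Note that the rightmost batch dimension of size 2 in the transition and observation distributions encodes the number of hidden states, matching the validation performed above.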
Exemple #29
0
def _replace_event_shape_in_shape_tensor(
    input_shape, event_shape_in, event_shape_out, validate_args):
  """Replaces the rightmost dims in a `Tensor` representing a shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers
    event_shape_in: the event shape expected to be present in rightmost dims
      of `input_shape`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
  """
  output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
      tensorshape_util.constant_value_as_shape(input_shape),
      event_shape_in,
      event_shape_out)

  # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
  # correctly supports control_dependencies.
  validation_dependencies = (
      map(tf.identity, (event_shape_in, event_shape_out))
      if validate_args else ())

  if (tensorshape_util.is_fully_defined(output_tensorshape) and
      (is_validated or not validate_args)):
    with tf.control_dependencies(validation_dependencies):
      output_shape = tf.convert_to_tensor(
          output_tensorshape, name='output_shape', dtype_hint=tf.int32)
    return output_shape, output_tensorshape

  with tf.control_dependencies(validation_dependencies):
    event_shape_in_ndims = (
        tf.size(event_shape_in)
        if tensorshape_util.num_elements(event_shape_in.shape) is None else
        tensorshape_util.num_elements(event_shape_in.shape))
    input_non_event_shape, input_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

  additional_assertions = []
  if is_validated:
    pass
  elif validate_args:
    # Check that `input_event_shape` and `event_shape_in` are compatible in the
    # sense that they have equal entries in any position that isn't a `-1` in
    # `event_shape_in`. Note that our validations at construction time ensure
    # there is at most one such entry in `event_shape_in`.
    mask = event_shape_in >= 0
    explicit_input_event_shape = tf.boolean_mask(input_event_shape, mask=mask)
    explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
    additional_assertions.append(
        assert_util.assert_equal(
            explicit_input_event_shape,
            explicit_event_shape_in,
            message='Input `event_shape` does not match `event_shape_in`.'))
    # We don't explicitly additionally verify
    # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
    # already makes this assertion.

  with tf.control_dependencies(additional_assertions):
    output_shape = tf.concat([input_non_event_shape, event_shape_out], axis=0,
                             name='output_shape')

  return output_shape, output_tensorshape
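For intuition, the following standalone sketch (shapes assumed purely for illustration) mirrors the split-and-concat that the function falls back to when shapes are not statically known: the trailing event dims `[3, 4]` are replaced with `[12]`.

```
import tensorflow as tf

input_shape = tf.constant([2, 5, 3, 4], dtype=tf.int32)  # batch dims [2, 5], event dims [3, 4]
event_shape_out = tf.constant([12], dtype=tf.int32)

# Split off the 2 rightmost (event) entries, then splice in the new event shape.
input_non_event_shape, _ = tf.split(input_shape, num_or_size_splits=[-1, 2])
output_shape = tf.concat([input_non_event_shape, event_shape_out], axis=0)
# output_shape == [2, 5, 12]
```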
Exemple #30
0
def custom_gradient(fx, gx, x, fx_gx_manually_stopped=False, name=None):
  """Embeds a custom gradient into a `Tensor`.

  This function works by clever application of `stop_gradient`. I.e., observe
  that:

  ```none
  h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))
  ```

  is such that `h(x) == stop_gradient(f(x))` and
  `grad[h(x), x] == stop_gradient(g(x))`.

  In addition to scalar-domain/scalar-range functions, this function also
  supports tensor-domain/scalar-range functions.

  Partial Custom Gradient:

  Suppose `h(x) = htilde(x, y)`. Note that `dh/dx = stop(g(x))` but `dh/dy =
  None`. This is because a `Tensor` cannot have only a portion of its gradient
  stopped. To circumvent this issue, one must manually `stop_gradient` the
  relevant portions of `f`, `g`. For example see the unit-test,
  `test_works_correctly_fx_gx_manually_stopped`.

  Args:
    fx: `Tensor`. Output of function evaluated at `x`.
    gx: `Tensor` or list of `Tensor`s. Gradient of function at (each) `x`.
    x: `Tensor` or list of `Tensor`s. Args of evaluation for `f`.
    fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually
      have `stop_gradient` applied.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    fx: Floating-type `Tensor` equal to `f(x)` but which has gradient
      `stop_gradient(g(x))`.
  """
  def maybe_stop(x):
    if fx_gx_manually_stopped:
      return x
    return tf.stop_gradient(x)

  with tf.name_scope(name or 'custom_gradient'):
    fx = tf.convert_to_tensor(fx, name='fx')
    # We don't want to bother eagerly computing `gx` since we may not even need
    # it.
    with tf.control_dependencies([fx]):
      if is_list_like(x):
        x = [identity(x_, name='x') for x_ in x]
      else:
        x = [identity(x, name='x')]

      if is_list_like(gx):
        gx = [identity(gx_, dtype=fx.dtype, name='gx')
              for gx_ in gx]
      else:
        gx = [identity(gx, dtype=fx.dtype, name='gx')]

      override_grad = []
      for x_, gx_ in zip(x, gx):
        # Observe: tf.gradients(f(x), x)[i].shape == x[i].shape
        # thus we check that the user is supplying correct shapes.
        equal_shape = assert_util.assert_equal(
            tf.shape(x_),
            tf.shape(gx_),
            message='Each `x` must have the same shape as each `gx`.')
        with tf.control_dependencies([equal_shape]):
          # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to
          # write the code this way, rather than, e.g.,
          # `sum_x * stop(gx) + stop(fx - sum_x * gx)`.
          # For more discussion regarding the relevant portions of the IEEE754
          # standard, see the StackOverflow question,
          # "Is there a floating point value of x, for which x-x == 0 is false?"
          # http://stackoverflow.com/q/2686644
          zeros_like_x_ = x_ - tf.stop_gradient(x_)
          override_grad.append(tf.reduce_sum(maybe_stop(gx_) * zeros_like_x_))
      override_grad = sum(override_grad)
      override_grad /= tf.cast(
          tf.size(fx), dtype=dtype_util.base_dtype(fx.dtype))

      # Proof of correctness:
      #
      #  f(x) = x * stop[gx] + stop[fx - x * gx]
      #       = stop[fx]
      #
      #  g(x) = grad[fx]
      #       = stop[gx] + grad[stop[fx - x * gx]]
      #       = stop[gx] + 0
      #
      # Notice that when x is zero it still works:
      # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx]
      #
      # The proof is similar for the tensor-domain case, except that we
      # `reduce_sum` the `stop[gx] * (x - stop[x])` then rescale by
      # `tf.size(fx)` since this reduced version is broadcast to `fx`.
      return maybe_stop(fx) + override_grad
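The `stop_gradient` identity from the docstring can be checked directly; a minimal sketch with toy values (not this helper's API):

```
import tensorflow as tf

x = tf.constant(3.)
with tf.GradientTape() as tape:
  tape.watch(x)
  fx = tf.square(x)      # f(x) = x**2, so the true df/dx is 6 at x = 3
  gx = tf.constant(7.)   # pretend custom gradient g(x) = 7
  hx = tf.stop_gradient(fx) + tf.stop_gradient(gx) * (x - tf.stop_gradient(x))

hx                    # 9.0 == f(x)
tape.gradient(hx, x)  # 7.0 == g(x), not the true derivative 6.0
```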