def train(self, sentences):
    token_ids, token_values, token_dense_shape = self._tokenize(sentences)
    tokens_sparse = tf.sparse.SparseTensor(
        indices=token_ids, values=token_values, dense_shape=token_dense_shape)
    tokens = tf.sparse.to_dense(tokens_sparse, default_value="")

    sparse_lookup_ids = tf.sparse.SparseTensor(
        indices=tokens_sparse.indices,
        values=self._words_to_indices(tokens_sparse.values),
        dense_shape=tokens_sparse.dense_shape)
    lookup_ids = tf.sparse.to_dense(sparse_lookup_ids, default_value=0)

    # Targets are the next word for each word of the sentence.
    tokens_ids_seq = lookup_ids[:, 0:-1]
    tokens_ids_target = lookup_ids[:, 1:]

    tokens_prefix = tokens[:, 0:-1]

    # Mask determining which positions we care about for a loss: all positions
    # that have a valid non-terminal token.
    mask = tf.logical_and(
        tf.logical_not(tf.equal(tokens_prefix, "")),
        tf.logical_not(tf.equal(tokens_prefix, "<E>")))

    input_mask = tf.cast(mask, tf.int32)

    with tf.GradientTape() as t:
      sentence_embeddings = tf.nn.embedding_lookup(self._embeddings,
                                                   tokens_ids_seq)

      lstm_initial_state = self._lstm_cell.get_initial_state(
          sentence_embeddings)

      lstm_output = self._rnn_layer(
          inputs=sentence_embeddings, initial_state=lstm_initial_state)

      # Flatten the LSTM outputs from [batch, time, units] to [batch * time, units].
      lstm_output = tf.reshape(lstm_output, [-1, self._lstm_cell.output_size])

      logits = self._logit_layer(lstm_output)

      targets = tf.reshape(tokens_ids_target, [-1])
      weights = tf.cast(tf.reshape(input_mask, [-1]), tf.float32)

      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=targets, logits=logits)

      # Final loss is the mask-weighted mean over all token losses.
      final_loss = tf.math.divide(
          tf.reduce_sum(tf.multiply(losses, weights)),
          tf.reduce_sum(weights),
          name="final_loss")

    watched = t.watched_variables()
    gradients = t.gradient(final_loss, watched)

    for w, g in zip(watched, gradients):
      w.assign_sub(g)

    return final_loss
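
A note on the pattern above: the padding mask turns the cross-entropy into a weighted mean over real tokens only. The following minimal, self-contained sketch shows the same masked-mean loss in isolation (the vocabulary size, logits, and targets are made-up placeholders, not part of the example above):

import tensorflow as tf

vocab_size = 10
logits = tf.random.normal([2, 4, vocab_size])        # [batch, time, vocab]
targets = tf.constant([[3, 1, 0, 0], [5, 2, 7, 0]])  # 0 marks padding
weights = tf.cast(tf.not_equal(targets, 0), tf.float32)

losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=tf.reshape(targets, [-1]),
    logits=tf.reshape(logits, [-1, vocab_size]))
masked_mean_loss = (tf.reduce_sum(losses * tf.reshape(weights, [-1])) /
                    tf.reduce_sum(weights))
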
Example #2
 def averaged_sum_squares(input_tensor):
     num_elements_cast = tf.cast(num_elements,
                                 dtype=dtype_util.real_dtype(
                                     input_tensor.dtype))
     return tf.reduce_sum(abs_square(input_tensor)) / num_elements_cast
 def regularizer(t):
     out = tfd.LogNormal(
         0., 1.).log_prob(1e-5 + tf.nn.softplus(c + t[Ellipsis, -1]))
     return -tf.reduce_sum(out) / num_updates
Example #4
 def true_log_joint(loc, x):
     log_prob = tf.reduce_sum(
         tfd.Normal(loc=0., scale=1.).log_prob(loc))
     log_prob += tf.reduce_sum(
         tfd.Normal(loc=loc, scale=0.5).log_prob(x))
     return log_prob
 def _sum_event_part(x):
   event_axes = ps.range(batch_ndims, ps.rank(x))
   return tf.reduce_sum(x, axis=event_axes)
Example #6
  def pack_batch(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Internal function to map over.

    Consumes a batch of input examples and produces a variable number of output
    examples.

    Args:
      x: a batch of input examples, as a mapping from feature name to `tf.Tensor`.
    Returns:
      a mapping from feature name to packed `tf.Tensor`s, including the derived
      "*_position" and "*_segment_id" features.
    """
    keys = list(feature_lengths)
    partial = empty_example.copy()
    first_key, *_ = keys
    dynamic_batch_size = tf.shape(x[first_key])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True,
          element_shape=[feature_lengths[k]])
      outputs[k + "_position"] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True,
          element_shape=[feature_lengths[k]])

    for i in tf.range(0, dynamic_batch_size):
      tf.autograph.experimental.set_loop_options(
          shape_invariants=[
              (partial, {k: tf.TensorShape([None]) for k in keys_etc}),
              (outputs, {k: tf.TensorShape(None) for k in keys_etc})]
      )

      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]),
                feature_lengths[k]))

      if not can_append:
        partial, outputs = _write_packed_example(partial, outputs)

      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:feature_lengths[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + "_position"] = tf.concat(
            [partial[k + "_position"],
             tf.range(new_seq_len, dtype=tf.int32)], 0)
      partial = new_partial

    partial, outputs = _write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
    for k in keys:
      packed[k + "_segment_id"] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + "_position"], 0), tf.int32), axis=1) *
          tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed
 def inverse_event_shape_tensor(self, y):
   return tf.reduce_sum(self._flat_size_splits)[..., tf.newaxis]
  def test_transform_parts_to_vector(self, split_sizes):
    batch_shape = [4, 2]

    # Create a joint distribution with parts of the specified sizes.
    seed = test_util.test_seed_stream()
    component_dists = tf.nest.map_structure(
        lambda size: tfd.MultivariateNormalDiag(  # pylint: disable=g-long-lambda
            loc=tf.random.normal(batch_shape + [size], seed=seed()),
            scale_diag=tf.exp(
                tf.random.normal(batch_shape + [size], seed=seed()))),
        split_sizes)
    if isinstance(split_sizes, dict):
      base_dist = tfd.JointDistributionNamed(component_dists)
    else:
      base_dist = tfd.JointDistributionSequential(component_dists)

    # Transform to a vector-valued distribution by concatenating the parts.
    bijector = tfb.Invert(ToySplit(split_sizes))

    with self.assertRaisesRegexp(ValueError, 'Overriding the batch shape'):
      tfd.TransformedDistribution(base_dist, bijector, batch_shape=[3])

    with self.assertRaisesRegexp(ValueError, 'Overriding the event shape'):
      tfd.TransformedDistribution(base_dist, bijector, event_shape=[3])

    concat_dist = tfd.TransformedDistribution(base_dist, bijector)

    concat_event_size = self.evaluate(
        tf.reduce_sum(tf.nest.flatten(split_sizes)))
    self.assertAllEqual(concat_dist.event_shape, [concat_event_size])
    self.assertAllEqual(self.evaluate(concat_dist.event_shape_tensor()),
                        [concat_event_size])
    self.assertAllEqual(concat_dist.batch_shape, batch_shape)
    self.assertAllEqual(self.evaluate(concat_dist.batch_shape_tensor()),
                        batch_shape)

    # Since the Split bijector has (constant) unit Jacobian, the transformed
    # entropy and mean/mode should match the base entropy and (split) base
    # mean/mode.
    self.assertAllEqual(*self.evaluate(
        (base_dist.entropy(), concat_dist.entropy())))

    self.assertAllEqual(*self.evaluate(
        (concat_dist.mean(), bijector.forward(base_dist.mean()))))
    self.assertAllEqual(*self.evaluate(
        (concat_dist.mode(), bijector.forward(base_dist.mode()))))

    # Since the Split bijector has zero log-det Jacobian, the transformed
    # `log_prob` and `prob` should match the base distribution.
    sample_shape = [3]
    x = base_dist.sample(sample_shape, seed=seed())
    y = bijector.forward(x)
    for attr in ('log_prob', 'prob'):
      base_attr = getattr(base_dist, attr)(x)
      concat_attr = getattr(concat_dist, attr)(y)
      self.assertAllClose(*self.evaluate((base_attr, concat_attr)))

    # Test that `.sample()` works and returns a result of the expected structure
    # and shape.
    y_sampled = concat_dist.sample(sample_shape, seed=seed())
    self.assertAllEqual(y.shape, y_sampled.shape)
Example #9
def _kl_independent(a, b, name='kl_independent'):
  """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default 'kl_independent'.

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
  p = a.distribution
  q = b.distribution

  # The KL between any two (non)-batched distributions is a scalar.
  # Given that the KL between two factored distributions is the sum, i.e.
  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
  # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
  if (tensorshape_util.is_fully_defined(a.event_shape) and
      tensorshape_util.is_fully_defined(b.event_shape)):
    if a.event_shape == b.event_shape:
      if p.event_shape == q.event_shape:
        num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                           tensorshape_util.rank(p.event_shape))
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

        return tf.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
      else:
        raise NotImplementedError('KL between Independents with different '
                                  'event shapes not supported.')
    else:
      raise ValueError('Event shapes do not match.')
  else:
    p_event_shape_tensor = p.event_shape_tensor()
    q_event_shape_tensor = q.event_shape_tensor()
    # NOTE: We could optimize by passing the event_shape_tensor of p and q
    # to a.event_shape_tensor() and b.event_shape_tensor().
    a_event_shape_tensor = a.event_shape_tensor()
    b_event_shape_tensor = b.event_shape_tensor()
    with tf.control_dependencies(
        [
            assert_util.assert_equal(
                a_event_shape_tensor, b_event_shape_tensor,
                message='Event shapes do not match.'),
            assert_util.assert_equal(
                p_event_shape_tensor, q_event_shape_tensor,
                message='Event shapes do not match.'),
        ]):
      num_reduce_dims = (
          ps.rank_from_shape(
              a_event_shape_tensor, a.event_shape) -
          ps.rank_from_shape(
              p_event_shape_tensor, p.event_shape))
      reduce_dims = ps.range(-num_reduce_dims, 0, 1)
      return tf.reduce_sum(
          kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
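
As a quick sanity check of the identity in the docstring, the KL between two `Independent`-wrapped Normals equals the sum of the per-dimension KLs. A short sketch, assuming TensorFlow Probability is importable as `tfp`:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

p = tfd.Normal(loc=tf.zeros([3, 4]), scale=1.)
q = tfd.Normal(loc=tf.ones([3, 4]), scale=2.)

kl_elementwise = tfd.kl_divergence(p, q)              # shape [3, 4]
kl_independent = tfd.kl_divergence(
    tfd.Independent(p, reinterpreted_batch_ndims=1),
    tfd.Independent(q, reinterpreted_batch_ndims=1))  # shape [3]
# kl_independent == tf.reduce_sum(kl_elementwise, axis=-1)
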
Example #10
def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
    """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.log(w))` is more
  efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `logx` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(w * exp(input))). It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
    with tf.name_scope(name or 'reduce_weighted_logsumexp'):
        logx = tf.convert_to_tensor(logx, name='logx')
        if w is None:
            lswe = tf.reduce_logsumexp(input_tensor=logx,
                                       axis=axis,
                                       keepdims=keep_dims)
            if return_sign:
                sgn = tf.ones_like(lswe)
                return lswe, sgn
            return lswe
        w = tf.convert_to_tensor(w, dtype=logx.dtype, name='w')
        log_absw_x = logx + tf.math.log(tf.abs(w))
        max_log_absw_x = tf.reduce_max(input_tensor=log_absw_x,
                                       axis=axis,
                                       keepdims=True)
        # If the largest element is `-inf` or `inf` then we don't bother subtracting
        # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
        # this is ok follows from the fact that we're actually free to subtract any
        # value we like, so long as we add it back after taking the `log(sum(...))`.
        max_log_absw_x = tf.where(tf.math.is_inf(max_log_absw_x),
                                  tf.zeros([], max_log_absw_x.dtype),
                                  max_log_absw_x)
        wx_over_max_absw_x = (tf.sign(w) * tf.exp(log_absw_x - max_log_absw_x))
        sum_wx_over_max_absw_x = tf.reduce_sum(input_tensor=wx_over_max_absw_x,
                                               axis=axis,
                                               keepdims=keep_dims)
        if not keep_dims:
            max_log_absw_x = tf.squeeze(max_log_absw_x, axis)
        sgn = tf.sign(sum_wx_over_max_absw_x)
        lswe = max_log_absw_x + tf.math.log(sgn * sum_wx_over_max_absw_x)
        if return_sign:
            return lswe, sgn
        return lswe
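
A brief numerical check of what the function computes, `log(abs(sum(w * exp(logx))))`, against a direct (less numerically stable) evaluation. This sketch assumes `reduce_weighted_logsumexp` as defined above is in scope:

import tensorflow as tf

logx = tf.constant([[0., 1., 2.], [1., 0., -1.]])
w = tf.constant([[-1., 2., 1.], [1., 1., 3.]])

lswe, sign = reduce_weighted_logsumexp(logx, w, axis=-1, return_sign=True)
direct = tf.reduce_sum(w * tf.exp(logx), axis=-1)
# sign * tf.exp(lswe) approximately equals direct.
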
Example #11
def convolution_batch(x,
                      kernel,
                      rank,
                      strides,
                      padding,
                      data_format=None,
                      dilations=None,
                      name=None):
    """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
    if rank != 2:
        raise NotImplementedError(
            'Argument `rank` currently only supports `2`; '
            'saw "{}".'.format(rank))
    if data_format is not None and data_format.upper() != 'NHWBC':
        raise ValueError(
            'Argument `data_format` currently only supports "NHWBC"; '
            'saw "{}".'.format(data_format))
    with tf.name_scope(name or 'conv2d_nhwbc'):
        # Prepare arguments.
        [
            rank,
            _,  # strides
            padding,
            dilations,
            data_format,
        ] = prepare_conv_args(rank, strides, padding, dilations)
        strides = prepare_tuple_argument(strides, rank + 2, arg_name='strides')

        dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

        x_shape = prefer_static.shape(x)
        x_shape_ = x.shape
        x = tf.reshape(
            x,  # [n, h, w, b, c]
            shape=prefer_static.pad(x_shape[:-2],
                                    paddings=[[0, 1]],
                                    constant_values=-1))  # [n, h, w, bc]

        kernel_shape = prefer_static.shape(kernel)  # [b, fh, fw, c, c']
        kernel_shape_ = kernel.shape
        kernel = tf.transpose(kernel, [1, 2, 0, 3, 4])
        kernel = tf.reshape(
            kernel,  # [fh, fw, b, c, c']
            shape=prefer_static.concat([
                kernel_shape[1:-2],
                [-1, kernel_shape[-1]],
            ],
                                       axis=0))  # [fh, fw, bc, c']

        y = tf.nn.depthwise_conv2d(x,
                                   kernel,
                                   strides=strides,
                                   padding=padding,
                                   data_format='NHWC',
                                   dilations=dilations)
        #  SAME: y.shape = [n, h,      w,      bcc']
        # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']
        y = tf.reshape(y,
                       shape=prefer_static.concat(
                           [
                               prefer_static.shape(y)[:-1],
                               kernel_shape[:1],
                               kernel_shape[-2:],
                           ],
                           axis=0))  # [n, h, w, b, c, c']
        y = tf.reduce_sum(y, axis=-2)  # [n, h, w, b, c']
        tensorshape_util.set_shape(
            y.shape,
            tensorshape_util.concatenate(x_shape_[:-1], kernel_shape_[-1]))
        return y
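
A hypothetical call, only to illustrate the expected layouts (this assumes `convolution_batch` and the helpers it uses above are in scope): `x` is laid out as [n, h, w, b, c] and `kernel` as [b, fh, fw, c, c'].

import tensorflow as tf

x = tf.random.normal([2, 8, 8, 3, 5])       # [n, h, w, b, c]
kernel = tf.random.normal([3, 3, 3, 5, 7])  # [b, fh, fw, c, c']
y = convolution_batch(
    x, kernel, rank=2, strides=1, padding='SAME', data_format='NHWBC')
# With 'SAME' padding, y has shape [2, 8, 8, 3, 7], i.e. [n, h, w, b, c'].
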
 def proposal_log_prob(x):
     counter['proposal_calls'] += 1
     event_dims = ps.range(independent_chain_ndims, ps.rank(x))
     return tf.reduce_sum(tfd.Normal(loc=0., scale=1.).log_prob(x),
                          axis=event_dims)
 def proposal_log_prob(x):
     event_dims = ps.range(independent_chain_ndims, ps.rank(x))
     return -0.5 * tf.reduce_sum(x**2. + np.log(2 * np.pi),
                                 axis=event_dims)
Example #14
 def _log_unnormalized_prob(self, counts):
     counts = self._maybe_assert_valid_sample(counts)
     return tf.reduce_sum(input_tensor=counts *
                          tf.nn.log_softmax(self.logits),
                          axis=-1)
Example #15
def _bfgs_inv_hessian_update(grad_delta, position_delta, normalization_factor,
                             inv_hessian_estimate):
    """Applies the BFGS update to the inverse Hessian estimate.

  The BFGS update rule is (note A^T denotes the transpose of a vector/matrix A).

  ```None
    rho = 1/(grad_delta^T * position_delta)
    U = (I - rho * position_delta * grad_delta^T)
    H_1 =  U * H_0 * U^T + rho * position_delta * position_delta^T
  ```

  Here, `H_0` is the inverse Hessian estimate at the previous iteration and
  `H_1` is the next estimate. Note that `*` should be interpreted as the
  matrix multiplication (with the understanding that matrix multiplication for
  scalars is ordinary multiplication, and for a matrix with a vector it is the
  action of the matrix on the vector).

  The implementation below utilizes an expanded version of the above formula
  to avoid the matrix multiplications that would be needed otherwise. By
  expansion it is easy to see that one only needs matrix-vector or
  vector-vector operations. The expanded version is:

  ```None
    f = 1 + rho * (grad_delta^T * H_0 * grad_delta)
    H_1 - H_0 = - rho * [position_delta * (H_0 * grad_delta)^T +
                        (H_0 * grad_delta) * position_delta^T] +
                  rho * f * [position_delta * position_delta^T]
  ```

  All the terms in square brackets are matrices and are constructed using
  vector outer products. All the other terms on the right hand side are scalars.
  Also worth noting that the first and second lines are both rank 1 updates
  applied to the current inverse Hessian estimate.

  Args:
    grad_delta: Real `Tensor` of shape `[..., n]`. The difference between the
      gradient at the new position and the old position.
    position_delta: Real `Tensor` of shape `[..., n]`. The change in position
      from the previous iteration to the current one.
    normalization_factor: Real `Tensor` of shape `[...]`. Should be equal to
      `grad_delta^T * position_delta`, i.e. `1/rho` as defined above.
    inv_hessian_estimate: Real `Tensor` of shape `[..., n, n]`. The previous
      estimate of the inverse Hessian. Should be positive definite and
      symmetric.

  Returns:
    A tuple containing the following fields
      is_valid: A Boolean `Tensor` of shape `[...]` indicating batch members
        where the update succeeded. The update can fail if the position change
        becomes orthogonal to the gradient change.
      next_inv_hessian_estimate: A `Tensor` of shape `[..., n, n]`. The next
        Hessian estimate updated using the BFGS update scheme. If the
        `inv_hessian_estimate` is symmetric and positive definite, the
        `next_inv_hessian_estimate` is guaranteed to satisfy the same
        conditions.
  """
    # The quadratic form: y^T.H.y; where H is the inverse Hessian and y is the
    # gradient change.
    conditioned_grad_delta = tf.linalg.matvec(inv_hessian_estimate, grad_delta)
    conditioned_grad_delta_norm = tf.reduce_sum(conditioned_grad_delta *
                                                grad_delta,
                                                axis=-1)

    # The first rank 1 update term requires the outer product: s.y^T.
    cross_term = _tensor_product(position_delta, conditioned_grad_delta)

    def _expand_scalar(s):
        # Expand dimensions of a batch of scalars to multiply or divide a matrix.
        return s[..., tf.newaxis, tf.newaxis]

    # Symmetrize
    cross_term += _tensor_product(conditioned_grad_delta, position_delta)
    position_term = _tensor_product(position_delta, position_delta)
    with tf.control_dependencies([position_term]):
        position_term *= _expand_scalar(1 + conditioned_grad_delta_norm /
                                        normalization_factor)

    return (
        inv_hessian_estimate +
        (position_term - cross_term) / _expand_scalar(normalization_factor))
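
To see that the expanded update in the docstring agrees with the textbook form `H_1 = U * H_0 * U^T + rho * s * s^T`, here is a small self-contained check in plain TensorFlow on random data (it reimplements both forms directly instead of using `_tensor_product`):

import tensorflow as tf

n = 4
s = tf.random.normal([n])   # position_delta
y = tf.random.normal([n])   # grad_delta
h0 = tf.eye(n)              # previous inverse Hessian estimate
rho = 1. / tf.tensordot(y, s, axes=1)

# Textbook form: H_1 = U H_0 U^T + rho s s^T, with U = I - rho s y^T.
u = tf.eye(n) - rho * s[:, None] * y[None, :]
h1_textbook = u @ h0 @ tf.transpose(u) + rho * s[:, None] * s[None, :]

# Expanded form used by the implementation above.
hy = tf.linalg.matvec(h0, y)
f = 1. + rho * tf.tensordot(y, hy, axes=1)
h1_expanded = (h0
               - rho * (s[:, None] * hy[None, :] + hy[:, None] * s[None, :])
               + rho * f * s[:, None] * s[None, :])
# h1_textbook approximately equals h1_expanded.
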
Example #16
def contrastive_loss(features,
                     labels=None,
                     temperature=1.0,
                     contrast_mode=enums.LossContrastMode.ALL_VIEWS,
                     summation_location=enums.LossSummationLocation.OUTSIDE,
                     denominator_mode=enums.LossDenominatorMode.ALL,
                     positives_cap=-1,
                     scale_by_temperature=True):
    r"""Contrastive loss over features.

  Implemented as described in: https://arxiv.org/abs/2004.11362, Equation 2.

  Given `num_views` different views of each of `batch_size` samples, let `f_i`
  (i \in [1, 2 ... (num_views * batch_size)]) denote each respective feature
  vector. The contrastive loss then takes the following form:

    L = \sum_{i} L_i

  where each L_i is computed as:

    L_i = -\tau * \sum_{k \in P(i)} \log(p_{ik})    (1)

  where P(i) is the set of positives for entry i (distinct from i) and where:

                       \exp(f_i^T f_k / \tau)
    p_{ik} = ----------------------------------------                        (2)
             \sum_{j \in A(i)} \exp(f_i^T f_j / \tau)

  where A(i) is the set of all positives or negatives (distinct from i). `i` is
  the anchor, and \tau is the temperature.

  This maximizes the likelihood of a given (anchor, positive) pair with
  respect to all possible pairs where the first member is the anchor and the
  second member is a positive or a negative.

  A typical way to define a positive is to define samples from the
  same class (but not the anchor itself) regardless of what view they are from.
  Similarly, a typical way to define a negative is for it to be any view of a
  sample from a different class.

  There are two ways to define which feature pairs should be treated as
  positives and negatives. All views of the same sample are always treated as
  positives. You can declare other samples to be positives by providing `labels`
  such that all samples with the same label will be positives for each other.

  If `labels` is not provided then we default to every sample belonging to its
  own unique class. Therefore, the only positive used is another view of the
  anchor itself. This implements the loss as described in:

    https://arxiv.org/pdf/2002.05709.pdf
    A Simple Framework for Contrastive Learning of Visual Representations
    Chen T., Kornblith S., Norouzi M., Hinton G.

  It is recommended to use features whose L_2 norm is 1, since that ensures
  that the loss does not return NaN values without changing the intended
  behaviour of the loss function.

  In (1) above, note that the summation over positives is located outside of the
  \log(). However, one can permute these two operations. The result is Eq. 3 in
  https://arxiv.org/abs/2004.11362. Users can specify the location of the
  summation relative to the \log() via the `summation_location` argument:
   - 'out': Eq. 2 in https://arxiv.org/abs/2004.11362.
   - 'in' : Eq. 3 in https://arxiv.org/abs/2004.11362.

  Additionally, in (2) above, note that the denominator sums over *all* entries
  distinct from i. One can change which terms are included in the denominator
  via the `denominator_mode` argument:
   - LossDenominatorMode.ALL : All entries (i.e., all negatives and all
             positives) distinct from i are included.
   - LossDenominatorMode.ONE_POSITIVE : All negatives are included but only the
             single positive in the numerator of (2) is included. Any other
             positives are excluded.
   - LossDenominatorMode.ONLY_NEGATIVES: All negatives are included but no
             positives are, not even the single positive in the numerator of
             (2).

  On TPUs, this method will internally perform the cross-replica operations that
  enable using the samples from all cores in computing the loss. The inputs to
  this function should be the features and labels from a single core and each
  core will compute the loss using just these features as anchors, but will use
  positives and negatives from the full global batch. Since the loss for each
  anchor is only computed on one TPU core, it's still necessary to have a
  cross-replica reduction in the final loss computation.

  Also, though it is not applicable to multiview contrastive learning, this
  function will work if |features| contains only 1 view. In the high batch size
  limit, the implemented contrastive loss with only 1 view, positives_cap = 1,
  and temperature = 1.0 is equivalent to the N-pairs loss
  (https://papers.nips.cc/paper/6200-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective.pdf)

  Args:
    features: A Tensor of rank at least 3, where the first 2 dimensions are
      batch_size and num_views, and the remaining dimensions are the feature
      shape. Note that when running on TPU, batch_size is the per-core batch
      size.
    labels: One-hot labels to be used to construct the supervised contrastive
      loss. Samples with the same labels are used as positives for each other.
      Labels must have shape [batch_size, num_labels] with numeric dtype and be
      0-1 valued. Note that when running on TPU, batch_size is the per-core
      batch size.
    temperature: Temperature at which softmax evaluation is done. Temperature
      must be a python scalar or scalar Tensor of numeric dtype.
    contrast_mode: LossContrastMode specifying which views get used as anchors
      (f_i in the expression above)
      'ALL_VIEWS': All the views of all samples are used as anchors (f_i in the
        expression above).
      'ONE_VIEW': Just the first view of each sample is used as an anchor (f_i
        in the expression above). This view is called the `core` view against
        which other views are contrasted.
    summation_location: LossSummationLocation specifying location of positives
      summation. See documentation above for more details.
    denominator_mode: LossDenominatorMode specifying which positives to include
      in contrastive denominator. See documentation above for more details.
    positives_cap: Integer maximum number of positives *other* than
      augmentations of anchor. Infinite if < 0. Must be multiple of num_views.
      Including augmentations, a maximum of (positives_cap + num_views - 1)
      positives is possible. This parameter modifies the contrastive numerator
      by selecting which positives are present in the summation, and which
      positives contribute to the denominator if denominator_mode ==
      enums.LossDenominatorMode.ALL.
    scale_by_temperature: Boolean. Whether to scale the loss by `temperature`.
      The loss gradient naturally has a 1/temperature scaling factor, so this
      counteracts it.

  Returns:
    A Tensor of shape [batch_size] and dtype tf.float32 containing the
    contrastive loss for each batch element. The loss for each batch element is
    the mean over all views.

  Raises:
    ValueError if the shapes of any of the Tensors are unexpected, or if both
    `labels` and `mask` are not `None`.
  """
    features = tf.convert_to_tensor(features)
    labels = tf.convert_to_tensor(labels) if labels is not None else None

    local_batch_size, num_views = _validate_contrastive_loss_inputs(
        features, labels, contrast_mode, summation_location, denominator_mode,
        positives_cap)

    # Flatten `features` to a single dimension per view per sample so it has shape
    # [local_batch_size, num_views, num_features].
    if features.shape.rank > 3:
        features = tf.reshape(
            features, tf.concat([tf.shape(features)[:2], [-1]], axis=0),
            'flattened_features')
    if features.dtype != tf.float32:
        features = tf.cast(features, tf.float32)

    # Grab the features from all TPU cores. We use the local batch as anchors and
    # the full global batch as contrastives. If not on TPU, global_features is the
    # same as features.
    global_features = utils.cross_replica_concat(features)
    global_batch_size = tf.compat.dimension_at_index(global_features.shape,
                                                     0).value
    local_replica_id = utils.local_tpu_replica_id()

    # Generate the [local_batch_size, global_batch_size] slice of the
    # [global_batch_size, global_batch_size] identity matrix that corresponds to
    # the current replica.
    diagonal_mask = tf.one_hot(
        tf.range(local_batch_size) + (local_replica_id * local_batch_size),
        global_batch_size)

    # Generate `mask` with shape [local_batch_size, global_batch_size] that
    # indicates which samples should be considered positives for each other.
    if labels is None:
        # Defaults to every sample belonging to its own unique class, containing
        # just that sample and other views of it.
        mask = diagonal_mask
    else:
        labels = tf.cast(labels,
                         tf.float32)  # TPU matmul op unsupported for ints.
        global_labels = utils.cross_replica_concat(labels)
        mask = tf.linalg.matmul(labels, global_labels, transpose_b=True)
    mask = tf.ensure_shape(mask, [local_batch_size, global_batch_size])

    # To streamline the subsequent TF, the first two dimensions of
    # `global_features` (i.e., global_batch_size and num_views) should be
    # transposed and then flattened. The result has shape
    # [num_views * global_batch_size, num_features], and its first dimension
    # elements are grouped by view, not by sample.
    all_global_features = tf.reshape(
        tf.transpose(global_features, perm=[1, 0, 2]),
        [num_views * global_batch_size, -1])

    if contrast_mode == enums.LossContrastMode.ONE_VIEW:
        anchor_features = features[:, 0]
        num_anchor_views = 1
    else:  # contrast_mode == enums.LossContrastMode.ALL_VIEWS
        # Reshape features to match how global_features is reshaped above.
        anchor_features = tf.reshape(tf.transpose(features, perm=[1, 0, 2]),
                                     [num_views * local_batch_size, -1])
        num_anchor_views = num_views

    # Generate `logits`, the tensor of (temperature-scaled) dot products of the
    # anchor features with all features. It has shape
    # [local_batch_size * num_anchor_views, global_batch_size * num_views]. To
    # improve numerical stability, subtract out the largest |logits| element in
    # each row from all elements in that row. Since |logits| is only ever used as
    # a ratio of exponentials of |logits| values, this subtraction does not change
    # the results correctness. A stop_gradient() is needed because this change is
    # just for numerical precision.
    logits = tf.linalg.matmul(anchor_features,
                              all_global_features,
                              transpose_b=True)
    temperature = tf.cast(temperature, tf.float32)
    logits = logits / temperature
    logits = (logits -
              tf.reduce_max(tf.stop_gradient(logits), axis=1, keepdims=True))
    exp_logits = tf.exp(logits)

    # The following masks are all tiled by the number of views, i.e., they have
    # shape [local_batch_size * num_anchor_views, global_batch_size * num_views].
    positives_mask, negatives_mask = (_create_tiled_masks(
        mask, diagonal_mask, num_views, num_anchor_views, positives_cap))
    num_positives_per_row = tf.reduce_sum(positives_mask, axis=1)

    if denominator_mode == enums.LossDenominatorMode.ALL:
        denominator = tf.reduce_sum(
            exp_logits * negatives_mask, axis=1,
            keepdims=True) + tf.reduce_sum(
                exp_logits * positives_mask, axis=1, keepdims=True)
    elif denominator_mode == enums.LossDenominatorMode.ONE_POSITIVE:
        denominator = exp_logits + tf.reduce_sum(
            exp_logits * negatives_mask, axis=1, keepdims=True)
    else:  # denominator_mode == enums.LossDenominatorMode.ONLY_NEGATIVES
        denominator = tf.reduce_sum(exp_logits * negatives_mask,
                                    axis=1,
                                    keepdims=True)

    # Note that num_positives_per_row can be zero only if 1 view is used. The
    # various tf.math.divide_no_nan() calls below are to handle this case.
    if summation_location == enums.LossSummationLocation.OUTSIDE:
        log_probs = (logits - tf.math.log(denominator)) * positives_mask
        log_probs = tf.reduce_sum(log_probs, axis=1)
        log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
    else:  # summation_location == enums.LossSummationLocation.INSIDE
        log_probs = exp_logits / denominator * positives_mask
        log_probs = tf.reduce_sum(log_probs, axis=1)
        log_probs = tf.math.divide_no_nan(log_probs, num_positives_per_row)
        log_probs = tf.math.log(log_probs)

    loss = -log_probs
    if scale_by_temperature:
        loss *= temperature
    loss = tf.reshape(loss, [num_anchor_views, local_batch_size])

    if num_views != 1:
        loss = tf.reduce_mean(loss, axis=0)
    else:
        # The 1 view case requires special handling because, unlike in the > 1 view case,
        # not all samples are guaranteed to have a positive. Also, no reduction over
        # views is needed.
        num_valid_views_per_sample = (tf.reshape(num_positives_per_row,
                                                 [1, local_batch_size]))
        loss = tf.squeeze(
            tf.math.divide_no_nan(loss, num_valid_views_per_sample))

    return loss
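
For intuition, a minimal single-device sketch of Eq. (1)-(2) in the default self-supervised setting (labels=None, two views, summation outside the log, ALL-entries denominator, no temperature rescaling of the loss). It is not a drop-in replacement for the function above; there is no cross-replica handling and no positives_cap:

import tensorflow as tf

def simple_contrastive_loss(features, temperature=1.0):
  # features: [batch, 2, d], assumed L2-normalized along the last axis.
  batch = tf.shape(features)[0]
  f = tf.reshape(tf.transpose(features, [1, 0, 2]), [-1, tf.shape(features)[-1]])
  logits = tf.matmul(f, f, transpose_b=True) / temperature
  logits -= tf.reduce_max(tf.stop_gradient(logits), axis=1, keepdims=True)

  # Positives: the other view of the same sample; self-pairs are excluded.
  same_sample = tf.tile(tf.eye(batch), [2, 2])
  self_mask = tf.eye(2 * batch)
  positives = same_sample - self_mask
  denominator_mask = 1.0 - self_mask  # all entries distinct from i

  exp_logits = tf.exp(logits) * denominator_mask
  log_prob = logits - tf.math.log(
      tf.reduce_sum(exp_logits, axis=1, keepdims=True))
  loss_i = (tf.reduce_sum(positives * log_prob, axis=1) /
            tf.reduce_sum(positives, axis=1))
  return tf.reduce_mean(-loss_i)

features = tf.math.l2_normalize(tf.random.normal([8, 2, 16]), axis=-1)
loss = simple_contrastive_loss(features, temperature=0.5)
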
Example #17
def EffectiveSampleSize(states,
                        filter_beyond_lag=300,
                        filter_threshold=0.05,
                        use_geyer=False,
                        center=True,
                        normalize=True):
    """ESS computation for one single Tensor argument."""
    def _axis_size(x, axis=None):
        """Get number of elements of `x` in `axis`, as type `x.dtype`."""
        if axis is None:
            return tf.cast(tf.size(x), x.dtype)
        return tf.cast(tf.reduce_prod(tf.gather(tf.shape(x), axis)), x.dtype)

    with tf.name_scope("effective_sample_size_single_state"):

        states = tf.convert_to_tensor(states, name="states")
        dt = states.dtype

        # filter_beyond_lag == None ==> auto_corr is the full sequence.
        auto_corr = SanitizedAutoCorrelationMean(states,
                                                 axis=0,
                                                 reduce_axis=1,
                                                 center=center,
                                                 normalize=normalize,
                                                 max_lags=filter_beyond_lag)
        orig_auto_corr = auto_corr
        if use_geyer:

            def _sum_pairs(x):
                if x.shape[0] % 2 != 0:
                    x = tf.concat(
                        [x, tf.zeros(tf.concat([[1], tf.shape(x)[1:]], 0))], 0)
                return tf.reduce_sum(
                    tf.reshape(x, [tf.shape(x)[0] // 2, 2, -1]), 1)

            def _make_pairs(x):
                return tf.reshape(tf.tile(x[:, tf.newaxis, :], [1, 2, 1]),
                                  [-1, x.shape[-1]])

            auto_corr_pairs = _make_pairs(
                _sum_pairs(auto_corr))[:auto_corr.shape[0]]
            mask = auto_corr_pairs < 0.
            mask = tf.cast(mask, dt)
            mask = tf.cumsum(mask, axis=0)
            mask = tf.maximum(1. - mask, 0.)
            auto_corr *= mask
        elif filter_threshold is not None:
            filter_threshold = tf.convert_to_tensor(filter_threshold,
                                                    dtype=dt,
                                                    name="filter_threshold")
            # Get a binary mask to zero out values of auto_corr below the threshold.
            #   mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i,
            #   mask[i, ...] = 0, otherwise.
            # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...]
            # Building step by step,
            #   Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2.
            # Step 1:  mask = [False, False, True, False]
            mask = tf.abs(auto_corr) < filter_threshold
            # Step 2:  mask = [0, 0, 1, 1]
            mask = tf.cast(mask, dtype=dt)
            # Step 3:  mask = [0, 0, 1, 2]
            mask = tf.cumsum(mask, axis=0)
            # Step 4:  mask = [1, 1, 0, 0]
            mask = tf.maximum(1. - mask, 0.)
            auto_corr *= mask

        # With R[k] := auto_corr[k, ...],
        # ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]}
        #     = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]} (since R[0] = 1)
        #     approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]}
        # where M is the filter_beyond_lag truncation point chosen above.

        # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total
        # ndims the same as auto_corr
        n = _axis_size(states, axis=0)
        k = tf.range(0., _axis_size(auto_corr, axis=0))
        nk_factor = (n - k) / n
        if auto_corr.shape.ndims is not None:
            new_shape = [-1] + [1] * (auto_corr.shape.ndims - 1)
        else:
            new_shape = tf.concat(
                ([-1], tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)),
                axis=0)
        nk_factor = tf.reshape(nk_factor, new_shape)

        # return tf.reduce_mean(n / (
        #   -1 + 2 * tf.reduce_sum(nk_factor * auto_corr, axis=0)), 0)
        # return n / (1.0 + 2 *
        #             tf.reduce_sum(nk_factor[1:, ...] * auto_corr[1:, ...],
        #             axis=0))
        # return tf.reduce_mean(n / (-auto_corr[0] + 2 *
        #   tf.reduce_sum(nk_factor * auto_corr, axis=0)), 0)
        # print(auto_corr[0])
        return n / (orig_auto_corr[0] + 2 * tf.reduce_sum(
            nk_factor[1:, Ellipsis] * auto_corr[1:, Ellipsis], axis=0))
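
For reference, TensorFlow Probability ships a related estimator, `tfp.mcmc.effective_sample_size`; a quick call on a toy chain might look like this (a sketch; numerical agreement with the hand-rolled version above is not implied, since the Geyer and centering/normalization options differ):

import tensorflow as tf
import tensorflow_probability as tfp

# Toy 'chain': 1000 correlated draws of 3 scalar quantities.
noise = tf.random.normal([1000, 3])
states = tf.cumsum(0.1 * noise, axis=0) + noise

ess = tfp.mcmc.effective_sample_size(
    states, filter_beyond_lag=300, filter_threshold=0.05)
# ess has shape [3]: one estimate per state dimension.
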
  def test_transform_vector_to_parts(self, split_sizes):
    batch_shape = [4, 2]
    base_event_size = tf.reduce_sum(tf.nest.flatten(split_sizes))
    base_dist = tfd.MultivariateNormalDiag(
        loc=tf.random.normal(
            batch_shape + [base_event_size], seed=test_util.test_seed()),
        scale_diag=tf.exp(tf.random.normal(
            batch_shape + [base_event_size], seed=test_util.test_seed())))

    bijector = ToySplit(split_sizes)
    split_dist = tfd.TransformedDistribution(base_dist, bijector)

    expected_event_shape = tf.nest.map_structure(
        lambda s: np.array([s]), split_sizes)
    output_event_shape = nest.map_structure_up_to(
        split_dist.dtype, np.array, split_dist.event_shape)
    self.assertAllEqualNested(output_event_shape, expected_event_shape)
    self.assertAllEqualNested(self.evaluate(split_dist.event_shape_tensor()),
                              expected_event_shape)
    self.assertAllEqual(split_dist.batch_shape, batch_shape)
    self.assertAllEqual(self.evaluate(split_dist.batch_shape_tensor()),
                        batch_shape)

    # Since the Split bijector has (constant) unit Jacobian, the transformed
    # entropy and mean/mode should match the base entropy and (split) base
    # mean/mode.
    self.assertAllEqual(*self.evaluate(
        (base_dist.entropy(), split_dist.entropy())))
    self.assertAllEqualNested(
        *self.evaluate((split_dist.mean(),
                        bijector.forward(base_dist.mean()))))
    self.assertAllEqualNested(
        *self.evaluate((split_dist.mode(),
                        bijector.forward(base_dist.mode()))))

    # Since the Split bijector has zero log-det Jacobian, the transformed
    # `log_prob` and `prob` should match the base distribution.
    sample_shape = [3]
    x = base_dist.sample(sample_shape, seed=test_util.test_seed())
    y = bijector.forward(x)
    for attr in ('log_prob', 'prob'):
      split_attr = getattr(split_dist, attr)(y)
      base_attr = getattr(base_dist, attr)(x)
      self.assertAllClose(*self.evaluate((base_attr, split_attr)), rtol=1e-5)

    # Test that `.sample()` works and returns a result of the expected structure
    # and shape.
    y_sampled = split_dist.sample(sample_shape, seed=test_util.test_seed())
    self.assertAllEqualNested(
        tf.nest.map_structure(lambda x: x.shape, y),
        tf.nest.map_structure(lambda x: x.shape, y_sampled))

    # Test that `batch_shape` override works and does not affect the event shape
    base_dist = tfd.Independent(
        tfd.Normal(loc=list(range(6)), scale=1.),
        reinterpreted_batch_ndims=1, validate_args=True)
    override_batch_shape = [5, 2]
    split_dist_batch_override = tfd.TransformedDistribution(
        base_dist, bijector, batch_shape=override_batch_shape)
    self.assertAllEqualNested(
        split_dist_batch_override.event_shape, expected_event_shape)
    self.assertAllEqualNested(
        self.evaluate(split_dist_batch_override.event_shape_tensor()),
        expected_event_shape)
    self.assertAllEqual(split_dist_batch_override.batch_shape,
                        override_batch_shape)
    self.assertAllEqual(
        self.evaluate(split_dist_batch_override.batch_shape_tensor()),
        override_batch_shape)

    # Test that `event_shape` override works as expected with `Split`
    override_event_shape = [6]
    base_dist = tfd.Normal(0., [2., 1.])
    split_dist_event_override = tfd.TransformedDistribution(
        base_dist, bijector, event_shape=override_event_shape)
    self.assertAllEqualNested(
        split_dist_event_override.event_shape, expected_event_shape)
    self.assertAllEqualNested(
        self.evaluate(split_dist_event_override.event_shape_tensor()),
        expected_event_shape)
    self.assertAllEqual(
        split_dist_event_override.batch_shape, base_dist.batch_shape)
    self.assertAllEqual(
        self.evaluate(split_dist_event_override.batch_shape_tensor()),
        self.evaluate(base_dist.batch_shape_tensor()))
Example #20
def _compute_log_acceptance_correction(current_state_parts,
                                       proposed_state_parts,
                                       current_volatility_parts,
                                       proposed_volatility_parts,
                                       current_drift_parts,
                                       proposed_drift_parts,
                                       step_size_parts,
                                       independent_chain_ndims,
                                       name=None):
    r"""Helper to `kernel` which computes the log acceptance-correction.

  Computes `log_acceptance_correction` as described in `MetropolisHastings`
  class. The proposal density is normal. More specifically,

  ```none
  q(proposed_state | current_state) \sim N(current_state + current_drift,
                                           step_size * current_volatility**2)

  q(current_state | proposed_state) \sim N(proposed_state + proposed_drift,
                                           step_size * proposed_volatility**2)
  ```

  The `log_acceptance_correction` is then

  ```none
  log_acceptance_correction = log(q(current_state | proposed_state)) -
                              log(q(proposed_state | current_state))
  ```

  Args:
    current_state_parts: Python `list` of `Tensor`s representing the value(s) of
      the current state of the chain.
    proposed_state_parts:  Python `list` of `Tensor`s representing the value(s)
      of the proposed state of the chain. Must broadcast with the shape of
      `current_state_parts`.
    current_volatility_parts: Python `list` of `Tensor`s representing the value
      of `volatility_fn(*current_volatility_parts)`. Must broadcast with the
      shape of `current_state_parts`.
    proposed_volatility_parts: Python `list` of `Tensor`s representing the value
      of `volatility_fn(*proposed_volatility_parts)`. Must broadcast with the
      shape of `current_state_parts`
    current_drift_parts: Python `list` of `Tensor`s representing value of the
      drift `_get_drift(*current_state_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    proposed_drift_parts: Python `list` of `Tensor`s representing value of the
      drift `_get_drift(*proposed_drift_parts, ..)`. Must broadcast with the
      shape of `current_state_parts`.
    step_size_parts: Python `list` of `Tensor`s representing the step size for
      Euler-Maruyama method. Must broadcast with the shape of
      `current_state_parts`.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """

    with tf.name_scope(name or 'compute_log_acceptance_correction'):

        proposed_log_density_parts = []
        dual_log_density_parts = []

        for [
                current_state,
                proposed_state,
                current_volatility,
                proposed_volatility,
                current_drift,
                proposed_drift,
                step_size,
        ] in zip(
                current_state_parts,
                proposed_state_parts,
                current_volatility_parts,
                proposed_volatility_parts,
                current_drift_parts,
                proposed_drift_parts,
                step_size_parts,
        ):
            axis = tf.range(independent_chain_ndims, tf.rank(current_state))

            state_diff = proposed_state - current_state

            current_volatility *= tf.sqrt(step_size)

            proposed_energy = (state_diff - current_drift) / current_volatility

            proposed_volatility *= tf.sqrt(step_size)
            # Compute part of `q(proposed_state | current_state)`
            proposed_energy = (tf.reduce_sum(mcmc_util.safe_sum(
                [tf.math.log(current_volatility), 0.5 * (proposed_energy**2)]),
                                             axis=axis))
            proposed_log_density_parts.append(-proposed_energy)

            # Compute part of `q(current_state | proposed_state)`
            dual_energy = (state_diff + proposed_drift) / proposed_volatility
            dual_energy = (tf.reduce_sum(mcmc_util.safe_sum(
                [tf.math.log(proposed_volatility), 0.5 * (dual_energy**2)]),
                                         axis=axis))
            dual_log_density_parts.append(-dual_energy)

        # Compute `q(proposed_state | current_state)`
        proposed_log_density_reduce = tf.add_n(proposed_log_density_parts)
        # Compute `q(current_state | proposed_state)`
        dual_log_density_reduce = tf.add_n(dual_log_density_parts)

        return mcmc_util.safe_sum(
            [dual_log_density_reduce, -proposed_log_density_reduce])
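
For a single scalar state the correction above reduces to the familiar log q(current | proposed) - log q(proposed | current) with Gaussian proposals. A self-contained sketch with hypothetical values (not tied to any particular target density):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

step_size = 0.1
x, x_new = 0.3, 0.7              # current and proposed state
drift, drift_new = 0.05, -0.02   # drift evaluated at x and at x_new
vol, vol_new = 1.0, 1.2          # volatility evaluated at x and at x_new

forward = tfd.Normal(x + drift, tf.sqrt(step_size) * vol)              # q(x_new | x)
reverse = tfd.Normal(x_new + drift_new, tf.sqrt(step_size) * vol_new)  # q(x | x_new)

log_acceptance_correction = reverse.log_prob(x) - forward.log_prob(x_new)
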
Example #21
    def _sample_n(self, n, seed):
        components_seed, mix_seed = samplers.split_seed(
            seed, salt='MixtureSameFamily')
        try:
            seed_stream = SeedStream(seed, salt='MixtureSameFamily')
        except TypeError as e:  # Can happen for Tensor seeds.
            seed_stream = None
            seed_stream_err = e
        try:
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=components_seed)
            if seed_stream is not None:
                seed_stream()  # Advance even if unused.
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `components_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. {}')
            warnings.warn(
                msg.format(self.components_distribution.name,
                           type(self.components_distribution), str(e)))
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=seed_stream())

        event_shape = None
        event_ndims = tensorshape_util.rank(self.event_shape)
        if event_ndims is None:
            event_shape = self.components_distribution.event_shape_tensor()
            event_ndims = ps.rank_from_shape(event_shape)
        event_ndims_static = tf.get_static_value(event_ndims)

        num_components = None
        if event_ndims_static is not None:
            num_components = tf.compat.dimension_value(
                x.shape[-1 - event_ndims_static])
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        if num_components is None:
            num_components = tf.shape(x)[-1 - event_ndims]

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        try:
            mix_sample = self.mixture_distribution.sample(
                n, seed=mix_seed)  # [n, B] or [n]
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `mixture_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(self.mixture_distribution.name,
                           type(self.mixture_distribution), str(e)))
            mix_sample = self.mixture_distribution.sample(
                n, seed=seed_stream())  # [n, B] or [n]
        mask = tf.one_hot(
            indices=mix_sample,  # [n, B] or [n]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k] or [n, k]

        # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
        batch_ndims = ps.rank(x) - event_ndims - 1
        mask_batch_ndims = ps.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = ps.shape(mask)
        mask = tf.reshape(mask,
                          shape=ps.concat([
                              mask_shape[:-1],
                              ps.ones([pad_ndims], dtype=tf.int32),
                              mask_shape[-1:],
                              ps.ones([event_ndims], dtype=tf.int32),
                          ],
                                          axis=0))

        if x.dtype in [
                tf.bfloat16, tf.float16, tf.float32, tf.float64, tf.complex64,
                tf.complex128
        ]:
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            if event_shape is None:
                event_shape = self.components_distribution.event_shape_tensor()
            ret = self._reparameterize_sample(ret, event_shape=event_shape)

        return ret
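
The component-selection trick above (one-hot over the sampled mixture index, multiply, then reduce_sum over the component axis) in isolation, with made-up shapes:

import tensorflow as tf

n, batch, k, event = 5, 3, 4, 2
x = tf.random.normal([n, batch, k, event])                    # per-component samples
mix_sample = tf.random.uniform([n, batch], maxval=k, dtype=tf.int32)

mask = tf.one_hot(mix_sample, depth=k)                        # [n, batch, k]
mask = mask[..., tf.newaxis]                                  # broadcast over the event dim
selected = tf.reduce_sum(x * mask, axis=-2)                   # [n, batch, event]
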
def _batchwise_reduce_sum(x):
    with tf1.name_scope("batchwise_reduce_sum"):
        return tf.reduce_sum(input_tensor=x, axis=tf.range(1, tf.rank(x)))
Example #23
    def _distributional_transform(self, x, event_shape):
        """Performs distributional transform of the mixture samples.

    Distributional transform removes the parameters from samples of a
    multivariate distribution by applying conditional CDFs:
      (F(x_1), F(x_2 | x_1), ..., F(x_d | x_1, ..., x_d-1))
    (the indexing is over the 'flattened' event dimensions).
    The result is a sample of product of Uniform[0, 1] distributions.

    We assume that the components are factorized, so the conditional CDFs become
      F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i),
    where w_i^k is the posterior mixture weight: for i > 0
      w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1)
    and w_0^k = w_k is the mixture probability of the k-th component.

    Arguments:
      x: Sample of mixture distribution
      event_shape: The event shape of this distribution

    Returns:
      Result of the distributional transform
    """

        if tensorshape_util.rank(x.shape) is None:
            # tf.math.softmax raises an error when applied to inputs of undefined
            # rank.
            raise ValueError(
                'Distributional transform does not support inputs of '
                'undefined rank.')

        # Obtain factorized components distribution and assert that it's
        # a scalar distribution.
        if isinstance(self._components_distribution, independent.Independent):
            univariate_components = self._components_distribution.distribution
        else:
            univariate_components = self._components_distribution

        with tf.control_dependencies([
                assert_util.assert_equal(
                    univariate_components.is_scalar_event(),
                    True,
                    message='`univariate_components` must have scalar event')
        ]):
            event_ndims = ps.rank_from_shape(event_shape)
            x_padded = self._pad_sample_dims(
                x, event_ndims=event_ndims)  # [S, B, 1, E]
            log_prob_x = univariate_components.log_prob(
                x_padded)  # [S, B, k, E]
            cdf_x = univariate_components.cdf(x_padded)  # [S, B, k, E]

            # log prob_k (x_1, ..., x_i-1)
            event_size = ps.cast(ps.reduce_prod(event_shape), dtype=tf.int32)
            cumsum_log_prob_x = tf.reshape(
                tf.math.cumsum(
                    # [S*prod(B)*k, prod(E)]
                    tf.reshape(log_prob_x, [-1, event_size]),
                    exclusive=True,
                    axis=-1),
                ps.shape(log_prob_x))  # [S, B, k, E]

            event_ndims = ps.rank_from_shape(event_shape)
            logits_mix_prob = self.mixture_distribution.logits_parameter()
            logits_mix_prob = tf.reshape(
                logits_mix_prob,  # [k] or [B, k]
                ps.concat([
                    ps.shape(logits_mix_prob),
                    ps.ones([event_ndims], dtype=tf.int32),
                ],
                          axis=0))  # [k, [1]*e] or [B, k, [1]*e]

            # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1)
            log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x

            component_axis = tensorshape_util.rank(x.shape) - event_ndims
            posterior_weights_x = tf.math.softmax(log_posterior_weights_x,
                                                  axis=component_axis)
            return tf.reduce_sum(posterior_weights_x * cdf_x,
                                 axis=component_axis)
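# A hand-rolled sketch of the distributional transform for a hypothetical
# 2-component mixture of factorized 2-D Normals, assuming TensorFlow
# Probability is available. It mirrors the docstring's formula
#   F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k(x_i)
# but is only an illustration, not the class's API.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

mix_logits = tf.math.log([0.3, 0.7])             # mixture weights w_k
locs = tf.constant([[-1., -1.], [2., 2.]])       # [k, E]
components = tfd.Normal(loc=locs, scale=tf.ones([2, 2]))  # scalar-event parts
x = tf.constant([0.5, -0.2])                     # one sample, event shape [E]

# Posterior weights: w_0^k = softmax(logits); w_i^k uses prob_k(x_1..x_i-1).
log_prob_x = components.log_prob(x)                           # [k, E]
cumsum_log_prob = tf.cumsum(log_prob_x, axis=-1, exclusive=True)
posterior_weights = tf.math.softmax(
    mix_logits[:, tf.newaxis] + cumsum_log_prob, axis=0)      # [k, E]

# Distributional transform: posterior-weighted mixture of component CDFs.
u = tf.reduce_sum(posterior_weights * components.cdf(x), axis=0)  # [E]
print(u.numpy())  # each entry lies in (0, 1)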
def _update_trajectory_grad(previous_kernel_results, previous_state,
                            proposed_state, proposed_velocity,
                            trajectory_jitter, accept_prob, step_size,
                            criterion_fn, max_leapfrog_steps):
  """Updates the trajectory length."""
  # Compute criterion grads.
  def leapfrog_action(dt):
    # This represents the effect on the criterion value as the state follows the
    # proposed velocity. This implicitly assumes an identity mass matrix.
    return criterion_fn(
        previous_state,
        tf.nest.map_structure(
            lambda x, v:  # pylint: disable=g-long-lambda
            (x + mcmc_util.left_justified_expand_dims_like(dt, v) * v),
            proposed_state,
            proposed_velocity),
        accept_prob)

  criterion, trajectory_grad = gradient.value_and_gradient(
      leapfrog_action, tf.zeros_like(accept_prob))
  trajectory_grad *= trajectory_jitter

  # Weight by acceptance probability.
  trajectory_grad = tf.where(accept_prob > 1e-4, trajectory_grad, 0.)
  trajectory_grad = tf.where(
      tf.math.is_finite(trajectory_grad), trajectory_grad, 0.)
  trajectory_grad = (
      tf.reduce_sum(trajectory_grad * accept_prob) /
      tf.reduce_sum(accept_prob + 1e-20))

  # Compute Adam/RMSProp step size.
  dtype = previous_kernel_results.adaptation_rate.dtype
  iteration_f = tf.cast(previous_kernel_results.step, dtype) + 1.
  msg_adaptation_rate = 0.05
  new_averaged_sq_grad = (
      (1 - msg_adaptation_rate) * previous_kernel_results.averaged_sq_grad +
      msg_adaptation_rate * trajectory_grad**2)
  adjusted_averaged_sq_grad = new_averaged_sq_grad / (
      1. - (1 - msg_adaptation_rate)**iteration_f)
  trajectory_step_size = (
      previous_kernel_results.adaptation_rate /
      tf.sqrt(adjusted_averaged_sq_grad + 1e-20))

  # Apply the gradient. Clip absolute value to ~log(2)/2.
  log_update = tf.clip_by_value(trajectory_step_size * trajectory_grad, -0.35,
                                0.35)
  new_max_trajectory_length = (
      previous_kernel_results.max_trajectory_length * tf.exp(log_update))

  # Iterate averaging.
  average_weight = iteration_f**(-0.5)
  new_averaged_max_trajectory_length = tf.exp(
      average_weight * tf.math.log(new_max_trajectory_length) +
      (1 - average_weight) *
      tf.math.log(1e-10 +
                  previous_kernel_results.averaged_max_trajectory_length))

  # Clip the maximum trajectory length.
  new_max_trajectory_length = _clip_max_trajectory_length(
      new_max_trajectory_length, step_size,
      previous_kernel_results.adaptation_rate, max_leapfrog_steps)

  return previous_kernel_results._replace(
      criterion=criterion,
      max_trajectory_length=new_max_trajectory_length,
      averaged_sq_grad=new_averaged_sq_grad,
      averaged_max_trajectory_length=new_averaged_max_trajectory_length)
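# A NumPy-only sketch of the RMSProp-style, log-space update performed above;
# names and constants here are illustrative, not TFP's API. A bias-corrected
# moving average of the squared gradient sets the step size, and the step is
# clipped to roughly +/- log(2)/2 so the trajectory length changes by less
# than 2x per update.
import numpy as np

def rmsprop_log_space_update(max_traj_len, avg_sq_grad, g, step,
                             adaptation_rate=0.025, decay=0.05):
  avg_sq_grad = (1 - decay) * avg_sq_grad + decay * g**2
  corrected = avg_sq_grad / (1 - (1 - decay)**(step + 1))   # bias correction
  step_size = adaptation_rate / np.sqrt(corrected + 1e-20)
  log_update = np.clip(step_size * g, -0.35, 0.35)          # ~log(2)/2 cap
  return max_traj_len * np.exp(log_update), avg_sq_grad

length, sq = 1.0, 0.0
for t in range(3):
  length, sq = rmsprop_log_space_update(length, sq, g=0.5, step=t)
print(length)  # grows multiplicatively, never by more than 2x per step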
Example #25
 def _forward_log_det_jacobian(self, x):
     with tf.control_dependencies(self._assertions(x)):
         return -tf.reduce_sum(tf.math.log(x[..., 1:] - x[..., :-1]),
                               axis=-1)
 def loss():
     return tf.reduce_sum(tf.compat.v1.nn.embedding_lookup(var0, [[1]]))
Example #27
 def losses(self):
     """Sum of the KL divergences between the priors and the posteriors."""
     return tf.reduce_sum([s.losses for s in self.steps])
Example #28
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            with tf.control_dependencies(self._assertions):
                # This sampling approach is almost the same as the approach used by
                # `MixtureSameFamily`. The differences are due to having a list of
                # `Distribution` objects rather than a single object, and maintaining
                # random seed management that is consistent with the non-static code
                # path.
                samples = []
                cat_samples = self.cat.sample(n, seed=seed)
                stream = seed_stream.SeedStream(seed, salt="Mixture")

                for c in range(self.num_components):
                    samples.append(self.components[c].sample(n, seed=stream()))
                stack_axis = -1 - tensorshape_util.rank(
                    self._static_event_shape)
                x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
                npdt = dtype_util.as_numpy_dtype(x.dtype)
                mask = tf.one_hot(
                    indices=cat_samples,  # [n, B]
                    depth=self._num_components,  # == k
                    on_value=npdt(1),
                    off_value=npdt(0))  # [n, B, k]
                mask = distribution_util.pad_mixture_dimensions(
                    mask, self, self._cat,
                    tensorshape_util.rank(
                        self._static_event_shape))  # [n, B, k, [1]*e]
                return tf.reduce_sum(input_tensor=x * mask,
                                     axis=stack_axis)  # [n, B, E]

        with tf.control_dependencies(self._assertions):
            n = tf.convert_to_tensor(value=n, name="n")
            static_n = tf.get_static_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            static_samples_shape = cat_samples.shape
            if tensorshape_util.is_fully_defined(static_samples_shape):
                samples_shape = tensorshape_util.as_list(static_samples_shape)
                samples_size = tensorshape_util.num_elements(
                    static_samples_shape)
            else:
                samples_shape = tf.shape(input=cat_samples)
                samples_size = tf.size(input=cat_samples)
            static_batch_shape = self.batch_shape
            if tensorshape_util.is_fully_defined(static_batch_shape):
                batch_shape = tensorshape_util.as_list(static_batch_shape)
                batch_size = tensorshape_util.num_elements(static_batch_shape)
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = tf.reduce_prod(input_tensor=batch_shape)
            static_event_shape = self.event_shape
            if tensorshape_util.is_fully_defined(static_event_shape):
                event_shape = np.array(
                    tensorshape_util.as_list(static_event_shape),
                    dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                             samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = tf.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = tf.reshape(
                tf.tile(tf.range(0, batch_size), [n]), samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [0 1] and [0 1 0 1]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 1, 0, 1
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = tf.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            stream = seed_stream.SeedStream(seed, salt="Mixture")

            for c in range(self.num_components):
                n_class = tf.size(input=partitioned_samples_indices[c])
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=stream())

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * tf.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = tf.reshape(
                    samples_class_c,
                    tf.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = tf.gather(samples_class_c,
                                            lookup_partitioned_batch_indices,
                                            name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = tf.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = tf.reshape(
                lhs_flat_ret,
                tf.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            tensorshape_util.set_shape(
                ret,
                tensorshape_util.concatenate(static_samples_shape,
                                             self.event_shape))
            return ret
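# A toy sketch of the partition/stitch pattern used in the non-static path
# above, assuming TensorFlow 2.x and the same partitions as in the comment,
# [1 1 0 0 1 1]: indices are partitioned by component, per-component results
# are produced, and tf.dynamic_stitch reassembles them in the original order.
import tensorflow as tf

cat_samples = tf.constant([1, 1, 0, 0, 1, 1])      # component of each draw
raw_indices = tf.range(tf.size(cat_samples))       # [0, 1, 2, 3, 4, 5]

partitioned_indices = tf.dynamic_partition(
    data=raw_indices, partitions=cat_samples, num_partitions=2)
# partitioned_indices == [[2, 3], [0, 1, 4, 5]]

# Stand-in per-component "samples": component 0 yields 10s, component 1 100s.
per_component = [tf.constant([10, 10]), tf.constant([100, 100, 100, 100])]

stitched = tf.dynamic_stitch(partitioned_indices, per_component)
print(stitched.numpy())  # [100 100  10  10 100 100]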
Example #29
 def losses(self):
     prior = tfd.Normal(0, 1)
     return (tf.reduce_sum(tfd.kl_divergence(self.weight, prior)) +
             tf.reduce_sum(tfd.kl_divergence(self.bias, prior)))
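# A small hypothetical sketch of the KL-based loss above, assuming TensorFlow
# Probability: an elementwise Normal posterior over three weights against a
# standard Normal prior, with the divergences summed into one scalar.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

posterior = tfd.Normal(loc=[0.1, -0.2, 0.3], scale=[0.9, 1.1, 1.0])
prior = tfd.Normal(0., 1.)

kl_loss = tf.reduce_sum(tfd.kl_divergence(posterior, prior))
print(kl_loss.numpy())  # a non-negative scalar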
Example #30
  def _parameter_control_dependencies(self, is_init):
    """Validate parameters."""
    bw, bh, kd = None, None, None
    try:
      shape = tf.broadcast_static_shape(self.bin_widths.shape,
                                        self.bin_heights.shape)
    except ValueError as e:
      raise ValueError('`bin_widths`, `bin_heights` must broadcast: {}'.format(
          str(e)))
    bin_sizes_shape = shape
    try:
      shape = tf.broadcast_static_shape(shape[:-1], self.knot_slopes.shape[:-1])
    except ValueError as e:
      raise ValueError(
          '`bin_widths`, `bin_heights`, and `knot_slopes` must broadcast on '
          'batch axes: {}'.format(str(e)))

    assertions = []
    if (tensorshape_util.is_fully_defined(bin_sizes_shape[-1:]) and
        tensorshape_util.is_fully_defined(self.knot_slopes.shape[-1:])):
      if tensorshape_util.rank(self.knot_slopes.shape) > 0:
        num_interior_knots = tensorshape_util.dims(bin_sizes_shape)[-1] - 1
        if tensorshape_util.dims(
            self.knot_slopes.shape)[-1] not in (1, num_interior_knots):
          raise ValueError(
              'Innermost axis of non-scalar `knot_slopes` must broadcast with '
              '{}; got {}.'.format(num_interior_knots, self.knot_slopes.shape))
    elif self.validate_args:
      if is_init != any(
          tensor_util.is_ref(t)
          for t in (self.bin_widths, self.bin_heights, self.knot_slopes)):
        bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
        bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
        kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
        shape = tf.broadcast_dynamic_shape(
            tf.shape((bw + bh)[..., :-1]), tf.shape(kd))
        assertions.append(
            assert_util.assert_greater(
                tf.shape(shape)[0],
                tf.zeros([], dtype=shape.dtype),
                message='`(bin_widths + bin_heights)[..., :-1]` must broadcast '
                'with `knot_slopes` to at least 1-D.'))

    if not self.validate_args:
      assert not assertions
      return assertions

    if (is_init != tensor_util.is_ref(self.bin_widths) or
        is_init != tensor_util.is_ref(self.bin_heights)):
      bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
      bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
      assertions += [
          assert_util.assert_near(
              tf.reduce_sum(bw, axis=-1),
              tf.reduce_sum(bh, axis=-1),
              message='`sum(bin_widths, axis=-1)` must equal '
              '`sum(bin_heights, axis=-1)`.'),
      ]
    if is_init != tensor_util.is_ref(self.bin_widths):
      bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
      assertions += [
          assert_util.assert_positive(
              bw, message='`bin_widths` must be positive.'),
      ]
    if is_init != tensor_util.is_ref(self.bin_heights):
      bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
      assertions += [
          assert_util.assert_positive(
              bh, message='`bin_heights` must be positive.'),
      ]
    if is_init != tensor_util.is_ref(self.knot_slopes):
      kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
      assertions += [
          assert_util.assert_positive(
              kd, message='`knot_slopes` must be positive.'),
      ]
    return assertions
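# A sketch of spline parameters that would satisfy the validations above,
# assuming 4 bins over an interval of width 4 (hence 3 interior knots); the
# checks below simply restate the assertions and are not the class's API.
import tensorflow as tf

interval = 4.
bin_widths = interval * tf.math.softmax(tf.constant([0.5, 1.0, 0.2, 0.8]))
bin_heights = interval * tf.math.softmax(tf.constant([1.0, 0.3, 0.9, 0.4]))
knot_slopes = tf.math.softplus(tf.constant([0.1, 0.2, 0.3]))  # 3 interior knots

tf.debugging.assert_near(tf.reduce_sum(bin_widths, axis=-1),
                         tf.reduce_sum(bin_heights, axis=-1))
tf.debugging.assert_positive(bin_widths)
tf.debugging.assert_positive(bin_heights)
tf.debugging.assert_positive(knot_slopes)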
Example #31
 def rank_not_equal_case():
     tf.debugging.Assert(tf.rank(weights) == 1, [tf.rank(weights)])
     weights_sum = tf.reduce_sum(weights)
     axes = tf.convert_to_tensor([[axis], [0]])
     avg = tf.tensordot(a, weights, axes) / weights_sum
     return avg, weights_sum
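# A small check of the tensordot-based weighted average above, using assumed
# toy values with TensorFlow 2.x and NumPy: contracting `a` with the 1-D
# `weights` along `axis` and dividing by the weight sum matches np.average.
import numpy as np
import tensorflow as tf

a = tf.constant([[1., 2., 3.], [4., 5., 6.]])
weights = tf.constant([0.2, 0.3, 0.5])
axis = 1

weights_sum = tf.reduce_sum(weights)
avg = tf.tensordot(a, weights, axes=[[axis], [0]]) / weights_sum
print(avg.numpy())                                                # [2.3 5.3]
print(np.average(a.numpy(), axis=axis, weights=weights.numpy()))  # matches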