Exemplo n.º 1
0
 def _Upd(c, x):
     """Folds a finiteness check on `x` into the running condition `c`.

     Args:
       c: the boolean condition accumulated so far.
       x: the tensor to inspect.

     Returns:
       `c` unchanged when `self._cond_is_finite` is falsy; otherwise `c` AND-ed
       with "every element of `x` is finite" and "no element of `x` is inf".
     """
     if self._cond_is_finite:
         all_finite = tf.reduce_all(tf.math.is_finite(x))
         c = tf.math.logical_and(c, all_finite)
         none_inf = tf.reduce_all(tf.math.logical_not(tf.math.is_inf(x)))
         c = tf.math.logical_and(c, none_inf)
     return c
Exemplo n.º 2
0
def IsWithinBBox(points, bbox):
    """Checks if points are within a 2-d bbox.

    The function returns true if points are strictly inside the box. It also
    returns true when the points are exactly on the box edges. A point is never
    reported as inside a bbox whose `_BBoxArea` is not greater than zero.

    Args:
      points: a float Tensor of shape [..., 2] of points to be tested. The last
        coordinates are (x, y).
      bbox: a float Tensor of shape [..., 4, 2] of bboxes. The last coordinates
        are the four corners of the bbox and (x, y). The corners are assumed to
        be given in counter-clockwise order; this is asserted at run time.

    Returns:
      Tensor: If ``pshape = tf.shape(points)[:-1]`` and
      ``bshape = tf.shape(bbox)[:-2]``, returns a boolean tensor of shape
      ``tf.concat([pshape, bshape], axis=0)``, where each element is true if the
      point is inside the corresponding box.  If a point falls exactly on an
      edge of the bbox, it is also true.
    """
    bshape = py_utils.GetShape(bbox)[:-2]
    pshape = py_utils.GetShape(points)[:-1]
    bbox = py_utils.HasShape(bbox, tf.concat([bshape, [4, 2]], axis=0))
    points = py_utils.HasShape(points, tf.concat([pshape, [2]], axis=0))
    # Enumerate all 4 corners:
    v1, v2, v3, v4 = (bbox[..., 0, :], bbox[..., 1, :], bbox[..., 2, :],
                      bbox[..., 3, :])
    # Every consecutive corner triple must turn counter-clockwise.
    v1v2v3_check = tf.reduce_all(_IsCounterClockwiseDirection(v1, v2, v3))
    v2v3v4_check = tf.reduce_all(_IsCounterClockwiseDirection(v2, v3, v4))
    v4v1v2_check = tf.reduce_all(_IsCounterClockwiseDirection(v4, v1, v2))
    v3v4v1_check = tf.reduce_all(_IsCounterClockwiseDirection(v3, v4, v1))
    # The data handed to each Assert is the corner triple that the matching
    # check examined, so a failure reports the offending corners.
    with tf.control_dependencies([
            py_utils.Assert(v1v2v3_check, [v1, v2, v3]),
            py_utils.Assert(v2v3v4_check, [v2, v3, v4]),
            py_utils.Assert(v4v1v2_check, [v4, v1, v2]),
            py_utils.Assert(v3v4v1_check, [v3, v4, v1])
    ]):
        # Inside (or on the boundary) means being on the left-hand side of, or
        # on, each of the four directed edges.
        is_inside = tf.math.logical_and(
            tf.math.logical_and(_IsOnLeftHandSideOrOn(points, v1, v2),
                                _IsOnLeftHandSideOrOn(points, v2, v3)),
            tf.math.logical_and(_IsOnLeftHandSideOrOn(points, v3, v4),
                                _IsOnLeftHandSideOrOn(points, v4, v1)))
    has_non_zero_area = tf.greater(_BBoxArea(bbox), 0)
    is_inside = tf.logical_and(tf.cast(is_inside, tf.bool), has_non_zero_area)
    # Swap the last two dimensions. The round trip through int32 is kept from
    # the original code (presumably because einsum is not used on bool inputs).
    is_inside = tf.einsum('...ij->...ji', tf.cast(is_inside, tf.int32))
    return tf.cast(is_inside, tf.bool)
Exemplo n.º 3
0
def IsWithinBBox(points, bbox):
  """Checks if points are within a 2-d bbox.

  The function returns true if points are strictly inside the box. It also
  returns true when the points are exactly on the box edges.

  Args:
    points: a float Tensor of shape [..., 2] of points to be tested. The last
      coordinates are (x, y).
    bbox: a float Tensor of shape [..., 4, 2] of bboxes. The last coordinates
      are the four corners of the bbox and (x, y). The corners are assumed to be
      given in counter-clockwise order; this is asserted at run time.

  Returns:
    If pshape = tf.shape(points)[:-1] and bshape = tf.shape(bbox)[:-2],
    a tensor of shape tf.concat([pshape, bshape], axis=0), of booleans, where
    each element is true if the point is inside the corresponding box.
    If a point falls exactly on an edge of the bbox, it is also true.
  """
  bshape = py_utils.GetShape(bbox)[:-2]
  pshape = py_utils.GetShape(points)[:-1]
  bbox = py_utils.HasShape(bbox, bshape + [4, 2])
  points = py_utils.HasShape(points, pshape + [2])
  # Enumerate all 4 corners:
  v1, v2, v3, v4 = (bbox[..., 0, :], bbox[..., 1, :], bbox[..., 2, :],
                    bbox[..., 3, :])
  # Every consecutive corner triple must turn counter-clockwise.
  v1v2v3_check = tf.reduce_all(_IsCounterClockwiseDirection(v1, v2, v3))
  v2v3v4_check = tf.reduce_all(_IsCounterClockwiseDirection(v2, v3, v4))
  v4v1v2_check = tf.reduce_all(_IsCounterClockwiseDirection(v4, v1, v2))
  v3v4v1_check = tf.reduce_all(_IsCounterClockwiseDirection(v3, v4, v1))
  # The data handed to each Assert is the corner triple that the matching
  # check examined, so a failure reports the offending corners.
  with tf.control_dependencies([
      py_utils.Assert(v1v2v3_check, [v1, v2, v3]),
      py_utils.Assert(v2v3v4_check, [v2, v3, v4]),
      py_utils.Assert(v4v1v2_check, [v4, v1, v2]),
      py_utils.Assert(v3v4v1_check, [v3, v4, v1])
  ]):
    # Inside (or on the boundary) means being on the left-hand side of, or on,
    # each of the four directed edges.
    is_inside = tf.logical_and(
        tf.logical_and(
            _IsOnLeftHandSideOrOn(points, v1, v2),
            _IsOnLeftHandSideOrOn(points, v2, v3)),
        tf.logical_and(
            _IsOnLeftHandSideOrOn(points, v3, v4),
            _IsOnLeftHandSideOrOn(points, v4, v1)))
  return is_inside
 def get_accuracy(self, loss, pred, target):
   """Stores per-example and per-character accuracy entries in `loss`.

   Positions where `target` equals the configured pad value are masked out of
   `pred` before comparison and excluded from the per-character average.

   Args:
     loss: a dict-like object; the keys "accuracy_per_example" and
       "accuracy_per_char" are written, each as a (value, batch_size) tuple.
     pred: predicted ids; its dtype is used for the integer computations.
     target: reference ids, cast to the dtype of `pred`.
   """
   params = self.params
   id_dtype = pred.dtype
   pad_value = int(
       params.input.feature_neighborhood_input.batch_opts.pad_value)

   target = tf.cast(target, id_dtype)
   # 1 where the target is a real (non-pad) id, 0 elsewhere.
   mask = tf.cast(tf.math.not_equal(target, pad_value), id_dtype)
   pred *= mask
   valid_count = tf.cast(tf.reduce_sum(mask), tf.float32)

   matches = tf.math.equal(pred, target)
   # An example counts as correct only if every position along axis 1 matches.
   whole_row_match = tf.cast(tf.reduce_all(matches, axis=1), tf.float32)
   loss["accuracy_per_example"] = (tf.reduce_mean(whole_row_match),
                                   params.input.batch_size)

   # Per-character accuracy: matching non-pad positions over non-pad count.
   matches = tf.cast(matches, tf.float32)
   matches *= tf.cast(mask, tf.float32)
   per_char = tf.reduce_sum(matches) / valid_count
   loss["accuracy_per_char"] = (per_char, params.input.batch_size)
Exemplo n.º 5
0
        def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                                      num_hyps_per_beam, *args, **kwargs):
            """Wrapper for adding bias to _PreBeamSearchStepCallback.

            Biases results.log_probs towards provided encoder_outputs.targets.
            When every entry of encoder_outputs.targets.weights is zero, the
            log_probs and the `consistent` state are passed through unchanged.

            Args:
              theta: a NestedMap of parameters.
              encoder_outputs: a NestedMap computed by encoder. Its
                `targets.labels` and `targets.weights` fields are read here.
              step_ids: A tensor of shape [tgt_batch, 1].
              states: A `.NestedMap` of tensors representing states that the
                clients would like to keep track of for each of the active
                hyps.
              num_hyps_per_beam: Beam size.
              *args: additional arguments to _PreBeamSearchStepCallback.
              **kwargs: additional arguments to _PreBeamSearchStepCallback.

            Returns:
              A tuple (results, out_states).
              results: A `.NestedMap` of beam search results.
                atten_probs:
                  The updated attention probs, of shape [tgt_batch, src_len].
                log_probs:
                  Log prob for each of the tokens in the target vocab. This is
                  of shape [tgt_batch, vocab_size].
              out_states: a `.NestedMap` The updated states. The states
                relevant here are:
                time_step: A scalar indicating current step of decoder.  Must
                  be provided and maintained by subclass.
                consistent: A boolean vector of shape [tgt_batch, ] which
                  tracks whether each hypothesis has exactly matched
                  encoder_outputs.targets so far.
            """
            p = self.params
            time_step = states.time_step
            # Run the unbiased step first; its log_probs are adjusted below.
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                *args, **kwargs)
            labels = encoder_outputs.targets.labels
            weights = encoder_outputs.targets.weights

            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    """Repeats a per-source vector for every hyp and flattens it."""
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    tensor = tf.tile(tensor,
                                     [num_hyps_per_beam, 1
                                      ])  # [num_hyps_per_beam, src_batch]
                    tgt_batch = tf.shape(step_ids)[
                        0]  # num_hyps_per_beam*src_batch
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from previous step
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden later
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.logical_and(states.consistent,
                                            local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    # Zero the bias weight for hyps that already diverged.
                    weight = weight * tf.cast(consistent, p.dtype)

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                uncertainty = tf.constant(
                    1e-10,
                    p.dtype)  # avoid 0 probs which may cause issues with log
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                # The interpolation is only meaningful for weights in [0, 1];
                # both bounds are asserted before the result is used.
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.log(probs), consistent

            def NoApplyBias():
                """No-op. Return original log_probs and consistent."""
                return bs_results.log_probs, states.consistent

            # Skip the biasing computation when all target weights are zero.
            log_probs, consistent = tf.cond(
                tf.reduce_all(tf.equal(weights, 0.0)), NoApplyBias, ApplyBias)
            bs_results.log_probs = log_probs
            out_states.consistent = consistent

            return bs_results, out_states
Exemplo n.º 6
0
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          input_batch: A `.NestedMap` with fields:

            - ids: The inputs tensor. It is expected to be of shape
              [batch, time].
            - paddings: The paddings tensor. Expected shape [batch, time].
            - task_ids: If p.task_emb is provided, must contain per-token task
              ids of shape [batch, time].
            - segment_ids, segment_pos: read only when p.packed_input is set.

        Returns:
          A NestedMap containing

          - encoded: The encoded features, either a tensor of shape
            [time, batch, depth], or a list of tensors if is_transparent is
            set in transformer_stack.
          - padding: of shape [time, batch]
          - segment_id: [time, batch] if packed inputs are supported by the
            model (and all layers), or None otherwise.
          - embedded_inputs: [time, batch, depth] embedded inputs tokens
            without positional encodings.
        """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            # ids and paddings must have matching shapes, and ids must be
            # rank 2.
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                # Off TPU, optionally trim the time axis to the longest
                # unpadded sequence in the batch.
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                # Everything beyond max_seq_length must be padding, otherwise
                # the truncation below would drop real tokens.
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :
                                                             max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :
                                                              max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            # With shared_emb the lookup goes through the softmax layer
            # instead of a dedicated token embedding layer.
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(
                    theta.token_emb, tf.reshape(input_ids, [-1]))
            else:
                input_embs = self.softmax.EmbLookup(
                    theta.softmax, tf.reshape(input_ids, [-1]))

            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            # Kept aside before positional/task embeddings are added; returned
            # as `embedded_inputs`.
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs
            if p.task_emb:
                input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                                      input_batch.task_ids)

            # Project to the model dimension only when it differs from the
            # embedding dimension.
            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            # Switch paddings (and segment ids) to time-major layout.
            paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        # NOTE: from here on the ops are created outside the name scope above.
        if not self.do_eval and p.apply_source_mask:
            # Augment padding for masked source word positions.
            dtype = paddings.dtype
            source_mask = tf.where(tf.equal(input_ids, p.source_mask_id),
                                   tf.ones_like(input_ids, dtype=dtype),
                                   tf.zeros_like(input_ids, dtype=dtype))
            # Make sure padding is between 0 and 1.
            paddings = tf.clip_by_value(paddings + tf.transpose(source_mask),
                                        0.0, 1.0)

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
Exemplo n.º 7
0
 def _Upd(c, k, x):
     """Records `x` under key `k` and queues a finiteness check for it.

     Args:
       c: a value that is passed through untouched.
       k: the key under which `x` is stored in the enclosing-scope `stats`.
       x: the tensor to record and to check.

     Returns:
       `c`, unchanged.
     """
     stats[k] = x
     all_finite = tf.reduce_all(tf.math.is_finite(x))
     is_finite_checks.append(all_finite)
     return c
Exemplo n.º 8
0
 def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                  unused_hyp_lens, done_hyps, unused_other_states_list):
     """Returns whether the loop should run another step.

     The result is true while `cur_step` is below the enclosing-scope
     `max_steps` and at least one entry of `done_hyps` is still false. The
     `unused_*` arguments are ignored.
     """
     steps_remaining = cur_step < max_steps
     some_hyp_unfinished = tf.logical_not(tf.reduce_all(done_hyps))
     return tf.logical_and(steps_remaining, some_hyp_unfinished)
Exemplo n.º 9
0
        def Callback(theta, encoder_outputs, step_ids, states,
                     num_hyps_per_beam, *args, **kwargs):
            """Runs `_PreBeamSearchStepCallback` and post-processes its log-probs.

            Two optional adjustments are applied, selected by the
            enclosing-scope flags `biased` and `stochastic`:

            - biased: interpolates the predicted probabilities with one-hot
              probabilities built from `encoder_outputs.targets.labels`, using
              `encoder_outputs.targets.weights`, and updates
              `out_states.consistent`. Skipped at run time when all weights
              are zero.
            - stochastic: when `encoder_outputs.stochastic_beam_search.enable`
              is true at run time, replaces log_probs with Gumbel-perturbed
              values restricted to the top `k` tokens (`k` comes from the
              enclosing scope) and stores intermediate values in
              `out_states.tmp_states`.

            Args:
              theta: parameters, forwarded to `_PreBeamSearchStepCallback`.
              encoder_outputs: encoder outputs; `targets` and/or
                `stochastic_beam_search` fields are read depending on the
                flags.
              step_ids: ids of the current step; squeezed along axis 1 here.
              states: decoder states; `time_step` is always read.
              num_hyps_per_beam: number of hyps per beam, used for tiling.
              *args: forwarded to `_PreBeamSearchStepCallback`.
              **kwargs: forwarded to `_PreBeamSearchStepCallback`.

            Returns:
              A tuple (bs_results, out_states) as returned by
              `_PreBeamSearchStepCallback`, with the adjustments above applied.
            """
            p = self.params
            time_step = states.time_step
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                *args, **kwargs)

            def TileForBeamAndFlatten(tensor):
                """Repeats a per-source vector for every hyp and flattens it."""
                tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                tensor = tf.tile(
                    tensor,
                    [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
                tgt_batch = tf.shape(step_ids)[
                    0]  # num_hyps_per_beam*src_batch
                return tf.reshape(tensor, [tgt_batch])

            if biased:
                labels = encoder_outputs.targets.labels
                weights = encoder_outputs.targets.weights

                def ApplyBias():
                    """Bias and update log_probs and consistent."""

                    # Consistent if step_ids == labels from previous step
                    # TODO(navari): Consider updating consistent only if weights > 0. Then
                    # re-evaluate the need for bias_only_if_consistent=True.
                    # Note that prev_label is incorrect for step 0 but is overridden
                    # later
                    prev_label = TileForBeamAndFlatten(
                        tf.gather(labels, tf.maximum(time_step - 1, 0),
                                  axis=1))
                    is_step0 = tf.equal(time_step, 0)
                    local_consistence = tf.math.logical_or(
                        is_step0, tf.equal(prev_label, tf.squeeze(step_ids,
                                                                  1)))
                    consistent = tf.math.logical_and(states.consistent,
                                                     local_consistence)

                    # get label, weight slices corresponding to current time_step
                    label = TileForBeamAndFlatten(
                        tf.gather(labels, time_step, axis=1))
                    weight = TileForBeamAndFlatten(
                        tf.gather(weights, time_step, axis=1))
                    if p.bias_only_if_consistent:
                        # Zero the bias weight for hyps that already diverged.
                        weight = weight * tf.cast(consistent,
                                                  py_utils.FPropDtype(p))

                    # convert from dense label to sparse label probs
                    vocab_size = tf.shape(bs_results.log_probs)[1]
                    label_probs = tf.one_hot(label,
                                             vocab_size,
                                             dtype=py_utils.FPropDtype(
                                                 p))  # [tgt_batch, vocab_size]
                    pred_probs = tf.exp(bs_results.log_probs)

                    # interpolate predicted probs and label probs
                    weight = tf.expand_dims(weight, 1)
                    # Both bounds on the weight are asserted before the
                    # interpolated value is used.
                    probs = py_utils.with_dependencies([
                        py_utils.assert_less_equal(weight, 1.),
                        py_utils.assert_greater_equal(weight, 0.)
                    ], (1.0 - weight) * pred_probs + weight * label_probs)
                    # Ensure that tf.math.log is applied to positive values.
                    probs = tf.maximum(probs,
                                       tf.constant(1e-12, dtype=probs.dtype))
                    return tf.math.log(probs), consistent

                def NoApplyBias():
                    """No-op. Return original log_probs and consistent."""
                    return bs_results.log_probs, states.consistent

                # Skip the biasing computation when all target weights are
                # zero.
                log_probs, consistent = tf.cond(
                    tf.reduce_all(tf.equal(weights, 0.0)), NoApplyBias,
                    ApplyBias)
                bs_results.log_probs = log_probs
                out_states.consistent = consistent

            if stochastic:
                # Read after the optional biasing above, so the perturbation
                # operates on the (possibly biased) log-probs.
                log_probs = bs_results.log_probs

                def PerturbedLogProbs():
                    """Returns perturbed log-probs and the new temporary states."""
                    # STEP 1: Perform top-k filtering. This is done as a performance
                    # optimization of avoiding sorting the entire `log_probs`, which is
                    # prohibitively slow.
                    top_k = tf.math.top_k(log_probs, k, sorted=True)
                    # shape: [tgt_batch, k]
                    top_k_log_probs = top_k.values
                    # shape: [tgt_batch, k]
                    top_k_ids = top_k.indices

                    # STEP 2: Perform top-p filtering.
                    # shape: [tgt_batch]
                    top_p_threshold = encoder_outputs.stochastic_beam_search.top_p_threshold
                    top_p_threshold = tf.clip_by_value(top_p_threshold, 0., 1.)
                    top_p_threshold = TileForBeamAndFlatten(top_p_threshold)
                    # shape: [tgt_batch, k]
                    filtered_top_k_log_probs = _KeepTopP(
                        top_k_log_probs, top_p_threshold)

                    # STEP 3: Perturb cumulative log-probs.
                    # shape: [tgt_batch, 1]
                    last_cumulative_log_probs = states.cumulative_log_probs
                    # shape: [tgt_batch, 1]
                    last_perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs
                    # Compute cumulative log-probs of the current step.
                    # shape: [tgt_batch, k]
                    cumulative_log_probs = (last_cumulative_log_probs +
                                            filtered_top_k_log_probs)
                    # Perturb cumulative log-probs by Gumbel noises under the condition
                    # that the max of the new perturbed log-probs is equal to
                    # perturbed_cumulative_log_probs of the previous step.
                    # shape: [tgt_batch, k]
                    new_perturbed_cumulative_log_probs = _SampleGumbelWithMax(
                        cumulative_log_probs,
                        last_perturbed_cumulative_log_probs,
                        encoder_outputs.stochastic_beam_search.seed, time_step,
                        encoder_outputs.stochastic_beam_search.src_ids,
                        encoder_outputs.stochastic_beam_search.src_paddings)

                    # STEP 4: Compute updated log_probs. This step is necessary because
                    # the output of PreBeamSearchStepCallback must be "per-step"
                    # log-probs, whereas so far "cumulative" log-probs have been computed.
                    # shape: [tgt_batch, k]
                    updated_top_k_log_probs = (
                        new_perturbed_cumulative_log_probs -
                        last_perturbed_cumulative_log_probs)
                    # Convert to the shape [tgt_batch, vocab_size].
                    # Tokens outside the top k start from LARGE_NEGATIVE_NUMBER.
                    updated_log_probs = tf.fill(
                        tf.shape(log_probs),
                        tf.constant(LARGE_NEGATIVE_NUMBER,
                                    dtype=log_probs.dtype))
                    updated_log_probs = _BatchScatter(updated_log_probs,
                                                      top_k_ids,
                                                      updated_top_k_log_probs)

                    return (updated_log_probs,
                            py_utils.NestedMap(
                                new_perturbed_cumulative_log_probs=
                                new_perturbed_cumulative_log_probs,
                                top_k_log_probs=top_k_log_probs,
                                top_k_ids=top_k_ids,
                            ))

                (bs_results.log_probs, out_states.tmp_states) = tf.cond(
                    encoder_outputs.stochastic_beam_search.enable,
                    PerturbedLogProbs,
                    # No-op.
                    lambda: (bs_results.log_probs, states.tmp_states))
                # These states are not updated here but will be updated in
                # PostBeamSearchStepCallback since doing so requires the knowledge of
                # the next step IDs.
                out_states.cumulative_log_probs = states.cumulative_log_probs
                out_states.perturbed_cumulative_log_probs = states.perturbed_cumulative_log_probs

            return bs_results, out_states
Exemplo n.º 10
0
 def StopFn(recurrent_theta, state, inputs):
     """Returns true unless every entry of `state.ids` equals the EOS id.

     Only `state.ids` and the enclosing-scope `p.target_eos_id` are read;
     `recurrent_theta` and `inputs` are discarded.

     NOTE(review): the value is true while some id is not yet EOS - confirm
     this matches the polarity the caller expects from a "stop" function.
     """
     del recurrent_theta, inputs
     every_id_is_eos = tf.reduce_all(tf.equal(state.ids, p.target_eos_id))
     return tf.logical_not(every_id_is_eos)
Exemplo n.º 11
0
def max_assignment(score: tf.Tensor,
                   *,
                   elementwise_upper_bound: tf.Tensor,
                   row_sums: tf.Tensor,
                   col_sums: tf.Tensor,
                   epsilon: float = 0.1,
                   num_iterations: int = 50,
                   use_epsilon_scaling: bool = True):
    """Differentiable max assignment with margin and upper bound constraints.

    Args:
      score: a 3D tensor of size [batch_size, n_rows, n_columns].
        score[i, j, k] denotes the weight if the assignment on this entry is
        non-zero.
      elementwise_upper_bound: a 3D tensor of size [batch_size, n_rows,
        n_columns]. Each entry denotes the maximum value assignment[i, j, k]
        can take and must be a non-negative value. For example,
        upper_bound[i, j, k]=1.0 for binary assignment problem.
      row_sums: a 2D tensor of size [batch_size, n_rows]. The row sum
        constraint. The output assignment p[i, j, :] must sum to
        row_sums[i, j].
      col_sums: a 2D tensor of size [batch_size, n_columns]. The column sum
        constraint. The output assignment p[i, :, k] must sum to
        col_sums[i, k].
      epsilon: the epsilon coefficient of entropy regularization. The value
        should be within the range (0, 1]. `0.01` might work better than
        `0.1`. `0.1` may not make the assignment close enough to 0 or 1.
      num_iterations: the number of update iterations to perform. With 0, no
        update is run and the outputs are computed from the initial dual
        variables.
      use_epsilon_scaling: whether to use epsilon scaling. In practice, the
        convergence of the iterative algorithm is much better if we start by
        solving the optimization with a larger epsilon value and re-use the
        solution (i.e. dual variables) for the instance with a smaller
        epsilon. This is called the epsilon scaling trick. See
        [Schmitzer 2019] (https://arxiv.org/pdf/1610.06519.pdf) as a
        reference. Here if use_epsilon_scaling=True, after each iteration we
        decrease the running epsilon by a constant factor until it reaches the
        target epsilon value. We found this to work well for gradient backward
        propagation, while the original scaling trick doesn't.

    Returns:
      A tuple with the following values.
        - assignment: a 3D tensor of size [batch_size, n_rows, n_columns].
          The output assignment.
        - used_iter: the number of iterations used; this is `num_iterations`
          exactly as passed in.
        - eps: a scalar tensor indicating the stopping epsilon value.
        - delta: a scalar tensor indicating the stopping delta value (the
          relative change on the margins of assignment p in the last
          iteration).
    """

    # Check if all shapes are correct
    score_shape = score.shape
    bsz = score_shape[0]
    n = score_shape[1]
    m = score_shape[2]
    score = tf.ensure_shape(score, [bsz, n, m])
    elementwise_upper_bound = tf.ensure_shape(elementwise_upper_bound,
                                              [bsz, n, m])
    row_sums = tf.ensure_shape(tf.expand_dims(row_sums, axis=2), [bsz, n, 1])
    col_sums = tf.ensure_shape(tf.expand_dims(col_sums, axis=1), [bsz, 1, m])

    # the total sum of row sums must be equal to total sum of column sums
    # NOTE(review): the tf.Assert results in this function are not wired into
    # any control dependency; whether they execute depends on the execution
    # mode - confirm.
    sum_diff = tf.reduce_sum(row_sums, axis=1) - tf.reduce_sum(col_sums,
                                                               axis=2)
    sum_diff = tf.abs(sum_diff)
    tf.Assert(tf.reduce_all(sum_diff < 1e-6), [sum_diff])

    # Convert upper_bound constraint into another margin constraint
    # by adding auxiliary variables & scores. Tensor `a`, `b` and `c`
    # represent the margins (i.e. reduced sum) of 3 axes respectively.
    #
    max_row_sums = tf.reduce_sum(elementwise_upper_bound,
                                 axis=-1,
                                 keepdims=True)
    max_col_sums = tf.reduce_sum(elementwise_upper_bound,
                                 axis=-2,
                                 keepdims=True)
    score_ = tf.stack([score, tf.zeros_like(score)], axis=1)  # (bsz, 2, n, m)
    a = tf.stack([row_sums, max_row_sums - row_sums], axis=1)  # (bsz, 2, n, 1)
    b = tf.stack([col_sums, max_col_sums - col_sums], axis=1)  # (bsz, 2, 1, m)
    c = tf.expand_dims(elementwise_upper_bound, axis=1)  # (bsz, 1, n, m)

    # Clip log(0) to a large negative values -1e+36 to avoid
    # getting inf or NaN values in computation. Cannot use larger
    # values because float32 would use `-inf` automatically.
    #
    tf.Assert(tf.reduce_all(a >= 0), [a])
    tf.Assert(tf.reduce_all(b >= 0), [b])
    tf.Assert(tf.reduce_all(c >= 0), [c])
    log_a = tf.maximum(tf.math.log(a), -1e+36)
    log_b = tf.maximum(tf.math.log(b), -1e+36)
    log_c = tf.maximum(tf.math.log(c), -1e+36)

    # Initialize the dual variables of margin constraints
    u = tf.zeros_like(a)
    v = tf.zeros_like(b)
    w = tf.zeros_like(c)

    # Without epsilon scaling the running eps starts at (and stays at) the
    # target epsilon; with scaling it starts at 1.0 and decays towards it.
    eps = tf.constant(1.0 if use_epsilon_scaling else epsilon,
                      dtype=score.dtype)
    epsilon = tf.constant(epsilon, dtype=score.dtype)

    def do_updates(cur_iter, eps, u, v, w):  # pylint: disable=unused-argument
        """Runs one round of dual-variable updates and returns the new state."""
        # Epsilon scaling, i.e. gradually decreasing `eps` until it
        # reaches the target `epsilon` value
        cur_iter = tf.cast(cur_iter, u.dtype)
        scaling = tf.minimum(0.6 * 1.04**cur_iter, 0.85)
        eps = tf.maximum(epsilon, eps * scaling)
        score_div_eps = score_ / eps

        # Update u
        log_q_1 = score_div_eps + (w + v) / eps
        log_q_1 = tf.reduce_logsumexp(log_q_1, axis=-1, keepdims=True)
        new_u = (log_a - tf.maximum(log_q_1, -1e+30)) * eps

        # Update v
        log_q_2 = score_div_eps + (w + new_u) / eps
        log_q_2 = tf.reduce_logsumexp(log_q_2, axis=-2, keepdims=True)
        new_v = (log_b - tf.maximum(log_q_2, -1e+30)) * eps

        # Update w
        log_q_3 = score_div_eps + (new_u + new_v) / eps
        log_q_3 = tf.reduce_logsumexp(log_q_3, axis=-3, keepdims=True)
        new_w = (log_c - tf.maximum(log_q_3, -1e+30)) * eps
        return eps, new_u, new_v, new_w

    def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
        """Returns the largest relative change between two consecutive states."""
        prev_sum_uvw = tf.stop_gradient((u + v + w) / eps)
        sum_uvw = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

        # Compute the relative changes on margins of P.
        # This will be used for stopping criteria.
        # Note the last update on w would guarantee the
        # margin constraint c is satisfied, so we don't
        # need to check it here.
        p = tf.exp(tf.stop_gradient(score_ / new_eps + sum_uvw))
        p_a = tf.reduce_sum(p, axis=-1, keepdims=True)
        p_b = tf.reduce_sum(p, axis=-2, keepdims=True)
        delta_a = tf.abs(a - p_a) / (a + 1e-6)
        delta_b = tf.abs(b - p_b) / (b + 1e-6)
        new_delta = tf.reduce_max(delta_a)
        new_delta = tf.maximum(new_delta, tf.reduce_max(delta_b))

        # Compute the relative changes on assignment solution P.
        # This will be used for stopping criteria.
        delta_p = tf.abs(tf.exp(prev_sum_uvw) -
                         tf.exp(sum_uvw)) / (tf.exp(sum_uvw) + 1e-6)
        new_delta = tf.maximum(new_delta, tf.reduce_max(delta_p))
        return new_delta

    # Bind the "previous" state before the loop so that it exists even when
    # the loop body never runs (num_iterations == 0); otherwise the delta
    # computation below would reference unbound names.
    prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
    for cur_iter in tf.range(num_iterations):
        prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
        eps, u, v, w = do_updates(cur_iter, eps, u, v, w)
    delta = compute_relative_changes(prev_eps, prev_u, prev_v, prev_w, eps, u,
                                     v, w)
    cur_iter = num_iterations
    assignment = tf.exp((score_ + u + v + w) / eps)
    # Slice 0 along axis 1 is the real assignment; slice 1 belongs to the
    # auxiliary all-zero score introduced for the upper-bound constraint.
    assignment = assignment[:, 0]
    return assignment, cur_iter, eps, delta
Exemplo n.º 12
0
  def FProp(self, theta, input_batch, interpolation_batch=None, lambdas=None):
    # pyformat: disable
    """Interpolates source ids in input_batch and interpolation_batch.

    Refer to Eq. (4) in paper https://arxiv.org/abs/2106.04060.
    It is a standard Transformer Encoder if interpolation_batch is None.

    Args:
      theta: A `.NestedMap` object containing weights values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task ids
          of shape [batch, time].
        - segment_ids, segment_pos: Read only if p.packed_input is set.
        - embs: Optional embeddings of ids. Only consulted when
          interpolation_batch is given; replaces the embedding lookup.
      interpolation_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task ids
          of shape [batch, time].
        - embs: Optional embeddings of ids; replaces the embedding lookup.
      lambdas: A pair of tensors to combine embeddings of ids in input_batch and
        interpolation_batch. Required when interpolation_batch is given.

    Returns:
      A NestedMap of

        - encoded: The encoded features, either a tensor of shape
          [time, batch, depth], or a list of tensors if is_transparent is set in
          transformer_stack.
        - padding: of shape [time, batch]
        - segment_id: [time, batch] if packed inputs are supported by the model
          (and all layers), or None otherwise.
        - embedded_inputs: [time, batch, depth] embedded inputs tokens without
          positional encodings.

    Raises:
      ValueError: If interpolation_batch is given but lambdas is None.
    """
    # pyformat: enable

    p = self.params
    if interpolation_batch is not None and lambdas is None:
      raise ValueError(
          'lambdas must be provided when interpolation_batch is given.')

    with tf.name_scope(p.name):
      src_segment_id = None
      src_segment_pos = None
      input_ids = py_utils.with_dependencies([
          py_utils.assert_shape_match(
              tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
          py_utils.assert_equal(tf.rank(input_batch.ids), 2)
      ], input_batch.ids)

      max_seq_length = None
      if (not py_utils.use_tpu() and
          FLAGS.transformer_encoder_truncates_inputs):
        # Drop trailing positions that are padding in every example. The
        # assert verifies that everything past the cut is indeed padding.
        max_seq_length = tf.cast(
            tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
            tf.int32)
        paddings = py_utils.with_dependencies([
            py_utils.assert_equal(
                tf.constant(True, tf.bool),
                tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
        ], input_batch.paddings)
        input_ids = input_ids[:, :max_seq_length]
        paddings = paddings[:, :max_seq_length]
        if p.packed_input:
          src_segment_id = input_batch.segment_ids[:, :max_seq_length]
          src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      else:
        paddings = input_batch.paddings
        if p.packed_input:
          src_segment_id = input_batch.segment_ids
          src_segment_pos = input_batch.segment_pos

      max_time = tf.shape(input_ids)[1]
      emb_dim = p.token_emb.embedding_dim

      # Input token embeddings + positional embeddings
      if not p.shared_emb:
        input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                              tf.reshape(input_ids, [-1]))
      else:
        input_embs = self.softmax.EmbLookup(theta.softmax,
                                            tf.reshape(input_ids, [-1]))

      if interpolation_batch is not None:
        other_input_ids = interpolation_batch.ids
        if not p.shared_emb:
          other_input_embs = self.token_emb.EmbLookup(
              theta.token_emb, tf.reshape(other_input_ids, [-1]))
        else:
          other_input_embs = self.softmax.EmbLookup(
              theta.softmax, tf.reshape(other_input_ids, [-1]))
        # Add a trailing axis so each lambda broadcasts over the depth axis.
        lambdas = [tf.expand_dims(lam, -1) for lam in lambdas]

        # Caller-supplied embeddings take precedence over the lookups above.
        if 'embs' in input_batch and input_batch.embs is not None:
          input_embs = input_batch.embs
        if ('embs' in interpolation_batch and
            interpolation_batch.embs is not None):
          other_input_embs = interpolation_batch.embs

        # Always bring both operands to [batch, time, dim] before mixing.
        # The lookups above are flattened to [batch * time, dim]; reshaping
        # only when interpolation_batch.embs was absent left a flattened
        # input_embs to be combined with rank-3 tensors whenever just
        # interpolation_batch.embs was supplied. The reshape is a no-op for
        # tensors that already have this shape.
        input_embs = tf.reshape(input_embs, [-1, max_time, emb_dim])
        other_input_embs = tf.reshape(
            other_input_embs, [-1, tf.shape(other_input_ids)[1], emb_dim])
        input_embs = lambdas[0] * input_embs + lambdas[1] * other_input_embs

        # A position stays padded only if it is padded in both batches
        # (p1 + p2 - 1 is 1 only when both are 1, then clipped to [0, 1]).
        # NOTE(review): interpolation_batch.paddings is not truncated to
        # max_seq_length here - confirm both batches share the same time
        # dimension when input truncation is enabled.
        paddings = paddings + interpolation_batch.paddings - 1.0
        paddings = tf.clip_by_value(paddings, 0.0, 1.0)

      input_embs = tf.reshape(input_embs, [-1, max_time, emb_dim])

      # Snapshot before task/positional embeddings are added; this is what is
      # returned as `embedded_inputs`.
      orig_input_embs = input_embs
      if p.task_emb:
        if interpolation_batch is None:
          input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                                input_batch.task_ids)
        else:
          # Task embeddings are interpolated with the same lambdas as the
          # token embeddings.
          task_embs = self.task_emb.EmbLookup(theta.task_emb,
                                              input_batch.task_ids)
          other_task_embs = self.task_emb.EmbLookup(
              theta.task_emb, interpolation_batch.task_ids)
          task_embs = lambdas[0] * task_embs + lambdas[1] * other_task_embs
          input_embs += task_embs

      if p.packed_input:
        position_embs = self.position_emb.FPropWithPosition(
            theta.position_emb, src_segment_pos)
      else:
        position_embs = self.position_emb.FProp(theta.position_emb, max_time)
        position_embs = tf.reshape(position_embs, [1, max_time, emb_dim])
      input_embs += position_embs

      if p.model_dim != p.token_emb.embedding_dim:
        input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

      # From here on paddings (and segment ids) are time-major.
      paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
      if p.packed_input:
        src_segment_id = tf.transpose(src_segment_id)

      input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

      # [time, batch, dim]
      transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)