Example #1
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay."""
     params = self.params
     warmup_steps = tf.cast(params.decay_start * params.worker_replicas,
                            tf.float32)
     current_step = tf.cast(current_step, tf.float32)
     if params.decay_end is not None:
         current_step = tf.where(current_step < params.decay_end,
                                 current_step,
                                 tf.cast(params.decay_end, tf.float32))
     peak_learning_rate = (warmup_steps**-0.5)
     return (params.model_dim**-0.5) * tf.minimum(
         tf.minimum((current_step + 1),
                    (current_step + 1)**-0.5), peak_learning_rate)
Example #2
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay."""
     p = self.params
     current_step = tf.cast(current_step, tf.float32)
     warmup_steps = tf.cast(p.warmup_steps, tf.float32)
     linear_warmup = tf.minimum(1.0, current_step / warmup_steps)
     rsqrt_decay = tf.math.rsqrt(tf.maximum(current_step, warmup_steps))
     return p.model_dim**-0.5 * linear_warmup * rsqrt_decay
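This is the usual linear-warmup / inverse-square-root decay factor. A minimal standalone sketch of the same arithmetic, assuming illustrative values model_dim=512 and warmup_steps=4000 (the real layer reads both from its params):

import tensorflow as tf

def linear_warmup_rsqrt_decay(step, model_dim=512, warmup_steps=4000.0):
    # Ramp linearly up to 1.0 over warmup_steps, then decay as step**-0.5.
    step = tf.cast(step, tf.float32)
    linear_warmup = tf.minimum(1.0, step / warmup_steps)
    rsqrt_decay = tf.math.rsqrt(tf.maximum(step, warmup_steps))
    return model_dim**-0.5 * linear_warmup * rsqrt_decay

# The factor peaks at model_dim**-0.5 * warmup_steps**-0.5 when step == warmup_steps.
print(float(linear_warmup_rsqrt_decay(4000)))  # ~6.99e-4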
Example #3
 def Update(self, new_value):
     state0 = self.GetValue()
     state1 = tf.stack([
         state0[0] + new_value[0],
         tf.minimum(state0[1], new_value[1]),
         tf.maximum(state0[2], new_value[2]),
     ])
     self.SetValue(state1)
Example #4
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay."""
     p = self.params
     current_step = tf.cast(current_step, tf.float32)
     warmup_steps = tf.cast(
         p.warmup_examples / (p.batch_size * self._num_replicas),
         tf.float32)
     return tf.minimum((current_step + 1) * warmup_steps**-1.5,
                       (current_step + 1)**-0.5)
Example #5
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay."""
     p = self.params
     current_step = tf.cast(current_step, tf.float32)
     warmup_steps = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
     if p.decay_end is not None:
         current_step = tf.where(current_step < p.decay_end, current_step,
                                 tf.cast(p.decay_end, tf.float32))
     return p.model_dim**-0.5 * tf.minimum(
         (current_step + 1) * warmup_steps**-1.5, (current_step + 1)**-0.5)
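The two arguments to the outer tf.minimum cross exactly at step + 1 == warmup_steps, which is where this schedule peaks. A quick check of that arithmetic (the warmup value of 4000 is just an illustrative assumption):

import tensorflow as tf

warmup_steps = tf.constant(4000.0)
step_plus_one = warmup_steps  # the crossover point
print(float(step_plus_one * warmup_steps**-1.5))  # ~0.0158 (warmup branch)
print(float(step_plus_one**-0.5))                 # ~0.0158 (decay branch)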
Example #6
 def FProp(self, theta, current_step):
     p = self.params
     assert p.total_steps > 0
     assert p.initial_value > p.final_value
     with tf.name_scope(p.name):
         decay_gap = p.initial_value - p.final_value
         return p.final_value + 0.5 * decay_gap * (1 + tf.cos(
             math.pi *
             tf.minimum(1.0,
                        tf.cast(current_step, tf.float32) / p.total_steps)))
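This is a standard cosine decay from initial_value down to final_value over total_steps. A self-contained sketch with assumed constants (the real layer takes them from its params):

import math
import tensorflow as tf

def cosine_decay(step, initial_value=1.0, final_value=0.0, total_steps=10000.0):
    # Fraction of the decay completed, capped at 1.0 after total_steps.
    frac = tf.minimum(1.0, tf.cast(step, tf.float32) / total_steps)
    decay_gap = initial_value - final_value
    return final_value + 0.5 * decay_gap * (1 + tf.cos(math.pi * frac))

print(float(cosine_decay(0)))      # 1.0
print(float(cosine_decay(5000)))   # ~0.5
print(float(cosine_decay(20000)))  # ~0.0 (held at final_value past total_steps)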
Example #7
 def GetTensorRange(self, t_name, ts):
     # Always straddle a real zero point.
     if self.do_eval:
         # At eval/inference time, use the memorized range.
         # Important: Don't capture these variables in training mode so as to
         # avoid extra/unnecessary captures.
         min_var = tf.stop_gradient(self._GetQStateVar(t_name, 'min'))
         max_var = tf.stop_gradient(self._GetQStateVar(t_name, 'max'))
         return (min_var, max_var)
     # Calculate min/max for all tensors.
     batch_min = tf.minimum(tf.reduce_min(ts), 0.0)
     batch_max = tf.maximum(tf.reduce_max(ts), 0.0)
     return (tf.stop_gradient(batch_min), tf.stop_gradient(batch_max))
Example #8
 def _Value(self, current_step):
     """Returns the current clipping cap."""
     p = self.params
     start_step = tf.cast(p.start_step, tf.float32)
     end_step = tf.cast(p.end_step, tf.float32)
     current_step = tf.cast(current_step, tf.float32)
     steps_ratio = (
         tf.minimum(end_step - start_step, current_step - start_step) /
         (end_step - start_step))
     rmax_tensor = (steps_ratio * p.end_cap +
                    (1.0 - steps_ratio) * p.start_cap)
     return tf.cond(tf.less(current_step, p.start_step),
                    lambda: tf.cast(p.start_cap, tf.float32),
                    lambda: tf.cast(rmax_tensor, tf.float32))
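The cap therefore ramps linearly from start_cap at start_step to end_cap at end_step and is held constant outside that window. A standalone sketch with assumed values (start_step=1000, end_step=2000, start_cap=8.0, end_cap=1.0):

import tensorflow as tf

def clip_cap(step, start_step=1000.0, end_step=2000.0, start_cap=8.0, end_cap=1.0):
    step = tf.cast(step, tf.float32)
    ratio = tf.minimum(end_step - start_step, step - start_step) / (end_step - start_step)
    rmax = ratio * end_cap + (1.0 - ratio) * start_cap
    # Before start_step the ratio would go negative, so hold at start_cap.
    return tf.where(step < start_step, start_cap, rmax)

print(float(clip_cap(500)))   # 8.0
print(float(clip_cap(1500)))  # 4.5 (halfway through the ramp)
print(float(clip_cap(5000)))  # 1.0 (the ratio is capped at 1.0 past end_step)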
Example #9
    def QuantizeTensors(self, t_name, ts, eval_only=False):
        p = self.params
        # Always straddle a real zero point.
        if self.do_eval:
            # At eval/inference time, use the memorized range.
            # Important: Don't capture these variables in training mode so as to
            # avoid extra/unnecessary captures.
            min_var = self._GetQStateVar(t_name, 'min')
            max_var = self._GetQStateVar(t_name, 'max')
            return [
                self._MaybeFakeQuant(t, min_var, max_var, num_bits=p.bits)
                for t in ts
            ]
        else:
            # At training time, use the batch calculated min/max.
            accumulator_name = self._GetAccumulatorNameForTensor(t_name)
            # Calculate min/max for all tensors.
            batch_min = 0.0
            batch_max = 0.0
            for t in ts:
                batch_min = tf.minimum(tf.reduce_min(t), batch_min)
                batch_max = tf.maximum(tf.reduce_max(t), batch_max)

            # New state.
            state1 = tf.stack([1.0, batch_min, batch_max])
            self.accumulators[accumulator_name].Update(state1)

            # Results.
            ts_out = []
            for i, t in enumerate(ts):
                if eval_only:
                    # If only quantizing at eval time, still record ranges as above
                    # but don't quantize.
                    quant_t = t
                else:
                    # If quantizing during training, skip quantization if it produces
                    # NANs. Sometimes early in the training process, things are unstable
                    # and ranges can produce numerical instability that makes it
                    # impossible to perform a fake_quant.
                    quant_t = self._MaybeFakeQuant(t,
                                                   batch_min,
                                                   batch_max,
                                                   num_bits=p.bits)
                    # TODO(laurenzo): Plumb quant_t_has_nans through state and report.
                    quant_t_has_nans = tf.math.is_nan(quant_t)
                    quant_t = tf.where(quant_t_has_nans, t, quant_t)
                ts_out.append(quant_t)
                summary_utils.histogram(
                    '%s/%s_%d' % (self._qvars_scope.name, t_name, i), t)
            return ts_out
Example #10
 def QuantizeWeight(self, w):
     p = self.params
     w_min = tf.reduce_min(w)
     w_max = tf.reduce_max(w)
     # NOTE: We force a small, non-zero range because otherwise, zero weights
     # can cause downstream inference engines to blow up.
     w_min = tf.minimum(w_min, -p.quantize_weight_epsilon)
     w_max = tf.maximum(w_max, p.quantize_weight_epsilon)
     quant_w = self._MaybeFakeQuant(w, w_min, w_max, num_bits=p.bits)
     if self.do_eval:
         return quant_w
     else:
         # If quantizing during training, skip quantization if it produces
         # NANs. Sometimes early in the training process, things are unstable
         # and ranges can produce numerical instability that makes it
         # impossible to perform a fake_quant.
         quant_w_has_nans = tf.math.is_nan(quant_w)
         return tf.where(quant_w_has_nans, w, quant_w)
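_MaybeFakeQuant is a lingvo helper that is not shown in this listing. The same forced non-zero-range idea can be sketched with TF's public fake-quant op; the epsilon and bit width below are assumed values:

import tensorflow as tf

def quantize_weight_sketch(w, epsilon=1e-7, bits=8):
    # Force a small, non-zero range so an all-zero weight tensor still quantizes.
    w_min = tf.minimum(tf.reduce_min(w), -epsilon)
    w_max = tf.maximum(tf.reduce_max(w), epsilon)
    quant_w = tf.quantization.fake_quant_with_min_max_vars(
        w, w_min, w_max, num_bits=bits)
    # Fall back to the raw weights wherever quantization produced NaNs.
    return tf.where(tf.math.is_nan(quant_w), w, quant_w)

print(quantize_weight_sketch(tf.zeros([3, 3])).numpy())  # all zeros, no blow-up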
Example #11
    def GetState(self, theta):
        """Gets the state from theta."""
        p = self.params
        if p.is_inference:
            # State is not used for inference. Just return dummy.
            return tf.zeros([1], tf.float32)
        else:
            # Calculations/vars need to be float but these can be ints in the params.
            clip_end_step = tf.cast(p.clip_end_step, tf.float32)
            clip_start_step = tf.cast(p.clip_start_step, tf.float32)
            quant_start_step = tf.cast(p.quant_start_step, tf.float32)
            global_step = tf.cast(theta.global_step, tf.float32)

            # Will be negative if before clipping starts.
            clip_ratio = (tf.minimum(clip_end_step - clip_start_step,
                                     global_step - clip_start_step) /
                          tf.maximum(1.0, clip_end_step - clip_start_step))
            # Currently fq is either on (1.0) or off (-1.0). Progressive quantization
            # may later occupy 0..1.0.
            fq_ratio = tf.where(global_step < quant_start_step, -1.0, 1.0)

            return tf.stack([clip_ratio, fq_ratio])
Example #12
    def _GetGlobalGradScale(self, all_grad_norm, has_nan_or_inf):
        """Returns a scaling factor for all gradients according to their norm.

    In case there are NaN or Inf values the function will return 0.0.

    Args:
      all_grad_norm: A scalar representing the total norm of all vars.
      has_nan_or_inf: A scalar of 0 or 1, indicating whether there is any NaN or
        Inf in input gradients.

    Returns:
      The gradient scale. 0 if gradient updates should be skipped for the step.
    """
        p = self.params
        # Computes gradient's scale.
        grad_scale = tf.constant(1.0)
        if p.clip_gradient_norm_to_value:
            # If all_grad_norm > p.clip_gradient_norm_to_value, scales
            # all_grads so that the norm is 1.0.
            grad_scale = tf.minimum(
                1.0, p.clip_gradient_norm_to_value / all_grad_norm)

        if p.grad_norm_to_clip_to_zero:
            # If all_grad_norm > p.grad_norm_to_clip_to_zero, treats
            # grad_scale as 0. This way, we ignore this step.
            grad_scale *= tf.cast(all_grad_norm < p.grad_norm_to_clip_to_zero,
                                  p.dtype)

        if p.grad_norm_tracker:
            grad_scale *= self.grad_norm_tracker.FPropDefaultTheta(
                all_grad_norm, has_nan_or_inf)

        # Force grad_scale to be 0 if there is any NaN or Inf in gradients.
        grad_scale = tf.where(has_nan_or_inf, 0.0, grad_scale)

        return grad_scale
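A standalone sketch of the same scaling logic using plain floats instead of the params object (the grad_norm_tracker branch is omitted and the two thresholds are assumed values):

import tensorflow as tf

def global_grad_scale(all_grad_norm, has_nan_or_inf,
                      clip_norm_to_value=1.0, norm_to_clip_to_zero=100.0):
    # Scale gradients so the clipped global norm is at most clip_norm_to_value.
    grad_scale = tf.minimum(1.0, clip_norm_to_value / all_grad_norm)
    # Treat absurdly large norms as a skipped step.
    grad_scale *= tf.cast(all_grad_norm < norm_to_clip_to_zero, tf.float32)
    # Skip the step entirely if any gradient is NaN or Inf.
    return tf.where(has_nan_or_inf, 0.0, grad_scale)

print(float(global_grad_scale(tf.constant(5.0), tf.constant(False))))    # 0.2
print(float(global_grad_scale(tf.constant(500.0), tf.constant(False))))  # 0.0
print(float(global_grad_scale(tf.constant(0.5), tf.constant(True))))     # 0.0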
Example #13
    def _RecordTensor(self, t_name):
        p = self.params
        if self.do_eval:
            return []

        accumulator_name = self._GetAccumulatorNameForTensor(t_name)
        accumulator = self.accumulators[accumulator_name]
        min_var = self._GetQStateVar(t_name, 'min')
        max_var = self._GetQStateVar(t_name, 'max')

        # Unpack state tensor.
        current_value = accumulator.GetValue()
        count = current_value[0]
        min_value = current_value[1]
        max_value = current_value[2]
        accumulator.Reset()

        def Ema(variable, value):
            return (1.0 - p.ema_decay) * (variable - value)

        # Note that small floating point issues can cause ranges that naturally
        # begin or end at zero to move slightly past, causing hard failures
        # downstream (checks that all ranges straddle zero). We therefore repeat
        # the straddling constraint here.
        return [
            tf.assign(
                min_var,
                tf.minimum(
                    0., min_var -
                    tf.where(count > 0., Ema(min_var, min_value), 0.))),
            tf.assign(
                max_var,
                tf.maximum(
                    0., max_var -
                    tf.where(count > 0., Ema(max_var, max_value), 0.))),
        ]
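Note that Ema(variable, value) returns (1 - ema_decay) * (variable - value), so variable - Ema(variable, value) is the usual exponential moving average ema_decay * variable + (1 - ema_decay) * value. A small sketch of the resulting range update with an assumed ema_decay of 0.99:

import tensorflow as tf

def update_range(min_var, max_var, min_value, max_value, ema_decay=0.99):
    # EMA toward the batch statistics, re-clamped so the range keeps straddling zero.
    new_min = tf.minimum(0.0, ema_decay * min_var + (1.0 - ema_decay) * min_value)
    new_max = tf.maximum(0.0, ema_decay * max_var + (1.0 - ema_decay) * max_value)
    return new_min, new_max

new_min, new_max = update_range(tf.constant(-1.0), tf.constant(1.0), -2.0, 0.5)
print(float(new_min), float(new_max))  # approximately -1.01 and 0.995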
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

    We take in an `x`, which represents the groundtruth sequence (i.e., English
    sequence). We return a sampled rollin (observed) canvas (i.e., random subset
    of the English sequence), as well as the target (indices) for an
    insertion-based model (i.e., the targets given the random observed subset).

    Args:
      theta: Ignored, this can be None.
      x: The symbol ids of shape `[batch_size, time_dim]`.
      x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
        0 is valid and 1 is invalid.
      eos_id: The <eos> token id to represent end-of-slot.
      force_sample_last_token: Set True to force sample the last token of `x`.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim]. Note that `c_dim` <= `time_dim`, but they need
          not be equal.
        - canvas_indices: The canvas indices (into `x`).
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices of shape [num_targets, 3].
          `num_targets` is the number of total targets in the entire batch.
          [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
          captures the token. Each row [batch, slot, vocab] represents the
          indices of the target -- i.e., the batch, slot and vocab combination
          of the target. Typical usage of these indices is to tf.gather_nd
          the log-probs (from the softmax layer).
        - target_weights: The target weights.

    Raises:
      ValueError: If invalid params.
    """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        oracle_policy = p.oracle_policy
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

        # Compute the desired length per example in the batch.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
            # accomplish this by adding +LARGE_NUMBER to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len, we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        c_indices = tf.sort(c_indices)
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        c *= tf.cast(1 - c_paddings, tf.int32)

        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed tokens to <eos>; note that some of these need a
        # zero weight (or else there would be an <eos> and a valid token in the
        # same slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
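The position sampling above is the Gumbel-max trick extended to top-k: adding i.i.d. Gumbel noise to (masked) logits and taking tf.nn.top_k yields positions sampled uniformly without replacement from the valid ones. A self-contained sketch of just that step, with made-up shapes and lengths:

import tensorflow as tf

batch_size, time_dim, k = 2, 6, 3
valid_len = tf.constant([6, 4])

# Large negative logits for padded positions, 0 for valid ones.
logits = tf.cast(
    tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(valid_len, 1),
    tf.float32) * -1e9
# Gumbel(0, 1) noise via -log(-log(U)).
gumbel = -tf.math.log(-tf.math.log(tf.random.uniform([batch_size, time_dim])))
_, indices = tf.nn.top_k(logits + gumbel, k)
print(indices.numpy())  # k distinct valid positions per row, e.g. [[5 0 2], [1 3 0]]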
    def _GetMask(self,
                 batch_size,
                 choose_range,
                 mask_size,
                 global_seed,
                 max_length=None,
                 masks_per_frame=0.0,
                 multiplicity=1,
                 dtype=tf.float32,
                 max_ratio=1.0):
        """Returns fixed size multi-masks starting from random positions.

    A multi-mask is a mask obtained by applying multiple masks.

    This function when max_length is given:
      1) Sample random mask lengths less than max_length with shape
         (batch_size, multiplicity).
      2) Truncate lengths to a max of (choose_range * max_ratio),
         so that each mask is fully contained within the corresponding sequence.
      3) Randomly sample start points of shape (batch_size, multiplicity)
         within (choose_range - lengths).
      4) For each batch, multiple masks (whose number is given by the
         multiplicity) are constructed.
      5) Return a mask of shape (batch_size, mask_size) where masks are
         obtained by composing the masks constructed in step 4).
         If masks_per_frame > 0, the number is given by
         min(masks_per_frame * choose_range, multiplicity).
         If not, all the masks are composed. The masked regions are set to zero.

    This function when max_length is not given:
      1) Sample random mask lengths less than (choose_range * max_ratio)
         with shape (batch_size, multiplicity).
      2) Proceed to steps 3), 4) and 5) of the above.

    Args:
      batch_size: Batch size. Integer number.
      choose_range: Range within which the masked entries must lie. Tensor of
        shape (batch_size,).
      mask_size: Size of the mask. Integer number.
      global_seed: an integer seed tensor for stateless random ops.
      max_length: Maximum number of allowed consecutive masked entries. Integer
        number or None.
      masks_per_frame: Number of masks per frame. Float number. If > 0, the
        multiplicity of the mask is set to be masks_per_frame * choose_range.
      multiplicity: Maximum number of total masks. Integer number.
      dtype: Data type.
      max_ratio: Maximum portion of the entire range allowed to be masked. Float
        number.

    Returns:
      mask: a fixed size multi-mask starting from a random position with shape
      (batch_size, mask_size).
    """
        p = self.params
        # Non-empty random seed values are only used for testing or when using
        # stateless random ops. seed_1 and seed_2 are set separately to avoid
        # correlation of mask size and mask position.
        if p.use_input_dependent_random_seed:
            seed_1 = global_seed + 1
            seed_2 = global_seed + 2
        elif p.random_seed:
            seed_1 = p.random_seed + 1
            seed_2 = 2 * p.random_seed
        else:
            seed_1 = p.random_seed
            seed_2 = p.random_seed
        # Sample lengths for multiple masks.
        if max_length and max_length > 0:
            max_length = tf.broadcast_to(tf.cast(max_length, dtype),
                                         (batch_size, ))
        else:
            max_length = tf.cast(choose_range, dtype=dtype) * max_ratio
        random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)
        masked_portion = random_uniform(shape=(batch_size, multiplicity),
                                        minval=0.0,
                                        maxval=1.0,
                                        dtype=dtype,
                                        seed=seed_1)
        masked_frame_size = self.EinsumBBmBm(max_length, masked_portion)
        masked_frame_size = tf.cast(masked_frame_size, dtype=tf.int32)
        # Make sure the sampled length is smaller than max_ratio * length_bound.
        # Note that sampling in this way is biased
        # (shorter sequences may be over-masked).
        choose_range = tf.expand_dims(choose_range, -1)
        choose_range = tf.tile(choose_range, [1, multiplicity])
        length_bound = tf.cast(choose_range, dtype=dtype)
        length_bound = tf.cast(max_ratio * length_bound, dtype=tf.int32)
        length = tf.minimum(masked_frame_size, tf.maximum(length_bound, 1))

        # Choose starting point.
        random_start = random_uniform(shape=(batch_size, multiplicity),
                                      maxval=1.0,
                                      seed=seed_2)
        start_with_in_valid_range = random_start * tf.cast(
            (choose_range - length + 1), dtype=dtype)
        start = tf.cast(start_with_in_valid_range, tf.int32)
        end = start + length - 1

        # Shift starting and end point by small value.
        delta = tf.constant(0.1)
        start = tf.expand_dims(tf.cast(start, dtype) - delta, -1)
        start = tf.tile(start, [1, 1, mask_size])
        end = tf.expand_dims(tf.cast(end, dtype) + delta, -1)
        end = tf.tile(end, [1, 1, mask_size])

        # Construct pre-mask of shape (batch_size, multiplicity, mask_size).
        diagonal = tf.expand_dims(
            tf.expand_dims(tf.cast(tf.range(mask_size), dtype=dtype), 0), 0)
        diagonal = tf.tile(diagonal, [batch_size, multiplicity, 1])
        pre_mask = tf.cast(tf.math.logical_and(diagonal < end,
                                               diagonal > start),
                           dtype=dtype)

        # Sum masks with appropriate multiplicity.
        if masks_per_frame > 0:
            multiplicity_weights = tf.tile(
                tf.expand_dims(tf.range(multiplicity, dtype=dtype), 0),
                [batch_size, 1])
            multiplicity_tensor = masks_per_frame * tf.cast(choose_range,
                                                            dtype=dtype)
            multiplicity_weights = tf.cast(
                multiplicity_weights < multiplicity_tensor, dtype=dtype)
            pre_mask = self.EinsumBmtBmBt(pre_mask, multiplicity_weights)
        else:
            pre_mask = tf.reduce_sum(pre_mask, 1)
        mask = tf.cast(1.0 - tf.cast(pre_mask > 0, dtype=dtype), dtype=dtype)

        if p.fprop_dtype is not None and p.fprop_dtype != p.dtype:
            mask = tf.cast(mask, p.fprop_dtype)

        return mask
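The pre-mask construction compares a [0, mask_size) ramp against the interval (start - 0.1, end + 0.1); the 0.1 delta keeps the integer endpoints strictly inside the band. A tiny sketch for a single mask, with made-up sizes:

import tensorflow as tf

mask_size = 10
start = tf.constant([[2.0]])   # shape (batch=1, multiplicity=1)
length = tf.constant([[3.0]])
end = start + length - 1.0

ramp = tf.reshape(tf.range(mask_size, dtype=tf.float32), [1, 1, mask_size])
band = tf.cast((ramp > start[..., None] - 0.1) & (ramp < end[..., None] + 0.1),
               tf.float32)
# Compose over the multiplicity axis and flip: masked frames become 0.
mask = 1.0 - tf.cast(tf.reduce_sum(band, 1) > 0, tf.float32)
print(mask.numpy())  # [[1. 1. 0. 0. 0. 1. 1. 1. 1. 1.]]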
    def _GetWarpMatrix(self,
                       batch_size,
                       choose_range,
                       matrix_size,
                       global_seed,
                       max_warp_frames=None,
                       dtype=tf.float32,
                       max_ratio=1.0):
        """Returns warp matrices starting from random positions.

    In this function when max_warp_frames != None:
      1) Sample random warp displacements from the interval
         [-max_warp_frames, max_warp_frames) to yield shift tensor
         with shape (batch_size,).
      2) Truncate lengths to a maximum magnitude of (choose_range * max_ratio),
         so that each shift is fully contained within the
         corresponding sequence.
      3) Randomly sample origin points of shape (batch_size, multiplicity)
         within [shift, choose_range - shift).
      4) Return a batch of 1-D linear maps that fix the boundary points and
         shift the origin point by the shift.

    When max_warp_frames == None:
      1) Sample random warp displacements with magnitudes less than
         (choose_range * max_ratio) to yield shift tensor with
         shape (batch_size,).
      2) Proceed through steps 3), 4).

    Args:
      batch_size: Batch size. Integer number.
      choose_range: Range within which the warp reference points must lie.
        Tensor of shape (batch_size,).
      matrix_size: Dimension of vector space warp matrix is applied to. Integer
        number.
      global_seed: an integer seed tensor for stateless random ops.
      max_warp_frames: Upper-bound on the warp distance. Integer or None.
      dtype: Data type.
      max_ratio: Maximum ratio between the shift distance and choose_range.
        Float number.

    Returns:
      warp_matrix: An array of fixed size warp matrices with shape
      (batch_size, matrix_size, matrix_size).
    """
        p = self.params
        # Non-empty random seed values are only used for testing or when using
        # stateless random ops. seed_3, seed_4, and seed_5 are set separately to
        # avoid correlation of warp magnitude and origin position.
        if p.use_input_dependent_random_seed:
            seed_3 = global_seed + 3
            seed_4 = global_seed + 4
            seed_5 = global_seed + 5
        elif p.random_seed:
            seed_3 = p.random_seed - 1
            seed_4 = p.random_seed - 1
            seed_5 = 2 * p.random_seed + 1
        else:
            seed_3 = p.random_seed
            seed_4 = p.random_seed
            seed_5 = p.random_seed

        choose_range_dtype = tf.cast(choose_range, dtype=dtype)
        length_upper_bound = tf.cast(max_ratio * choose_range_dtype,
                                     dtype=tf.int32)
        # Set shift length.

        random_uniform = _random_uniform_op(p.use_input_dependent_random_seed)

        if max_warp_frames and max_warp_frames > 0:
            shift = random_uniform(shape=(batch_size, ),
                                   minval=-1 * max_warp_frames,
                                   maxval=max_warp_frames + 1,
                                   dtype=tf.int32,
                                   seed=seed_3)
        else:
            random_ratio = random_uniform(shape=(batch_size, ),
                                          minval=-1.0,
                                          maxval=1.0,
                                          dtype=dtype,
                                          seed=seed_4)
            shift = tf.cast(
                random_ratio * tf.cast(length_upper_bound, dtype=dtype),
                tf.int32)
        # Make sure the sampled shift magnitude is at most max_ratio * choose_range.
        # Note that sampling in this way is biased
        # (shorter sequences may be over-warped).
        final_shift = tf.maximum(-length_upper_bound,
                                 tf.minimum(shift, length_upper_bound))
        # Choose origin anchor point.
        mid_range = tf.cast(choose_range, dtype=tf.int32)
        mid_range = tf.maximum(choose_range - 2, 0)
        random_origin = random_uniform(shape=(batch_size, ),
                                       maxval=1.0,
                                       seed=seed_5)
        origin_with_in_valid_range = random_origin * tf.cast(mid_range,
                                                             dtype=dtype)
        origin = tf.cast(origin_with_in_valid_range, tf.int32) + 1
        # Set destination point of the origin anchor point under the warp map.
        destination = origin + final_shift
        # Cast origin and destination.
        origin = tf.cast(origin, dtype=dtype)
        destination = tf.cast(destination, dtype=dtype)

        return self._ConstructWarpMatrix(batch_size=batch_size,
                                         matrix_size=matrix_size,
                                         origin=origin,
                                         destination=destination,
                                         choose_range=choose_range_dtype,
                                         dtype=dtype)
    def _ConstructWarpMatrix(self, batch_size, matrix_size, origin,
                             destination, choose_range, dtype):
        """Returns warp matrices according to origin, destination and choose_range.

    This function constructs a batch of warp matrices which maps the batch
    of origin points to the batch of destination points with fixed boundary
    coordinates at 0 and choose_range.

    The warping function, defined by the origin anchor point `origin`,
    the destination of the origin anchor point `destination` and the
    length of the domain in the warping axis `choose_range` is a piecewise
    linear map that fixes the points 0 and `choose_range` and maps
    `origin` to `destination`.

    For the warping matrix to be non-singular, destination must lie in the
    range 1<= destination <= choose_range - 1, so a destination
    out of this range is adjusted to be in this range before the warping
    matrix is constructed.

    The warping map can be explicitly written by first defining the slopes:
      1) slope_0 = origin / destination.
      2) slope_1 = (choose_range - origin) / (choose_range - destination).
      3) slope_2 = 1.0.

    Then the origin point orig_i of the mapped coordinate i is given by:
      1) i < destination: orig_i = slope_0 * i.
      2) destination <= i < choose_range:
         orig_i = slope_1 * i - (slope_1 - slope_0) * destination.
      3) i >= choose_range: orig_i = i.

    Denoting n_i = ceil(orig_i), the warp matrix element warp[i][j] is given by:
      1) j = n_i: 1 - n_i + orig_i.
      2) j = n_i - 1: n_i - orig_i.
      3) Otherwise: 0.

    Applying the warp matrix to an array of pixels, i.e.,
    warped_pixel[i] = sum_j warp[i][j] * pixel[j], one would get
    warped_pixel[i] = (n_i-orig_i) pixel[n_i-1] + (1-n_i+orig_i) pixel[n_i].

    Args:
      batch_size: Batch size. Integer number.
      matrix_size: Dimension of the vector space the warp matrix is applied to.
        Integer number.
      origin: Origin anchor point for warping. Tensor of shape (batch_size,) and
        data type dtype.
      destination: Destination of the origin anchor point upon warping. Tensor
        of shape (batch_size,) and data type dtype.
      choose_range: Range within which the warp reference points must lie.
        Tensor of shape (batch_size,) data type dtype.
      dtype: Data type of origin, destination, choose_range and the output warp
        matrix.

    Returns:
      warp_matrix: An array of fixed size warp matrices with shape
      (batch_size, matrix_size, matrix_size).
    """
        p = self.params

        # Entries of destination must be in the range
        # 1 <= destination <= choose_range - 1
        # for warp matrix to have non-singular values.
        destination = tf.minimum(tf.maximum(destination, 1.0),
                                 choose_range - 1.0)

        # Construct piece-wise linear function fixing boundary points
        # specified by zero, choose_range and matrix size and maps
        # the origin anchor point to the destination.
        destination_bc = tf.broadcast_to(destination,
                                         (matrix_size, batch_size))
        destination_bc = tf.transpose(destination_bc)
        choose_range_bc = tf.broadcast_to(choose_range,
                                          (matrix_size, batch_size))
        choose_range_bc = tf.transpose(choose_range_bc)

        # Slopes of piece-wise linear function.
        slope_0 = origin / destination
        slope_1 = (choose_range - origin) / (choose_range - destination)
        slope_2 = 1.0

        # x is a batch of origin matrices.
        # The origin matrix is the matrix such that
        # origin[i][j] = Origin coordinate of coordinate i for the warp map.
        # Denoting the destination of the origin anchor point in the
        # warp map as "dest," the origin coordinate of point i is given by:
        # 1) i < dest: slope_0 * i.
        # 2) dest <= i < choose_range: slope_1 * i - (slope_1 - slope_0) * dest.
        # 3) i >= choose_range: i.
        x = tf.broadcast_to(tf.cast(tf.range(matrix_size), dtype=dtype),
                            (batch_size, matrix_size))
        x = (self.EinsumBBmBm(slope_0, x) + self.EinsumBBmBm(
            slope_1 - slope_0, tf.nn.relu(x - destination_bc)) +
             self.EinsumBBmBm(slope_2 - slope_1,
                              tf.nn.relu(x - choose_range_bc)))
        x = tf.broadcast_to(x, (matrix_size, batch_size, matrix_size))
        x = tf.transpose(x, perm=[1, 2, 0])

        # y is a batch of coordinate matrices.
        # A coordinate matrix is a matrix such that
        # coordinate[i][j] = j.
        y = tf.broadcast_to(tf.cast(tf.range(matrix_size), dtype=dtype),
                            (batch_size, matrix_size, matrix_size))
        # Warp matrix is obtained by applying hat function element-wise to (x-y).
        # Denoting the origin point of i under the warp map as orig_i,
        # and n_i = ceil(orig_i), the warp matrix element warp[i][j] is given by:
        # 1) j = n_i: 1 - n_i + orig_i.
        # 2) j = n_i - 1: n_i - orig_i.
        # 3) Otherwise: 0.
        # Applying the warp matrix to pixels, i.e.,
        # warped_pixel[i] = sum_j warp[i][j] * original_pixel[j], one would get
        # warped_pixel[i] = (n_i - orig_i) * original_pixel[n_i-1]
        #                   + (1 - n_i + orig_i) * original_pixel[n_i].
        warp_matrix = x - y
        warp_matrix = _hat(warp_matrix)
        if p.fprop_dtype is not None and p.fprop_dtype != dtype:
            warp_matrix = tf.cast(warp_matrix, p.fprop_dtype)

        return warp_matrix
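_hat is a helper that is not shown in this listing; from the docstring it behaves like the triangular function hat(x) = max(0, 1 - |x|), so row i of the warp matrix linearly interpolates between the two pixels bracketing the source coordinate orig_i. A quick sketch of that row structure under that assumption:

import tensorflow as tf

def hat(x):
    # Triangular "hat": 1 at x == 0, falling linearly to 0 at |x| == 1.
    return tf.maximum(0.0, 1.0 - tf.abs(x))

orig_i = tf.constant(2.25)                 # source coordinate of one warped pixel
j = tf.range(5, dtype=tf.float32)          # pixel grid
row = hat(orig_i - j)
print(row.numpy())                         # [0.   0.   0.75 0.25 0.  ]
pixels = tf.constant([0.0, 10.0, 20.0, 30.0, 40.0])
print(float(tf.reduce_sum(row * pixels)))  # 22.5 == 0.75 * 20 + 0.25 * 30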
def merge_hyps(global_score_values, histories_in, mask, num_beams,
               num_hyps_per_beam):
    """Merges candidate hypotheses with identical histories.

  This function takes a set of candidate hypotheses, represented as Tensors of
  scores and histories, and merges all pairs of hypotheses that have identical
  history hashes. When two hypotheses are merged, the hyp with lower global
  score gets "deleted" and has its probability mass added to the higher scoring
  one. Hypotheses are "deleted" by giving them empty history and a large
  negative global score. The function output is a tuple of new
  (global_score_values, histories) Tensors.

  All input Tensors are assumed to be in "div" hypothesis ordering. That is,
  element [i, ...] corresponds to the j-th hyp of the n-th beam, where j = i % k
  and n = i / k.

  Example:
    Suppose num_beams = 1, num_hyps_per_beam = 2, candidates_per_hyp = 5,
    global_score_values is
      [[11 12 13 14 15],
       [17 16 10 19 20]]
    and histories_in is
      [[1 2 3 4 5],
       [5 6 3 7 8]].

    There are two pairs of hypotheses with identical histories that should
    be merged -- two with hash value 3 and two with hash 5. In each pair, the
    one with lower score will be deleted and merged into the one with higher
    score.

    The output is a new set of global_score_values,
      [[ 11     12 13.04 14 -1e34 ],
       [ 17.13  16 -1e34 19 20    ]]
    and new histories
      [[1 2 3 4 0],
       [5 6 0 7 8]].
    Hypotheses deleted in the merge now have zero history and a large negative
    score. The destination of each merge now has additional probability mass.
    (Note _log_sum_exp(13, 10) ~= 13.04 and _log_sum_exp(15, 17) ~= 17.13.)

  Args:
    global_score_values: Tensor of shape [b * k, candidates_per_hyp], the global
      scores of each candidate hypothesis.
    histories_in: int32 Tensor of shape [b * k, candidates_per_hyp], the
      histories of each candidate hypothesis.
    mask: Tensor of shape [b * k, 1] indicating which entries in
      global_score_values and histories_in are valid.
    num_beams: int, the number of beams (b above).
    num_hyps_per_beam: int, the number of hypotheses per beam (k above).

  Returns:
    A tuple of new (global_score_values, histories) updated so that input
    hypotheses with identical histories are now merged. Hypotheses deleted in
    the merge have a new global score of BEST_SCORES_INIT and a history of 0.
  """
    values_dtype = global_score_values.dtype
    candidates_per_hyp = histories_in.get_shape()[1]
    k = num_hyps_per_beam

    # High-level strategy: To detect hyps to merge, we'll permute the hypotheses
    # within each beam so that their histories are in sorted order. We can then
    # in parallel check whether each history is equal to its left or right
    # neighbor (i.e. whether the hyps should be merged), and if so, which of them
    # has the higher global score (the direction of the merge). When two hyps need
    # to be merged, we'll "delete" the one with lower score (by giving it a large
    # negative score and empty history) and add its probability mass to the other.
    #
    # Note we only have to do pair-wise merging once per beam search step, because
    # (ignoring hash collisions) there are at most two candidate hypotheses with
    # any particular history. This follows from the fact that hypotheses are
    # unique at the start of the beam search step, as are the top K non-epsilon
    # extensions of those hypotheses. Thus, if there are two paths with
    # identical histories, they must have the form
    #   h_i <eps> == h_j s  (for some i != j, s != eps),
    # where h_i and h_j are distinct input hypotheses, and s is some non-epsilon
    # symbol.

    # Reshape inputs to [b, num_hyps_per_beam * candidates_per_hyp] so they're
    # grouped by beam.
    histories = histories_in
    orig_scores_shape = tf.shape(global_score_values)
    histories = tf.reshape(histories, [num_beams, -1])
    histories_valid = tf.cast(
        tf.reshape(tf.tile(mask, [1, candidates_per_hyp]), [num_beams, -1]),
        values_dtype)
    # Compute the permutation of hyps within each beam that puts the histories in
    # sorted order, and the one that permutes the sorted hyps back to the
    # original order.
    sorted_history_indices = tf.argsort(histories, axis=1)
    inverse_indices = tf.argsort(sorted_history_indices, axis=1)

    def to_flat_indices(column_indices_per_row):
        column_indices_per_row.shape.assert_has_rank(2)
        flat_indices = (column_indices_per_row +
                        num_hyps_per_beam * candidates_per_hyp *
                        tf.reshape(tf.range(num_beams), [num_beams, 1]))
        return tf.reshape(flat_indices, [-1])

    # Convert to linear indices so we can use fast_gather.
    sorted_history_indices_flat = to_flat_indices(sorted_history_indices)
    inverse_indices_flat = to_flat_indices(inverse_indices)

    def history_sort(values):
        return tf.reshape(
            fast_gather(tf.reshape(values,
                                   [-1, 1]), sorted_history_indices_flat,
                        num_beams * k * candidates_per_hyp),
            [num_beams, k * candidates_per_hyp])

    def history_unsort(values):
        return tf.reshape(
            fast_gather(tf.reshape(values, [-1, 1]), inverse_indices_flat,
                        num_beams * k * candidates_per_hyp), orig_scores_shape)

    sorted_histories = history_sort(histories)
    sorted_histories_valid = history_sort(histories_valid)

    # Indicators of whether each hypothesis is a duplicate of its left/right
    # neighbors.
    # [num_batches, k * candidates_per_hyp - 1]
    dup_mask = tf.cast(
        tf.equal(sorted_histories[:, 1:], sorted_histories[:, :-1]),
        values_dtype) * (sorted_histories_valid[:, 1:] *
                         sorted_histories_valid[:, :-1])
    padding = tf.zeros([num_beams, 1], dtype=values_dtype)
    is_dup_of_left = tf.concat([padding, dup_mask], axis=1)
    is_dup_of_right = tf.concat([dup_mask, padding], axis=1)

    # Examine global scores to see which hyps should be merged, and within those
    # cases, which hyps get deleted/retained in the merge.
    sorted_global_scores = history_sort(global_score_values)
    # Global scores of each hyp's left and right neighbors.
    right_global_scores = tf.concat([sorted_global_scores[:, 1:], padding],
                                    axis=1)
    left_global_scores = tf.concat([padding, sorted_global_scores[:, :-1]],
                                   axis=1)

    # Masks indicating whether each candidate hyp is better or worse than its
    # left or right neighbor.
    is_better_than_right = tf.cast(
        tf.greater_equal(sorted_global_scores, right_global_scores),
        values_dtype)
    is_worse_than_right = 1.0 - is_better_than_right
    is_better_than_left = tf.cast(
        tf.greater(sorted_global_scores, left_global_scores), values_dtype)
    is_worse_than_left = 1.0 - is_better_than_left

    # Determine which hypotheses need to be merged.
    is_merge_source = tf.minimum(
        is_dup_of_left * is_worse_than_left +
        is_dup_of_right * is_worse_than_right, 1.0)
    is_left_merge_dest = is_dup_of_left * is_better_than_left
    is_right_merge_dest = is_dup_of_right * is_better_than_right
    is_merge_dest = tf.minimum(is_left_merge_dest + is_right_merge_dest, 1.0)
    # Mask of hyps unaffected by merging.
    is_unchanged = tf.maximum(1.0 - is_merge_source - is_merge_dest, 0.0)

    sorted_global_scores = (
        is_unchanged * sorted_global_scores +
        is_merge_source * BEST_SCORES_INIT + is_left_merge_dest *
        _log_sum_exp(left_global_scores, sorted_global_scores) +
        is_right_merge_dest *
        _log_sum_exp(right_global_scores, sorted_global_scores))
    # Set histories of deleted (merge source) hyps to zero.
    sorted_histories *= tf.cast(1.0 - is_merge_source, sorted_histories.dtype)

    # Put everything back in its original order and rank.
    global_score_values_out = history_unsort(sorted_global_scores)
    histories_out = history_unsort(sorted_histories)
    return global_score_values_out, histories_out
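_log_sum_exp is not shown in this listing; with the standard definition log(exp(a) + exp(b)), the merged scores in the docstring example work out as below. A quick check, not library code:

import tensorflow as tf

def log_sum_exp(a, b):
    # Numerically stable log(exp(a) + exp(b)).
    return tf.math.reduce_logsumexp(tf.stack([a, b]))

print(float(log_sum_exp(tf.constant(13.0), tf.constant(10.0))))  # ~13.05
print(float(log_sum_exp(tf.constant(17.0), tf.constant(15.0))))  # ~17.13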