Example 1
  def _ProcessSingleInput(self, source_id, src, tgt):
    """Performs strings-to-ids on the given input pair via p.tokenizer_dict."""
    _, src_labels, src_paddings = self.StringsToIds(
        tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
    tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
        tf.reshape(tgt, [1]), is_source=False, key=self._tgt_tokenizer_key)
    # Mask positions to 0 where padding is 1, for consistency. We do this
    # because the tokenizer implementation may use the EOS token to pad.
    src_labels = py_utils.ApplyPadding(src_paddings, src_labels)
    tgt_ids = py_utils.ApplyPadding(tgt_paddings, tgt_ids)
    tgt_labels = py_utils.ApplyPadding(tgt_paddings, tgt_labels)

    features = py_utils.NestedMap()
    features.src = py_utils.NestedMap()
    features.src.ids = src_labels
    # ids_indicator is 1 if and only if the output from the tokenizer has a
    # non-padded id. Unlike weights, it will not mutate and can be used, for
    # example, to determine the actual sequence length.
    features.src.ids_indicator = 1 - src_paddings
    features.tgt = py_utils.NestedMap()
    features.tgt.ids = tgt_ids
    features.tgt.labels = tgt_labels
    features.tgt.ids_indicator = 1 - tgt_paddings

    src_task_id, tgt_task_id = self._GetTaskIds(source_id)
    # task_ids are padded with zeros.
    features.src.task_ids = tf.cast(
        features.src.ids_indicator, dtype=tf.int32) * src_task_id
    features.tgt.task_ids = tf.cast(
        features.tgt.ids_indicator, dtype=tf.int32) * tgt_task_id

    if not py_utils.use_tpu():
      features.src.strs = src
      features.tgt.strs = tgt
    return features.Transform(tf.squeeze)
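A minimal standalone sketch of the same masking and ids_indicator computation in plain TensorFlow, assuming py_utils.ApplyPadding zeroes positions where the padding is 1 (the toy values are hypothetical):

import tensorflow as tf

# Hypothetical tokenizer output where EOS (id 2) is also used to pad.
tgt_ids = tf.constant([[5, 7, 2, 2]], dtype=tf.int32)
tgt_paddings = tf.constant([[0., 0., 0., 1.]])  # last position is padding

# Same effect as py_utils.ApplyPadding(tgt_paddings, tgt_ids): zero out
# positions where padding == 1, so a padding EOS cannot be mistaken for data.
masked_ids = tgt_ids * tf.cast(1 - tgt_paddings, tgt_ids.dtype)

# ids_indicator marks real (non-padded) positions and never mutates, so it
# can later recover the actual sequence length.
ids_indicator = 1 - tgt_paddings
print(masked_ids.numpy())                            # [[5 7 2 0]]
print(tf.reduce_sum(ids_indicator, axis=1).numpy())  # [3.]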
Example 2
def SequenceAppendToken(x, x_paddings, token, extend=False):
    """Appends <token> to sequence `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to extend `x` along the length dimension; this must be
      True if any sequence in `x` has length `x_len_max`, or an invalid
      sequence will be emitted.

  Returns:
    A tuple.
      - The new sequence, Tensor of shape [batch_size, x_len_max].
      - The new paddings, Tensor of shape [batch_size, x_len_max].
  """
    batch_size = py_utils.GetShape(x)[0]
    x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    if extend:
        x = tf.pad(x, [[0, 0], [0, 1]])
    # Mask all invalid entries of `x` to 0.
    x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)
    # Append the <token> based on `x_len`.
    x += tf.scatter_nd(tf.stack([tf.range(batch_size), x_len], axis=1),
                       tf.cast(tf.fill([batch_size], token), x.dtype),
                       py_utils.GetShape(x))
    x_paddings = 1 - tf.sequence_mask(x_len + 1,
                                      py_utils.GetShape(x)[1],
                                      x_paddings.dtype)
    return x, x_paddings
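A plain-TensorFlow sketch of the same append, with py_utils.GetShape replaced by tf.shape; the inputs and the EOS id of 2 are hypothetical:

import tensorflow as tf

def sequence_append_token(x, x_paddings, token, extend=False):
  # Sketch of SequenceAppendToken; assumes 2D inputs and float paddings.
  batch_size = tf.shape(x)[0]
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  if extend:
    x = tf.pad(x, [[0, 0], [0, 1]])
  max_len = tf.shape(x)[1]
  x *= tf.sequence_mask(x_len, max_len, x.dtype)  # zero out invalid entries
  x += tf.scatter_nd(  # write <token> at position x_len of each row
      tf.stack([tf.range(batch_size), x_len], axis=1),
      tf.fill([batch_size], tf.cast(token, x.dtype)), tf.shape(x))
  x_paddings = 1 - tf.sequence_mask(x_len + 1, max_len, x_paddings.dtype)
  return x, x_paddings

x = tf.constant([[3, 4, 0], [5, 0, 0]])
p = tf.constant([[0., 0., 1.], [0., 1., 1.]])
y, yp = sequence_append_token(x, p, token=2, extend=True)
print(y.numpy())   # [[3 4 2 0] [5 2 0 0]]
print(yp.numpy())  # [[0. 0. 0. 1.] [0. 1. 1. 1.]]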
Example 3
  def _InputBatch(self):
    ret = py_utils.NestedMap()

    ret.bucket_keys = self._bucket_keys

    ret.src = py_utils.NestedMap()
    ret.src.ids = tf.cast(self._src_ids, dtype=tf.int32)
    ret.src.paddings = self._src_paddings

    ret.tgt = py_utils.NestedMap()
    ret.tgt.ids = self._tgt_ids
    ret.tgt.labels = tf.cast(self._tgt_labels, dtype=tf.int32)
    ret.tgt.weights = self._tgt_weights
    ret.tgt.paddings = self._tgt_paddings

    if (self.params.fprop_dtype is None or
        self.params.dtype == self.params.fprop_dtype):
      return ret

    def _Cast(v):
      if not v.dtype.is_floating:
        return v
      return tf.cast(v, self.params.fprop_dtype)

    return ret.Transform(_Cast)
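The _Cast transform only touches floating-point tensors; a sketch of the same selective cast over a nested batch using tf.nest, with a hypothetical fprop_dtype of bfloat16:

import tensorflow as tf

fprop_dtype = tf.bfloat16  # stands in for p.fprop_dtype

def _cast(v):
  # Integer tensors (ids, labels) keep their dtype; only floats are cast.
  if not v.dtype.is_floating:
    return v
  return tf.cast(v, fprop_dtype)

batch = {'src': {'ids': tf.zeros([2, 5], tf.int32),
                 'paddings': tf.zeros([2, 5], tf.float32)}}
batch = tf.nest.map_structure(_cast, batch)
print(batch['src']['ids'].dtype)       # int32 (unchanged)
print(batch['src']['paddings'].dtype)  # bfloat16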
Example 4
  def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
    """Computes mean and variance over the valid data points in inputs."""
    inputs = py_utils.with_dependencies([
        py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
        py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
    ], inputs)
    rank = tf.rank(mask)
    reduce_over_dims = tf.range(0, rank - 1)
    sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                          reduce_over_dims)
    count_v = tf.reduce_sum(mask, reduce_over_dims)
    # The input shape is guaranteed to be a multiple of the mask shape because
    # the `inputs * mask` op above broadcast successfully.
    mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
    count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
      sum_v = tf.tpu.cross_replica_sum(sum_v)
      count_v = tf.tpu.cross_replica_sum(count_v)

    count_v = tf.maximum(count_v, 1.0)
    mean = sum_v / count_v
    sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                           reduce_over_dims)

    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
      sum_vv = tf.tpu.cross_replica_sum(sum_vv)

    variance = py_utils.with_dependencies([
        py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
    ], sum_vv / count_v)
    return mean, variance
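A toy eager-mode check of the masked moments, assuming inputs and mask share the same shape (so mask_multiplier is all ones) and skipping the TPU cross-replica sums:

import tensorflow as tf

inputs = tf.constant([[1., 2.], [3., 4.], [100., 200.]])
mask = tf.constant([[1., 1.], [1., 1.], [0., 0.]])  # last row is invalid

reduce_over_dims = tf.range(0, tf.rank(mask) - 1)  # all but the last dim
sum_v = tf.reduce_sum(inputs * mask, reduce_over_dims)
count_v = tf.maximum(tf.reduce_sum(mask, reduce_over_dims), 1.0)
mean = sum_v / count_v
sum_vv = tf.reduce_sum((inputs - mean)**2 * mask, reduce_over_dims)
variance = sum_vv / count_v
print(mean.numpy(), variance.numpy())  # [2. 3.] [1. 1.] - the 100s are ignored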
Example 5
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay."""
     p = self.params
     current_step = tf.cast(current_step, tf.float32)
     warmup_steps = tf.cast(p.warmup_steps, tf.float32)
     linear_warmup = tf.minimum(1.0, current_step / warmup_steps)
     rsqrt_decay = tf.math.rsqrt(tf.maximum(current_step, warmup_steps))
     return p.model_dim**-0.5 * linear_warmup * rsqrt_decay
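A quick numeric sketch of this warmup-then-rsqrt schedule, with hypothetical model_dim=512 and warmup_steps=4000:

import math

model_dim, warmup_steps = 512, 4000  # hypothetical values

def lr(step):
  linear_warmup = min(1.0, step / warmup_steps)
  rsqrt_decay = 1.0 / math.sqrt(max(step, warmup_steps))
  return model_dim**-0.5 * linear_warmup * rsqrt_decay

for step in [1, 2000, 4000, 16000]:
  print(step, lr(step))
# The rate rises linearly until step 4000, then decays as step**-0.5.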
Example 6
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay."""
     p = self.params
     current_step = tf.cast(current_step, tf.float32)
     warmup_steps = tf.cast(
         p.warmup_examples / (p.batch_size * self._num_replicas),
         tf.float32)
     return tf.minimum((current_step + 1) * warmup_steps**-1.5,
                       (current_step + 1)**-0.5)
Example 7
 def _InputBatch(self):
   length = tf.reduce_prod(self.shape)
   counter = summary_utils.StatsCounter('CountingInputGenerator')
   new_value = tf.cast(counter.IncBy(length), dtype=tf.int32) - length
   new_value = tf.stop_gradient(new_value)
   values = new_value + tf.range(length)
   shaped_values = tf.reshape(tf.cast(values, dtype=tf.float32), self.shape)
   targets = tf.reduce_sum(shaped_values, axis=0)
   return py_utils.NestedMap(src_ids=shaped_values, tgt_ids=targets)
Example 8
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay."""
     p = self.params
     current_step = tf.cast(current_step, tf.float32)
     warmup_steps = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
     if p.decay_end is not None:
         current_step = tf.where(current_step < p.decay_end, current_step,
                                 tf.cast(p.decay_end, tf.float32))
     return p.model_dim**-0.5 * tf.minimum(
         (current_step + 1) * warmup_steps**-1.5, (current_step + 1)**-0.5)
Example 9
    def ComputePredictions(self, theta, batch):
        # pyformat: disable
        """Compute the model predictions.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      batch: A `.NestedMap`.

        - src: A `.NestedMap`.
          - ids: The source ids, ends in <eos>.
          - paddings: The source paddings.

        - tgt: A `.NestedMap`.
          - ids: The target ids, ends in <eos>.
          - paddings: The target paddings.

    Returns:
      A `.NestedMap`.
        - outputs: The contextualized output vectors of shape
          [batch_size, time_dim, model_dim].
        - tgt: A `.NestedMap` (optional, only during training).
          - ids: The canvas ids.
          - paddings: The canvas paddings.
          - target_indices: The target indices.
          - target_weights: The target weights.
    """
        # pyformat: enable
        p = self.params

        # TODO(williamchan): Currently, we only support KERMIT mode (i.e., no
        # encoder, unified architecture).
        assert not p.encoder

        # Sometimes src and tgt have different types. We reconcile here and use
        # int32.
        batch.src.ids = tf.cast(batch.src.ids, tf.int32)
        batch.tgt.ids = tf.cast(batch.tgt.ids, tf.int32)

        canvas_and_targets = self._CreateCanvasAndTargets(batch)
        batch = py_utils.NestedMap(tgt=py_utils.NestedMap(
            ids=canvas_and_targets.canvas,
            paddings=canvas_and_targets.canvas_paddings))

        predictions = super(InsertionModel,
                            self).ComputePredictions(theta, batch)

        if not self.do_eval:
            predictions.tgt = py_utils.NestedMap(
                ids=canvas_and_targets.canvas,
                paddings=canvas_and_targets.canvas_paddings,
                target_indices=canvas_and_targets.target_indices,
                target_weights=canvas_and_targets.target_weights)

        return predictions
Example 10
    def FProp(self, theta, inputs, paddings, domain_ids=None):
        """Applies data augmentation by randomly mask spectrum in inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: A tensor of shape [batch, time, freq, num_channels].
      paddings: A 0/1 tensor of shape [batch, time].
      domain_ids: input domain_ids of shape [batch, time].

    Returns:
      A pair of tensors:

      - augmented_inputs: A tensor of shape [batch, time, freq, num_channels].
      - paddings: A 0/1 tensor of shape [batch, time].
    """
        p = self.params

        global_seed = None  # A tensor seed in case stateless random ops are needed.
        if p.use_input_dependent_random_seed:
            global_seed = _global_seed_from_inputs(inputs)

        batch_size, series_length, _, _ = py_utils.GetShape(inputs)
        if len(p.domain_ids) > 1:
            augmented_inputs = tf.zeros_like(inputs)
            original_inputs = inputs
            for i, domain_id in enumerate(p.domain_ids):
                augmented_domain = self._AugmentationNetwork(
                    series_length,
                    inputs,
                    paddings,
                    global_seed=global_seed,
                    domain_id_index=i)
                target_domain = tf.cast(tf.expand_dims(
                    tf.tile([domain_id], [batch_size]), -1),
                                        dtype=p.dtype)
                # [batch, time].
                domain_mask = tf.cast(tf.equal(domain_ids, target_domain),
                                      dtype=p.dtype)
                augmented_domain = self.EinsumBxycBxBxyc(
                    augmented_domain, domain_mask, name='einsum_domainmasking')
                original_inputs = self.EinsumBxycBxBxyc(
                    original_inputs,
                    1.0 - domain_mask,
                    name='einsum_domainmasking2')
                augmented_inputs = augmented_domain + augmented_inputs
            augmented_inputs = original_inputs + augmented_inputs
        else:
            augmented_inputs = self._AugmentationNetwork(
                series_length,
                inputs,
                paddings,
                global_seed=global_seed,
                domain_id_index=0)
        return augmented_inputs, paddings
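The per-domain merge relies on an einsum that scales each [freq, channels] slice by a per-timestep mask value; a standalone sketch of that bxyc,bx->bxyc contraction with hypothetical shapes:

import tensorflow as tf

batch, time, freq, channels = 2, 3, 4, 1
augmented = tf.random.normal([batch, time, freq, channels])
original = tf.random.normal([batch, time, freq, channels])
# domain_mask[b, x] is 1.0 where example b at time x belongs to this domain.
domain_mask = tf.constant([[1., 1., 1.], [0., 0., 0.]])

merged = (tf.einsum('bxyc,bx->bxyc', augmented, domain_mask) +
          tf.einsum('bxyc,bx->bxyc', original, 1.0 - domain_mask))
# Example 0 takes the augmented features; example 1 keeps the originals.
print(tf.reduce_all(tf.equal(merged[1], original[1])).numpy())  # True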
Example 11
    def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                          hyp_ids, hyp_lens, done_hyps, other_states,
                          pre_beam_search_step_callback,
                          post_beam_search_step_callback):
        """Extend greedy search hyps for one step.

    Args:
      theta: A `.NestedMap` object containing weights' values of the decoder
        layer and its children layers.
      encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
        the callbacks.
      cur_step: A scalar int tensor, the current time step, 0-based.
      step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
        current search step.
      hyp_ids: An int tensor of shape [num_hyps, tgt_seq_len].
      hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
        counted.
      done_hyps: Whether or not a hyp has finished.
      other_states: A `.NestedMap` of other beam search states. This
        `.NestedMap` is managed and updated by the client. It is expected that
        each of its member tensors are of rank >= 1. t[i, ...] is the state of
        the i-th hyp at the beginning of this search step.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        See class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        See class header comments for more details.

    Returns:
      A tuple of the following elements for the next greedy search step:
      (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states).
    """
        p = self.params
        # Increment hyp_lens by 1 if the hyp is not finished yet.
        hyp_lens = hyp_lens + (1 - tf.cast(done_hyps, tf.int32))

        bs_results, new_other_states = pre_beam_search_step_callback(
            theta, encoder_outputs, step_ids, other_states,
            1)  # num_hyps_per_beam
        new_step_ids = tf.math.argmax(bs_results.log_probs, 1)
        new_step_ids = tf.cast(new_step_ids, tf.int32)
        new_step_ids = tf.reshape(new_step_ids, tf.shape(step_ids))
        final_other_states = post_beam_search_step_callback(
            theta, encoder_outputs, new_step_ids, new_other_states)

        # Stash new_step_ids into the right slot.
        new_step_ids_1d = tf.reshape(new_step_ids, [-1])
        hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                                   new_step_ids_1d)
        # Update done_hyps if the current step_ids is the end of sequence token.
        done_hyps = tf.math.logical_or(
            done_hyps, tf.equal(new_step_ids_1d, p.target_eos_id))

        return (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                final_other_states)
Example 12
 def UnstackFeatures(self, src_inputs, src_paddings):
     """Unstacks src_input and src_paddings based off stack height."""
     sh = self.params.stack_height
     bs, old_series_length, _, channels = py_utils.GetShape(src_inputs)
     unstacked_series_length = old_series_length * sh
     src_inputs = tf.reshape(src_inputs,
                             [bs, unstacked_series_length, -1, channels])
     content = 1 - src_paddings
     lengths = tf.cast(sh * tf.reduce_sum(content, axis=1), tf.int32)
     mask = tf.sequence_mask(lengths, maxlen=unstacked_series_length)
     src_paddings = 1 - tf.cast(mask, tf.int32)
     return src_inputs, src_paddings
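A toy walkthrough of the unstacking with a hypothetical stack_height of 2:

import tensorflow as tf

stack_height = 2
# One example, 2 stacked frames, each holding stack_height * 3 features.
src_inputs = tf.reshape(tf.range(12, dtype=tf.float32), [1, 2, 6, 1])
src_paddings = tf.constant([[0., 1.]])  # only the first stacked frame is valid

bs, old_len, _, ch = src_inputs.shape
unstacked_len = old_len * stack_height
unstacked = tf.reshape(src_inputs, [bs, unstacked_len, -1, ch])
lengths = tf.cast(
    stack_height * tf.reduce_sum(1 - src_paddings, axis=1), tf.int32)
paddings = 1 - tf.cast(tf.sequence_mask(lengths, unstacked_len), tf.int32)
print(unstacked.shape)   # (1, 4, 3, 1)
print(paddings.numpy())  # [[0 0 1 1]]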
Example 13
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay."""
     params = self.params
     warmup_steps = tf.cast(params.decay_start * params.worker_replicas,
                            tf.float32)
     current_step = tf.cast(current_step, tf.float32)
     if params.decay_end is not None:
         current_step = tf.where(current_step < params.decay_end,
                                 current_step,
                                 tf.cast(params.decay_end, tf.float32))
     peak_learning_rate = (warmup_steps**-0.5)
     return (params.model_dim**-0.5) * tf.minimum(
         tf.minimum((current_step + 1),
                    (current_step + 1)**-0.5), peak_learning_rate)
Example 14
    def FProp(self, theta, current_step):
        p = self.params
        current_step = tf.cast(current_step, tf.int64)
        interval_starts = [0] + p.boundaries
        values = []
        for interval_start, schedule, schedule_theta in zip(
                interval_starts, self.schedules, theta.schedules):
            relative_step = tf.maximum(
                tf.cast(0, current_step.dtype),
                current_step - tf.cast(interval_start, current_step.dtype))
            values.append(schedule.FProp(schedule_theta, relative_step))

        return py_utils.PiecewiseConstant(current_step, p.boundaries, values,
                                          values[0].dtype)
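py_utils.PiecewiseConstant then picks the value of whichever schedule owns the current interval; a sketch of that selection with tf.searchsorted, using hypothetical boundaries and values:

import tensorflow as tf

boundaries = tf.constant([1000, 2000], tf.int64)  # hypothetical p.boundaries
values = tf.constant([1e-3, 1e-4, 1e-5])          # one value per interval

def piecewise(step):
  # Index of the first boundary greater than `step` selects the interval.
  idx = tf.searchsorted(boundaries, tf.reshape(step, [1]), side='right')
  return tf.gather(values, idx)[0]

for step in [0, 999, 1000, 2500]:
  print(step, piecewise(tf.constant(step, tf.int64)).numpy())
# -> 1e-3 up to step 999, 1e-4 from step 1000, 1e-5 from step 2000.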
Example 15
 def _Value(self, current_step):
     """Returns the current clipping cap."""
     p = self.params
     start_step = tf.cast(p.start_step, tf.float32)
     end_step = tf.cast(p.end_step, tf.float32)
     current_step = tf.cast(current_step, tf.float32)
     steps_ratio = (
         tf.minimum(end_step - start_step, current_step - start_step) /
         (end_step - start_step))
     rmax_tensor = (steps_ratio * p.end_cap +
                    (1.0 - steps_ratio) * p.start_cap)
     return tf.cond(tf.less(current_step, p.start_step),
                    lambda: tf.cast(p.start_cap, tf.float32),
                    lambda: tf.cast(rmax_tensor, tf.float32))
Example 16
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
    """Concats sequence `x` with sequence `y`.

  This function is length-aware (based on the paddings).

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.
      - Concatenation of `x` and `y` of shape
        [batch_size, x_len_max + y_len_max].
      - Paddings of the concatenation of shape
        [batch_size, x_len_max + y_len_max].
  """
    # Get the length (w/ eos).
    x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

    batch_size = py_utils.GetShape(x)[0]
    y_len_max = py_utils.GetShape(y)[1]

    # Pad `x` with necessary <pad>.
    x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
    # Replace all <pad> with 0.
    x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

    # Compute the write indices of `y` in `xy`.
    indices = tf.stack([
        tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
        (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
         tf.expand_dims(x_len, 1)),
    ], 2)

    xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

    # We need to remap all <pad> to `pad`.
    xy = tf.where(
        tf.less(tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
                tf.expand_dims(x_len + y_len, 1)), xy,
        tf.fill(py_utils.GetShape(xy), pad))
    xy_paddings = 1 - tf.sequence_mask(x_len + y_len,
                                       py_utils.GetShape(xy)[1],
                                       x_paddings.dtype)
    return xy, xy_paddings
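The core trick is the scatter_nd write of `y` starting at each row's x_len; a toy trace with a hypothetical pad id of 0:

import tensorflow as tf

x = tf.constant([[7, 8, 0, 0]])          # x_len = 2
y = tf.constant([[9, 0]])                # y_len = 1, y_len_max = 2
x_len = tf.constant([2])
batch_size, y_len_max = 1, 2

x = tf.concat([x, tf.zeros_like(y)], 1)  # make room for y: shape [1, 6]
indices = tf.stack([
    tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
    tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
    tf.expand_dims(x_len, 1),
], 2)                                    # [[[0, 2], [0, 3]]]
xy = x + tf.scatter_nd(indices, y, tf.shape(x))
print(xy.numpy())                        # [[7 8 9 0 0 0]]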
Example 17
def _global_seed_from_inputs(input_floats):
    """Generates a random seed tensor based on input floats and mode key.

  Args:
    input_floats: a set of float input tensors that are derived from the input
      data (for example, input tokens). The important thing is that these are
      usually different for each batch.

  Returns:
    A tensor of shape=[2] with integer seed tensors derived from the inputs.
  """
    timestamp = tf.math.floormod(tf.cast(tf.timestamp(), dtype=tf.int64),
                                 10000000)
    input_sum = tf.cast(tf.reduce_sum(tf.math.abs(input_floats)),
                        dtype=tf.int64)
    return tf.stack([timestamp + input_sum, timestamp - input_sum], axis=-1)
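The returned shape-[2] tensor matches the seed format that TensorFlow's stateless random ops expect; a usage sketch with a hypothetical stand-in for the input floats:

import tensorflow as tf

inputs = tf.random.normal([2, 3])        # stands in for input-derived floats
seed = _global_seed_from_inputs(inputs)  # shape [2]

# Stateless ops are pure functions of (shape, seed): same seed, same values.
a = tf.random.stateless_uniform([4], seed=seed)
b = tf.random.stateless_uniform([4], seed=seed)
print(tf.reduce_all(tf.equal(a, b)).numpy())  # True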
Example 18
    def _ProcessBeamSearchDecodeOut(self, input_batch, encoder_outputs,
                                    decoder_outs):
        self.r1_shape = decoder_outs[0]
        self.r2_shape = decoder_outs[1]
        self.r3_shape = decoder_outs[2]
        tf.logging.info('r1_shape: %s', self.r1_shape)
        tf.logging.info('r2_shape: %s', self.r2_shape)
        tf.logging.info('r3_shape: %s', self.r3_shape)

        hyps = decoder_outs[3]
        prev_hyps = decoder_outs[4]
        done_hyps = decoder_outs[5]
        scores = decoder_outs[6]
        atten_probs = decoder_outs[7]
        eos_scores = decoder_outs[8]
        eos_atten_probs = decoder_outs[9]
        source_seq_lengths = decoder_outs[10]

        tlen = tf.cast(
            tf.round(tf.reduce_sum(1.0 - input_batch.tgt.paddings, 1) - 1.0),
            tf.int32)
        ret_dict = {
            'target_ids': input_batch.tgt.ids[:, 1:],
            'eval_weight': input_batch.eval_weight,
            'tlen': tlen,
            'hyps': hyps,
            'prev_hyps': prev_hyps,
            'done_hyps': done_hyps,
            'scores': scores,
            'atten_probs': atten_probs,
            'eos_scores': eos_scores,
            'eos_atten_probs': eos_atten_probs,
            'source_seq_lengths': source_seq_lengths,
        }
        return ret_dict
Example 19
        def Polynomial(x):
            """Polynomial function of x."""
            p = self.params
            x0, y0 = p.start
            x1, y1 = p.limit

            assert x0 < x1, '%s must be < %s' % (x0, x1)

            x0 = tf.cast(x0, dtype=x.dtype)
            x1 = tf.cast(x1, dtype=x.dtype)
            y0 = tf.cast(y0, dtype=x.dtype)
            y1 = tf.cast(y1, dtype=x.dtype)

            f_x = ((x - x0) / (x1 - x0))**p.power
            y = y0 + f_x * (y1 - y0)
            return tf.where(x < x0, y0, tf.where(x >= x1, y1, y))
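A numeric sketch of the clamped polynomial ramp, with hypothetical start=(0, 0.0), limit=(100, 1.0), and power=2:

(x0, y0), (x1, y1), power = (0, 0.0), (100, 1.0), 2  # hypothetical params

def polynomial(x):
  if x < x0:
    return y0
  if x >= x1:
    return y1
  return y0 + ((x - x0) / (x1 - x0))**power * (y1 - y0)

for x in [0, 50, 100, 200]:
  print(x, polynomial(x))  # 0.0, 0.25, 1.0, 1.0 - clamped outside [x0, x1]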
Example 20
      def ApplyBias():
        """Bias and update log_probs and consistent."""

        def TileForBeamAndFlatten(tensor):
          tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
          tensor = tf.tile(
              tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
          tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
          return tf.reshape(tensor, [tgt_batch])

        # Consistent if step_ids == labels from the previous step.
        # TODO(navari): Consider updating consistent only if weights > 0. Then
        # re-evaluate the need for bias_only_if_consistent=True.
        # Note that prev_label is incorrect for step 0 but is overridden later.
        prev_label = TileForBeamAndFlatten(
            tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
        is_step0 = tf.equal(time_step, 0)
        local_consistence = tf.math.logical_or(
            is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
        consistent = tf.math.logical_and(states.consistent, local_consistence)

        # get label, weight slices corresponding to current time_step
        label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
        weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
        if p.bias_only_if_consistent:
          weight = weight * tf.cast(consistent, p.dtype)

        # convert from dense label to sparse label probs
        vocab_size = tf.shape(bs_results.log_probs)[1]
        uncertainty = tf.constant(
            1e-10, p.dtype)  # avoid 0 probs which may cause issues with log
        label_probs = tf.one_hot(
            label,
            vocab_size,
            on_value=1 - uncertainty,
            off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
            dtype=p.dtype)  # [tgt_batch, vocab_size]
        pred_probs = tf.exp(bs_results.log_probs)

        # interpolate predicted probs and label probs
        weight = tf.expand_dims(weight, 1)
        probs = py_utils.with_dependencies([
            py_utils.assert_less_equal(weight, 1.),
            py_utils.assert_greater_equal(weight, 0.)
        ], (1.0 - weight) * pred_probs + weight * label_probs)
        return tf.math.log(probs), consistent
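The bias itself is a per-token interpolation between the model's predicted distribution and a near-one-hot label distribution; a standalone sketch with a hypothetical vocab size and weight:

import tensorflow as tf

vocab_size, uncertainty = 4, 1e-10
label = tf.constant([2])
weight = tf.constant([[0.7]])  # 0.0 => pure model probs, 1.0 => pure labels

label_probs = tf.one_hot(
    label, vocab_size,
    on_value=1 - uncertainty,
    off_value=uncertainty / (vocab_size - 1))
pred_probs = tf.constant([[0.1, 0.2, 0.3, 0.4]])

probs = (1.0 - weight) * pred_probs + weight * label_probs
print(probs.numpy())  # ~[0.03 0.06 0.79 0.12]: mass moves to token 2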
Example 21
 def IncBy(self, delta):
   """Increment the counter by delta and return the new value."""
   # NOTE: We must ensure _value is computed (_var + 0) before
   # updating _var with delta.
   delta = tf.cast(delta, tf.int64)
   with tf.control_dependencies([self._value]):
     scalar(self._name, self._value)
     return tf.identity(tf.assign_add(self._var, delta))
Example 22
  def restore(self, restored_tensors, restored_shapes):
   restored_tensor = restored_tensors[0]
   if restored_shapes is not None:
     restored_tensor = tf.reshape(restored_tensor, restored_shapes[0])
   return tf.assign(
       self.op,
       tf.cast(restored_tensor, tf.bfloat16),
       validate_shape=restored_shapes is None and
       self.op.get_shape().is_fully_defined())
Example 23
    def _matmul_gather(self, values, axis=0, batch_major_state=True):
        """Returns values gathered.

    Args:
      values: Values to gather from.
      axis: Axis to gather on. Defaults to 0 (rows).
      batch_major_state: Whether the values to gather from use batch major or
        not. Defaults to True. For Transformer model, batch_major_state is set
        to False (time is the major dim).

    Returns:
      Gathered values.

    Raises:
      NotImplementedError if axis is neither 0 nor 1.
    """

        dtype = values.dtype
        if dtype != tf.float32 and dtype != tf.bfloat16:
            values = tf.cast(values, tf.float32)

        if axis == 0:
            if values.shape.rank is not None and values.shape.rank > 2:
                if not batch_major_state:
                    values = tf.transpose(values, [1, 0, 2])
                results = tf.cast(
                    tf.gather(values, tf.cast(self._ids, tf.int32)), dtype)
                # pylint:disable=g-long-ternary
                return (tf.transpose(results, [1, 0, 2])
                        if not batch_major_state else results)
                # pylint:enable=g-long-ternary
            else:
                one_hot_ids = tf.one_hot(self._ids,
                                         self._ids_size,
                                         dtype=values.dtype)
                return tf.cast(tf.matmul(one_hot_ids, values), dtype)
        elif axis == 1:
            one_hot_ids = tf.one_hot(self._ids,
                                     self._ids_size,
                                     dtype=values.dtype,
                                     axis=0)
            return tf.cast(tf.matmul(values, one_hot_ids), dtype)
        else:
            raise NotImplementedError("Only row/col-wise gather implemented.")
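The one-hot matmul is just a gather expressed as a matrix product (which is often friendlier on TPU); a quick equivalence check with hypothetical values:

import tensorflow as tf

values = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
ids = tf.constant([2, 0])

one_hot_ids = tf.one_hot(ids, 3, dtype=values.dtype)  # [2, 3]
via_matmul = tf.matmul(one_hot_ids, values)           # row-wise gather
via_gather = tf.gather(values, ids)
print(tf.reduce_all(tf.equal(via_matmul, via_gather)).numpy())  # True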
Example 24
 def _Apply():
     if self.params.use_bf16_gradients_ar:
         return optimizer.apply_gradients(
             [(tf.cast(g, tf.float32), v)
              for (v, g) in var_grad.Flatten()],
             name='meta_backprop')
     else:
         return optimizer.apply_gradients(
             [(g, v) for (v, g) in var_grad.Flatten()],
             name='meta_backprop')
Example 25
 def FProp(self, theta, current_step):
     p = self.params
     assert p.total_steps > 0
     assert p.initial_value > p.final_value
     with tf.name_scope(p.name):
         decay_gap = p.initial_value - p.final_value
         return p.final_value + 0.5 * decay_gap * (1 + tf.cos(
             math.pi *
             tf.minimum(1.0,
                        tf.cast(current_step, tf.float32) / p.total_steps)))
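A numeric sketch of the cosine decay with hypothetical initial_value=1e-3, final_value=1e-5, and total_steps=1000:

import math

initial_value, final_value, total_steps = 1e-3, 1e-5, 1000  # hypothetical

def cosine_decay(step):
  decay_gap = initial_value - final_value
  frac = min(1.0, step / total_steps)
  return final_value + 0.5 * decay_gap * (1 + math.cos(math.pi * frac))

for step in [0, 500, 1000, 2000]:
  print(step, cosine_decay(step))
# initial_value at step 0, then a half-cosine down to final_value at
# total_steps, flat afterwards.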
Example 26
def MakeCausalPadding(seq_len, block_size, left_context, right_context):
    """Makes the causal padding tensor for a full sequence.

  Args:
    seq_len: int or scalar int tensor. Sequence length.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.

  Returns:
    A tensor of [num_blocks, block_size, context_size] taking values in {0, 1},
    where context_size = block_size + (left_context - 1) + right_context.
    Element [b, i, j] is zero if, in the b-th block, the i-th frame can
    access the j-th frame in the context.
  """
    seq_len = py_utils.with_dependencies([
        py_utils.assert_greater_equal(
            seq_len, 1, message='seq_len must be at least 1')
    ], seq_len)

    num_blocks = (seq_len + block_size - 1) // block_size
    context_size = block_size + (left_context - 1) + right_context

    # [num_blocks, block_size]: source positions in the original sequence.
    src_positions = tf.reshape(tf.range(num_blocks * block_size),
                               [num_blocks, block_size])
    # [num_blocks,]: source positions at the start of each block.
    block_start_positions = tf.range(0, num_blocks * block_size, block_size)
    # [context_size]:  positions relative to the block start.
    relative_context_positions = tf.range(context_size) - (left_context - 1)

    # [num_blocks, context_size]: target positions in the original sequence.
    tgt_positions = (block_start_positions[:, tf.newaxis] +
                     relative_context_positions[tf.newaxis, :])
    # [num_blocks, block_size, context_size]: position differences between source-
    # target pairs.
    position_diff = src_positions[:, :,
                                  tf.newaxis] - tgt_positions[:, tf.newaxis, :]
    # [num_blocks, block_size, context_size]: if attention is allowed between
    # source-target pairs.
    valid_atten = tf.math.logical_and(-right_context <= position_diff,
                                      position_diff < left_context)

    # [num_blocks, block_size]: if the source position is valid, not padded.
    valid_src = src_positions < seq_len
    # [num_blocks, context_size]: if the target position is valid, not padded.
    valid_tgt = tf.math.logical_and(0 <= tgt_positions,
                                    tgt_positions < seq_len)

    valid_atten &= tf.math.logical_and(valid_src[:, :, tf.newaxis],
                                       valid_tgt[:, tf.newaxis, :])

    padding = 1.0 - tf.cast(valid_atten, dtype=tf.float32)

    return padding
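A tiny worked instance of the call above, assuming lingvo's py_utils is available and using hypothetical arguments seq_len=4, block_size=2, left_context=2, right_context=0 (so context_size=3, and each frame may attend to itself and one frame back):

# Hypothetical call of MakeCausalPadding as defined above.
padding = MakeCausalPadding(
    seq_len=4, block_size=2, left_context=2, right_context=0)
print(padding.numpy().astype(int))
# [[[1 0 1]    block 0: frame 0 attends only to itself (no left neighbor),
#   [1 0 0]]            frame 1 attends to frame 0 and itself;
#  [[0 0 1]    block 1: frame 2 attends to frame 1 and itself,
#   [1 0 0]]]           frame 3 attends to frame 2 and itself.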
Example 27
    def FProp(self, theta, inputs):
        """Apply projection to inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
        p = self.params
        with tf.name_scope(p.name):
            computation_cost.Add(
                self, 'flops',
                tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64)) *
                tf.cast(symbolic.ToTensor(p.input_dims * p.output_dims),
                        tf.int64) * 2)
            return py_utils.ProjectLastDim(inputs, theta.w, p.input_dims,
                                           p.output_dims)
Example 28
    def ApplyClippingWithState(self, state, x):
        """Applies clipping to x.

    Args:
      state: Clipping state.
      x: Input tensor to clip.
    Returns:
      Clipped (or identity) x.
    """
        cap = tf.cast(state, x.dtype)
        return tf.clip_by_value(x, -cap, cap)
Example 29
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
  """Override py_utils.GenerateStepSeedPair to use GetOverWriteGlobalStep."""
  seed_dtype = tf.int32 if py_utils.use_tpu() else tf.int64
  if p.is_inference and p.random_seed is None:
    # Unlike tf.random*, stateless random ops are completely determined by the
    # passed-in seeds. This means at inference time the same inputs will produce
    # the same outputs, even if the model is supposed to have randomness such as
    # dropout during inference. We inject additional randomness only during
    # inference if the graph is exported with random_seed=None as a workaround.
    return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

  with tf.name_scope('op_seed') as scope:
    global_step = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
    step_seed = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
    seeds = tf.stack([global_step, step_seed])

    if p.random_seed is not None:
      seeds += p.random_seed
    if op_seed is not None:
      seeds += op_seed
    return seeds
Example 30
    def _TimeWarp(self,
                  inputs,
                  seq_lengths,
                  global_seed,
                  dtype=tf.float32,
                  domain_id_index=0):
        """Applies time warping with given degree to inputs.

    Args:
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      seq_lengths: The actual sequence lengths, of shape (batch_size,), from
        which the mask has been sampled.
      global_seed: an integer seed tensor for stateless random ops.
      dtype: Data type.
      domain_id_index: Domain ID index.

    Returns:
      Inputs with random time warping applied.
    """
        p = self.params
        batch_size, time_length, _, _ = py_utils.GetShape(inputs)

        # Get parameters for warping.
        time_warp_max_frames = p.time_warp_max_frames[domain_id_index]
        max_ratio = p.time_warp_max_ratio[domain_id_index]
        time_warp_bound = p.time_warp_bound[domain_id_index]
        assert time_warp_bound in ('static', 'dynamic')

        # If maximum warp length is zero, do nothing.
        if ((time_warp_max_frames == 0 and time_warp_bound == 'static')
                or max_ratio <= 0.0):
            return inputs
        seq_lengths = tf.cast(seq_lengths, tf.int32)

        # Discard upper-bound on time-warp frames when
        # dynamic time warping is used.
        if time_warp_bound == 'dynamic':
            time_warp_max_frames = None

        # Create warping matrix in time direction and apply
        warp_matrix = self._GetWarpMatrix(batch_size,
                                          choose_range=seq_lengths,
                                          matrix_size=time_length,
                                          global_seed=global_seed,
                                          max_warp_frames=time_warp_max_frames,
                                          dtype=dtype,
                                          max_ratio=max_ratio)

        return self.EinsumBxycBzxBzyc(inputs,
                                      warp_matrix,
                                      name='einsum_forwarping')