Exemple #1
0
    def _InputBatch(self):
        """Packs the cached input tensors into a single `NestedMap` batch.

        The batch has `bucket_keys`, a `src` map (`ids`, `paddings`) and a
        `tgt` map (`ids`, `labels`, `weights`, `paddings`). `src.ids` and
        `tgt.labels` are cast to int32.

        Returns:
          The batch. When `params.fprop_dtype` is set and differs from
          `params.dtype`, every floating-point tensor in it is cast to
          `params.fprop_dtype`; all other tensors pass through unchanged.
        """
        batch = py_utils.NestedMap(
            bucket_keys=self._bucket_keys,
            src=py_utils.NestedMap(
                ids=tf.cast(self._src_ids, dtype=tf.int32),
                paddings=self._src_paddings),
            tgt=py_utils.NestedMap(
                ids=self._tgt_ids,
                labels=tf.cast(self._tgt_labels, dtype=tf.int32),
                weights=self._tgt_weights,
                paddings=self._tgt_paddings))

        fprop_dtype = self.params.fprop_dtype
        if fprop_dtype is None or self.params.dtype == fprop_dtype:
            # No separate forward-prop dtype requested; nothing to cast.
            return batch

        # Only floating-point leaves are converted; ids/labels stay integral.
        return batch.Transform(
            lambda t: tf.cast(t, fprop_dtype) if t.dtype.is_floating else t)
def SequenceAppendToken(x, x_paddings, token, extend=False):
    """Writes `token` right after the last valid position of every row of `x`.

    The valid length of each row is derived from `x_paddings` (number of
    zero-padding entries, rounded). Entries of `x` beyond that length are
    zeroed before the token is added in.

    Args:
      x: A sequence of tokens of shape [batch_size, x_len_max].
      x_paddings: The paddings of `x`.
      token: The token to append (of type integer).
      extend: Whether to grow `x` by one position along the length dimension.
        This must be true if any row of `x` already has length `x_len_max`,
        otherwise an invalid sequence is emitted for that row.

    Returns:
      A tuple.
        - The new sequence, of shape [batch_size, x_len_max], or
          [batch_size, x_len_max + 1] when `extend` is true.
        - The new paddings, with the same shape as the new sequence.
    """
    num_rows = py_utils.GetShape(x)[0]
    valid_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

    if extend:
        # Add one zero column at the end of the length axis.
        x = tf.pad(x, [[0, 0], [0, 1]])
    out_shape = py_utils.GetShape(x)

    # Zero out everything past the valid length of each row.
    keep = tf.sequence_mask(valid_len, out_shape[1], x.dtype)
    x = x * keep

    # Place `token` at index `valid_len` of each row.
    write_at = tf.stack([tf.range(num_rows), valid_len], axis=1)
    token_column = tf.cast(tf.fill([num_rows], token), x.dtype)
    x = x + tf.scatter_nd(write_at, token_column, out_shape)

    new_paddings = 1 - tf.sequence_mask(valid_len + 1, out_shape[1],
                                        x_paddings.dtype)
    return x, new_paddings
Exemple #3
0
    def _ProcessSingleInput(self, source_id, src, tgt):
        """Tokenizes one (src, tgt) string pair and builds its feature map.

        Args:
          source_id: Passed to `_GetTaskIds` to obtain the source/target task
            ids.
          src: The source string tensor (reshaped to [1] for tokenization).
          tgt: The target string tensor (reshaped to [1] for tokenization).

        Returns:
          A `NestedMap` with `src` (`ids`, `ids_indicator`, `task_ids`) and
          `tgt` (`ids`, `labels`, `ids_indicator`, `task_ids`), plus `strs` in
          both when not running on TPU. Every leaf is passed through
          `tf.squeeze`.
        """
        _, src_labels, src_paddings = self.StringsToIds(
            tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
        tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
            tf.reshape(tgt, [1]), is_source=False, key=self._tgt_tokenizer_key)

        # Force padded positions to 0: the tokenizer implementation may have
        # used the EOS token as its padding value.
        src_labels = py_utils.ApplyPadding(src_paddings, src_labels)
        tgt_ids = py_utils.ApplyPadding(tgt_paddings, tgt_ids)
        tgt_labels = py_utils.ApplyPadding(tgt_paddings, tgt_labels)

        # An indicator is 1 exactly where the tokenizer produced a non-padded
        # id. Unlike weights it is never modified later, so it can be used to
        # recover the actual sequence length.
        src_indicator = 1 - src_paddings
        tgt_indicator = 1 - tgt_paddings

        src_task_id, tgt_task_id = self._GetTaskIds(source_id)

        # Multiplying by the indicator leaves task ids at 0 on padding.
        features = py_utils.NestedMap(
            src=py_utils.NestedMap(
                ids=src_labels,
                ids_indicator=src_indicator,
                task_ids=tf.cast(src_indicator, dtype=tf.int32) * src_task_id),
            tgt=py_utils.NestedMap(
                ids=tgt_ids,
                labels=tgt_labels,
                ids_indicator=tgt_indicator,
                task_ids=tf.cast(tgt_indicator, dtype=tf.int32) * tgt_task_id))

        if not py_utils.use_tpu():
            features.src.strs = src
            features.tgt.strs = tgt
        return features.Transform(tf.squeeze)
    def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
        """Computes mean and variance over the valid data points in inputs.

        Statistics are reduced over every dimension except the last one, with
        each entry weighted by `mask`.

        Args:
          inputs: The tensor to compute statistics over.
          mask: Non-negative numeric weights with the same rank as `inputs`,
            broadcastable against it. It is cast to `inputs.dtype` before use.
          enable_cross_replica_sum_on_tpu: If true and running on TPU, the
            sums and counts are aggregated across replicas before normalizing.

        Returns:
          A (mean, variance) tuple.
        """
        inputs = py_utils.with_dependencies([
            py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
            py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
        ], inputs)
        rank = tf.rank(mask)
        reduce_over_dims = tf.range(0, rank - 1)

        # Cast the mask exactly once so that the sum, the count and the
        # variance are all computed in inputs.dtype. Previously only the sum
        # used a cast mask, so a mask of another dtype failed further down.
        mask = tf.cast(mask, inputs.dtype)

        sum_v = tf.reduce_sum(inputs * mask, reduce_over_dims)
        count_v = tf.reduce_sum(mask, reduce_over_dims)
        # Input shape is guaranteed to be a multiple of mask shape because the
        # inputs * mask op above was successfully broadcasted.
        mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
        count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_v = tf.tpu.cross_replica_sum(sum_v)
            count_v = tf.tpu.cross_replica_sum(count_v)

        # Clamp the count to at least 1 so an all-zero mask does not divide
        # by zero.
        count_v = tf.maximum(count_v, 1.0)
        mean = sum_v / count_v
        centered = inputs - mean
        sum_vv = tf.reduce_sum(centered * centered * mask, reduce_over_dims)

        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_vv = tf.tpu.cross_replica_sum(sum_vv)

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
        ], sum_vv / count_v)
        return mean, variance
Exemple #5
0
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay.

     The value is `model_dim**-0.5 * min(1, step / warmup_steps) *
     rsqrt(max(step, warmup_steps))`, computed in float32.
     """
     p = self.params
     step = tf.cast(current_step, tf.float32)
     warmup = tf.cast(p.warmup_steps, tf.float32)
     # Ramps linearly from 0 to 1 over the warmup period, then stays at 1.
     warmup_factor = tf.minimum(1.0, step / warmup)
     # Constant during warmup, inverse-square-root of the step afterwards.
     decay_factor = tf.math.rsqrt(tf.maximum(step, warmup))
     scale = p.model_dim**-0.5
     return scale * warmup_factor * decay_factor
Exemple #6
0
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay.

     The warmup length in steps is `warmup_examples` divided by
     `batch_size * num_replicas`. The result is the smaller of the warmup
     branch `(step + 1) * warmup_steps**-1.5` and the decay branch
     `(step + 1)**-0.5`.
     """
     p = self.params
     next_step = tf.cast(current_step, tf.float32) + 1
     examples_per_step = p.batch_size * self._num_replicas
     warmup_steps = tf.cast(p.warmup_examples / examples_per_step,
                            tf.float32)
     warmup_branch = next_step * warmup_steps**-1.5
     decay_branch = next_step**-0.5
     return tf.minimum(warmup_branch, decay_branch)
Exemple #7
0
 def _InputBatch(self):
     """Produces a batch of consecutive counter values shaped `self.shape`.

     A 'CountingInputGenerator' stats counter is advanced by the number of
     elements in the batch; the batch holds the values the counter covered
     in this increment, as float32.

     Returns:
       A `NestedMap` with `src_ids` (the shaped values) and `tgt_ids` (their
       sum over axis 0).
     """
     num_elements = tf.reduce_prod(self.shape)
     counter = summary_utils.StatsCounter('CountingInputGenerator')
     # IncBy returns the post-increment value; subtract to get the start.
     first_value = tf.stop_gradient(
         tf.cast(counter.IncBy(num_elements), dtype=tf.int32) - num_elements)
     flat_values = tf.cast(first_value + tf.range(num_elements),
                           dtype=tf.float32)
     src_ids = tf.reshape(flat_values, self.shape)
     return py_utils.NestedMap(src_ids=src_ids,
                               tgt_ids=tf.reduce_sum(src_ids, axis=0))
Exemple #8
0
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay.

     The warmup length is `warmup_steps * worker_replicas`. When
     `decay_end` is set, the step is held at `decay_end` once it reaches it.
     The result is `model_dim**-0.5` times the smaller of
     `(step + 1) * warmup**-1.5` and `(step + 1)**-0.5`.
     """
     p = self.params
     step = tf.cast(current_step, tf.float32)
     warmup = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
     if p.decay_end is not None:
         # Freeze the schedule once decay_end has been reached.
         step = tf.where(step < p.decay_end, step,
                         tf.cast(p.decay_end, tf.float32))
     next_step = step + 1
     warmup_branch = next_step * warmup**-1.5
     decay_branch = next_step**-0.5
     return p.model_dim**-0.5 * tf.minimum(warmup_branch, decay_branch)
    def FProp(self, theta, inputs, paddings, domain_ids=None):
        """Applies data augmentation by randomly masking spectrum in inputs.

        With more than one entry in `params.domain_ids`, each domain gets its
        own augmentation pass; a position takes the augmented value of the
        domain whose id equals its `domain_ids` entry, and positions matching
        no configured domain keep their original value.

        Args:
          theta: A NestedMap object containing weights' values of this layer
            and its children layers.
          inputs: A tensor of shape [batch, time, freq, num_channels].
          paddings: A 0/1 tensor of shape [batch, time].
          domain_ids: input domain_ids of shape [batch, time].

        Returns:
          A pair of 2 tensors:

          - augmented_inputs: A tensor of shape
            [batch, time, freq, num_channels].
          - paddings: A 0/1 tensor of shape [batch, time].
        """
        p = self.params

        # A tensor seed, only needed when stateless random ops are used.
        global_seed = (_global_seed_from_inputs(inputs)
                       if p.use_input_dependent_random_seed else None)

        batch_size, series_length, _, _ = py_utils.GetShape(inputs)

        if not len(p.domain_ids) > 1:
            # Single-domain case: one augmentation pass over everything.
            single_domain = self._AugmentationNetwork(
                series_length,
                inputs,
                paddings,
                global_seed=global_seed,
                domain_id_index=0)
            return single_domain, paddings

        augmented_sum = tf.zeros_like(inputs)
        remaining_original = inputs
        for index, domain_id in enumerate(p.domain_ids):
            augmented_domain = self._AugmentationNetwork(
                series_length,
                inputs,
                paddings,
                global_seed=global_seed,
                domain_id_index=index)
            per_example_id = tf.expand_dims(
                tf.tile([domain_id], [batch_size]), -1)
            target_domain = tf.cast(per_example_id, dtype=p.dtype)
            # [batch, time]: 1 where the position belongs to this domain.
            domain_mask = tf.cast(tf.equal(domain_ids, target_domain),
                                  dtype=p.dtype)
            # Keep this domain's augmentation only where the mask is 1 ...
            augmented_domain = self.EinsumBxycBxBxyc(
                augmented_domain, domain_mask, name='einsum_domainmasking')
            # ... and drop the original values at those same positions.
            remaining_original = self.EinsumBxycBxBxyc(
                remaining_original,
                1.0 - domain_mask,
                name='einsum_domainmasking2')
            augmented_sum = augmented_domain + augmented_sum

        return remaining_original + augmented_sum, paddings
Exemple #10
0
    def ComputePredictions(self, theta, batch):
        # pyformat: disable
        """Compute the model predictions.

        The source and target ids in `batch` are cast to int32 in place, a
        canvas and its targets are created from the batch, and the parent
        class's `ComputePredictions` is run on the canvas.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          batch: A `.NestedMap`.

            - src: A `.NestedMap`.
              - ids: The source ids, ends in eos.
              - paddings: The source paddings.

            - tgt: A `.NestedMap`.
              - ids: The target ids, ends in eos.
              - paddings: The target paddings.

        Returns:
          A `.NestedMap`.
            - outputs: The contextualized output vectors of shape
              [batch_size, time_dim, model_dim].
            - tgt: A `.NestedMap` (optional, only during training).
              - ids: The canvas ids.
              - paddings: The canvas paddings.
              - target_indices: The target indices.
              - target_weights: The target weights.
        """
        # pyformat: enable
        p = self.params

        # TODO(williamchan): Currently, we only support KERMIT mode (i.e., no
        # encoder, unified architecture).
        assert not p.encoder

        # src and tgt ids sometimes arrive with different integer types;
        # reconcile both to int32.
        batch.src.ids = tf.cast(batch.src.ids, tf.int32)
        batch.tgt.ids = tf.cast(batch.tgt.ids, tf.int32)

        canvas_and_targets = self._CreateCanvasAndTargets(batch)
        canvas = canvas_and_targets.canvas
        canvas_paddings = canvas_and_targets.canvas_paddings

        # The parent model only sees the canvas, presented as its target.
        canvas_batch = py_utils.NestedMap(
            tgt=py_utils.NestedMap(ids=canvas, paddings=canvas_paddings))
        predictions = super(InsertionModel,
                            self).ComputePredictions(theta, canvas_batch)

        if self.do_eval:
            return predictions

        # Outside of eval, also expose the canvas and its targets.
        predictions.tgt = py_utils.NestedMap(
            ids=canvas,
            paddings=canvas_paddings,
            target_indices=canvas_and_targets.target_indices,
            target_weights=canvas_and_targets.target_weights)
        return predictions
    def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                          hyp_ids, hyp_lens, done_hyps, other_states,
                          pre_beam_search_step_callback,
                          post_beam_search_step_callback):
        """Extend greedy search hyps for one step.

        Args:
          theta: A `.NestedMap` object containing weights' values of the
            decoder layer and its children layers.
          encoder_outputs: A `.NestedMap` containing encoder outputs to be
            passed to the callbacks.
          cur_step: A scalar int tensor, the current time step, 0-based.
          step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
            current search step.
          hyp_ids: An int tensor of shape [num_hyps, tgt_seq_len].
          hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
            counted.
          done_hyps: Whether or not a hyp has finished.
          other_states: A `.NestedMap` of other beam search states. This
            `.NestedMap` is managed and updated by the client. It is expected
            that each of its member tensors are of rank >= 1. t[i, ...] is the
            state of the i-th hyp at the beginning of this search step.
          pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
            callback. See class header comments for more details.
          post_beam_search_step_callback: The `PostBeamSearchStepCallback`
            callback. See class header comments for more details.

        Returns:
          A tuple of following elements for the next greedy search step,
          (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states)
        """
        p = self.params

        # Only hyps that have not finished yet grow by one token.
        still_running = 1 - tf.cast(done_hyps, tf.int32)
        hyp_lens = hyp_lens + still_running

        num_hyps_per_beam = 1
        bs_results, new_other_states = pre_beam_search_step_callback(
            theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

        # Greedy choice: the highest-scoring id for every hyp, shaped like
        # the incoming step_ids.
        best_ids = tf.cast(tf.math.argmax(bs_results.log_probs, 1), tf.int32)
        new_step_ids = tf.reshape(best_ids, tf.shape(step_ids))

        final_other_states = post_beam_search_step_callback(
            theta, encoder_outputs, new_step_ids, new_other_states)

        # Record the chosen ids in the slot for the current step.
        flat_step_ids = tf.reshape(new_step_ids, [-1])
        hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                                   flat_step_ids)

        # A hyp becomes done once it emits the end of sequence token.
        emitted_eos = tf.equal(flat_step_ids, p.target_eos_id)
        done_hyps = tf.math.logical_or(done_hyps, emitted_eos)

        next_step = cur_step + 1
        return (next_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                final_other_states)
 def UnstackFeatures(self, src_inputs, src_paddings):
     """Unstacks src_inputs and src_paddings based off stack height.

     The time axis of `src_inputs` is multiplied by `params.stack_height`
     (the third axis is inferred by the reshape), and each sequence length
     derived from `src_paddings` is scaled by the same factor.

     Returns:
       A tuple of (unstacked inputs, int32 paddings for the unstacked time
       axis).
     """
     stack_height = self.params.stack_height
     batch, stacked_len, _, channels = py_utils.GetShape(src_inputs)
     unstacked_len = stacked_len * stack_height
     unstacked_inputs = tf.reshape(src_inputs,
                                   [batch, unstacked_len, -1, channels])

     # Each valid stacked frame expands into stack_height valid frames.
     valid_frames = tf.reduce_sum(1 - src_paddings, axis=1)
     lengths = tf.cast(stack_height * valid_frames, tf.int32)
     is_valid = tf.sequence_mask(lengths, maxlen=unstacked_len)
     unstacked_paddings = 1 - tf.cast(is_valid, tf.int32)
     return unstacked_inputs, unstacked_paddings
            def ApplyBias():
                """Biases log_probs towards the labels and updates `consistent`.

                Returns:
                  A tuple of (log of the interpolated probabilities, the
                  updated `consistent` flags).
                """
                # num_hyps_per_beam * src_batch
                tgt_batch = tf.shape(step_ids)[0]

                def _ExpandToHyps(per_source):
                    """Tiles a per-source vector across hyps and flattens it."""
                    as_row = tf.reshape(per_source, [1, -1])  # [1, src_batch]
                    # [num_hyps_per_beam, src_batch]
                    tiled = tf.tile(as_row, [num_hyps_per_beam, 1])
                    return tf.reshape(tiled, [tgt_batch])

                def _ColumnAt(tensor, step):
                    """Selects column `step` of `tensor` and expands it."""
                    return _ExpandToHyps(tf.gather(tensor, step, axis=1))

                # A hyp stays consistent when its step_ids equal the labels of
                # the previous step.
                # TODO(navari): Consider updating consistent only if weights > 0.
                # Then re-evaluate the need for bias_only_if_consistent=True.
                # prev_label is not meaningful at step 0; the is_step0 term
                # below overrides the comparison in that case.
                prev_label = _ColumnAt(labels, tf.maximum(time_step - 1, 0))
                is_step0 = tf.equal(time_step, 0)
                matches_prev = tf.equal(prev_label, tf.squeeze(step_ids, 1))
                local_consistence = tf.math.logical_or(is_step0, matches_prev)
                consistent = tf.math.logical_and(states.consistent,
                                                 local_consistence)

                # Label and weight slices for the current time_step.
                label = _ColumnAt(labels, time_step)
                weight = _ColumnAt(weights, time_step)
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # Turn the dense labels into near-one-hot probabilities. The
                # small off-value avoids exact zeros, which may cause issues
                # with the log below.
                vocab_size = tf.shape(bs_results.log_probs)[1]
                uncertainty = tf.constant(1e-10, p.dtype)
                off_value = uncertainty / tf.cast(vocab_size - 1, p.dtype)
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=off_value,
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # Interpolate between predicted and label probabilities; the
                # weight is checked to lie in [0, 1].
                weight = tf.expand_dims(weight, 1)
                weight_checks = [
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ]
                probs = py_utils.with_dependencies(
                    weight_checks,
                    (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.math.log(probs), consistent
Exemple #14
0
    def FProp(self, theta, current_step):
        """Evaluates every sub-schedule and selects one by `p.boundaries`.

        Each sub-schedule is evaluated at the step relative to the start of
        its own interval (clamped at 0); the first interval starts at 0 and
        later ones start at the corresponding boundary.
        `py_utils.PiecewiseConstant` then picks among the values using the
        absolute step.
        """
        p = self.params
        current_step = tf.cast(current_step, tf.int64)
        step_dtype = current_step.dtype
        zero = tf.cast(0, step_dtype)

        values = [
            schedule.FProp(
                schedule_theta,
                tf.maximum(zero,
                           current_step - tf.cast(interval_start, step_dtype)))
            for interval_start, schedule, schedule_theta in zip(
                [0] + p.boundaries, self.schedules, theta.schedules)
        ]

        return py_utils.PiecewiseConstant(current_step, p.boundaries, values,
                                          values[0].dtype)
Exemple #15
0
 def _Value(self, current_step):
     """Returns the current clipping cap.

     Before `start_step` the cap is `start_cap`. Afterwards it moves
     linearly from `start_cap` towards `end_cap`, reaching `end_cap` at
     `end_step` and staying there.
     """
     p = self.params
     step = tf.cast(current_step, tf.float32)
     start = tf.cast(p.start_step, tf.float32)
     end = tf.cast(p.end_step, tf.float32)

     # Fraction of the ramp completed, capped at 1 past end_step.
     elapsed = tf.minimum(end - start, step - start)
     ratio = elapsed / (end - start)
     interpolated_cap = ratio * p.end_cap + (1.0 - ratio) * p.start_cap

     before_start = tf.less(step, p.start_step)
     return tf.cond(before_start,
                    lambda: tf.cast(p.start_cap, tf.float32),
                    lambda: tf.cast(interpolated_cap, tf.float32))
Exemple #16
0
 def FProp(self, theta, current_step):
     """Returns the current learning rate decay.

     The peak is `(decay_start * worker_replicas)**-0.5`. When `decay_end`
     is set, the step is held at `decay_end` once it reaches it. The result
     is `model_dim**-0.5` times the minimum of `step + 1`,
     `(step + 1)**-0.5` and the peak.
     """
     params = self.params
     decay_start_steps = tf.cast(
         params.decay_start * params.worker_replicas, tf.float32)
     step = tf.cast(current_step, tf.float32)
     if params.decay_end is not None:
         # Freeze the schedule once decay_end has been reached.
         step = tf.where(step < params.decay_end, step,
                         tf.cast(params.decay_end, tf.float32))

     peak_learning_rate = decay_start_steps**-0.5
     next_step = step + 1
     unclipped = tf.minimum(next_step, next_step**-0.5)
     scale = params.model_dim**-0.5
     return scale * tf.minimum(unclipped, peak_learning_rate)
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
    """Concats sequence `x` with sequence `y`.

    This function is length aware (based off the paddings): for every row,
    the valid prefix of `y` is written directly after the valid prefix of
    `x`, and all remaining positions are filled with `pad`. Lengths are taken
    as the number of non-padded entries, so valid entries are expected to
    form a prefix of each row.

    The valid region of `x` is selected with a length mask rather than by
    comparing against `pad`. Valid tokens of `x` that happen to equal `pad`
    are therefore preserved, and whatever values sit in the padded region of
    `x` cannot leak into the result.

    Args:
      x: A sequence of tokens of shape [batch_size, x_len_max].
      x_paddings: The paddings of `x`.
      y: A sequence of tokens of shape [batch_size, y_len_max], with the same
        dtype as `x`.
      y_paddings: The paddings of `y`.
      pad: The pad token to fill the concatenated sequence (of type integer).

    Returns:
      A tuple.
        - Concatenation of `x` and `y` of shape
          [batch_size, x_len_max + y_len_max].
        - Paddings of the concatenation of shape
          [batch_size, x_len_max + y_len_max].
    """
    # Get the length (w/ eos).
    x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
    y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)
    xy_len = x_len + y_len

    batch_size, x_len_max = py_utils.GetShape(x)
    y_len_max = py_utils.GetShape(y)[1]

    # Zero everything past each row's valid length, then make room for `y`
    # with zeros of a matching dtype. The scatter-add below relies on the
    # destination slots being exactly 0.
    x = x * tf.sequence_mask(x_len, x_len_max, x.dtype)
    x = tf.concat([x, tf.zeros(py_utils.GetShape(y), x.dtype)], 1)

    # Compute the write indices of `y` in `xy`: row b, column x_len[b] + j.
    indices = tf.stack([
        tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
        (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
         tf.expand_dims(x_len, 1)),
    ], 2)

    xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

    # Everything at or beyond the combined length (including the padded tail
    # of `y` that was scattered there) is remapped to `pad`.
    xy_len_max = py_utils.GetShape(xy)[1]
    is_valid = tf.sequence_mask(xy_len, xy_len_max)
    pad_fill = tf.cast(tf.fill(py_utils.GetShape(xy), pad), xy.dtype)
    xy = tf.where(is_valid, xy, pad_fill)

    xy_paddings = 1 - tf.cast(is_valid, x_paddings.dtype)
    return xy, xy_paddings
Exemple #18
0
    def _ProcessBeamSearchDecodeOut(self, input_batch, encoder_outputs,
                                    decoder_outs):
        """Turns positional beam search outputs into a dict keyed by name.

        The first three entries of `decoder_outs` are stored on `self` as
        `r1_shape`, `r2_shape` and `r3_shape` and logged. Entries 3 through 10
        are returned under their names, together with `target_ids` (target
        ids without their first column), `eval_weight` and `tlen` (number of
        non-padded target positions minus one, as int32).

        Args:
          input_batch: The input batch; `tgt.ids`, `tgt.paddings` and
            `eval_weight` are read from it.
          encoder_outputs: Unused here.
          decoder_outs: An indexable sequence with at least 11 entries.

        Returns:
          The dict described above.
        """
        self.r1_shape = decoder_outs[0]
        self.r2_shape = decoder_outs[1]
        self.r3_shape = decoder_outs[2]
        tf.logging.info('r1_shape: %s', self.r1_shape)
        tf.logging.info('r2_shape: %s', self.r2_shape)
        tf.logging.info('r3_shape: %s', self.r3_shape)

        # Names of decoder_outs[3:11], in positional order.
        decoded_names = ('hyps', 'prev_hyps', 'done_hyps', 'scores',
                         'atten_probs', 'eos_scores', 'eos_atten_probs',
                         'source_seq_lengths')
        decoded = {}
        for position, name in enumerate(decoded_names, start=3):
            decoded[name] = decoder_outs[position]

        valid_tgt_positions = tf.reduce_sum(1.0 - input_batch.tgt.paddings, 1)
        tlen = tf.cast(tf.round(valid_tgt_positions - 1.0), tf.int32)

        ret_dict = {
            'target_ids': input_batch.tgt.ids[:, 1:],
            'eval_weight': input_batch.eval_weight,
            'tlen': tlen,
        }
        ret_dict.update(decoded)
        return ret_dict
Exemple #19
0
        def Polynomial(x):
            """Polynomial function of x.

            Returns `y0` for `x < x0` and `y1` for `x >= x1`, where
            `(x0, y0) = p.start` and `(x1, y1) = p.limit`. In between, the
            value is `y0 + ((x - x0) / (x1 - x0))**p.power * (y1 - y0)`.
            """
            p = self.params
            (x0, y0), (x1, y1) = p.start, p.limit

            assert x0 < x1, '%s must be < %s' % (x0, x1)

            # Bring all four endpoints to the dtype of x.
            x0, x1, y0, y1 = [
                tf.cast(endpoint, dtype=x.dtype)
                for endpoint in (x0, x1, y0, y1)
            ]

            progress = ((x - x0) / (x1 - x0))**p.power
            interior = y0 + progress * (y1 - y0)
            capped_above = tf.where(x >= x1, y1, interior)
            return tf.where(x < x0, y0, capped_above)
def _global_seed_from_inputs(input_floats):
    """Generates a seed tensor from the input floats and the current time.

    Args:
      input_floats: a set of float input tensors that are derived from the
        input data (for example, input tokens). The important thing is that
        these are usually different for each batch.

    Returns:
      A tensor of shape=[2] with int64 seeds: the timestamp (modulo
      10000000) plus, respectively minus, the summed absolute input values.
    """
    now = tf.cast(tf.timestamp(), dtype=tf.int64)
    timestamp = tf.math.floormod(now, 10000000)
    input_magnitude = tf.cast(tf.reduce_sum(tf.math.abs(input_floats)),
                              dtype=tf.int64)
    seed_pair = [timestamp + input_magnitude, timestamp - input_magnitude]
    return tf.stack(seed_pair, axis=-1)
Exemple #21
0
 def restore(self, restored_tensors, restored_shapes):
     """Assigns the first restored tensor to `self.op`, cast to bfloat16.

     When `restored_shapes` is given, the tensor is reshaped to its first
     entry and shape validation is turned off. Otherwise the shape is
     validated only if the shape of `self.op` is fully defined.

     Returns:
       The assign op.
     """
     tensor = restored_tensors[0]
     shapes_given = restored_shapes is not None
     if shapes_given:
         tensor = tf.reshape(tensor, restored_shapes[0])
     validate = (not shapes_given
                 and self.op.get_shape().is_fully_defined())
     return tf.assign(self.op,
                      tf.cast(tensor, tf.bfloat16),
                      validate_shape=validate)
 def IncBy(self, delta):
   """Increment the counter by delta and return the new value.

   `delta` is cast to int64. A scalar summary of the pre-increment value is
   also emitted under the counter's name.
   """
   increment = tf.cast(delta, tf.int64)
   # NOTE: _value must be evaluated before _var is updated, hence the
   # control dependency around both the summary and the assign.
   with tf.control_dependencies([self._value]):
     scalar(self._name, self._value)
     updated = tf.assign_add(self._var, increment)
     return tf.identity(updated)
    def _matmul_gather(self, values, axis=0, batch_major_state=True):
        """Returns the entries of `values` selected by `self._ids`.

        Values that are neither float32 nor bfloat16 are processed in float32
        and cast back to their original dtype on return. For `axis == 0`,
        values whose static rank is known and above 2 use `tf.gather`; all
        other cases multiply by a one-hot matrix of the ids.

        Args:
          values: Values to gather from.
          axis: Axis to gather on. Defaults to 0 (rows).
          batch_major_state: Whether the values to gather from use batch major
            or not. Defaults to True. For Transformer model, batch_major_state
            is set to False (time is the major dim).

        Returns:
          Gathered values.

        Raises:
          NotImplementedError: if axis is neither 0 nor 1.
        """
        out_dtype = values.dtype
        if out_dtype not in (tf.float32, tf.bfloat16):
            values = tf.cast(values, tf.float32)

        if axis not in (0, 1):
            raise NotImplementedError("Only row/col-wise gather implemented.")

        if axis == 1:
            # Column selection: values x one_hot, ids laid out along axis 0.
            column_selector = tf.one_hot(self._ids,
                                         self._ids_size,
                                         dtype=values.dtype,
                                         axis=0)
            return tf.cast(tf.matmul(values, column_selector), out_dtype)

        rank = values.shape.rank
        if rank is None or not rank > 2:
            # Row selection through a one-hot matmul.
            row_selector = tf.one_hot(self._ids,
                                      self._ids_size,
                                      dtype=values.dtype)
            return tf.cast(tf.matmul(row_selector, values), out_dtype)

        # Higher-rank values: gather along the first (batch) axis, swapping
        # the first two axes around the gather for time-major state.
        if not batch_major_state:
            values = tf.transpose(values, [1, 0, 2])
        gathered = tf.cast(tf.gather(values, tf.cast(self._ids, tf.int32)),
                           out_dtype)
        if batch_major_state:
            return gathered
        return tf.transpose(gathered, [1, 0, 2])
 def _Apply():
     """Applies the gradients in `var_grad` through `optimizer`.

     When `params.use_bf16_gradients_ar` is set, every gradient is cast to
     float32 before being applied.
     """
     cast_to_float32 = self.params.use_bf16_gradients_ar
     grads_and_vars = [
         (tf.cast(g, tf.float32) if cast_to_float32 else g, v)
         for (v, g) in var_grad.Flatten()
     ]
     return optimizer.apply_gradients(grads_and_vars, name='meta_backprop')
Exemple #25
0
 def FProp(self, theta, current_step):
     """Returns a cosine-shaped value between initial_value and final_value.

     The value starts at `initial_value` at step 0, follows half a cosine
     period down to `final_value` at `total_steps`, and stays there
     afterwards. Requires `total_steps > 0` and
     `initial_value > final_value`.
     """
     p = self.params
     assert p.total_steps > 0
     assert p.initial_value > p.final_value
     with tf.name_scope(p.name):
         # Fraction of the schedule completed, capped at 1.
         progress = tf.minimum(
             1.0,
             tf.cast(current_step, tf.float32) / p.total_steps)
         half_gap = 0.5 * (p.initial_value - p.final_value)
         return p.final_value + half_gap * (1 + tf.cos(math.pi * progress))
Exemple #26
0
def MakeCausalPadding(seq_len, block_size, left_context, right_context):
  """Makes the causal padding tensor for a full sequence.

  Args:
    seq_len: int or scalar int tensor. Sequence length.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.

  Returns:
    A float32 tensor of [num_blocks, block_size, context_size] taking values
    in {0, 1}, where context_size = block_size + (left_context - 1) +
    right_context. Element b, i, j is zero if in the b-th block, the i-th
    frame can access the j-th frame in the context.
  """
  length_check = py_utils.assert_greater_equal(
      seq_len, 1, message='seq_len must be at least 1')
  seq_len = py_utils.with_dependencies([length_check], seq_len)

  # Number of blocks needed to cover seq_len (ceiling division).
  num_blocks = (seq_len + block_size - 1) // block_size
  padded_len = num_blocks * block_size
  past_frames = left_context - 1
  context_size = block_size + past_frames + right_context

  # [num_blocks, block_size]: source positions in the original sequence.
  src_positions = tf.reshape(tf.range(padded_len), [num_blocks, block_size])
  # [num_blocks,]: source positions at the start of each block.
  block_starts = tf.range(0, padded_len, block_size)
  # [context_size]: positions relative to the block start.
  context_offsets = tf.range(context_size) - past_frames

  # [num_blocks, context_size]: target positions in the original sequence.
  tgt_positions = (
      block_starts[:, tf.newaxis] + context_offsets[tf.newaxis, :])

  # [num_blocks, block_size, context_size]: source minus target position.
  src_expanded = src_positions[:, :, tf.newaxis]
  tgt_expanded = tgt_positions[:, tf.newaxis, :]
  position_diff = src_expanded - tgt_expanded

  # [num_blocks, block_size, context_size]: the target lies inside the
  # source's context window.
  in_window = tf.math.logical_and(position_diff >= -right_context,
                                  position_diff < left_context)

  # [num_blocks, block_size]: the source position is valid, not padded.
  valid_src = src_positions < seq_len
  # [num_blocks, context_size]: the target position is valid, not padded.
  valid_tgt = tf.math.logical_and(tgt_positions >= 0,
                                  tgt_positions < seq_len)
  both_valid = tf.math.logical_and(valid_src[:, :, tf.newaxis],
                                   valid_tgt[:, tf.newaxis, :])

  valid_atten = tf.math.logical_and(in_window, both_valid)
  return 1.0 - tf.cast(valid_atten, dtype=tf.float32)
Exemple #27
0
    def ApplyClippingWithState(self, state, x):
        """Clips x symmetrically to the range given by the clipping state.

        Args:
          state: Clipping state, used as the cap after a cast to x's dtype.
          x: Input tensor to clip.

        Returns:
          x clipped to [-cap, cap].
        """
        upper = tf.cast(state, x.dtype)
        lower = -upper
        return tf.clip_by_value(x, lower, upper)
Exemple #28
0
  def FProp(self, theta, inputs):
    """Apply projection to inputs.

    Also records a 'flops' computation cost of
    2 * prod(leading dims of inputs) * input_dims * output_dims.

    Args:
      theta: A NestedMap object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
    p = self.params
    with tf.name_scope(p.name):
      # Number of vectors being projected (all dims except the last).
      num_vectors = tf.reduce_prod(
          tf.cast(tf.shape(inputs)[:-1], tf.int64))
      weight_size = tf.cast(
          symbolic.ToTensor(p.input_dims * p.output_dims), tf.int64)
      computation_cost.Add(self, 'flops', num_vectors * weight_size * 2)
      return py_utils.ProjectLastDim(inputs, theta.w, p.input_dims,
                                     p.output_dims)
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
    """Override py_utils.GenerateStepSeedPair to use GetOverWriteGlobalStep.

    Args:
      p: Params; `is_inference` and `random_seed` are read from it.
      unused_global_step: Ignored.
      op_seed: Optional value added to both seeds.

    Returns:
      A seed pair tensor: int32 on TPU, int64 otherwise.
    """
    if py_utils.use_tpu():
        seed_dtype = tf.int32
    else:
        seed_dtype = tf.int64

    if p.is_inference and p.random_seed is None:
        # Stateless random ops are fully determined by their seeds, so at
        # inference time identical inputs would always give identical
        # outputs, even for models that are meant to be random (e.g. dropout
        # during inference). As a workaround, graphs exported with
        # random_seed=None draw fresh random seeds here.
        return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

    with tf.name_scope('op_seed') as scope:
        global_step = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
        step_seed = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
        seeds = tf.stack([global_step, step_seed])

        # Fold in whichever optional offsets were provided.
        for offset in (p.random_seed, op_seed):
            if offset is not None:
                seeds = seeds + offset
        return seeds
Exemple #30
0
        def _DerivePaddingsAndIds(src_ids, tgt_labels):
            """Derives target ids, paddings, weights and the bucket key.

            tgt_ids is tgt_labels shifted right by one, with a SOS ID
            prepended. Paddings are all zeros and weights all ones (float32).
            The bucket key is the larger of the two non-padded lengths, as
            int32.
            """
            shifted_labels = tgt_labels[:-1]
            tgt_ids = tf.concat([[p.sos_id], shifted_labels], axis=0)

            src_shape = tf.shape(src_ids)
            tgt_shape = tf.shape(tgt_ids)
            src_paddings = tf.zeros(src_shape, dtype=tf.float32)
            tgt_paddings = tf.zeros(tgt_shape, dtype=tf.float32)
            tgt_weights = tf.ones(tgt_shape, dtype=tf.float32)

            src_len = tf.reduce_sum(1.0 - src_paddings)
            tgt_len = tf.reduce_sum(1.0 - tgt_paddings)
            bucket_key = tf.cast(tf.maximum(src_len, tgt_len), tf.int32)

            return src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key