Example #1
    def _Proc(record):
      """Parses a serialized tf.Example record."""
      outputs = [
          ('inputs', tf.io.VarLenFeature(tf.int64)),
          ('targets', tf.io.VarLenFeature(tf.int64)),
          # Default eval weight to 1.0
          ('eval_weight',
           tf.io.FixedLenFeature([], tf.float32, default_value=1.0)),
      ]
      features = tf.io.parse_single_example(record, dict(outputs))
      for k, v in six.iteritems(features):
        if k != 'eval_weight':
          features[k] = v.values
        else:
          eval_weight = v

      src_ids = features['inputs']
      tgt_labels = features['targets']

      # Derive trivial segmentation for unpacked input.
      src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key = _DerivePaddingsAndIds(
          src_ids, tgt_labels)

      src_len = tf.shape(src_ids)[0]
      tgt_len = tf.shape(tgt_ids)[0]
      src_pos = tf.range(src_len, dtype=tf.int32)
      src_seg = tf.zeros_like(src_paddings)
      tgt_pos = tf.range(tgt_len, dtype=tf.int32)
      tgt_seg = tf.zeros_like(tgt_paddings)

      return [
          src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights,
          src_pos, src_seg, tgt_pos, tgt_seg, eval_weight
      ], bucket_key
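For context, a record consumed by _Proc is an ordinary serialized tf.Example. A minimal sketch with made-up token ids (omitting 'eval_weight' exercises the 1.0 default above):

    import tensorflow as tf

    example = tf.train.Example(features=tf.train.Features(feature={
        'inputs': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[3, 4, 5, 2])),
        'targets': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[7, 8, 2])),
        # 'eval_weight' omitted on purpose: parsing falls back to 1.0.
    }))
    record = example.SerializeToString()

Example #2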
 def Step(recurrent_theta, state0, inputs):
     """Computes one decoder step."""
     del inputs
     with tf.name_scope('single_sampler_step'):
         # Compute logits and states.
         bs_result, bs_state1 = pre_step_callback(
             recurrent_theta.theta,
             recurrent_theta.encoder_outputs,
             tf.expand_dims(state0.ids, 1),  # [batch, 1].
             state0.bs_state,
             num_hyps_per_beam=1)
         batch = tf.shape(bs_result.log_probs)[0]
         state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
         state1.logits = bs_result.log_probs
         # Sample ids from logits. [batch].
         state1.ids = tf.reshape(
             tf.random.stateless_categorical(
                 state1.logits / p.temperature,
                 num_samples=1,
                 seed=tf.stack(
                     [recurrent_theta.random_seed, state0.timestep]),
                 dtype=state0.ids.dtype,
                 name='sample_next_id'), [batch])
         if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
             state1.ids = tf.where(
                 tf.math.logical_and(
                     bs_result.is_last_chunk,
                     tf.equal(state1.ids, p.target_eoc_id)),
                 tf.fill(tf.shape(state1.ids), p.target_eos_id),
                 state1.ids)
         state1.bs_state = post_step_callback(
             recurrent_theta.theta, recurrent_theta.encoder_outputs,
             state1.ids, bs_state1)
     return state1, py_utils.NestedMap()
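The sampling above uses tf.random.stateless_categorical, which is deterministic for a given seed pair; seeding with (random_seed, timestep) makes every decoder step reproducible. A standalone sketch with made-up logits:

    import tensorflow as tf

    logits = tf.math.log([[0.1, 0.6, 0.3]])  # [batch=1, vocab=3]
    ids = tf.random.stateless_categorical(
        logits, num_samples=1, seed=[42, 0], dtype=tf.int32)
    ids = tf.reshape(ids, [1])  # [batch], as with state1.ids above

Example #3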
    def _InputBatch(self):
        p = self.params

        @tf.function
        def ReadData():
            x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                                     [p.data_dtype, p.label_dtype])
            # Always convert to float32.
            return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

        # Load data and label into memory and keep them around.
        data, label = ops.cached_call(f=ReadData.get_concrete_function(),
                                      T=[tf.float32, tf.float32])
        b, shape = self.InfeedBatchSize(), list(p.data_shape)
        data = tf.reshape(data, [-1] + shape)
        label = tf.reshape(label, [-1])
        label = py_utils.HasShape(label, [tf.shape(data)[0]])
        sample_ids = ops.random_permutation_sequence(
            num=p.num_samples,
            batch=b,
            repeat=p.repeat,
            seed=p.random_seed if p.random_seed else 0)
        n = tf.shape(sample_ids)[0]
        raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids), [b] + shape)
        ret = py_utils.NestedMap(
            raw=raw,
            data=self._Preprocess(raw),
            label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [b]),
            weight=py_utils.PadOrTrimTo(tf.ones([n], dtype=tf.float32), [b]))
        if not py_utils.use_tpu():
            ret['sample_ids'] = sample_ids
        return ret
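Example #4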
def ComputeSplits(batch_size, num_splits):
    """Creates a tensor of size num_splits of number of values per split.

  Assigns each split floor(batch_size/num_splits) values and distributes the
  remainder (if any) round-robin across the splits.

  Example::

    batch_size: [5]
    num_splits: 3
    returns: [2, 2, 1]

  Args:
    batch_size: tensor of rank 0, size of tensor to be split
    num_splits: number of splits to split tensor into
  Returns:
    tensor of length num_splits containing sizes of each split
  """
    values = tf.tile(tf.math.floordiv([batch_size], num_splits),
                     tf.constant([num_splits], dtype=tf.int32))
    mods = tf.tile(tf.constant([1]), tf.math.floormod([batch_size],
                                                      num_splits))
    zeros = tf.tile(tf.constant([0]),
                    tf.subtract(tf.shape(values), tf.shape(mods)))
    mods = tf.concat([mods, zeros], 0)
    ret = tf.add(values, mods)
    # for some reason TF erases shape information if num_splits is 1
    if num_splits == 1:
        ret.set_shape([1])
    return ret
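For intuition, the same floor-plus-round-robin split in plain Python (a reference sketch, not part of the library):

    def compute_splits_py(batch_size, num_splits):
        base, rem = divmod(batch_size, num_splits)
        return [base + 1 if i < rem else base for i in range(num_splits)]

    assert compute_splits_py(5, 3) == [2, 2, 1]  # the docstring example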
Example #5
  def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
    """Computes mean and variance over the valid data points in inputs."""
    inputs = py_utils.with_dependencies([
        py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
        py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
    ], inputs)
    rank = tf.rank(mask)
    reduce_over_dims = tf.range(0, rank - 1)
    sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                          reduce_over_dims)
    count_v = tf.reduce_sum(mask, reduce_over_dims)
    # Input shape is guaranteed to be a multiple of mask shape because the
    # inputs * mask op above was successfully broadcasted.
    mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
    count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
      sum_v = tf.tpu.cross_replica_sum(sum_v)
      count_v = tf.tpu.cross_replica_sum(count_v)

    count_v = tf.maximum(count_v, 1.0)
    mean = sum_v / count_v
    sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                           reduce_over_dims)

    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
      sum_vv = tf.tpu.cross_replica_sum(sum_vv)

    variance = py_utils.with_dependencies([
        py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
    ], sum_vv / count_v)
    return mean, variance
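The core of the masked-moment computation can be cross-checked with a small NumPy sketch on made-up data: only positions where the mask is 1 contribute, so padded values are excluded from both statistics.

    import numpy as np

    inputs = np.array([[1., 2., 100.], [3., 4., 100.]])  # [batch, time]
    mask = np.array([[1., 1., 0.], [1., 1., 0.]])        # the 100s are padding
    count = mask.sum()
    mean = (inputs * mask).sum() / count                 # 2.5; 100s excluded
    variance = (((inputs - mean) ** 2) * mask).sum() / count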
Example #6
    def FProp(self, theta, input_batch):
        """Encodes source as represented by `inputs` and `paddings`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields:
        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].

    Returns:
      A NestedMap containing:

      - encoded: The encoded features, a tensor of shape [time, batch, depth]
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
    """
        p = self.params
        src_segment_id = None
        with tf.name_scope(p.name):
            # Now the rnn layers.
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            self._emb_out = xs
            ps = paddings
            # When cc_schedule is specified, make sure lstm_tpl is QuantizedLSTMCell
            # with the same cc_schedule so that the RNN layer output is within
            # clipping range.
            xs = self.rnn[0].FProp(theta.rnn[0], xs, ps)
            xs = self.dropout.FProp(theta.dropout, xs)
            for i in range(1, p.num_lstm_layers):
                layer = self.rnn[i]
                ys, _ = layer.FProp(theta.rnn[i], xs, ps)
                ys = self.dropout.FProp(theta.dropout, ys)
                if hasattr(layer.params, 'cell'):
                    layer_params = layer.params.cell
                else:
                    layer_params = layer.params
                if layer_params.num_input_nodes == layer_params.num_output_nodes:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    # When cc_schedule is specified, make sure lstm_tpl is
                    # QuantizedLSTMCell with the same cc_schedule so that the RNN layer
                    # output is within clipping range.
                    xs = ys
            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id)
Example #7
    def FProp(self, theta, inputs, paddings):
        """Apply convolution to inputs.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. It is expected to be of shape [batch, time,
        frequency, channel]. The time dimension corresponds to the height
        dimension as in images and the frequency dimension corresponds to the
        width dimension as in images.
      paddings: The paddings tensor, expected to be of shape [batch, time].

    Returns:
      outputs, out_paddings pair.
    """
        p = self.params
        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
                py_utils.assert_shape_match(
                    tf.shape(inputs),
                    tf.concat([
                        tf.shape(paddings),
                        [-1, symbolic.ToStatic(self.input_channels)]
                    ], 0))
            ], inputs)

            def _ApplyPadding(tensor_in, padding_in):
                padding_expanded = tf.expand_dims(
                    tf.expand_dims(padding_in, -1), -1)
                return tensor_in * (1.0 - padding_expanded)

            # Zeroing out padded inputs.
            inputs = _ApplyPadding(inputs, paddings)

            # Apply conv on 'inputs'.
            out = self._ApplyConv(theta, inputs)

            if p.partial_conv:
                out = self._RescaleBoundary(out, paddings)
            # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
            # But there's likely no real problems. Trying to set it gives an error:
            # pooling with SAME padding is not implemented for dilation_rate > 1.
            # NOTE: we use window=p.filter_stride[0] to be compatible with legacy
            # implementation.  Consider updating it to be the actual shape.
            conv_padding = ComputeConvOutputPadding(paddings,
                                                    window=p.filter_stride[0],
                                                    stride=p.filter_stride[0])
            # Assuming padded nodes will be properly zero-ed out if necessary by
            # sub-sequent layers.
            # out = _ApplyPadding(out, conv_padding)
            out = py_utils.HasShape(
                out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
            return out, conv_padding
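Example #8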
    def _AugmentationNetwork(self,
                             series_length,
                             inputs,
                             paddings,
                             global_seed,
                             domain_id_index=0):
        """Returns augmented features.

    Args:
      series_length: Total length of time series.
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      paddings: Batch of padding vectors of shape (batch_size, time_length).
      global_seed: an integer seed tensor for stateless random ops.
      domain_id_index: domain id index.

    Returns:
      Batch of output features of shape (batch_size, time_length, num_freq,
      channels) obtained by applying random augmentations to inputs.
    """
        p = self.params
        dtype = p.dtype

        # Unstack the features.
        if p.unstack:
            inputs, paddings = self.UnstackFeatures(inputs, paddings)

        lengths = tf.reduce_sum(1 - paddings, 1)
        inputs = self._TimeWarp(inputs,
                                lengths,
                                global_seed=global_seed,
                                dtype=dtype,
                                domain_id_index=domain_id_index)
        inputs = self._TimeMask(inputs,
                                lengths,
                                global_seed=global_seed,
                                noisify=p.use_noise,
                                gaussian_noise=p.gaussian_noise,
                                dtype=dtype,
                                domain_id_index=domain_id_index)
        inputs = self._FrequencyMask(inputs,
                                     global_seed=global_seed,
                                     dtype=dtype,
                                     domain_id_index=domain_id_index)

        # Restack the features after applying specaugment.
        if p.unstack:
            inputs = tf.reshape(
                inputs,
                [tf.shape(inputs)[0], series_length, -1,
                 tf.shape(inputs)[3]])

        return inputs
Example #9
    def FProp(self, theta, input_batch):
        p = self.params
        with tf.name_scope(p.name):
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            if p.packed_input:
                src_segment_id = tf.expand_dims(
                    tf.transpose(input_batch.segment_ids), 2)
            else:
                src_segment_id = None
            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)
            ps = paddings
            # Now the rnn layers.
            outputs_list = []
            for i in range(0, p.num_lstm_layers):
                layer = self.rnn[i]
                ys = layer.FProp(theta.rnn[i],
                                 xs,
                                 ps,
                                 segment_id=src_segment_id)
                ys = self.dropout.FProp(theta.dropout, ys)
                if i >= p.residual_start:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    xs = ys
                outputs_list.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   outputs_list)

            if p.lstm_cell_size * 2 != p.encoder_out_dim:
                # Project to the right depth.
                xs = self.final_proj.FProp(theta.final_proj, xs, ps)
                summary_utils.histogram('final_proj_out', xs)

            if src_segment_id is not None:
                src_segment_id = tf.squeeze(src_segment_id, [2])

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id)
Example #10
    def _DerivePaddingsAndIds(src_ids, tgt_labels):
      """tgt_ids is tgt_labels shifted right by one, with a SOS ID prepended."""
      tgt_ids = tf.concat([[p.sos_id], tgt_labels[:-1]], axis=0)
      src_paddings = tf.zeros(tf.shape(src_ids), dtype=tf.float32)
      tgt_paddings = tf.zeros(tf.shape(tgt_ids), dtype=tf.float32)
      tgt_weights = tf.ones(tf.shape(tgt_ids), dtype=tf.float32)

      bucket_key = tf.cast(
          tf.maximum(
              tf.reduce_sum(1.0 - src_paddings),
              tf.reduce_sum(1.0 - tgt_paddings)), tf.int32)

      return src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key
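Concretely, the shift turns the label sequence into decoder inputs. A tiny sketch, assuming p.sos_id == 1 and an eos id of 2:

    import tensorflow as tf

    tgt_labels = tf.constant([5, 6, 7, 2])               # ends with eos
    tgt_ids = tf.concat([[1], tgt_labels[:-1]], axis=0)  # [1, 5, 6, 7]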
Example #11
    def FProp(self, theta, input_batch, state0=None):
        p = self.params
        src_segment_id = None
        with tf.name_scope(p.name):
            # Reshape to [t, b]
            inputs = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings))
            ], tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

            # Setup streaming states.
            if not state0:
                state0 = self.zero_state(theta, tf.shape(inputs)[1])
            state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)
            ps = paddings
            # Now the rnn layers.
            outputs_list = []
            for i in range(0, p.num_lstm_layers):
                layer = self.rnn[i]
                ys, state1.rnn[i] = layer.FProp(theta.rnn[i],
                                                xs,
                                                ps,
                                                state0=state0.rnn[i])
                ys = self.dropout.FProp(theta.dropout, ys)
                if i >= p.residual_start:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    xs = ys
                outputs_list.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   outputs_list)

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(ps, [2]),
                                      segment_id=src_segment_id,
                                      state=state1)
Example #12
    def InitBeamSearchStateCallback(theta, encoder_outputs, num_hyps_per_beam):
      """Wrapper for adding bias to _InitBeamSearchStateCallback.

      Expands the state to track the consistency of each hypothesis with the
      provided target.

      Args:
        theta: A NestedMap object containing weights' values of this layer and
          its children layers.
        encoder_outputs: A NestedMap computed by encoder.
        num_hyps_per_beam: An int, the number of hyps to keep per source sentence.

      Returns:
        initial_results: a `.NestedMap` of initial results.
        states: a `.NestedMap` of initial model states that the client
          would like to keep track of for each hyp. The states relevant here
          are:
          time_step: A scalar indicating current step (=0 for initial state) of
            decoder.  Must be provided and maintained by super.
          consistent: A boolean tensor of shape [tgt_batch, 1] which tracks
              whether each hypothesis has exactly matched
              encoder_outputs.targets
              so far.
      """
      initial_results, states = self._InitBeamSearchStateCallback(
          theta, encoder_outputs, num_hyps_per_beam)
      assert hasattr(states, 'time_step')
      num_hyps = tf.shape(encoder_outputs.padding)[1] * num_hyps_per_beam
      # states.consistent is initially all True
      states.consistent = tf.ones([
          num_hyps,
      ], dtype=tf.bool)
      return initial_results, states
Example #13
  def Finalize(self):
    """Finishes creation of the overall figure, returning the image summary."""
    subplot_grid_shape = self._subplot_grid_shape
    if subplot_grid_shape is None:
      subplot_grid_shape = (len(self._subplots), 1)

    # AddMatplotlibFigureSummary (due to restrictions of py_func) only supports
    # flattened list of tensors so we must do some bookkeeping to maintain a
    # mapping from _SubplotMetadata object to flattened_tensors.
    subplot_slices = []
    flattened_tensors = []
    for subplot in self._subplots:
      start = len(flattened_tensors)
      subplot_slices.append((start, start + len(subplot.tensor_list)))
      flattened_tensors.extend(subplot.tensor_list)

    def PlotFunc(fig, *numpy_data_list):
      gs = gridspec.GridSpec(*subplot_grid_shape, **self._gridspec_kwargs)
      for n, subplot in enumerate(self._subplots):
        axes = fig.add_subplot(gs[n])
        start, end = subplot_slices[n]
        subplot_data = numpy_data_list[start:end]
        subplot.plot_func(fig, axes, *subplot_data)

    func = functools.partial(_RenderMatplotlibFigures, self._figsize,
                             self._max_outputs, PlotFunc)
    batch_sizes = [tf.shape(t)[0] for t in flattened_tensors]
    num_tensors = len(flattened_tensors)
    with tf.control_dependencies([
        tf.assert_equal(
            batch_sizes, [batch_sizes[0]] * num_tensors, summarize=num_tensors)
    ]):
      rendered = tf.py_func(
          func, flattened_tensors, tf.uint8, name='RenderMatplotlibFigures')
    return tf.summary.image(self._name, rendered, max_outputs=self._max_outputs)
Example #14
  def FProp(self, theta, inputs, paddings=None):
    """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].
      paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
        input tensor.

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
    p = self.params
    if paddings is None:
      paddings = self._GetDefaultPaddings(inputs)
    with tf.name_scope(p.name):
      norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
          theta, inputs, paddings)
      with tf.control_dependencies([
          py_utils.assert_greater_equal(norm_variance,
                                        tf.zeros_like(norm_variance)),
          py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                      tf.shape(norm_mean)),
          py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                      tf.shape(norm_variance)),
      ]):
        if p.use_fused_batch_norm_for_eval and self.do_eval:
          bn_output, _, _ = nn.fused_batch_norm(
              inputs,
              gamma,
              beta,
              norm_mean,
              norm_variance,
              self._epsilon,
              is_training=False)
        else:
          bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                                norm_variance, beta, gamma,
                                                self._epsilon)

        if p.set_padded_output_to_zero:
          bn_output *= 1.0 - paddings

      return bn_output
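For reference, tf.nn.batch_normalization applies the standard elementwise transform; a NumPy sketch of the same math:

    import numpy as np

    def batch_norm_ref(x, mean, variance, beta, gamma, epsilon):
        # out = (x - mean) / sqrt(variance + epsilon) * gamma + beta
        return (x - mean) / np.sqrt(variance + epsilon) * gamma + beta

Example #15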
    def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                          hyp_ids, hyp_lens, done_hyps, other_states,
                          pre_beam_search_step_callback,
                          post_beam_search_step_callback):
        """Extend greedy search hyps for one step.

    Args:
      theta: A `.NestedMap` object containing weights' values of the decoder
        layer and its children layers.
      encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
        the callbacks.
      cur_step: A scalar int tensor, the current time step, 0-based.
      step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
        current search step.
      hyp_ids: An int tensor of shape [num_hyps, tgt_seq_len].
      hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
        counted.
      done_hyps: Whether or not a hyp has finished.
      other_states: A `.NestedMap` of other beam search states. This
        `.NestedMap` is managed and updated by the client. It is expected that
        each of its member tensors are of rank >= 1. t[i, ...] is the state of
        the i-th hyp at the beginning of this search step.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        See class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        See class header comments for more details.

    Returns:
      A tuple of following elements for the next greedy search step,
      (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states)
    """
        p = self.params
        # Increment hyp_lens by 1 if the hyp is not finished yet.
        hyp_lens = hyp_lens + (1 - tf.cast(done_hyps, tf.int32))

        bs_results, new_other_states = pre_beam_search_step_callback(
            theta, encoder_outputs, step_ids, other_states,
            1)  # num_hyps_per_beam
        new_step_ids = tf.math.argmax(bs_results.log_probs, 1)
        new_step_ids = tf.cast(new_step_ids, tf.int32)
        new_step_ids = tf.reshape(new_step_ids, tf.shape(step_ids))
        final_other_states = post_beam_search_step_callback(
            theta, encoder_outputs, new_step_ids, new_other_states)

        # Stash new_step_ids into the right slot.
        new_step_ids_1d = tf.reshape(new_step_ids, [-1])
        hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                                   new_step_ids_1d)
        # Update done_hyps if the current step_ids is the end of sequence token.
        done_hyps = tf.math.logical_or(
            done_hyps, tf.equal(new_step_ids_1d, p.target_eos_id))

        return (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                final_other_states)
Example #16
  def FProp(self, theta, inputs):
    """Applies batch normalization.

    Using the implementation in github.com/
    tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py#L550

    Args:
      theta: A nested map object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor.  Shaped [..., dim].

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
    p = self.params
    inputs_dtype = inputs.dtype
    inputs = tf.cast(inputs, p.dtype)
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match([tf.shape(inputs)[-1]], tf.shape(
            theta.beta))
    ], inputs)
    with tf.name_scope(p.name) as scope:
      if self.do_eval:
        outputs = tf.nn.batch_normalization(inputs, theta.moving_mean,
                                            theta.moving_variance, theta.beta,
                                            theta.gamma, p.epsilon)
      else:
        mean, variance = self._Moments(inputs, p.bn_group_size)
        mean = py_utils.CheckNumerics(
            mean, 'mean of {} failed numeric check'.format(scope))
        variance = py_utils.CheckNumerics(
            variance, 'variance of {} failed numeric check'.format(scope))
        outputs = tf.nn.batch_normalization(inputs, mean, variance, theta.beta,
                                            theta.gamma, p.epsilon)
      outputs.set_shape(inputs.get_shape())
      return tf.cast(outputs, inputs_dtype)
Example #17
        def ReOrderHyps(x_in):
            """Reorders x_in based on prev hyp ids."""
            if isinstance(x_in, tf.Tensor) and x_in.shape.ndims > 0:
                # For rank > 1 tensors we use an efficient matmul-based gather
                # on TPU that takes into account the range of the values. For
                # rank-1 tensors we rely on tf.gather and XLA to optimize the
                # layout efficiently.
                if x_in.shape.ndims > 1:
                    if p.batch_major_state:
                        num_hyps = tf.shape(old_hyp_ids)[0]
                        x_out = beam_search_tpu_ops.fast_gather(
                            x_in,
                            old_hyp_ids,
                            num_hyps,
                            max_value=None,
                            batch_major_state=p.batch_major_state)
                    else:
                        # Use corrected indices only here for batch major compute as
                        # key/value caches are the states being affected.
                        correct_old_hyp_ids = (old_hyp_ids_in_cache_order
                                               if p.batch_major_compute else
                                               old_hyp_ids)

                        def _GatherStep(x_in, t):
                            """Gather for one time step.

              Args:
                x_in: A tensor of shape [T, B, ...]. We first take slice t,
                  gather old_hyp_ids from that slice, and write the gathered
                  slice back in place to update the original x_in.
                t: current time step

              Returns:
                Updated x_in and time step
              """
                            x = tf.gather(tf.gather(x_in, t),
                                          correct_old_hyp_ids)
                            return inplace_ops.alias_inplace_update(
                                x_in, t, x), t + 1

                        x_out, _ = tf.while_loop(
                            lambda _, t: t <= cur_step, _GatherStep,
                            (x_in, tf.zeros([], tf.int32)))
                else:
                    x_out = tf.gather(x_in, old_hyp_ids)
                x_out.set_shape(x_in.get_shape())
                return x_out
            else:
                return x_in
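Example #18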
def SplitTensors(xs, num_splits):
    """Splits tensors in `xs` evenly into num_splits along the 1st dimenion.

  Args:
    xs: A tuple of tensors. Each tensor's 1st dimension is the same size.
    num_splits: A python integer.

  Returns:
    A tuple of lists of tensors; the tuple has len(xs) elements.

    i-th element in each list corresponds to i-th split of each tensor in xs
    along the first dimension of each tensor.
  """
    # assert first dim of all tensors in xs is equal
    batch_dims = [tf.shape(x)[0] for x in xs]
    all_batch_dims = tf.stack(batch_dims)

    all_batch_dims = py_utils.with_dependencies([
        py_utils.assert_equal(all_batch_dims,
                              tf.shape(xs[0])[0],
                              message='first dim of tensors in xs must match'),
        py_utils.assert_greater_equal(
            tf.shape(xs[0])[0],
            num_splits,
            message='first dim of tensors in xs must be at least num_splits'
        )
    ], all_batch_dims)

    splits = ComputeSplits(tf.shape(xs[0])[0], num_splits)
    # add the above assertion into the compute graph
    splits = py_utils.with_dependencies([all_batch_dims], splits)
    split_xs = [
        tf.split(axis=0, num_or_size_splits=splits, value=x) for x in xs
    ]

    return split_xs
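A hedged usage sketch: two tensors sharing a leading dimension of 5, split into three pieces of sizes [2, 2, 1] as computed by ComputeSplits above:

    x = tf.zeros([5, 4])
    y = tf.zeros([5])
    x_splits, y_splits = SplitTensors((x, y), num_splits=3)
    # len(x_splits) == 3 and x_splits[0] has shape [2, 4]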
Example #19
def SequenceLength(padding):
  """Computes the length of a sequence based on binary padding.

  Args:
    padding: A tensor of binary paddings shaped [batch, seqlen].

  Returns:
    seq_lens, A tensor of shape [batch] containing the non-padded length of
      each element of `padding` along the batch dimension.
  """
  seq_lens = tf.cast(tf.round(tf.reduce_sum(1 - padding, axis=1)), tf.int32)
  # Get rid of any extra dimensions.
  batch_size = tf.shape(padding)[0]
  seq_lens = tf.reshape(seq_lens, [batch_size], name='seq_lens')
  return seq_lens
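Usage sketch: zeros mark valid positions and ones mark padding, so each row's length is its count of zeros:

    padding = tf.constant([[0., 0., 1.], [0., 1., 1.]])
    SequenceLength(padding)  # => [2, 1]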
Example #20
      def ApplyBias():
        """Bias and update log_probs and consistent."""

        def TileForBeamAndFlatten(tensor):
          tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
          tensor = tf.tile(
              tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
          tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
          return tf.reshape(tensor, [tgt_batch])

        # Consistent if step_ids == labels from previous step
        # TODO(navari): Consider updating consistent only if weights > 0. Then
        # re-evaluate the need for bias_only_if_consistent=True.
        # Note that prev_label is incorrect for step 0 but is overridden later.
        prev_label = TileForBeamAndFlatten(
            tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
        is_step0 = tf.equal(time_step, 0)
        local_consistence = tf.math.logical_or(
            is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
        consistent = tf.math.logical_and(states.consistent, local_consistence)

        # get label, weight slices corresponding to current time_step
        label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
        weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
        if p.bias_only_if_consistent:
          weight = weight * tf.cast(consistent, p.dtype)

        # convert from dense label to sparse label probs
        vocab_size = tf.shape(bs_results.log_probs)[1]
        uncertainty = tf.constant(
            1e-10, p.dtype)  # avoid 0 probs which may cause issues with log
        label_probs = tf.one_hot(
            label,
            vocab_size,
            on_value=1 - uncertainty,
            off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
            dtype=p.dtype)  # [tgt_batch, vocab_size]
        pred_probs = tf.exp(bs_results.log_probs)

        # interpolate predicted probs and label probs
        weight = tf.expand_dims(weight, 1)
        probs = py_utils.with_dependencies([
            py_utils.assert_less_equal(weight, 1.),
            py_utils.assert_greater_equal(weight, 0.)
        ], (1.0 - weight) * pred_probs + weight * label_probs)
        return tf.math.log(probs), consistent
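The tail of ApplyBias is a per-example interpolation between the model's distribution and a near-one-hot label distribution. A standalone NumPy sketch with made-up numbers:

    import numpy as np

    vocab, uncertainty, label = 3, 1e-10, 1
    label_probs = np.full(vocab, uncertainty / (vocab - 1))
    label_probs[label] = 1.0 - uncertainty
    model_probs = np.array([0.7, 0.2, 0.1])
    w = 0.5  # per-example bias weight, asserted to lie in [0, 1]
    biased = (1.0 - w) * model_probs + w * label_probs
    log_probs = np.log(biased)  # uncertainty keeps every entry positive

Example #21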
    def _FrequencyMask(self,
                       inputs,
                       global_seed,
                       dtype=tf.float32,
                       domain_id_index=0):
        """Applies frequency masking with given degree to inputs.

    Args:
      inputs: Batch of input features of shape (batch_size, time_length,
        num_freq, channels).
      global_seed: an integer seed tensor for stateless random ops.
      dtype: Data type.
      domain_id_index: domain id index.

    Returns:
      Inputs with random frequency masking applied.
    """
        p = self.params

        # Mask parameters.
        freq_mask_max_bins = p.freq_mask_max_bins[domain_id_index]
        multiplicity = p.freq_mask_count[domain_id_index]

        # If masking length or count is zero, do nothing.
        if freq_mask_max_bins == 0 or multiplicity == 0:
            return inputs

        # Arguments to pass to mask generator.
        batch_size, _, num_freq, _ = py_utils.GetShape(inputs)
        choose_range = tf.cast(tf.broadcast_to(num_freq, (batch_size, )),
                               dtype=tf.int32)
        # Create masks in frequency direction and apply.
        block_arrays = self._GetMask(tf.shape(inputs)[0],
                                     choose_range=choose_range,
                                     mask_size=num_freq,
                                     global_seed=global_seed,
                                     max_length=freq_mask_max_bins,
                                     masks_per_frame=0.0,
                                     multiplicity=multiplicity,
                                     dtype=dtype,
                                     max_ratio=1.0)
        return self.EinsumBxycByBxyc(inputs, block_arrays)
Example #22
    def FProp(self, theta, inputs):
        """Apply projection to inputs.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
        p = self.params
        with tf.name_scope(p.name):
            computation_cost.Add(
                self, 'flops',
                tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64)) *
                tf.cast(symbolic.ToTensor(p.input_dims * p.output_dims),
                        tf.int64) * 2)
            return py_utils.ProjectLastDim(inputs, theta.w, p.input_dims,
                                           p.output_dims)
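Example #23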
    def GreedySearchDecode(self,
                           theta,
                           encoder_outputs,
                           init_beam_search_state=None,
                           pre_beam_search_step_callback=None,
                           post_beam_search_step_callback=None,
                           max_steps=None):
        """Performs greedy-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is the same
      as src_batch_size.

        - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
          token is encountered during search.
        - hyp_lens: [num_hyps].
        - done_hyps: [num_hyps], whether or not an eos is encountered.
    """
        p = self.params
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta,
            encoder_outputs,
            1  # num_hyps_per_beam
        )

        num_hyps = tf.shape(initial_results.log_probs)[0]

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        cur_step = tf.constant(0, dtype=tf.int32)
        done_hyps = inplace_ops.empty(shape=[num_hyps],
                                      dtype=tf.bool,
                                      init=True,
                                      name='done_hyps')
        hyp_lens = inplace_ops.empty(shape=[num_hyps],
                                     dtype=tf.int32,
                                     init=True,
                                     name='hyp_lens')
        hyp_ids = inplace_ops.empty(shape=[max_steps, num_hyps],
                                    dtype=tf.int32,
                                    init=True,
                                    name='hyp_ids')

        def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                         unused_hyp_lens, done_hyps, unused_other_states_list):
            return tf.math.logical_and(
                cur_step < max_steps,
                tf.math.logical_not(tf.reduce_all(done_hyps)))

        def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                     other_states_list):
            (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
             new_other_states) = self._GreedySearchStep(
                 theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
                 done_hyps, other_states.Pack(other_states_list),
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              tf.TensorShape(hyp_ids.get_shape()),
                              tf.TensorShape(hyp_lens.get_shape()),
                              tf.TensorShape(done_hyps.get_shape()),
                              _GetShapes(flat_other_states, none_shapes=True)))

        # transpose hyp_ids so it matches BeamSearchDecode's output
        final_hyp_ids = tf.transpose(final_hyp_ids)
        return final_hyp_ids, final_hyp_lens, final_done_hyps
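Example #24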
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
    """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses per
    beam.
  """
    source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
    value_dict = {}
    for output in beam_search_outputs:
        hyps_per_beam = py_utils.with_dependencies([
            py_utils.assert_equal(source_batch,
                                  tf.shape(output.topk_hyps)[0]),
        ], tf.shape(output.topk_hyps)[1])
        for k, v in six.iteritems(output._asdict()):
            if v is None:
                continue
            if k == 'done_hyps':
                v = tf.transpose(v)
            if k not in value_dict:
                value_dict[k] = []
            value_dict[k].append(
                tf.reshape(v, [source_batch, hyps_per_beam, -1]))

    # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
    concatenated = {}
    for k, values in six.iteritems(value_dict):
        if len(values) != len(beam_search_outputs):
            raise ValueError('Incomplete values for %s: %s' %
                             (k, beam_search_outputs))
        concatenated[k] = tf.concat(values, axis=1)

    scores = concatenated['topk_scores']
    scores = tf.where(tf.equal(concatenated['topk_lens'], 0),
                      tf.fill(tf.shape(scores), -1e6), scores)
    scores = tf.squeeze(scores, -1)

    # Select top max_hyps_per_beam indices per beam.
    _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
    batch_ids = tf.tile(tf.expand_dims(tf.range(source_batch), -1),
                        [1, max_hyps_per_beam])
    # [source_batch, max_hyps_per_beam, 2]
    gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

    # Gather the merged top hyps according to 'gather_indices'.
    top = beam_search_outputs[0]._asdict()
    total_hyps = source_batch * max_hyps_per_beam
    for k, v in six.iteritems(concatenated):
        v = tf.gather_nd(v, gather_indices)
        if k == 'done_hyps':
            v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
        elif k == 'topk_hyps':
            v = tf.reshape(v, [source_batch, max_hyps_per_beam])
        elif k == 'topk_ids':
            v = tf.reshape(v, [total_hyps, -1])
        elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
            v = tf.reshape(v, [total_hyps])
        else:
            raise ValueError('Unexpected field: %s' % k)
        top[k] = v
    return BeamSearchDecodeOutput(**top)
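The top-k selection above pairs each beam index with a hyp index and gathers with tf.gather_nd. A small sketch of the same pattern on made-up scores:

    import tensorflow as tf

    scores = tf.constant([[0.1, 0.9, 0.5],
                          [0.8, 0.2, 0.3]])  # [source_batch=2, hyps=3]
    _, top = tf.nn.top_k(scores, k=2)        # per-row top-2 indices
    batch = tf.tile(tf.expand_dims(tf.range(2), -1), [1, 2])
    idx = tf.stack([batch, top], axis=-1)    # [2, 2, 2]
    picked = tf.gather_nd(scores, idx)       # [[0.9, 0.5], [0.8, 0.3]]

Example #25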
    def BeamSearchDecode(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam_override=0,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
        """Performs beam-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks. Mostly opaque to BeamSearchHelper, except that it should
        contain either a 'seq_lengths' field of shape [source_batch_size] or
        a 'paddings' field of shape [source_max_lengths, source_batch_size].
      num_hyps_per_beam_override: If set to a value <= 0, this parameter is
        ignored. If set to a value > 0, then this value will be used to override
        `p.num_hyps_per_beam`.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A `BeamSearchDecodeOutput`.
    """
        p = self.params
        num_hyps_per_beam = p.num_hyps_per_beam
        if num_hyps_per_beam_override > 0:
            num_hyps_per_beam = num_hyps_per_beam_override
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta, encoder_outputs, num_hyps_per_beam)

        num_hyps = tf.shape(initial_results.log_probs)[0]
        num_beams = num_hyps // num_hyps_per_beam

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        min_score = -1e36
        best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
        cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
        in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
        in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
        in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
        bs_atten_probs = tf.zeros(
            [max_steps, num_hyps,
             tf.shape(initial_results.atten_probs)[1]],
            dtype=p.dtype)
        cur_step = tf.constant(0, dtype=tf.int32)
        all_done = tf.constant(False, dtype=tf.bool)
        core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                          in_prev_hyps, in_done_hyps, bs_atten_probs)

        def LoopContinue(cur_step, all_done, unused_step_ids,
                         unused_core_bs_states, unused_other_states_list):
            return tf.math.logical_and(cur_step < max_steps,
                                       tf.math.logical_not(all_done))

        def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                     other_states_list):
            (cur_step, all_done, new_step_ids, new_bs_states,
             new_other_states) = self._BeamSearchStep(
                 theta, encoder_outputs, cur_step, step_ids, core_bs_states,
                 other_states.Pack(other_states_list), num_hyps_per_beam,
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, all_done, new_step_ids, new_bs_states,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(all_done.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              _GetShapes(core_bs_states),
                              _GetShapes(flat_other_states, none_shapes=True)))
        # [target_seq_len, num_beams * num_hyps_per_beam].
        final_done_hyps = final_bs_states[5]
        final_other_states = other_states.Pack(flat_final_other_states)

        # Assume that `paddings` has shape [source_max_lengths, source_batch_size]
        # by default, and compute `encoded_seq_lengths` accordingly. This can be
        # overridden by directly passing `seq_lengths` in the `encoder_outputs`
        # NestedMap.
        encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
        if encoded_seq_lengths is None:
            source_paddings = encoder_outputs.padding
            if isinstance(source_paddings, py_utils.NestedMap):
                encoded_seq_lengths = tf.cast(
                    tf.round(
                        tf.reduce_sum(
                            1.0 - tf.transpose(source_paddings.Flatten()[0]),
                            1)), tf.int32)
            else:
                encoded_seq_lengths = tf.cast(
                    tf.round(
                        tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
                    tf.int32)

        # [num_beams, num_hyps_per_beam].
        topk_hyps = ops.top_k_terminated_hyps(
            final_done_hyps,
            encoded_seq_lengths,
            k=num_hyps_per_beam,
            num_hyps_per_beam=num_hyps_per_beam,
            length_normalization=p.length_normalization,
            coverage_penalty=p.coverage_penalty,
            target_seq_length_ratio=p.target_seq_length_ratio,
            eoc_id=p.target_eoc_id,
            merge_paths=p.merge_paths)
        # [num_beams * num_hyps_per_beam, ...].
        max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
        topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
            tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
        # [num_beams, num_hyps_per_beam].
        topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

        return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                      topk_lens, topk_scores, None,
                                      final_other_states)
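Example #26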
    def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                        core_bs_states, other_states, num_hyps_per_beam,
                        pre_beam_search_step_callback,
                        post_beam_search_step_callback):
        """Extend beam search hyps for one step.

      | num_beams = Number of source sequences to be decoded.
      | num_hyps_per_beam = Number of hyps to keep per source sequence.
      | num_hyps = num_beams * num_hyps_per_beam
      | src_seq_len = Number of time steps in the source sequence.
      | src_batch = Number of examples in the source sequence.
      | tgt_seq_len = Maximum allowed time steps in the target sequence.
      | tgt_batch = num_hyps_per_beam * src_batch

    Args:
      theta: A `.NestedMap` object containing weights' values of the decoder
        layer and its children layers.
      encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
        the callbacks.
      cur_step: A scalar int tensor, the current time step, 0-based.
      step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
        current search step.
      core_bs_states: A tuple of core beam search states. This list is
        maintained by this helper class.
      other_states: A `.NestedMap` of other beam search states. This
        `.NestedMap` is managed and updated by the client. It is expected that
        each of its member tensors are of rank >= 1. t[i, ...] is the state of
        the i-th hyp at the beginning of this search step.
      num_hyps_per_beam: Num of hyps to keep per beam.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        See class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        See class header comments for more details.

    Returns:
      A tuple of following elements for the next beam search step,
      (next step, all_done, step_ids, core_bs_states, other_states)
    """
        p = self.params

        bs_results, other_states = pre_beam_search_step_callback(
            theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

        (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
         in_done_hyps, in_atten_probs) = core_bs_states

        (out_best_scores, out_cumulative_scores, out_scores, out_hyps,
         out_prev_hyps, out_done_hyps, out_atten_probs,
         all_done) = ops.beam_search_step(
             tf.cast(bs_results.log_probs, dtype=p.dtype),
             tf.cast(bs_results.atten_probs, dtype=p.dtype),
             best_scores,
             cumulative_scores,
             in_scores,
             in_hyps,
             in_prev_hyps,
             in_done_hyps,
             in_atten_probs,
             bs_results.is_last_chunk if self._model_uses_eoc_id else [],
             cur_step,
             eoc_id=p.target_eoc_id,
             eos_id=p.target_eos_id,
             beam_size=p.beam_size,
             num_hyps_per_beam=num_hyps_per_beam,
             valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
             merge_paths=p.merge_paths,
             allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
             ensure_full_beam=p.ensure_full_beam,
             force_eos_in_last_step=p.force_eos_in_last_step,
             local_eos_threshold=p.local_eos_threshold)

        new_step_ids = tf.reshape(out_hyps[cur_step, :], tf.shape(step_ids))
        new_step_ids.set_shape(step_ids.get_shape())

        old_hyp_ids = tf.reshape(
            tf.slice(out_prev_hyps, begin=[cur_step, 0], size=[1, -1]), [-1])

        if p.batch_major_compute:
            # Transform the indices into key/value-cache order for fast
            # decoding (prefix_states in other_states): the num_hyps dimension
            # of the cache is laid out as num_beams by num_hyps_per_beam, which
            # differs from the old_hyp_ids assumption (num_hyps_per_beam by
            # num_beams). Both a transpose and a recomputation are needed to
            # correct the indices.
            num_beams = tf.shape(best_scores)[0]
            old_hyp_ids_in_cache_order = tf.reshape(
                tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])),
                [-1])
            old_hyp_ids_in_cache_order = (
                (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
                old_hyp_ids_in_cache_order // num_beams)

        new_bs_states = (out_best_scores, out_cumulative_scores, out_scores,
                         out_hyps, out_prev_hyps, out_done_hyps,
                         out_atten_probs)

        def ReOrderHyps(x_in):
            """Reorders x_in based on prev hyp ids."""
            if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims
                    and x_in.shape.ndims > 0):
                if x_in.shape.ndims > 2 and not p.batch_major_state:
                    # For batch-major compute, use the corrected indices only
                    # here, since the key/value caches are the states affected
                    # by the cache layout.
                    correct_old_hyp_ids = (old_hyp_ids_in_cache_order
                                           if p.batch_major_compute else
                                           old_hyp_ids)
                    x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
                else:
                    x_out = tf.gather(x_in, old_hyp_ids)
                x_out.set_shape(x_in.get_shape())
                return x_out
            else:
                return x_in

        new_other_states = other_states.Transform(ReOrderHyps)

        final_other_states = post_beam_search_step_callback(
            theta, encoder_outputs, new_step_ids, new_other_states)

        return (cur_step + 1, all_done, new_step_ids, new_bs_states,
                final_other_states)
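
ReOrderHyps above is a gather along the hyp dimension: new hyp i inherits the state of old hyp old_hyp_ids[i]. A minimal standalone sketch with assumed toy values (names and numbers are illustrative only):

import tensorflow as tf

# State of shape [num_hyps, d]; both surviving hyps extend old hyp 1.
state = tf.constant([[1.0, 1.0], [2.0, 2.0]])
old_hyp_ids = tf.constant([1, 1])
reordered = tf.gather(state, old_hyp_ids)  # [[2., 2.], [2., 2.]]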
Example No. 27
 def TileForBeamAndFlatten(tensor):
   tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
   tensor = tf.tile(
       tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
   tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
   return tf.reshape(tensor, [tgt_batch])
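
The helper above produces a hyps-major layout: the whole source batch is repeated num_hyps_per_beam times. A minimal standalone sketch of that behavior with assumed toy values (the numbers are illustrative, not from the surrounding code):

import tensorflow as tf

num_hyps_per_beam = 2
tensor = tf.constant([10, 20, 30])                 # [src_batch]
step_ids = tf.zeros([2 * 3, 1], dtype=tf.int32)    # [num_hyps_per_beam * src_batch, 1]

tiled = tf.reshape(tensor, [1, -1])                # [1, src_batch]
tiled = tf.tile(tiled, [num_hyps_per_beam, 1])     # [num_hyps_per_beam, src_batch]
flat = tf.reshape(tiled, [tf.shape(step_ids)[0]])  # [num_hyps_per_beam * src_batch]
# flat == [10, 20, 30, 10, 20, 30]: hyps-major, matching the step_ids layout.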
Example No. 28
 def PadToTargetSeqLen(tensor, constant):
   length = tf.shape(tensor)[1]
   pad = tf.maximum(0, p.beam_search.target_seq_len - length)
   return tf.pad(tensor, [[0, 0], [0, pad]], constant_values=constant)
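
A quick sketch of what the padding helper computes, with an assumed constant standing in for p.beam_search.target_seq_len (values are illustrative only):

import tensorflow as tf

target_seq_len = 5                     # stand-in for p.beam_search.target_seq_len
ids = tf.constant([[1, 2, 3]])         # [batch, length]
pad = tf.maximum(0, target_seq_len - tf.shape(ids)[1])
padded = tf.pad(ids, [[0, 0], [0, pad]], constant_values=0)
# padded == [[1, 2, 3, 0, 0]]; tensors already at or beyond the target length
# are unchanged, because pad is clamped at 0.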
Example No. 29
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, a tensor of per-token task ids
            of shape [batch, time].

    Returns:
      A NestedMap containing

      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set in
        transformer_stack.
      - padding: The paddings tensor, of shape [time, batch].
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded input tokens without
        positional encodings.
    """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(
                    theta.token_emb, tf.reshape(input_ids, [-1]))
            else:
                input_embs = self.softmax.EmbLookup(
                    theta.softmax, tf.reshape(input_ids, [-1]))

            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs
            if p.task_emb:
                input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                                      input_batch.task_ids)

            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        if not self.do_eval and p.apply_source_mask:
            # Augment padding for masked source word positions.
            dtype = paddings.dtype
            source_mask = tf.where(tf.equal(input_ids, p.source_mask_id),
                                   tf.ones_like(input_ids, dtype=dtype),
                                   tf.zeros_like(input_ids, dtype=dtype))
            # Make sure padding is between 0 and 1.
            paddings = tf.clip_by_value(paddings + tf.transpose(source_mask),
                                        0.0, 1.0)

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
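
The apply_source_mask branch above folds masked source positions into the paddings. A self-contained sketch of that step, kept in [batch, time] for brevity (the code above works on transposed [time, batch] paddings; names and values here are illustrative):

import tensorflow as tf

source_mask_id = 4                         # stand-in for p.source_mask_id
input_ids = tf.constant([[7, 4, 9]])       # [batch, time]
paddings = tf.constant([[0.0, 0.0, 0.0]])  # [batch, time]

source_mask = tf.cast(tf.equal(input_ids, source_mask_id), paddings.dtype)
paddings = tf.clip_by_value(paddings + source_mask, 0.0, 1.0)
# paddings == [[0., 1., 0.]]: the masked position is now treated as padding.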
Example No. 30
    def BeamSearchDecodePostProcess(self, num_hyps_per_beam, max_steps,
                                    r1_shape, r2_shape, r3_shape, hyps,
                                    prev_hyps, done_hyps, scores, atten_probs,
                                    eos_scores, eos_atten_probs,
                                    source_seq_lengths,
                                    *flat_final_other_states):
        """Beam search post processing functions on CPUs.


    Args:
      num_hyps_per_beam: Number of hyps per beam.
      max_steps: Maximum number of beam search steps.
      r1_shape: A tensor of shape [1] with value [time].
      r2_shape: A tensor of shape [2] with values [time, b * k].
      r3_shape: A tensor of shape [3] with values [time, b * k, seq_len].
      hyps: A tensor of shape [time * b * k] with ids of the tokens selected.
      prev_hyps: A tensor of shape [time * b * k] with indices of the previous
        hyps that were selected.
      done_hyps: A boolean tensor of shape [time * b * k] whose values indicate
        whether a hyp was terminated.
      scores: A tensor of shape [time * b * k] with scores of the tokens
        selected.
      atten_probs: A tensor of shape [time * b * k, seq_len] containing the
        attention probabilities over the source words for each word in the
        hyps.
      eos_scores: A tensor of shape [time * b * k] with scores of the eos
        tokens selected.
      eos_atten_probs: A tensor of shape [time * b * k, seq_len] containing
        the attention probabilities over the source words for each eos token.
      source_seq_lengths: A tensor of shape [time] containing the source
        seq_lengths.
      *flat_final_other_states: An array of tensors that are part of other
        states.

    Returns:
      final_done_hyps: A tensor of shape [time, b * k] containing serialized
        `Hypothesis` protos for terminated hyps.
      topk_hyps, topk_ids, topk_lens, topk_scores: The top-k terminated hyps.
      flat_final_other_states: An array of tensors that are part of other
        states.
    """
        p = self.params

        def _ReshapeBackToHigherRank(inps, r_shape):
            for i in range(len(inps)):
                inps[i] = tf.reshape(inps[i], r_shape)
            return inps

        # Reshape all tensors back to original shapes of rank 1, 2 and 3.
        r1_inps = [source_seq_lengths]
        r1_inps = _ReshapeBackToHigherRank(r1_inps, r1_shape)
        r2_inps = [hyps, prev_hyps, done_hyps, scores, eos_scores]
        r2_inps = _ReshapeBackToHigherRank(r2_inps, r2_shape)
        r3_inps = [atten_probs, eos_atten_probs]
        r3_inps = _ReshapeBackToHigherRank(r3_inps, r3_shape)

        (source_seq_lengths, hyps, prev_hyps, done_hyps, scores, eos_scores,
         atten_probs, eos_atten_probs) = (r1_inps + r2_inps + r3_inps)

        final_done_hyps = ops.hyps_from_beam_search_outs(
            hyps,
            prev_hyps,
            done_hyps,
            scores,
            atten_probs,
            eos_scores,
            eos_atten_probs,
            eos_id=p.target_eos_id,
            num_hyps_per_beam=num_hyps_per_beam)
        topk_hyps = ops.top_k_terminated_hyps(
            final_done_hyps,
            source_seq_lengths,
            k=num_hyps_per_beam,
            num_hyps_per_beam=num_hyps_per_beam,
            length_normalization=p.length_normalization,
            coverage_penalty=p.coverage_penalty,
            target_seq_length_ratio=p.target_seq_length_ratio,
            eoc_id=p.target_eoc_id,
            merge_paths=p.merge_paths)
        topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
            topk_hyps, max_seq_length=max_steps)
        # [num_beams, num_hyps_per_beam].
        topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))
        return (final_done_hyps, topk_hyps, topk_ids, topk_lens,
                topk_scores) + tuple(flat_final_other_states)
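
The inputs arrive flattened (e.g. straight off a TPU computation) and are restored to ranks 1, 2 and 3 through the r*_shape tensors. A minimal sketch of the reshape step with assumed toy dimensions:

import tensorflow as tf

time, bk, seq_len = 4, 6, 5                      # assumed: time, b * k, source length
r2_shape = tf.constant([time, bk])
r3_shape = tf.constant([time, bk, seq_len])

scores = tf.zeros([time * bk])                   # flattened rank-2 input
atten_probs = tf.zeros([time * bk * seq_len])    # flattened rank-3 input

scores = tf.reshape(scores, r2_shape)            # back to [time, b * k]
atten_probs = tf.reshape(atten_probs, r3_shape)  # back to [time, b * k, seq_len]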