Esempio n. 1
0
    def FProp(self, theta, input_batch):
        """Encodes a batch of source ids with the stacked RNN layers.

        Args:
          theta: A `.NestedMap` of weight values for this layer and its
            children layers.
          input_batch: A `.NestedMap` holding `ids` and `paddings` (asserted
            to be rank 2 with matching shapes) and, when `p.packed_input` is
            set, `segment_ids`.

        Returns:
          A `.NestedMap` with:

          - encoded: the last layer's output, optionally merged by the
            transparent merger and/or projected by `final_proj`.
          - padding: the transposed (time-major) paddings.
          - segment_id: the transposed segment ids, or None when packed input
            is disabled.
        """
        p = self.params
        with tf.name_scope(p.name):
            shape_checks = [
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
            ]
            # Time-major ids, gated on the shape assertions above.
            ids_tm = py_utils.with_dependencies(shape_checks,
                                                tf.transpose(input_batch.ids))
            # Time-major paddings with a trailing unit dimension.
            pad_tm = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
            seg_tm = (tf.expand_dims(tf.transpose(input_batch.segment_ids), 2)
                      if p.packed_input else None)

            emb = self.emb.EmbLookup(theta.emb, ids_tm)
            emb = self.ApplyClipping(theta, emb)
            summary_utils.histogram('input_emb', emb)
            layer_in = self.dropout.FProp(theta.dropout, emb)

            # Stacked RNN layers. Layers with index >= p.residual_start wrap
            # a residual connection around themselves.
            per_layer_outputs = []
            for idx in range(p.num_lstm_layers):
                layer_out = self.rnn[idx].FProp(theta.rnn[idx],
                                                layer_in,
                                                pad_tm,
                                                segment_id=seg_tm)
                layer_out = self.dropout.FProp(theta.dropout, layer_out)
                if idx < p.residual_start:
                    layer_in = layer_out
                else:
                    layer_in = self.ApplyClipping(theta, layer_in + layer_out)
                per_layer_outputs.append(layer_in)
                summary_utils.histogram('layer_out_%s' % idx, layer_in)

            encoded = layer_in
            if p.is_transparent:
                encoded = self.transparent_merger.FProp(
                    theta.transparent_merger, per_layer_outputs)

            if p.encoder_out_dim != 2 * p.lstm_cell_size:
                # Project to the requested encoder depth (the un-projected
                # depth is presumably 2 * lstm_cell_size, i.e. bidirectional).
                encoded = self.final_proj.FProp(theta.final_proj, encoded,
                                                pad_tm)
                summary_utils.histogram('final_proj_out', encoded)

            if seg_tm is not None:
                seg_tm = tf.squeeze(seg_tm, [2])

            return py_utils.NestedMap(encoded=encoded,
                                      padding=tf.squeeze(pad_tm, [2]),
                                      segment_id=seg_tm)
Esempio n. 2
0
def ComputeConvOutputPadding(paddings, window, stride,
                             padding_algorithm='SAME'):
  """Computes paddings for convolution and pooling output.

  out_padding[i] == 1 iff any in_padding corresponding to that output is 1.

  The input is right-padded with 1.0 (i.e. marked as padding) up to a multiple
  of `stride`. It is then max-pooled with the given window, stride and padding
  algorithm. As a special case, 'SAME' with stride 1 returns `paddings`
  unchanged.

  Args:
    paddings: The paddings tensor. It is expected to be of shape [batch, time].
    window: The size of the windows.
    stride: The time-stride between adjacent windows.
    padding_algorithm: 'SAME' or 'VALID'.

  Returns:
    out_padding, the new padding tensor.
    - For 'SAME', its shape is [batch, ceil(time / stride)].
    - For 'VALID', its shape is
      [batch, ceil((padded_time - window + 1) / stride)], where padded_time is
      `time` rounded up to a multiple of `stride`.
  """
  # Only 'SAME' with stride 1 preserves the time length, so only then can the
  # input paddings be returned directly. 'VALID' shortens the output by
  # window - 1 frames even at stride 1, so it must go through the pooling.
  if stride == 1 and padding_algorithm == 'SAME':
    return paddings

  # Pad so input_length divides stride; the extra frames count as padding.
  input_length = py_utils.GetShape(paddings)[1]
  pad_len = (input_length + stride - 1) // stride * stride - input_length
  paddings = tf.pad(paddings, [[0, 0], [0, pad_len]], constant_values=1.0)
  # MAX pooling marks an output frame as padding when any input frame in its
  # window is padding (assuming paddings are 0/1 valued).
  out_padding = tf.nn.pool(
      tf.expand_dims(paddings, -1),
      [window],
      'MAX',
      padding=padding_algorithm,
      strides=[stride],
  )
  return tf.squeeze(out_padding, -1)
Esempio n. 3
0
    def FProp(self, theta, input_batch):
        """Runs the embedding lookup and the RNN stack over source ids.

        Args:
          theta: A `.NestedMap` of weight values for this layer and its
            children layers.
          input_batch: A `.NestedMap` with fields:
            - ids: the input ids, expected shape [batch, time].
            - paddings: the paddings, expected shape [batch, time].

        Returns:
          A `.NestedMap` containing:

          - encoded: the encoded features, a tensor of shape
            [time, batch, depth].
          - padding: of shape [time, batch].
          - segment_id: always None; this encoder does not produce segment
            ids.
        """
        p = self.params
        with tf.name_scope(p.name):
            shape_checks = [
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
            ]
            # Time-major ids, gated on the shape assertions above.
            ids_tm = py_utils.with_dependencies(shape_checks,
                                                tf.transpose(input_batch.ids))
            # Time-major paddings with a trailing unit dimension.
            pad_tm = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

            emb = self.emb.EmbLookup(theta.emb, ids_tm)
            emb = self.ApplyClipping(theta, emb)
            # Keep the clipped embeddings on the layer instance.
            self._emb_out = emb

            # When cc_schedule is specified, lstm_tpl should be a
            # QuantizedLSTMCell with the same cc_schedule so that each RNN
            # layer's output stays within the clipping range.
            # The first layer's FProp returns a single tensor; the following
            # layers return an (output, state) pair.
            first_out = self.rnn[0].FProp(theta.rnn[0], emb, pad_tm)
            hidden = self.dropout.FProp(theta.dropout, first_out)
            for idx in range(1, p.num_lstm_layers):
                layer = self.rnn[idx]
                layer_out, _ = layer.FProp(theta.rnn[idx], hidden, pad_tm)
                layer_out = self.dropout.FProp(theta.dropout, layer_out)
                # Wrapped layers keep their dimensions on `params.cell`.
                cell_p = getattr(layer.params, 'cell', layer.params)
                if cell_p.num_input_nodes != cell_p.num_output_nodes:
                    # Input and output depths differ: no residual skip.
                    hidden = layer_out
                else:
                    hidden = self.ApplyClipping(theta, hidden + layer_out)

            return py_utils.NestedMap(encoded=hidden,
                                      padding=tf.squeeze(pad_tm, [2]),
                                      segment_id=None)
Esempio n. 4
0
 def _ReadRecordSentencePairProto(self, record):
     """Parses `record` as a serialized SentencePair proto.

     Args:
       record: a serialized `tensorflow.babelfish.SentencePair` string tensor.

     Returns:
       A `(src, tgt)` pair of string tensors taken from the `sentence` field of
       the nested `src_sentence` and `tgt_sentence` messages.
     """
     # The proto's `lang` field is deliberately not decoded: it is ignored
     # until TextPackedInput figures out how to handle lang_ids.

     def _DecodeStrings(serialized, message_type, field_names):
         """Decodes string-typed `field_names` and squeezes the result."""
         _, values = tf.io.decode_proto(
             bytes=serialized,
             message_type=message_type,
             field_names=field_names,
             output_types=[tf.string] * len(field_names),
             descriptor_source=_GetDescriptorSetForTextInput())
         return tf.squeeze(values)

     sentence_protos = _DecodeStrings(record,
                                      'tensorflow.babelfish.SentencePair',
                                      ['src_sentence', 'tgt_sentence'])
     sentences = _DecodeStrings(sentence_protos,
                                'tensorflow.babelfish.Sentence', ['sentence'])
     return sentences[0], sentences[1]
            def ApplyBias():
                """Biases the beam search log-probs towards the given labels.

                Reads p, labels, weights, time_step, step_ids, states,
                bs_results and num_hyps_per_beam from the enclosing scope.

                Returns:
                  A `(log_probs, consistent)` pair:
                  - log_probs: the log of the weighted interpolation between
                    the predicted probabilities and the smoothed one-hot label
                    probabilities.
                  - consistent: the updated per-hypothesis consistency flags.
                """

                def _ExpandToHyps(per_src):
                    # [src_batch] -> [1, src_batch]
                    #   -> [num_hyps_per_beam, src_batch]
                    #   -> [num_hyps_per_beam * src_batch], with the final
                    # length read from the leading dimension of step_ids.
                    row = tf.reshape(per_src, [1, -1])
                    tiled = tf.tile(row, [num_hyps_per_beam, 1])
                    return tf.reshape(tiled, [tf.shape(step_ids)[0]])

                # A hypothesis stays consistent while the id it just emitted
                # matches the label of the previous step. At step 0 the
                # gathered "previous" label is meaningless, so step 0 is
                # always treated as consistent.
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                prev_label = _ExpandToHyps(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                at_first_step = tf.equal(time_step, 0)
                step_consistent = tf.math.logical_or(
                    at_first_step,
                    tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.math.logical_and(states.consistent,
                                                 step_consistent)

                # Label and bias weight for the current step, one per hyp.
                cur_label = _ExpandToHyps(
                    tf.gather(labels, time_step, axis=1))
                bias_weight = _ExpandToHyps(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    bias_weight = bias_weight * tf.cast(consistent, p.dtype)

                # Smoothed one-hot label distribution. The tiny uncertainty
                # avoids zero label probabilities, which could cause problems
                # in the final log.
                vocab_size = tf.shape(bs_results.log_probs)[1]
                eps = tf.constant(1e-10, p.dtype)
                label_probs = tf.one_hot(
                    cur_label,
                    vocab_size,
                    on_value=1 - eps,
                    off_value=eps / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # Interpolate: (1 - w) * predicted + w * label, with w
                # asserted to lie in [0, 1].
                bias_weight = tf.expand_dims(bias_weight, 1)
                weight_checks = [
                    py_utils.assert_less_equal(bias_weight, 1.),
                    py_utils.assert_greater_equal(bias_weight, 0.),
                ]
                probs = py_utils.with_dependencies(
                    weight_checks,
                    (1.0 - bias_weight) * pred_probs +
                    bias_weight * label_probs)
                return tf.math.log(probs), consistent
Esempio n. 6
0
    def FProp(self, theta, input_batch, state0=None):
        """Encodes source ids, threading per-layer RNN states for streaming.

        Args:
          theta: A `.NestedMap` of weight values for this layer and its
            children layers.
          input_batch: A `.NestedMap` with `ids` and `paddings`, asserted to
            be rank 2 with matching shapes.
          state0: Optional initial state, a `.NestedMap` whose `rnn` field
            holds one state per RNN layer. When falsy, `self.zero_state` is
            used instead.

        Returns:
          A `.NestedMap` with:
          - encoded: the encoded features.
          - padding: the time-major paddings.
          - segment_id: always None.
          - state: a `.NestedMap` whose `rnn` field lists each layer's
            returned state.
        """
        p = self.params
        with tf.name_scope(p.name):
            # Reshape to [t, b]
            shape_checks = [
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
            ]
            ids_tm = py_utils.with_dependencies(shape_checks,
                                                tf.transpose(input_batch.ids))
            pad_tm = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

            # Streaming states: start from zeros if none were supplied.
            if not state0:
                state0 = self.zero_state(theta, tf.shape(ids_tm)[1])
            final_states = [None] * p.num_lstm_layers

            emb = self.emb.EmbLookup(theta.emb, ids_tm)
            emb = self.ApplyClipping(theta, emb)
            summary_utils.histogram('input_emb', emb)
            layer_in = self.dropout.FProp(theta.dropout, emb)

            # Stacked RNN layers. Layers with index >= p.residual_start wrap
            # a residual connection around themselves.
            per_layer_outputs = []
            for idx in range(p.num_lstm_layers):
                layer_out, final_states[idx] = self.rnn[idx].FProp(
                    theta.rnn[idx], layer_in, pad_tm, state0=state0.rnn[idx])
                layer_out = self.dropout.FProp(theta.dropout, layer_out)
                if idx < p.residual_start:
                    layer_in = layer_out
                else:
                    layer_in = self.ApplyClipping(theta, layer_in + layer_out)
                per_layer_outputs.append(layer_in)
                summary_utils.histogram('layer_out_%s' % idx, layer_in)

            encoded = layer_in
            if p.is_transparent:
                encoded = self.transparent_merger.FProp(
                    theta.transparent_merger, per_layer_outputs)

            return py_utils.NestedMap(
                encoded=encoded,
                padding=tf.squeeze(pad_tm, [2]),
                segment_id=None,
                state=py_utils.NestedMap(rnn=final_states))
Esempio n. 7
0
 def _Slice(tensor):
     """Returns `tensor` at time index `state0.t` along `self.params.axis`.

     The time dimension is squeezed out of the result.
     """
     axis = self.params.axis
     full_shape = py_utils.GetShape(tensor)
     # Start offsets: zero everywhere except state0.t on the time axis,
     # e.g. [0, t, 0, 0, ...] for axis=1.
     begin = tf.one_hot(axis, tf.rank(tensor), on_value=state0.t)
     # Slice size: the full shape, but 1 on the time axis,
     # e.g. [shape[0], 1, shape[2], ...] for axis=1.
     size = tf.concat(
         [
             full_shape[:axis],
             tf.constant([1], dtype=tf.int32),
             full_shape[axis + 1:],
         ],
         axis=0)
     return tf.squeeze(tf.slice(tensor, begin, size), axis=axis)
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
    """Merges beam search hypotheses produced by several decoders.

    Every non-None field of every output is reshaped to
    [source_batch, hyps_per_beam, -1] and concatenated along the hyps axis.
    The best `max_hyps_per_beam` hypotheses per source example are then picked
    by `topk_scores`. Hypotheses with `topk_lens == 0` get a score of -1e6 to
    push them down the ranking.

    Args:
      max_hyps_per_beam: the number of top hyps in the merged results. Must be
        less than or equal to the total number of input hyps.
      beam_search_outputs: a list of BeamSearchDecodeOutput objects. They must
        share the same source_batch and max sequence length.

    Returns:
      A BeamSearchDecodeOutput holding max_hyps_per_beam hypotheses per beam.
      Fields that are None in every input stay None.

    Raises:
      ValueError: if a field is set in only some of the outputs, or if a field
        has no known merged layout.
    """
    source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]

    # Field name -> [source_batch, hyps_per_beam, -1] tensors, one for each
    # output in which that field is set.
    per_field = {}
    for output in beam_search_outputs:
        batch_check = py_utils.assert_equal(source_batch,
                                            tf.shape(output.topk_hyps)[0])
        hyps_per_beam = py_utils.with_dependencies(
            [batch_check], tf.shape(output.topk_hyps)[1])
        for name, value in six.iteritems(output._asdict()):
            if value is None:
                continue
            if name == 'done_hyps':
                value = tf.transpose(value)
            per_field.setdefault(name, []).append(
                tf.reshape(value, [source_batch, hyps_per_beam, -1]))

    # Concatenate along the 'num_hyps_per_beam' axis.
    merged = {}
    for name, values in six.iteritems(per_field):
        if len(values) != len(beam_search_outputs):
            raise ValueError('Incomplete values for %s: %s' %
                             (name, beam_search_outputs))
        merged[name] = tf.concat(values, axis=1)

    raw_scores = merged['topk_scores']
    is_empty = tf.equal(merged['topk_lens'], 0)
    scores = tf.where(is_empty, tf.fill(tf.shape(raw_scores), -1e6),
                      raw_scores)
    scores = tf.squeeze(scores, -1)

    # Indices of the best max_hyps_per_beam hyps for each source example.
    _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
    batch_ids = tf.tile(tf.expand_dims(tf.range(source_batch), -1),
                        [1, max_hyps_per_beam])
    # [source_batch, max_hyps_per_beam, 2] (batch, hyp) index pairs.
    gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

    total_hyps = source_batch * max_hyps_per_beam

    def _RestoreLayout(name, picked):
        """Reshapes a gathered field back to its BeamSearchDecodeOutput layout."""
        if name == 'done_hyps':
            return tf.transpose(tf.reshape(picked, [total_hyps, -1]))
        if name == 'topk_hyps':
            return tf.reshape(picked, [source_batch, max_hyps_per_beam])
        if name == 'topk_ids':
            return tf.reshape(picked, [total_hyps, -1])
        if name in ('topk_lens', 'topk_scores', 'topk_decoded'):
            return tf.reshape(picked, [total_hyps])
        raise ValueError('Unexpected field: %s' % name)

    # Start from the first output so all-None fields are carried over as-is.
    top = beam_search_outputs[0]._asdict()
    for name, value in six.iteritems(merged):
        top[name] = _RestoreLayout(name, tf.gather_nd(value, gather_indices))
    return BeamSearchDecodeOutput(**top)