Example #1
    def _InitBeamSearchStateCallback(self, theta, encoder_outputs,
                                     num_hyps_per_beam):
        """Returns initial beams search states.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      encoder_outputs: a NestedMap computed by encoder.
      num_hyps_per_beam: An int, the number of hyps to keep per source
        sentence.
    Returns:
      A tuple (initial_results, states).
        initial_results: a `.NestedMap` of initial results.
          atten_probs:
            The initial attention probs, of shape [tgt_batch, src_len].
        states: a `.NestedMap` of initial model states.
          prefix_states:
            Zero-initialized per-layer attention key/value caches.
          time_step:
            A scalar tensor, the current decode step, initially 0.
    """
        p = self.params

        source_encs = encoder_outputs.encoded
        num_hyps = py_utils.GetShape(source_encs)[1] * num_hyps_per_beam
        source_len = py_utils.GetShape(source_encs)[0]

        # Dummy attention probs
        atten_probs = tf.ones([num_hyps, source_len]) / tf.to_float(source_len)
        initial_results = py_utils.NestedMap(
            log_probs=tf.zeros([num_hyps, p.softmax.num_classes],
                               dtype=py_utils.FPropDtype(p)),
            atten_probs=atten_probs)

        batch_size = num_hyps
        atten_hidden_dim = p.trans_tpl.tr_atten_tpl.atten_hidden_dim
        if not atten_hidden_dim:
            atten_hidden_dim = p.model_dim

        if p.beam_search.name == 'tpu_beam_search':
            seq_len = p.target_seq_len
        else:
            seq_len = 0

        prefix_states = py_utils.NestedMap({
            'layer_%d' % layer: py_utils.NestedMap({
                'key':
                tf.zeros([seq_len, batch_size, atten_hidden_dim],
                         dtype=py_utils.FPropDtype(p)),
                'value':
                tf.zeros([seq_len, batch_size, atten_hidden_dim],
                         dtype=py_utils.FPropDtype(p)),
            })
            for layer in range(p.num_trans_layers)
        })

        return initial_results, py_utils.NestedMap({
            'prefix_states': prefix_states,
            'time_step': tf.constant(0)
        })
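
Note: every example on this page resolves its computation dtype through
py_utils.FPropDtype. In lingvo this helper simply prefers the fprop_dtype
override over the layer's variable dtype; a minimal stand-in sketch of the
same contract (not the library source):

def FPropDtype(params):
    # Use the forward-prop dtype override when set, else the layer dtype.
    if params.fprop_dtype is not None:
        return params.fprop_dtype
    return params.dtype
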
Example #2
  def _InitBeamSearchStateCallback(self,
                                   theta,
                                   source_encs,
                                   source_paddings,
                                   num_hyps_per_beam,
                                   additional_source_info=None):
    """Returns initial beams search states.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      source_encs: A tensor of shape [src_len, src_batch, source_dim].
          Can be [time, batch, depth, num_layers] if is_transparent is set.
      source_paddings: A tensor of shape [src_len, src_batch].
      num_hyps_per_beam: An int, the number of hyps to keep per source
        sentence.
      additional_source_info: a `.NestedMap` of tensors containing extra context
          information about the source that may be useful for decoding.
    Returns:
      A tuple (initial_results, states).
        initial_results: a `.NestedMap` of initial results.
          atten_probs:
            The initial attention probs, of shape [tgt_batch, src_len].
        states: a `.NestedMap` of initial model states.
          prefix_states:
            Zero-initialized per-layer attention key/value caches.
          time_step:
            The current decode step, initially 0.
    """
    p = self.params
    # additional_source_info is currently not used.
    del additional_source_info

    num_hyps = py_utils.GetShape(source_encs)[1] * num_hyps_per_beam
    source_len = py_utils.GetShape(source_encs)[0]

    # Dummy attention probs
    atten_probs = tf.ones([num_hyps, source_len]) / tf.to_float(source_len)
    initial_results = py_utils.NestedMap({'atten_probs': atten_probs})

    batch_size = num_hyps
    key_channels = p.model_dim
    value_channels = p.model_dim

    prefix_states = py_utils.NestedMap({
        'layer_%d' % layer: py_utils.NestedMap({
            'key':
                tf.zeros([batch_size, 0, key_channels],
                         dtype=py_utils.FPropDtype(p)),
            'value':
                tf.zeros([batch_size, 0, value_channels],
                         dtype=py_utils.FPropDtype(p)),
        }) for layer in range(p.num_trans_layers)
    })

    return initial_results, py_utils.NestedMap({
        'prefix_states': prefix_states,
        'time_step': 0
    })
Example #3
 def _ValidateBiases(self, content_bias, positional_bias, n, h):
   if content_bias is not None:
     content_bias = py_utils.HasShape(content_bias, [n, h])
   else:
     content_bias = tf.constant(0, dtype=py_utils.FPropDtype(self.params))
   if positional_bias is not None:
     positional_bias = py_utils.HasShape(positional_bias, [n, h])
   else:
     positional_bias = tf.constant(0, dtype=py_utils.FPropDtype(self.params))
   return content_bias, positional_bias
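
Note: the scalar-zero fallback lets downstream attention code add the bias
unconditionally, because a 0-d constant broadcasts against any [n, h] term.
A minimal, runnable broadcast demo (plain TF, names illustrative):

import tensorflow as tf

term = tf.ones([2, 3])                          # some [n, h] term
absent_bias = tf.constant(0, dtype=tf.float32)  # the "no bias" case
out = term + absent_bias                        # scalar 0 broadcasts to [2, 3]
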
Example #4
            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    # [num_hyps_per_beam, src_batch]
                    tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
                    # num_hyps_per_beam * src_batch
                    tgt_batch = tf.shape(step_ids)[0]
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from previous step.
                # TODO(navari): Consider updating consistent only if weights > 0.
                # Then re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden
                # later.
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.math.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.math.logical_and(states.consistent,
                                                 local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent,
                                              py_utils.FPropDtype(p))

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                label_probs = tf.one_hot(
                    label, vocab_size,
                    dtype=py_utils.FPropDtype(p))  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                # Ensure that tf.math.log is applied to positive values.
                probs = tf.maximum(probs, tf.constant(1e-12,
                                                      dtype=probs.dtype))
                return tf.math.log(probs), consistent
Example #5
  def zero_state(self, batch_size):
    p = self.params
    num_groups = self.num_groups

    if not p.cumulative:
      return py_utils.NestedMap()

    # Note: Prefer storing data in <=4D tensors, as TFLite doesn't support
    # implicit broadcasting for 5D (or larger) tensors on many operators.
    cached_count_shape = [batch_size, 1]
    cached_moment_shape = [batch_size, num_groups]
    cached_sum = tf.zeros(cached_moment_shape, py_utils.FPropDtype(p))
    cached_count = tf.zeros(cached_count_shape, py_utils.FPropDtype(p))
    cached_var = tf.zeros(cached_moment_shape, py_utils.FPropDtype(p))
    return py_utils.NestedMap(
        cached_sum=cached_sum, cached_count=cached_count, cached_var=cached_var)
Example #6
  def zero_state(self, batch_size):
    p = self.params
    num_groups = self.num_groups

    if not p.cumulative:
      return py_utils.NestedMap()

    if p.input_rank == 4:
      cache_shape = [batch_size, 1, 1, num_groups, 1]
    else:
      cache_shape = [batch_size, 1, num_groups, 1]
    cached_sum = tf.zeros(cache_shape, py_utils.FPropDtype(p))
    cached_count = tf.zeros(cache_shape, py_utils.FPropDtype(p))
    cached_var = tf.zeros(cache_shape, py_utils.FPropDtype(p))
    return py_utils.NestedMap(
        cached_sum=cached_sum, cached_count=cached_count, cached_var=cached_var)
Example #7
 def _Cast(x):
     if x is None:
         return None
     x = tf.convert_to_tensor(x)
     if not x.dtype.is_floating:
         return x
     return tf.cast(x, py_utils.FPropDtype(self.params))
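
Note: a helper like _Cast is typically mapped over a whole structure of
tensors, so integer ids and paddings keep their dtypes while floats move to
the fprop dtype. A self-contained sketch (CastFloats is a hypothetical name):

import tensorflow as tf

def CastFloats(nested, fprop_dtype):
    def _Cast(x):
        x = tf.convert_to_tensor(x)
        if not x.dtype.is_floating:
            return x
        return tf.cast(x, fprop_dtype)
    return tf.nest.map_structure(_Cast, nested)

mixed = {'ids': tf.constant([1, 2]), 'paddings': tf.constant([0.0, 1.0])}
casted = CastFloats(mixed, tf.bfloat16)  # ids stay int32; paddings -> bfloat16
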
Example #8
    def _config_outfeed(self, xformer, infeed_batch):
        """Setup the outfeed ops."""
        fprop_dtype = py_utils.FPropDtype(self.model_params.task)

        assert len(infeed_batch) in (6, 7), len(infeed_batch)
        if len(infeed_batch) == 7:
            (key, tgt_ids, tgt_segment_id, tgt_segment_pos, tgt_labels, _,
             _) = infeed_batch
        elif len(infeed_batch) == 6:
            (key, tgt_ids, tgt_segment_id, tgt_segment_pos, tgt_labels,
             _) = infeed_batch
        tgt_segment_id = tf.cast(tgt_segment_id, fprop_dtype)

        input_batch = py_utils.NestedMap()
        input_batch.src = py_utils.NestedMap()
        input_batch.src.ids = (0 * tgt_ids)  # unused
        input_batch.src.segment_ids = (0 * tgt_segment_id)  # unused
        input_batch.src.segment_pos = (0 * tgt_segment_pos)  # unused
        input_batch.tgt = py_utils.NestedMap()
        input_batch.tgt.ids = tgt_ids
        input_batch.tgt.segment_ids = tgt_segment_id
        input_batch.tgt.segment_pos = tgt_segment_pos
        input_batch.tgt.labels = tgt_labels  # only used when --fprop=true

        with tpu_summary.context(rewrite_while_loop=True):
            dec_ret = xformer.DecodeIds(xformer.theta, input_batch)
            dec_metrics = tpu_summary.merge_all()
            key = infeed_batch[0]
            return [
                key, tgt_ids, tgt_segment_id, dec_ret.topk_ids,
                dec_ret.topk_lens, dec_ret.topk_scores, dec_metrics
            ]
Example #9
    def testMoEFFLayerFProp(self, use_fflayer_start_moe, use_fflayer_end_moe,
                            expected_aux_loss):
        p = self._GetParams()
        if use_fflayer_start_moe:
            p.fflayer_start_tpl = gshard_builder.MoEBuilder.Params().Set(
                e_dim=2, c_dim=2, num_devices=2)
        if use_fflayer_end_moe:
            p.fflayer_end_tpl = gshard_builder.MoEBuilder.Params().Set(
                e_dim=2, c_dim=2, num_devices=2)
        l = p.Instantiate()
        inputs, paddings = self._GetInputs()
        inputs = tf.convert_to_tensor(inputs)
        paddings = tf.convert_to_tensor(paddings)
        in_nmap = py_utils.NestedMap(features=inputs, paddings=paddings)
        in_nmap.aux_loss = tf.convert_to_tensor(0., py_utils.FPropDtype(p))
        out_nmap = l.FPropDefaultTheta(in_nmap)
        self.assertIn('aux_loss', out_nmap)
        loss = tf.reduce_sum(out_nmap.features) + 0.01 * out_nmap.aux_loss
        grads = tf.gradients(
            loss,
            l.vars.Flatten(),
            unconnected_gradients=tf.UnconnectedGradients.ZERO)

        with self.session() as sess:
            tf.global_variables_initializer().run()
            out_vals = sess.run(out_nmap.features)
            grad_vals = sess.run(grads)
            self.assertEqual(out_nmap.aux_loss.shape, ())
            aux_loss = sess.run(out_nmap.aux_loss)
            self.assertAlmostEqual(expected_aux_loss, aux_loss, places=5)
            print([x.shape for x in out_vals])
            print([g.shape for g in grad_vals])
Example #10
 def _InferenceSubgraph_Default(self):
     """Default inference subgraph."""
     batch_size = None
     seq_length = None
     fp_dtype = py_utils.FPropDtype(self.params)
     tshape = (batch_size, seq_length)
     input_ids = tf.placeholder(dtype=tf.int32, shape=tshape)
     targets = tf.placeholder(dtype=tf.int32, shape=tshape)
     paddings = tf.placeholder(dtype=fp_dtype, shape=tshape)
     weights = tf.placeholder(dtype=fp_dtype, shape=tshape)
     segment_ids = tf.placeholder(dtype=tf.int32, shape=tshape)
     segment_pos = tf.placeholder(dtype=tf.int32, shape=tshape)
     word_count = tf.placeholder(dtype=tf.int32, shape=(batch_size,))
     num_sentences = tf.placeholder(dtype=tf.int32, shape=(batch_size,))
     feeds = {
         'ids': input_ids,
         'labels': targets,
         'paddings': paddings,
         'weights': weights,
         'segment_ids': segment_ids,
         'segment_pos': segment_pos,
         'word_count': word_count,
         'num_sentences': num_sentences
     }
     input_batch = py_utils.NestedMap(feeds)
     loss, _ = self.FPropTower(self.theta, input_batch)
     fetches = {'loss': loss['loss'][0]}
     return fetches, feeds
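
Note: the returned (fetches, feeds) pair is the usual TF1 inference-subgraph
contract: callers feed placeholder values and run the fetch tensors. A toy,
runnable version of the pattern (all names illustrative, not lingvo's):

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

ids = tf.placeholder(tf.float32, shape=(None, None))
feeds = {'ids': ids}
fetches = {'loss': tf.reduce_mean(ids)}
with tf.Session() as sess:
    print(sess.run(fetches['loss'], feed_dict={feeds['ids']: np.ones([2, 7])}))
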
Example #11
    def testTransformerStackAlternateLayers(self):
        batch = 3
        tf.flags.FLAGS.tpu_compatible = True
        with self.session(use_gpu=False) as sess:
            model_dim = 2
            num_transformer_layers = 2
            transformer_tpl = layers_with_attention.TransformerLayer.Params()
            transformer_tpl.tr_atten_tpl.num_attention_heads = 1
            transformer_tpl.tr_fflayer_tpl.hidden_dim = 2

            params = mt_layers.TransformerStack.Params().Set(
                name='transformer',
                model_dim=model_dim,
                num_transformer_layers=num_transformer_layers,
                transformer_tpl=[
                    transformer_tpl.Copy()
                    for _ in range(num_transformer_layers)
                ],
                random_seed=123456)

            xformer = mt_layers.TransformerStack(params)
            input_arr = np.array([
                [[0, 1]] * batch,
                [[1, -1]] * batch,
            ], dtype=int)
            paddings_arr = np.array([[0] * batch, [0] * batch], dtype=int)
            inputs = tf.constant(input_arr.tolist(),
                                 dtype=py_utils.FPropDtype(params))
            paddings = tf.constant(paddings_arr.tolist(),
                                   dtype=py_utils.FPropDtype(params))
            output, _, _ = xformer.FProp(xformer.theta, inputs, paddings)

            tf.global_variables_initializer().run()
            output = sess.run(output)
            print(repr(output))
            # pylint: disable=bad-whitespace
            # pyformat: disable
            self.assertAllCloseAccordingToType(
                np.array([[[-2.17566538, -0.2821945],
                           [-2.17566514, -0.28219438],
                           [-2.17566514, -0.28219438]],
                          [[-0.71516591, -0.90594757],
                           [-0.71516603, -0.90594769],
                           [-0.71516603, -0.90594769]]]), output)
Example #12
    def _testDecoderFPropHelper(self, params):
        """Computes decoder from params and computes loss with random inputs."""
        dec = decoder.AsrDecoder(params)
        src_seq_len = 5
        src_enc = tf.random_normal([src_seq_len, 2, 8],
                                   seed=982774838,
                                   dtype=py_utils.FPropDtype(params))
        src_enc_padding = tf.constant(
            [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
            dtype=py_utils.FPropDtype(params))
        encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                             padding=src_enc_padding)
        # shape=[4, 5]
        target_ids = tf.transpose(
            tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                         [5, 6, 7, 8], [10, 5, 2, 5]],
                        dtype=tf.int32))
        # shape=[4, 5]
        target_labels = tf.transpose(
            tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                         [5, 7, 8, 10], [10, 5, 2, 4]],
                        dtype=tf.int32))
        # shape=[4, 5]
        target_paddings = tf.transpose(
            tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0],
                         [0, 1, 0, 0], [1, 1, 1, 0]],
                        dtype=py_utils.FPropDtype(params)))
        target_transcripts = tf.constant(
            ['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
        target_weights = 1.0 - target_paddings
        # ids/labels/weights/paddings are all in [batch, time] shape.
        targets = py_utils.NestedMap({
            'ids': target_ids,
            'labels': target_labels,
            'weights': target_weights,
            'paddings': target_paddings,
            'transcripts': target_transcripts,
        })
        metrics, per_sequence_loss = dec.FPropWithPerExampleLoss(
            encoder_outputs, targets)
        loss = metrics['loss']

        return loss, per_sequence_loss
Example #13
    def ExtendStep(self,
                   theta,
                   source_vecs,
                   prefix_states,
                   aux_vecs=None,
                   aux_paddings=None,
                   t=None):
        """Transformer Layer, extend one step in decoding.

    This function is expected to be called during fast decoding of Transformer
    models.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      source_vecs: [source_batch, dim].
      prefix_states: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
      aux_vecs: [aux_time, aux_batch, dim]
      aux_paddings: [aux_time, aux_batch]
      t: a scalar, the current time step, 0-based.

    Returns:
      The attention context vector, [target_batch, source_dim]

      The attention probability vector, [source_time, target_batch]

      Updated prefix states
    """
        p = self.params

        if p.has_aux_atten:
            assert aux_vecs is not None
            assert aux_paddings is not None

        batch_size = tf.shape(source_vecs)[0]

        # First the self-attention layer.
        atten_vec, atten_prob, new_states = self.self_atten.ExtendStep(
            theta.self_atten, source_vecs, prefix_states, t)

        atten_vec = tf.expand_dims(atten_vec, axis=0)
        # Next the source attention layer.
        if p.has_aux_atten:
            atten_vec, atten_prob = self.atten.FProp(theta.atten, atten_vec,
                                                     aux_paddings, aux_vecs)

        # Finally, the feedforward layer.
        h = self.fflayer.FProp(
            theta.fflayer, atten_vec,
            tf.zeros([1, batch_size], dtype=py_utils.FPropDtype(p)))
        h = tf.squeeze(h, 0)
        return h, atten_prob, new_states
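
Note: ExtendStep pairs with the _InitBeamSearchStateCallback examples above;
decoding advances one token at a time, threading prefix_states through each
call. A hedged usage sketch (layer, cur_vec, NextToken, and max_steps are all
hypothetical, so this is not runnable as-is):

for t in range(max_steps):
    h, atten_prob, prefix_states = layer.ExtendStep(
        layer.theta, cur_vec, prefix_states, t=t)
    cur_vec = NextToken(h)  # hypothetical: choose and embed the next token
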
Example #14
  def FPropTower(self, theta, input_batch):
    p = self.params
    tf.logging.info('input_batch=%r', input_batch)
    ids, paddings, labels_ids, weights = self._TrimIfPossible(
        input_batch.ids, input_batch.paddings, input_batch.labels,
        input_batch.weights)
    fprop_dtype = py_utils.FPropDtype(p)
    paddings = tf.cast(paddings, fprop_dtype)
    weights = tf.cast(weights, fprop_dtype)
    tf.logging.info('inputs={}'.format((ids, paddings, labels_ids, weights)))

    batch_size = tf.shape(ids)[0]
    state0 = None
    labels = py_utils.NestedMap(class_ids=labels_ids, class_weights=weights)
    fprop_kwargs = dict()
    if 'segment_ids' in input_batch:
      fprop_kwargs.update(
          segment_ids=input_batch.segment_ids,
          segment_pos=input_batch.segment_pos)
    xent_output, _ = self.lm.FProp(theta.lm, ids, paddings, state0, labels,
                                   **fprop_kwargs)

    if 'segment_ids' in input_batch:
      num_sentences = input_batch.num_sentences
    else:
      num_sentences = tf.ones(shape=[batch_size], dtype=tf.int32)
    # +num_sentences to account for the end of sequence symbol.
    num_words = tf.cast(
        tf.reduce_sum(input_batch.word_count + num_sentences), fprop_dtype)
    predicted_labels = tf.cast(xent_output.per_example_argmax, labels_ids.dtype)

    num_preds = xent_output.total_weight
    mean_acc = tf.reduce_sum(
        tf.cast(tf.equal(labels_ids, predicted_labels), fprop_dtype) *
        weights) / tf.math.maximum(num_preds, 1)
    loss = xent_output.avg_xent
    per_sequence_loss = tf.reduce_sum(
        xent_output.per_example_xent * weights, axis=1)
    if p.train.sum_loss_across_tokens_in_batch:
      loss = xent_output.total_xent
    else:
      per_sequence_loss /= tf.reduce_sum(weights, axis=1)
    return {
        'loss': (loss, num_preds),
        'fraction_of_correct_next_step_preds': (mean_acc, num_preds),
        'log_pplx': (xent_output.avg_xent, num_preds),
        'log_pplx_per_word': (xent_output.total_xent / num_words, num_words),
        'num_predictions': (num_preds, 1),
        'num_words': (num_words, 1),
        'num_sentences': (tf.reduce_sum(num_sentences), 1),
    }, {
        'loss': per_sequence_loss,
    }
Example #15
 def _getDecoderFPropMetrics(self, params):
     """Creates decoder from params and computes metrics with random inputs."""
     dec = params.Instantiate()
     src_seq_len = 5
     src_enc = tf.random.normal([src_seq_len, 2, 8],
                                seed=982774838,
                                dtype=py_utils.FPropDtype(params))
     src_enc_padding = tf.constant(
         [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
         dtype=py_utils.FPropDtype(params))
     encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                          padding=src_enc_padding)
     # shape=[4, 5]
     target_ids = tf.transpose(
         tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                      [5, 6, 7, 8], [10, 5, 2, 5]],
                     dtype=tf.int32))
     # shape=[4, 5]
     target_labels = tf.transpose(
         tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                      [5, 7, 8, 10], [10, 5, 2, 4]],
                     dtype=tf.int32))
     # shape=[4, 5]
     target_paddings = tf.transpose(
         tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0],
                      [0, 1, 0, 0], [1, 1, 1, 0]],
                     dtype=py_utils.FPropDtype(params)))
     target_transcripts = tf.constant(
         ['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
     target_weights = 1.0 - target_paddings
     # ids/labels/weights/paddings are all in [batch, time] shape.
     targets = py_utils.NestedMap({
         'ids': target_ids,
         'labels': target_labels,
         'weights': target_weights,
         'paddings': target_paddings,
         'transcripts': target_transcripts,
     })
     decoder_outputs = dec.FPropDefaultTheta(encoder_outputs, targets)
     return decoder_outputs.metrics, decoder_outputs.per_sequence['loss']
Example #16
 def _UpdateVnConfig(self):
   """Update vn config from the various vn flags."""
   p = self.params
   tp = p.train
   if tp:
     vn_enabled = ((tp.vn_std > 0) and p.vn and
                   (p.vn.global_vn or p.vn.per_step_vn))
     if p.is_eval or (not vn_enabled):
       p.vn = py_utils.VariationalNoiseParams(None, False, False)
     else:
       # vn.scale is dependent on global_step.
       p.vn.scale = tf.cast(self._global_step > tp.vn_start_step,
                            py_utils.FPropDtype(p)) * tp.vn_std
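
Note: the cast trick on the last two lines evaluates to 0.0 until global_step
passes tp.vn_start_step and to tp.vn_std afterwards, i.e. variational noise
switches on at a configured step. A two-line check in plain TF:

import tensorflow as tf

scale = tf.cast(tf.constant(120) > 100, tf.float32) * 0.1  # -> 0.1 after step 100
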
Example #17
    def testTransformerStackAlternateLayers(self):
        batch = 3
        tf.flags.FLAGS.tpu_compatible = True
        with self.session(use_gpu=False):
            model_dim = 2
            num_transformer_layers = 2
            transformer_tpl = layers_with_attention.TransformerLayer.Params()
            transformer_tpl.tr_atten_tpl.num_attention_heads = 1
            transformer_tpl.tr_fflayer_tpl.hidden_dim = 2

            params = mt_layers.TransformerStack.Params().Set(
                name='transformer',
                model_dim=model_dim,
                num_transformer_layers=num_transformer_layers,
                transformer_tpl=[
                    transformer_tpl.Copy()
                    for _ in range(num_transformer_layers)
                ],
                random_seed=123456)

            xformer = mt_layers.TransformerStack(params)
            input_arr = np.array([
                [[0, 1]] * batch,
                [[1, -1]] * batch,
            ], dtype=int)
            paddings_arr = np.array([[0] * batch, [0] * batch], dtype=int)
            inputs = tf.constant(input_arr.tolist(),
                                 dtype=py_utils.FPropDtype(params))
            paddings = tf.constant(paddings_arr.tolist(),
                                   dtype=py_utils.FPropDtype(params))
            output, _, _ = xformer.FProp(xformer.theta, inputs, paddings)

            self.evaluate(tf.global_variables_initializer())
            output = self.evaluate(output)
            print(repr(output))
            self.assertAllCloseAccordingToType(
                np.array([[[-0.940543, 1.479253]] * batch,
                          [[-0.413938, -2.550903]] * batch]), output)
Example #18
    def FProp(self, theta, inp):
        """Look up style embedding."""

        p = self.params
        b_size = tf.shape(inp)[0]
        styles_w = tf.tile(tf.nn.tanh(theta.styles_w), [1, b_size, 1])
        styles_paddings = tf.zeros([p.num_styles, b_size],
                                   dtype=py_utils.FPropDtype(p))
        packed_src = self.atten.InitForSourcePacked(theta.atten, styles_w,
                                                    styles_w, styles_paddings)
        style_emb, probs, _ = self.atten.ComputeContextVectorWithSource(
            theta.atten, packed_src, inp)
        # TODO(yonghui): Extract and return the attention probabilities.
        return style_emb, probs
Example #19
def _CreateSourceAndTargets(params):
    """Creates encoder outputs and targets from params for the decoder."""
    src_seq_len = 5
    src_enc = tf.random.normal([src_seq_len, 2, 8],
                               seed=982774838,
                               dtype=py_utils.FPropDtype(params))
    src_enc_padding = tf.constant(
        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        dtype=py_utils.FPropDtype(params))
    encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                         padding=src_enc_padding)
    # shape=[4, 5]
    target_ids = tf.transpose(
        tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                     [5, 6, 7, 8], [10, 5, 2, 5]],
                    dtype=tf.int32))
    # shape=[4, 5]
    target_labels = tf.transpose(
        tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                     [5, 7, 8, 10], [10, 5, 2, 4]],
                    dtype=tf.int32))
    # shape=[4, 5]
    target_paddings = tf.transpose(
        tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0],
                     [1, 1, 1, 0]],
                    dtype=py_utils.FPropDtype(params)))
    target_transcripts = tf.constant(['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
    target_weights = 1.0 - target_paddings
    # ids/labels/weights/paddings are all in [batch, time] shape.
    targets = py_utils.NestedMap({
        'ids': target_ids,
        'labels': target_labels,
        'weights': target_weights,
        'paddings': target_paddings,
        'transcripts': target_transcripts,
    })
    return encoder_outputs, targets
Example #20
    def FPropTower(self, theta, input_batch):
        p = self.params
        fprop_dtype = py_utils.FPropDtype(p)
        tf.logging.info('input_batch=%r', input_batch)
        ids = input_batch.ids
        labels_ids = input_batch.labels
        paddings = tf.cast(input_batch.paddings, fprop_dtype)
        weights = tf.cast(input_batch.weights, fprop_dtype)
        tf.logging.info('inputs={}'.format(
            (ids, paddings, labels_ids, weights)))

        batch_size = tf.shape(ids)[0]
        state0 = self.lm.zero_state(theta.lm, batch_size)
        labels = py_utils.NestedMap(class_ids=labels_ids,
                                    class_weights=weights)
        xent_output, _ = self.lm.FProp(theta.lm,
                                       ids,
                                       paddings,
                                       state0,
                                       labels,
                                       segment_ids=input_batch.segment_ids,
                                       segment_pos=input_batch.segment_pos)

        # +input_batch.num_sentences to account for the end of sequence symbol.
        num_words = tf.cast(
            tf.reduce_sum(input_batch.word_count +
                          tf.cast(input_batch.num_sentences, dtype=tf.int32)),
            fprop_dtype)
        predicted_labels = tf.cast(xent_output.per_example_argmax,
                                   labels_ids.dtype)
        num_sentences = tf.reduce_sum(input_batch.num_sentences)

        num_preds = xent_output.total_weight
        mean_acc = tf.reduce_sum(
            tf.cast(tf.equal(labels_ids, predicted_labels), fprop_dtype) *
            weights) / tf.math.maximum(num_preds, 1)
        loss = xent_output.avg_xent
        return {
            'loss': (loss, num_preds),
            'fraction_of_correct_next_step_preds': (mean_acc, num_preds),
            'log_pplx': (xent_output.avg_xent, num_preds),
            'log_pplx_per_word': (xent_output.total_xent / num_words,
                                  num_words),
            'num_predictions': (num_preds, 1),
            'num_words': (num_words, 1),
            'num_sentences': (num_sentences, 1)
        }, {}
Example #21
  def FProp(self, theta, ids, segment_pos):
    p = self.params
    fprop_dtype = py_utils.FPropDtype(p)

    ids = self._MaybeSplit(ids)
    segment_pos = self._MaybeSplit(segment_pos)

    one_hot_ids = tf.one_hot(ids, p.vocab_size, dtype=fprop_dtype)
    one_hot_ids = self._MaybeSplit(one_hot_ids)

    one_hot_pos = tf.one_hot(segment_pos, p.max_len, dtype=fprop_dtype)
    one_hot_pos = self._MaybeSplit(one_hot_pos)

    token_emb = tf.einsum('VH,BLV->BLH', theta.embedding, one_hot_ids)
    token_emb = self._MaybeSplit(token_emb)

    pos_emb = tf.einsum('VH,BLV->BLH', theta.pos_emb, one_hot_pos)
    pos_emb = self._MaybeSplit(pos_emb)
    return self._MaybeSplit(token_emb + pos_emb)
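
Note: the one_hot/einsum pairs above are embedding gathers written in a form
that shards cleanly across devices (hence the _MaybeSplit calls). A hedged
equivalence check in plain TF:

import tensorflow as tf

emb = tf.random.normal([10, 4])                  # [V, H]
ids = tf.constant([[1, 2, 3], [4, 5, 6]])        # [B, L]
one_hot = tf.one_hot(ids, 10, dtype=emb.dtype)   # [B, L, V]
via_einsum = tf.einsum('VH,BLV->BLH', emb, one_hot)
via_gather = tf.gather(emb, ids)                 # numerically the same lookup
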
Example #22
    def zero_state(self, batch_size):
        """Returns the initial state given the batch size.

    Args:
      batch_size: the batch size.

    Returns:
      state0: A NestedMap of tensors including:
        - context: A Tensor of shape [b, filter_shape[0]-1, 1, c].
    """
        p = self.params
        assert p.filter_shape[1] == 1, (
            'zero_state() only supports 1d causal convolution.')

        context = tf.zeros(
            shape=[batch_size] +
            [p.filter_shape[0] - 1, p.filter_shape[1], p.filter_shape[2]],
            dtype=py_utils.FPropDtype(p))
        return py_utils.NestedMap(context=context)
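
Note: the cached context holds the last filter_shape[0] - 1 frames, which is
exactly what a streaming causal 1-D convolution prepends before each new
chunk. A self-contained sketch of that pattern (shapes and names assumed,
not lingvo's actual streaming step):

import tensorflow as tf

def StreamChunk(chunk, context, filters):
    # chunk: [b, t, 1, c_in]; context: [b, k - 1, 1, c_in];
    # filters: [k, 1, c_in, c_out], with k = filter_shape[0] >= 2.
    k = filters.shape[0]
    padded = tf.concat([context, chunk], axis=1)  # [b, k - 1 + t, 1, c_in]
    out = tf.nn.conv2d(padded, filters, strides=[1, 1, 1, 1], padding='VALID')
    new_context = padded[:, -(k - 1):]  # carry the last k - 1 frames forward
    return out, new_context
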
Example #23
    def _InitBeamSearchStateCallback(self, theta, encoder_outputs,
                                     num_hyps_per_beam):
        """Returns initial beams search states.

    Args:
      theta: a NestedMap of parameters.
      encoder_outputs: a NestedMap computed by encoder.
      num_hyps_per_beam: An int, the number of hyps to keep per source
        sentence.

    Returns:
      A tuple (initial_results, states).
        initial_results: a `.NestedMap` of initial results.
          atten_probs:
            The initial attention probs, of shape [tgt_batch, src_len].
        states: a `.NestedMap` of initial model states.
          rnn_states:
            Initial state of the RNN.
          atten_context:
            Initial attention context vector.
          atten_states:
            Initial attention state.
    """
        p = self.params
        num_beams = py_utils.GetShape(encoder_outputs.padding)[1]
        num_hyps = num_beams * num_hyps_per_beam
        rnn_states, init_atten_context, atten_probs, atten_states = (
            self._InitDecoder(theta, encoder_outputs, num_hyps))

        initial_results = py_utils.NestedMap(
            log_probs=tf.zeros([num_hyps, p.softmax.num_classes],
                               dtype=py_utils.FPropDtype(p)),
            atten_probs=atten_probs)

        return initial_results, py_utils.NestedMap({
            'rnn_states': rnn_states,
            'atten_context': init_atten_context,
            'atten_probs': atten_probs,
            'atten_states': atten_states,
        })
Example #24
        def Callback(theta, encoder_outputs, num_hyps_per_beam):
            initial_results, states = self._InitBeamSearchStateCallback(
                theta, encoder_outputs, num_hyps_per_beam)
            assert hasattr(states, 'time_step')
            if tf.is_tensor(encoder_outputs.padding):
                batch_size = tf.shape(encoder_outputs.padding)[1]
            else:  # Required for multisource models.
                batch_size = tf.shape(
                    list(encoder_outputs.padding.values())[0])[1]
            num_hyps = batch_size * num_hyps_per_beam

            if biased:
                # states.consistent is initially all True.
                states.consistent = tf.ones([num_hyps], dtype=tf.bool)

            if stochastic:
                dtype = py_utils.FPropDtype(self.params)
                states.cumulative_log_probs = tf.zeros([num_hyps, 1],
                                                       dtype=dtype)
                states.perturbed_cumulative_log_probs = tf.zeros([num_hyps, 1],
                                                                 dtype=dtype)
                # Temporary tensors that store information passed from
                # PreBeamSearchStepCallback to PostBeamSearchStepCallback. These are
                # used for updating states.cumulative_log_probs and
                # states.perturbed_cumulative_log_probs for the next step, which
                # requires the knowledge of the chosen IDs, which only becomes available
                # after PreBeamSearchStepCallback.
                states.tmp_states = py_utils.NestedMap(
                    # Top-k (non-perturbed) log-probs. Used for updating
                    # `cumulative_log_probs` in PostBeamSearchStepCallback.
                    top_k_log_probs=tf.zeros([num_hyps, k], dtype=dtype),
                    # Vocab ID of each item of `top_k_log_probs`.
                    top_k_ids=tf.zeros([num_hyps, k], dtype=tf.int32),
                    # Perturbed cumulative log-probs of the top-k IDs. Used for updating
                    # `perturbed_cumulative_log_probs` in PostBeamSearchStepCallback.
                    new_perturbed_cumulative_log_probs=tf.zeros([num_hyps, k],
                                                                dtype=dtype),
                )

            return initial_results, states
Example #25
  def FProp(self, theta, inputs, paddings, state0=None, labels=None):
    """Computes xent loss given the language model input activations.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: Input ids. An int32 tensor of shape [time, batch].
      paddings: A 0/1 tensor of shape [time, batch].
      state0: Not used for Transformer.
      labels: If not None, a `.NestedMap` containing the following fields:
        - class_weights, a tensor with shape [time, batch] containing the
          weights for each target word.
        - class_ids, a tensor with shape [time, batch] of int32 dtype
          containing the target class labels.
        - class_probabilities, a tensor with shape [time, batch, vocab_size]
          of float values indicating class-membership probabilities.

    Returns:
      If `labels` is not None, returns (xent_output, state1), where
      `xent_output` is a `.NestedMap` as defined by `SoftmaxLayer`'s return
      value and `state1` is the next recurrent state. Otherwise,
      `xent_output` only contains the softmax logits.
    """
    p = self.params
    ids = py_utils.HasRank(inputs, 2)
    paddings = py_utils.HasShape(paddings, tf.shape(ids))
    per_example_xent, logits = self.stack.FProp(
        theta.stack, ids, paddings, None, None, None, None,
        tf.cast(labels.class_ids, py_utils.FPropDtype(p)), labels.class_weights)
    per_example_argmax = py_utils.ArgMax(logits)
    total_xent = tf.reduce_sum(per_example_xent * labels.class_weights)
    total_weights = tf.reduce_sum(labels.class_weights)
    xent_output = py_utils.NestedMap(
        total_weight=total_weights,
        per_example_xent=per_example_xent,
        logits=logits,
        per_example_argmax=per_example_argmax,
        avg_xent=total_xent / total_weights,
        total_xent=total_xent)
    return xent_output, {}
Example #26
    def StyleEmbFromProbs(self, theta, inp):
        """Look up style embedding based on feedin probabilities.

    Args:
      theta: params for this layer and its sub-layers.
      inp: attention probabilities of shape [batch_size, num_styles].

    Returns:
      style_emb: The weighted combination of style embeddings based on inp.
    """
        p = self.params
        b_size = tf.shape(inp)[0]
        styles_w = tf.tile(tf.nn.tanh(theta.styles_w), [1, b_size, 1])
        styles_paddings = tf.zeros([p.num_styles, b_size],
                                   dtype=py_utils.FPropDtype(p))
        atten_probs = tf.tile(tf.expand_dims(inp, 1), [1, p.num_heads, 1])
        atten_probs = tf.reshape(atten_probs, [-1, p.num_styles])
        packed_src = self.atten.InitForSourcePacked(theta.atten, styles_w,
                                                    styles_w, styles_paddings)
        style_emb, _ = self.atten.ComputeContextVectorWithAttenProbs(
            theta.atten, packed_src.source_contexts, atten_probs)
        return style_emb
Example #27
  def ZeroState(self, theta, prepared_inputs, batch_size):
    """Produce a zero state for this step.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      prepared_inputs: A set of inputs pre-processed by using
        PrepareExternalInputs.
      batch_size: Number of elements in the batched input.

    Returns:
      state0, a state parameter to pass to FProp on its first invocation.
    """
    max_seq_length = py_utils.GetShape(prepared_inputs.src, 3)[0]
    atten_state = self.atten.ZeroAttentionState(max_seq_length, batch_size)
    (new_atten_context, _,
     new_atten_states) = self.atten.ComputeContextVectorWithSource(
         theta.atten,
         prepared_inputs.packed_src,
         tf.zeros([batch_size, self.params.atten.query_dim],
                  dtype=py_utils.FPropDtype(self.params)),
         attention_state=atten_state)
    return py_utils.NestedMap(
        atten_context=new_atten_context, atten_state=new_atten_states)
Example #28
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task
            ids of shape [batch, time].

    Returns:
      A NestedMap containing

      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set in
        transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
        positional encodings.
    """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(
                    theta.token_emb, tf.reshape(input_ids, [-1]))
            else:
                input_embs = self.softmax.EmbLookup(
                    theta.softmax, tf.reshape(input_ids, [-1]))

            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs
            if p.task_emb:
                input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                                      input_batch.task_ids)

            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        if not self.do_eval and p.apply_source_mask:
            # Augment padding for masked source word positions.
            dtype = paddings.dtype
            source_mask = tf.where(tf.equal(input_ids, p.source_mask_id),
                                   tf.ones_like(input_ids, dtype=dtype),
                                   tf.zeros_like(input_ids, dtype=dtype))
            # Make sure padding is between 0 and 1.
            paddings = tf.clip_by_value(paddings + tf.transpose(source_mask),
                                        0.0, 1.0)

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
Example #29
 def Cast(self, v):
     """Cast tensor dtype to fprop_dtype."""
     if not v.dtype.is_floating:
         return v
     return tf.cast(v, py_utils.FPropDtype(self.params))
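
Note: this is the method form of Example #7's _Cast closure; both leave
non-floating tensors (ids, masks, lengths) untouched so that only float
activations follow the fprop dtype.
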
Example #30
    def FProp(self,
              theta,
              query_vec,
              source_paddings,
              source_vecs=None,
              query_segment_id=None,
              source_segment_id=None):
        """Transformer attention, residual and normalization layer.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      query_vec: [target_time, target_batch, dim]
      source_paddings: [source_time, source_batch]
      source_vecs: [source_time, source_batch, dim].
      query_segment_id: [target_time, target_batch]
      source_segment_id: [source_time, source_batch]
    Returns:
      (output, atten_probs). output is of shape [target_time, target_batch,
      source_dim], atten_probs is of shape [target_time, target_batch,
      source_time].
    """
        p = self.params
        unnormalized_query_vec = query_vec
        query_vec = self.layer_norm.FProp(theta.layer_norm, query_vec)

        if source_vecs is None:
            source_vecs = query_vec
            source_segment_id = query_segment_id

        if p.is_masked:
            assert source_vecs is not None
            query_vec = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(source_vecs),
                                            tf.shape(query_vec))
            ], query_vec)
            # Prepares mask for self-attention
            # [time, time]
            target_time = tf.shape(query_vec)[0]
            target_bs = tf.shape(query_vec)[1]
            triangle_padding = 1.0 - tf.matrix_band_part(
                tf.ones([target_time, target_time],
                        dtype=py_utils.FPropDtype(p)), -1, 0)
            # [time, batch, time]
            causal_padding = tf.tile(tf.expand_dims(triangle_padding, 1),
                                     [1, target_bs, 1])

            causal_padding = tf.reshape(causal_padding, [-1, target_time])
        else:
            causal_padding = None

        query_dim = tf.shape(query_vec)[-1]
        packed_src = self.atten.PackSource(theta.atten, source_vecs,
                                           source_vecs, source_paddings,
                                           source_segment_id)

        if query_segment_id is not None:
            query_segment_id = tf.reshape(query_segment_id, [-1])
        ctx_vec, atten_prob, _ = self.atten.ComputeContextVectorWithSource(
            theta.atten,
            packed_src,
            tf.reshape(query_vec, [-1, query_dim]),
            per_step_source_padding=causal_padding,
            query_segment_id=query_segment_id)
        ctx_vec = self.residual_dropout.FProp(theta.residual_dropout, ctx_vec)
        input_to_add = (unnormalized_query_vec
                        if p.add_unnormalized_input else query_vec)
        h = input_to_add + tf.reshape(ctx_vec, tf.shape(query_vec))
        atten_prob = tf.reshape(atten_prob, [
            tf.shape(query_vec)[0],
            tf.shape(query_vec)[1],
            tf.shape(source_vecs)[0]
        ])
        return h, atten_prob
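
Note: the causal mask above uses the standard lower-triangular band trick. A
small runnable check of the construction, using the same TF1-era
tf.matrix_band_part alias:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

t = 4
tri = 1.0 - tf.matrix_band_part(tf.ones([t, t]), -1, 0)
# tri[i, j] == 1.0 where j > i (future positions get padded out),
# 0.0 on and below the diagonal.
with tf.Session() as sess:
    print(sess.run(tri))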