Example #1
 def testMergeAll(self):
   with self.session() as sess:
     model = MockTransformer()
     x = tf.constant(0, dtype=tf.float32)
     y = tf.constant(0, dtype=tf.int64)
     with tpu_summary.context():
       x, y = model.FProp(x, y)
       summaries = tpu_summary.merge_all()
     x, y, summaries = sess.run((x, y, summaries))
     self.assertEqual((3.0, 3), (x, y))
     expected = {
         'x_mean/decoder000': 3.0,
         'x_mean/decoder001': 3.0,
         'x_mean/decoder002': 3.0,
         'x_mean/encoder000': 1.0,
         'x_mean/encoder001': 2.0,
         'x_mean/encoder002': 3.0,
         'y_mean/decoder000': 1,
         'y_mean/decoder001': 2,
         'y_mean/decoder002': 3,
         'y_mean/encoder000': 0,
         'y_mean/encoder001': 0,
         'y_mean/encoder002': 0,
     }
     self.assertEqual(expected, self._CanonicalizeSummaryName(summaries))
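Example #1 exercises the whole API surface at once: summaries are recorded inside a tpu_summary.context() and collected as a name-to-tensor dict by tpu_summary.merge_all(). A minimal sketch of the pattern, assuming lingvo's tpu_summary module (the fprop helper and the summary name are illustrative):

import tensorflow.compat.v1 as tf
from lingvo.core import tpu_summary

def fprop(x):
  # Recorded only while a tpu_summary.context() is active.
  tpu_summary.scalar('x_mean', tf.reduce_mean(x))
  return x * 2

with tpu_summary.context():
  out = fprop(tf.constant([1.0, 2.0, 3.0]))
  # merge_all() returns a dict mapping summary names to tensors.
  summaries = tpu_summary.merge_all()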
Example #2
    def _config_outfeed(self, xformer, infeed_batch):
        """Setup the outfeed ops."""
        fprop_dtype = py_utils.FPropDtype(self.model_params.task)

        assert len(infeed_batch) == 6 or len(infeed_batch) == 7, len(
            infeed_batch)
        if len(infeed_batch) == 7:
            (key, tgt_ids, tgt_segment_id, tgt_segment_pos, tgt_labels, _,
             _) = infeed_batch
        elif len(infeed_batch) == 6:
            (key, tgt_ids, tgt_segment_id, tgt_segment_pos, tgt_labels,
             _) = infeed_batch
        tgt_segment_id = tf.cast(tgt_segment_id, fprop_dtype)

        input_batch = py_utils.NestedMap()
        input_batch.src = py_utils.NestedMap()
        input_batch.src.ids = (0 * tgt_ids)  # unused
        input_batch.src.segment_ids = (0 * tgt_segment_id)  # unused
        input_batch.src.segment_pos = (0 * tgt_segment_pos)  # unused
        input_batch.tgt = py_utils.NestedMap()
        input_batch.tgt.ids = tgt_ids
        input_batch.tgt.segment_ids = tgt_segment_id
        input_batch.tgt.segment_pos = tgt_segment_pos
        input_batch.tgt.labels = tgt_labels  # only used when --fprop=true

        with tpu_summary.context(rewrite_while_loop=True):
            dec_ret = xformer.DecodeIds(xformer.theta, input_batch)
            dec_metrics = tpu_summary.merge_all()
            key = infeed_batch[0]
            return [
                key, tgt_ids, tgt_segment_id, dec_ret.topk_ids,
                dec_ret.topk_lens, dec_ret.topk_scores, dec_metrics
            ]
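Note that DecodeIds runs its decoding loop inside tf.while_loop, so the context here is opened with rewrite_while_loop=True; merge_all() then yields dec_metrics as ordinary tensors that can ride through the TPU outfeed alongside the beam-search outputs.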
Example #3
 def testWhileLoopReduceSum(self):
   with self.session() as sess:
     model = MockTransformer()
     x = tf.constant(0, dtype=tf.float32)
     y = tf.constant(0, dtype=tf.int64)
     with tpu_summary.context(rewrite_while_loop=True):
       x, y = model.BeamSearch(x, y, decoder_reduce_sum=True)
       summaries = tpu_summary.merge_all()
     tf.logging.info('summaries=%r', summaries)
     x, y, summaries = sess.run((x, y, summaries))
     self.assertEqual((3.0, 30), (x, y))
     expected = {
         'x_mean/encoder000': 1.0,
         'x_mean/encoder001': 2.0,
         'x_mean/encoder002': 3.0,
         'y_mean/encoder000': 0,
         'y_mean/encoder001': 0,
         'y_mean/encoder002': 0,
         'x_mean/decoder000': 30.0,
         'x_mean/decoder001': 30.0,
         'x_mean/decoder002': 30.0,
         'y_mean/decoder000': 145.0,
         'y_mean/decoder001': 155.0,
         'y_mean/decoder002': 165.0,
     }
     self.assertEqual(expected, self._CanonicalizeSummaryName(summaries))
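With rewrite_while_loop=True, summaries created inside the decoder's while loop are reduced over loop iterations before merge_all() returns them. That is why the decoder entries differ from Example #1: the test name and the decoder_reduce_sum=True flag suggest the decoder values here are accumulated as sums across the loop steps rather than taken from a single step.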
Example #4
 def testWhileLoopNoRewrite(self):
   with self.session() as sess:
     model = MockTransformer()
     x = tf.constant(0, dtype=tf.float32)
     y = tf.constant(0, dtype=tf.int64)
     with tpu_summary.context():
       x, y = model.BeamSearch(x, y)
       # ValueError: Tensor decoder000/Mean:0 is not an element of this graph.
       with self.assertRaises(ValueError):
         summaries = tpu_summary.merge_all()
         x, y, summaries = sess.run((x, y, summaries))
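Example #4 is the negative case: without rewrite_while_loop=True, a summary created inside a tf.while_loop body belongs to the loop's own graph, so it cannot be fetched from the outer graph. A minimal sketch of the same failure mode (illustrative, not taken from the test suite):

with tpu_summary.context():  # note: no rewrite_while_loop=True
  def body(i):
    tpu_summary.scalar('step_mean', tf.cast(i, tf.float32))
    return i + 1
  n = tf.while_loop(lambda i: i < 3, body, [tf.constant(0)])
  summaries = tpu_summary.merge_all()
  # sess.run(summaries) is expected to raise ValueError here,
  # just as in the test above.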
Example #5
    def testLayerStackSummary(self):
        # In this test we verify that summaries created inside stacked layers
        # are processed properly with and without RepeatLayer.
        model_dim = 4
        num_heads = 2
        d_kv = 2
        d_ff = 8
        num_experts = 2
        builder = gshard_builder.DenseBuilder.Params().Set(
            deterministic_dropout=True,
            dtype=tf.float32,
            relative_attention_type='bias',
            model_dim=model_dim,
            attention_num_heads=num_heads,
            attention_combine_dims=True,
            attention_num_memory_heads=1,
            model_dim_reshape_segments=None,
            ff_dim=d_ff,
            moe_hidden_dim=d_ff,
            e_dim=num_experts,
            c_dim=1,
            num_groups=num_experts,
            num_devices=num_experts,
            attention_key_value_dim=d_kv).Instantiate()

        def _GetOutputs(enc, dec):
            x, seg_id, pos_id = self._GetInputs()
            enc_inputs = py_utils.NestedMap(vec=x,
                                            segment_id=seg_id,
                                            segment_pos=pos_id,
                                            aux_loss=tf.constant(0.0))
            enc_outs = enc.FPropDefaultTheta(enc_inputs)
            dec_inputs = py_utils.NestedMap(
                vec=x,
                segment_id=seg_id,
                segment_pos=pos_id,
                encoder_output=enc_outs.vec,
                encoder_segment_id=tf.zeros_like(seg_id),
                encoder_segment_pos=tf.zeros_like(pos_id),
                aux_loss=enc_outs.aux_loss)
            return dec.FPropDefaultTheta(dec_inputs).vec

        # Build a graph with RepeatLayer unrolled.
        g = tf.Graph()
        with g.as_default(), tpu_summary.context(), cluster_factory.SetEval(
                mode=True):
            tf.random.set_seed(None)
            enc = builder.EncoderLayerStack(
                'encoder',
                sub_layers=[builder.DenseReluDense('ffw')],
                num=2,
                use_repeat_layer=True).Instantiate()
            dec = builder.DecoderLayerStack(
                'decoder',
                sub_layers=[builder.MoE('moe', decoder=True)],
                num=2,
                use_repeat_layer=True).Instantiate()
            rep_unroll_out = _GetOutputs(enc, dec)
            rep_unroll_summary = tpu_summary.merge_all()

        expected_rep_unroll_summary = [
            'index_1/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
            'index_1/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating',
            'over_capacity_1_ratio/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1_ratio/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_2_ratio/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2_ratio/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'top1_expert/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
            'top1_expert/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating'
        ]
        self.assertCountEqual(expected_rep_unroll_summary, rep_unroll_summary)

        tf.Session.reset(target='')
        with tf.Session(graph=g) as sess:
            sess.run(tf.global_variables_initializer())
            rep_unroll_out, rep_unroll_summary = sess.run(
                [rep_unroll_out, rep_unroll_summary])
            var_values = sess.run(tf.trainable_variables())
        # Build a graph without RepeatLayer.
        g = tf.Graph()
        with g.as_default(), tpu_summary.context():
            tf.random.set_seed(None)
            enc = builder.EncoderLayerStack('encoder',
                                            sub_layers=[
                                                builder.DenseReluDense('ffw')
                                            ],
                                            num=2).Instantiate()
            dec = builder.DecoderLayerStack(
                'decoder',
                sub_layers=[builder.MoE('moe', decoder=True)],
                num=2).Instantiate()
            dec_out = _GetOutputs(enc, dec)
            dec_summary = tpu_summary.merge_all()

        expected_dec_summary = [
            'index_1/decoder_1/layer_000/moe/ffw/compute_gating',
            'index_1/decoder_1/layer_001/moe/ffw/compute_gating',
            'over_capacity_1_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity',
            'over_capacity_1_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity',
            'over_capacity_2_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
            'over_capacity_2_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity_1',
            'top1_expert/decoder_1/layer_000/moe/ffw/compute_gating',
            'top1_expert/decoder_1/layer_001/moe/ffw/compute_gating'
        ]
        self.assertCountEqual(expected_dec_summary, dec_summary)

        tf.Session.reset(target='')
        with tf.Session(graph=g) as sess:
            tf_vars = [
                enc.vars.layer_000.ln.w.scale, enc.vars.layer_000.ffw.w.wi,
                enc.vars.layer_000.ffw.w.wo, enc.vars.layer_001.ln.w.scale,
                enc.vars.layer_001.ffw.w.wi, enc.vars.layer_001.ffw.w.wo,
                enc.vars.final_layer_norm.w.scale,
                dec.vars.layer_000.ln.w.scale, dec.vars.layer_000.moe.moe.wi,
                dec.vars.layer_000.moe.moe.wo,
                dec.vars.layer_000.moe.ffw.top_2_gating.w,
                dec.vars.layer_001.ln.w.scale, dec.vars.layer_001.moe.moe.wi,
                dec.vars.layer_001.moe.moe.wo,
                dec.vars.layer_001.moe.ffw.top_2_gating.w,
                dec.vars.final_layer_norm.w.scale
            ]
            for val, var in zip(var_values, tf_vars):
                sess.run(tf.assign(var, val))
            dec_out, dec_summary = sess.run([dec_out, dec_summary])
            self.assertAllClose(dec_out, rep_unroll_out)

            for name, alt_name in zip(expected_dec_summary,
                                      expected_rep_unroll_summary):
                self.assertAllClose(dec_summary[name],
                                    rep_unroll_summary[alt_name])
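The point of Example #5 is that the two graphs must agree once the trained variables are copied across: assertCountEqual checks that only the summary key layout differs (blocks/blocks_body* names under RepeatLayer versus layer_00* names in the plain stack), while the final assertAllClose calls check that the corresponding summary values and decoder outputs are numerically identical.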
Example #6
  def testFlatBeamSearchWithExtensionBuffer(self, rule):
    batch_size = 2
    beam_size = 4
    ext_size = 128
    nbest_size = 8
    max_steps = 300
    vocab_size = 100

    decoder = TestDecoder(batch_size, beam_size, max_steps, vocab_size, rule)
    dec_state = decoder.new_state()
    dec_callback = decoder.dec_callback

    with tpu_summary.context(rewrite_while_loop=True):
      bs = flat_beam_search_helper.flat_beam_search(
          batch_size,
          beam_size,
          max_steps,
          dec_callback,
          dec_state,
          bos_id=1,
          eos_id=0,
          beam_gap=None,
          ext_size=ext_size,
          nbest_size=nbest_size,
          debug=True)
      debug_tensors = tpu_summary.merge_all()

    tf.logging.info('bs=%r', bs)
    tf.logging.info('debug_tensors=%r', debug_tensors)

    with self.session() as sess:
      [bs, debug_tensors] = sess.run([bs, debug_tensors])

    tf.logging.info('bs=%r', bs)

    loop_vars, dec_state_, nbest = bs
    (topk_ids, topk_lens, topk_scores) = nbest
    del loop_vars, dec_state_, nbest

    self.assertEqual((batch_size, nbest_size, max_steps), topk_ids.shape)
    self.assertEqual((batch_size, nbest_size), topk_lens.shape)
    self.assertEqual((batch_size, nbest_size), topk_scores.shape)

    print('Decoder output rule=%r' % decoder.rule)
    print('batch_size=%d beam_size=%d ext_size=%d nbest_size=%d max_steps=%d' %
          (batch_size, beam_size, ext_size, nbest_size, max_steps))

    topk = [[
        topk_ids[b, k, 0:topk_lens[b, k]].tolist() for k in range(nbest_size)
    ] for b in range(batch_size)]

    for b in range(batch_size):
      for k in range(nbest_size):
        print('topk[%d][%d] (%0.6f): %r' %
              (b, k, topk_scores[b, k], topk[b][k]))

    # pyformat: disable
    if decoder.rule == '+1':
      expected = 2 * [[
          [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
          [1, 2, 3, 4, 5, 6, 7, 9, 0],
          [1, 2, 3, 4, 5, 6, 8, 9, 0],
          [1, 2, 3, 4, 5, 7, 8, 9, 0],
          [1, 2, 3, 4, 6, 7, 8, 9, 0],
          [1, 2, 3, 5, 6, 7, 8, 9, 0],
          [1, 2, 4, 5, 6, 7, 8, 9, 0],
          [1, 3, 4, 5, 6, 7, 8, 9, 0],
      ]]
    elif decoder.rule == 'sum':
      expected = 2 * [[
          [1, 1, 2, 4, 9, 0],
          [1, 1, 2, 5, 9, 0],
          [1, 2, 3, 6, 12, 24, 48, 96, 0],
          [1, 1, 2, 4, 8, 16, 32, 64, 29, 0],
          [1, 1, 2, 4, 8, 16, 32, 65, 29, 0],
          [1, 1, 2, 5, 10, 19, 0],
          [1, 2, 3, 6, 12, 25, 49, 0],
          [1, 2, 3, 6, 12, 24, 49, 0],
      ]]
    elif decoder.rule == 'fib':
      expected = 2 * [[
          [1, 1, 2, 3, 5, 9, 0],
          [1, 1, 2, 3, 6, 9, 0],
          [1, 2, 3, 5, 9, 0],
          [1, 1, 4, 5, 9, 0],
          [1, 2, 3, 6, 9, 0],
          [1, 1, 3, 4, 7, 11, 18, 29, 0],
          [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 0],
          [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 0],
      ]]
    # pyformat: enable

    self.assertEqual(expected, topk)
Example #7
  def testFlatBeamSearchWithPrefix(self, rule):
    batch_size = 2
    beam_size = 4
    max_steps = 20
    vocab_size = 100
    prefix_size = 4

    prefix_len = np.zeros([batch_size])
    prefix_len[-2] = 2
    prefix_len[-1] = 3
    prefix_id = np.zeros([batch_size, prefix_size])
    prefix_id[0, -2:] = [11, 12]
    prefix_id[1, -3:] = [21, 22, 23]

    decoder = TestDecoder(batch_size, beam_size, max_steps, vocab_size, rule)
    dec_state = decoder.new_state()
    dec_callback = decoder.dec_callback

    with tpu_summary.context(rewrite_while_loop=True):
      bs = flat_beam_search_helper.flat_beam_search(
          batch_size,
          beam_size,
          max_steps,
          dec_callback,
          dec_state,
          bos_id=1,
          eos_id=0,
          prefix=prefix_id,
          prefix_len=prefix_len,
          beam_gap=None,
          debug=True)
      debug_tensors = tpu_summary.merge_all()

    tf.logging.info('bs=%r', bs)
    tf.logging.info('debug_tensors=%r', debug_tensors)

    with self.session() as sess:
      [bs, debug_tensors] = sess.run([bs, debug_tensors])

    tf.logging.info('bs=%r', bs)

    loop_vars, dec_state_, nbest = bs
    (topk_ids, topk_lens, topk_scores) = nbest
    del loop_vars, dec_state_

    self.assertEqual((batch_size, beam_size, max_steps + prefix_size),
                     topk_ids.shape)
    self.assertEqual((batch_size, beam_size), topk_lens.shape)
    self.assertEqual((batch_size, beam_size), topk_scores.shape)

    print('Decoder output rule=%r' % decoder.rule)
    print('batch_size=%d beam_size=%d max_steps=%d' %
          (batch_size, beam_size, max_steps))

    topk = [[
        topk_ids[b, k, 0:topk_lens[b, k]].tolist() for k in range(beam_size)
    ] for b in range(batch_size)]

    for b in range(batch_size):
      for k in range(beam_size):
        print('topk[%d][%d] (%0.6f): %r' %
              (b, k, topk_scores[b, k], topk[b][k]))

    # pyformat: disable
    if decoder.rule == '+1':
      expected = [
          [[11, 12, 13, 14, 15, 16, 17, 18, 19, 0],
           [11, 12, 13, 14, 15, 16, 17, 19, 0],
           [11, 12, 13, 14, 15, 16, 18, 19, 0],
           [11, 12, 13, 14, 15, 17, 18, 19, 0]],
          [[21, 22, 23, 24, 25, 26, 27, 28, 29, 0],
           [21, 22, 23, 24, 25, 26, 27, 29, 0],
           [21, 22, 23, 24, 25, 26, 28, 29, 0],
           [21, 22, 23, 24, 25, 27, 28, 29, 0]]]
    elif decoder.rule == 'sum':
      expected = [
          [[11, 12, 23, 46, 92, 0],
           [11, 12, 23, 46, 93, 0],
           [11, 12, 23, 47, 93, 0],
           [11, 12, 24, 47, 94, 0]],
          [[21, 22, 23, 66, 32, 64, 29, 0],
           [21, 22, 23, 66, 32, 65, 29, 0],
           [21, 22, 23, 66, 32, 64, 28, 56, 12, 24, 48, 96, 0],
           [21, 22, 23, 69, 0]]]
    elif decoder.rule == 'fib':
      expected = [
          [[11, 12, 23, 35, 58, 93, 0],
           [11, 12, 23, 35, 59, 0],
           [11, 12, 23, 36, 59, 0],
           [11, 12, 23, 35, 58, 94, 0]],
          [[21, 22, 23, 45, 69, 0],
           [21, 22, 23, 46, 69, 0],
           [21, 22, 23, 45, 68, 13, 81, 94, 0],
           [21, 22, 23, 45, 68, 13, 81, 95, 0]]]
    # pyformat: enable

    self.assertEqual(expected, topk)