def testMergeAll(self):
  """Checks that FProp summaries (no while loop) are collected by merge_all."""
  with self.session() as sess:
    transformer = MockTransformer()
    xf = tf.constant(0, dtype=tf.float32)
    yi = tf.constant(0, dtype=tf.int64)
    with tpu_summary.context():
      xf, yi = transformer.FProp(xf, yi)
      merged = tpu_summary.merge_all()
    xf, yi, merged = sess.run((xf, yi, merged))
    self.assertEqual((3.0, 3), (xf, yi))
    # Three encoder and three decoder layers each emit an x/y mean summary.
    expected = {}
    for idx in range(3):
      expected['x_mean/encoder%03d' % idx] = float(idx + 1)
      expected['y_mean/encoder%03d' % idx] = 0
      expected['x_mean/decoder%03d' % idx] = 3.0
      expected['y_mean/decoder%03d' % idx] = idx + 1
    self.assertEqual(expected, self._CanonicalizeSummaryName(merged))
def _config_outfeed(self, xformer, infeed_batch):
  """Setup the outfeed ops.

  Args:
    xformer: the transformer model whose `DecodeIds` is run on the infeed.
    infeed_batch: a sequence of 6 or 7 infeed tensors. In both layouts the
      first five entries are (key, tgt_ids, tgt_segment_id, tgt_segment_pos,
      tgt_labels); any trailing entries are unused here.

  Returns:
    A list of outfeed tensors: [key, tgt_ids, tgt_segment_id, topk_ids,
    topk_lens, topk_scores, dec_metrics].
  """
  fprop_dtype = py_utils.FPropDtype(self.model_params.task)
  assert len(infeed_batch) in (6, 7), len(infeed_batch)
  # Both arities share the same leading five elements, so a single slice
  # unpack replaces the duplicated per-arity tuple unpacking (which also
  # redundantly re-read infeed_batch[0] into `key` afterwards).
  key, tgt_ids, tgt_segment_id, tgt_segment_pos, tgt_labels = infeed_batch[:5]
  tgt_segment_id = tf.cast(tgt_segment_id, fprop_dtype)
  input_batch = py_utils.NestedMap()
  input_batch.src = py_utils.NestedMap()
  # Source side is not consumed by decoding; zeroed placeholders keep the
  # expected NestedMap structure.
  input_batch.src.ids = (0 * tgt_ids)  # unused
  input_batch.src.segment_ids = (0 * tgt_segment_id)  # unused
  input_batch.src.segment_pos = (0 * tgt_segment_pos)  # unused
  input_batch.tgt = py_utils.NestedMap()
  input_batch.tgt.ids = tgt_ids
  input_batch.tgt.segment_ids = tgt_segment_id
  input_batch.tgt.segment_pos = tgt_segment_pos
  input_batch.tgt.labels = tgt_labels  # only used when --fprop=true
  # rewrite_while_loop=True lets summaries created inside the decode loop be
  # fetched as regular tensors via merge_all().
  with tpu_summary.context(rewrite_while_loop=True):
    dec_ret = xformer.DecodeIds(xformer.theta, input_batch)
    dec_metrics = tpu_summary.merge_all()
  return [
      key, tgt_ids, tgt_segment_id, dec_ret.topk_ids, dec_ret.topk_lens,
      dec_ret.topk_scores, dec_metrics
  ]
def testWhileLoopReduceSum(self):
  """Checks while-loop summaries with rewrite + decoder_reduce_sum enabled."""
  with self.session() as sess:
    transformer = MockTransformer()
    xf = tf.constant(0, dtype=tf.float32)
    yi = tf.constant(0, dtype=tf.int64)
    with tpu_summary.context(rewrite_while_loop=True):
      xf, yi = transformer.BeamSearch(xf, yi, decoder_reduce_sum=True)
      merged = tpu_summary.merge_all()
    tf.logging.info('summaries=%r', merged)
    xf, yi, merged = sess.run((xf, yi, merged))
    self.assertEqual((3.0, 30), (xf, yi))
    # Encoder summaries are unchanged; decoder summaries are reduced over the
    # while loop (sums over the 10 loop iterations).
    expected = {}
    for idx in range(3):
      expected['x_mean/encoder%03d' % idx] = float(idx + 1)
      expected['y_mean/encoder%03d' % idx] = 0
      expected['x_mean/decoder%03d' % idx] = 30.0
      expected['y_mean/decoder%03d' % idx] = 145.0 + 10.0 * idx
    self.assertEqual(expected, self._CanonicalizeSummaryName(merged))
def testWhileLoopNoRewrite(self):
  """Without rewrite_while_loop, in-loop summaries cannot be fetched."""
  with self.session() as sess:
    transformer = MockTransformer()
    xf = tf.constant(0, dtype=tf.float32)
    yi = tf.constant(0, dtype=tf.int64)
    with tpu_summary.context():
      xf, yi = transformer.BeamSearch(xf, yi)
      # ValueError: Tensor decoder000/Mean:0 is not an element of this graph.
      with self.assertRaises(ValueError):
        merged = tpu_summary.merge_all()
        xf, yi, merged = sess.run((xf, yi, merged))
def testLayerStackSummary(self):
  """Verifies stack-layer summaries match with and without RepeatLayer.

  Builds the same encoder/decoder stack twice — once with
  use_repeat_layer=True (unrolled) and once as plain stacked layers —
  transplants the trained variable values from the first graph into the
  second, and checks that both outputs and MoE gating summaries agree.
  """
  # In this test we verify that summaries created inside stack layers
  # are processed properly with and without RepeatedLayer
  model_dim = 4
  num_heads = 2
  d_kv = 2
  d_ff = 8
  num_experts = 2
  builder = gshard_builder.DenseBuilder.Params().Set(
      deterministic_dropout=True,
      dtype=tf.float32,
      relative_attention_type='bias',
      model_dim=model_dim,
      attention_num_heads=num_heads,
      attention_combine_dims=True,
      attention_num_memory_heads=1,
      model_dim_reshape_segments=None,
      ff_dim=d_ff,
      moe_hidden_dim=d_ff,
      e_dim=num_experts,
      c_dim=1,
      num_groups=num_experts,
      num_devices=num_experts,
      attention_key_value_dim=d_kv).Instantiate()

  def _GetOutputs(enc, dec):
    # Runs encoder then decoder on shared test inputs; returns decoder vec.
    x, seg_id, pos_id = self._GetInputs()
    enc_inputs = py_utils.NestedMap(
        vec=x,
        segment_id=seg_id,
        segment_pos=pos_id,
        aux_loss=tf.constant(0.0))
    enc_outs = enc.FPropDefaultTheta(enc_inputs)
    dec_inputs = py_utils.NestedMap(
        vec=x,
        segment_id=seg_id,
        segment_pos=pos_id,
        encoder_output=enc_outs.vec,
        encoder_segment_id=tf.zeros_like(seg_id),
        encoder_segment_pos=tf.zeros_like(pos_id),
        aux_loss=enc_outs.aux_loss)
    return dec.FPropDefaultTheta(dec_inputs).vec

  # Build a graph with RepeatLayer unrolled.
  g = tf.Graph()
  with g.as_default(), tpu_summary.context(), cluster_factory.SetEval(
      mode=True):
    tf.random.set_seed(None)
    enc = builder.EncoderLayerStack(
        'encoder',
        sub_layers=[builder.DenseReluDense('ffw')],
        num=2,
        use_repeat_layer=True).Instantiate()
    dec = builder.DecoderLayerStack(
        'decoder',
        sub_layers=[builder.MoE('moe', decoder=True)],
        num=2,
        use_repeat_layer=True).Instantiate()
    rep_unroll_out = _GetOutputs(enc, dec)
    rep_unroll_summary = tpu_summary.merge_all()

  # Summary names under RepeatLayer carry the blocks/blocks_body[_i] scopes.
  expected_rep_unroll_summary = [
      'index_1/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
      'index_1/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating',
      'over_capacity_1_ratio/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity',
      'over_capacity_1_ratio/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity',
      'over_capacity_2_ratio/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating/over_capacity_1',
      'over_capacity_2_ratio/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
      'top1_expert/decoder_1/blocks/blocks_body/layer_000/moe/ffw/compute_gating',
      'top1_expert/decoder_1/blocks/blocks_body_1/layer_000/moe/ffw/compute_gating'
  ]
  self.assertCountEqual(expected_rep_unroll_summary, rep_unroll_summary)
  tf.Session.reset(target='')
  with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    rep_unroll_out, rep_unroll_summary = sess.run(
        [rep_unroll_out, rep_unroll_summary])
    # Capture the (randomly initialized) variable values so the second graph
    # can be made numerically identical.
    var_values = sess.run(tf.trainable_variables())

  # Build a graph without RepeatLayer.
  g = tf.Graph()
  with g.as_default(), tpu_summary.context():
    tf.random.set_seed(None)
    enc = builder.EncoderLayerStack(
        'encoder', sub_layers=[builder.DenseReluDense('ffw')],
        num=2).Instantiate()
    dec = builder.DecoderLayerStack(
        'decoder', sub_layers=[builder.MoE('moe', decoder=True)],
        num=2).Instantiate()
    dec_out = _GetOutputs(enc, dec)
    dec_summary = tpu_summary.merge_all()

  # Same summaries, but named with plain layer_00i scopes.
  expected_dec_summary = [
      'index_1/decoder_1/layer_000/moe/ffw/compute_gating',
      'index_1/decoder_1/layer_001/moe/ffw/compute_gating',
      'over_capacity_1_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity',
      'over_capacity_1_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity',
      'over_capacity_2_ratio/decoder_1/layer_000/moe/ffw/compute_gating/over_capacity_1',
      'over_capacity_2_ratio/decoder_1/layer_001/moe/ffw/compute_gating/over_capacity_1',
      'top1_expert/decoder_1/layer_000/moe/ffw/compute_gating',
      'top1_expert/decoder_1/layer_001/moe/ffw/compute_gating'
  ]
  self.assertCountEqual(expected_dec_summary, dec_summary)
  tf.Session.reset(target='')
  with tf.Session(graph=g) as sess:
    # NOTE(review): this ordering must match tf.trainable_variables() of the
    # RepeatLayer graph above, since values are copied positionally.
    tf_vars = [
        enc.vars.layer_000.ln.w.scale, enc.vars.layer_000.ffw.w.wi,
        enc.vars.layer_000.ffw.w.wo, enc.vars.layer_001.ln.w.scale,
        enc.vars.layer_001.ffw.w.wi, enc.vars.layer_001.ffw.w.wo,
        enc.vars.final_layer_norm.w.scale, dec.vars.layer_000.ln.w.scale,
        dec.vars.layer_000.moe.moe.wi, dec.vars.layer_000.moe.moe.wo,
        dec.vars.layer_000.moe.ffw.top_2_gating.w,
        dec.vars.layer_001.ln.w.scale, dec.vars.layer_001.moe.moe.wi,
        dec.vars.layer_001.moe.moe.wo,
        dec.vars.layer_001.moe.ffw.top_2_gating.w,
        dec.vars.final_layer_norm.w.scale
    ]
    for val, var in zip(var_values, tf_vars):
      sess.run(tf.assign(var, val))
    dec_out, dec_summary = sess.run([dec_out, dec_summary])
    # With identical weights, outputs and per-layer summaries must agree
    # between the two stacking implementations.
    self.assertAllClose(dec_out, rep_unroll_out)
    for name, alt_name in zip(expected_dec_summary,
                              expected_rep_unroll_summary):
      self.assertAllClose(dec_summary[name], rep_unroll_summary[alt_name])
def testFlatBeamSearchWithExtensionBuffer(self, rule):
  """Runs flat beam search with an extension buffer and checks n-best output.

  Args:
    rule: TestDecoder rule name ('+1', 'sum' or 'fib') selecting how the
      mock decoder proposes next token ids; supplied by the test
      parameterization.
  """
  batch_size = 2
  beam_size = 4
  ext_size = 128
  nbest_size = 8
  max_steps = 300
  vocab_size = 100
  decoder = TestDecoder(batch_size, beam_size, max_steps, vocab_size, rule)
  dec_state = decoder.new_state()
  dec_callback = decoder.dec_callback
  # rewrite_while_loop=True so in-loop debug summaries are fetchable.
  with tpu_summary.context(rewrite_while_loop=True):
    bs = flat_beam_search_helper.flat_beam_search(
        batch_size,
        beam_size,
        max_steps,
        dec_callback,
        dec_state,
        bos_id=1,
        eos_id=0,
        beam_gap=None,
        ext_size=ext_size,
        nbest_size=nbest_size,
        debug=True)
    debug_tensors = tpu_summary.merge_all()
  tf.logging.info('bs=%r', bs)
  tf.logging.info('debug_tensors=%r', debug_tensors)
  with self.session() as sess:
    [bs, debug_tensors] = sess.run([bs, debug_tensors])
  tf.logging.info('bs=%r', bs)
  # Only the n-best tuple is checked; loop vars and decoder state are unused.
  loop_vars, dec_state_, nbest = bs
  (topk_ids, topk_lens, topk_scores) = nbest
  del loop_vars, dec_state_, nbest
  self.assertEqual((batch_size, nbest_size, max_steps), topk_ids.shape)
  self.assertEqual((batch_size, nbest_size), topk_lens.shape)
  self.assertEqual((batch_size, nbest_size), topk_scores.shape)
  print('Decoder output rule=%r' % decoder.rule)
  print('batch_size=%d beam_size=%d ext_size=%d nbest_size=%d max_steps=%d' %
        (batch_size, beam_size, ext_size, nbest_size, max_steps))
  # Trim each hypothesis to its reported length for comparison.
  topk = [[
      topk_ids[b, k, 0:topk_lens[b, k]].tolist() for k in range(nbest_size)
  ] for b in range(batch_size)]
  for b in range(batch_size):
    for k in range(nbest_size):
      print('topk[%d][%d] (%0.6f): %r' %
            (b, k, topk_scores[b, k], topk[b][k]))
  # Golden hypotheses per rule; identical for both batch elements
  # (hence the 2 * [...]).
  # pyformat: disable
  if decoder.rule == '+1':
    expected = 2 * [[
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],
        [1, 2, 3, 4, 5, 6, 7, 9, 0],
        [1, 2, 3, 4, 5, 6, 8, 9, 0],
        [1, 2, 3, 4, 5, 7, 8, 9, 0],
        [1, 2, 3, 4, 6, 7, 8, 9, 0],
        [1, 2, 3, 5, 6, 7, 8, 9, 0],
        [1, 2, 4, 5, 6, 7, 8, 9, 0],
        [1, 3, 4, 5, 6, 7, 8, 9, 0],
    ]]
  elif decoder.rule == 'sum':
    expected = 2 * [[
        [1, 1, 2, 4, 9, 0],
        [1, 1, 2, 5, 9, 0],
        [1, 2, 3, 6, 12, 24, 48, 96, 0],
        [1, 1, 2, 4, 8, 16, 32, 64, 29, 0],
        [1, 1, 2, 4, 8, 16, 32, 65, 29, 0],
        [1, 1, 2, 5, 10, 19, 0],
        [1, 2, 3, 6, 12, 25, 49, 0],
        [1, 2, 3, 6, 12, 24, 49, 0],
    ]]
  elif decoder.rule == 'fib':
    expected = 2 * [[
        [1, 1, 2, 3, 5, 9, 0],
        [1, 1, 2, 3, 6, 9, 0],
        [1, 2, 3, 5, 9, 0],
        [1, 1, 4, 5, 9, 0],
        [1, 2, 3, 6, 9, 0],
        [1, 1, 3, 4, 7, 11, 18, 29, 0],
        [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 0],
        [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 0],
    ]]
  # pyformat: enable
  self.assertEqual(expected, topk)
def testFlatBeamSearchWithPrefix(self, rule):
  """Runs flat beam search seeded with per-batch prefixes and checks output.

  Args:
    rule: TestDecoder rule name ('+1', 'sum' or 'fib') selecting how the
      mock decoder proposes next token ids; supplied by the test
      parameterization.
  """
  batch_size = 2
  beam_size = 4
  max_steps = 20
  vocab_size = 100
  prefix_size = 4
  # Prefixes are right-aligned: batch 0 gets [11, 12] (len 2), batch 1 gets
  # [21, 22, 23] (len 3); with batch_size=2, indices -2/-1 are rows 0/1.
  prefix_len = np.zeros([batch_size])
  prefix_len[-2] = 2
  prefix_len[-1] = 3
  prefix_id = np.zeros([batch_size, prefix_size])
  prefix_id[0, -2:] = [11, 12]
  prefix_id[1, -3:] = [21, 22, 23]
  decoder = TestDecoder(batch_size, beam_size, max_steps, vocab_size, rule)
  dec_state = decoder.new_state()
  dec_callback = decoder.dec_callback
  # rewrite_while_loop=True so in-loop debug summaries are fetchable.
  with tpu_summary.context(rewrite_while_loop=True):
    bs = flat_beam_search_helper.flat_beam_search(
        batch_size,
        beam_size,
        max_steps,
        dec_callback,
        dec_state,
        bos_id=1,
        eos_id=0,
        prefix=prefix_id,
        prefix_len=prefix_len,
        beam_gap=None,
        debug=True)
    debug_tensors = tpu_summary.merge_all()
  tf.logging.info('bs=%r', bs)
  tf.logging.info('debug_tensors=%r', debug_tensors)
  with self.session() as sess:
    [bs, debug_tensors] = sess.run([bs, debug_tensors])
  tf.logging.info('bs=%r', bs)
  loop_vars, dec_state_, nbest = bs
  (topk_ids, topk_lens, topk_scores) = nbest
  del loop_vars, dec_state_
  # Output length includes the prefix slots in front of the decoded steps.
  self.assertEqual((batch_size, beam_size, max_steps + prefix_size),
                   topk_ids.shape)
  self.assertEqual((batch_size, beam_size), topk_lens.shape)
  self.assertEqual((batch_size, beam_size), topk_scores.shape)
  print('Decoder output rule=%r' % decoder.rule)
  print('batch_size=%d beam_size=%d max_steps=%d' %
        (batch_size, beam_size, max_steps))
  # Trim each hypothesis to its reported length for comparison.
  topk = [[
      topk_ids[b, k, 0:topk_lens[b, k]].tolist() for k in range(beam_size)
  ] for b in range(batch_size)]
  for b in range(batch_size):
    for k in range(beam_size):
      print('topk[%d][%d] (%0.6f): %r' %
            (b, k, topk_scores[b, k], topk[b][k]))
  # Golden hypotheses per rule; each starts with its batch element's prefix.
  # pyformat: disable
  if decoder.rule == '+1':
    expected = [
        [[11, 12, 13, 14, 15, 16, 17, 18, 19, 0],
         [11, 12, 13, 14, 15, 16, 17, 19, 0],
         [11, 12, 13, 14, 15, 16, 18, 19, 0],
         [11, 12, 13, 14, 15, 17, 18, 19, 0]],
        [[21, 22, 23, 24, 25, 26, 27, 28, 29, 0],
         [21, 22, 23, 24, 25, 26, 27, 29, 0],
         [21, 22, 23, 24, 25, 26, 28, 29, 0],
         [21, 22, 23, 24, 25, 27, 28, 29, 0]]]
  elif decoder.rule == 'sum':
    expected = [
        [[11, 12, 23, 46, 92, 0],
         [11, 12, 23, 46, 93, 0],
         [11, 12, 23, 47, 93, 0],
         [11, 12, 24, 47, 94, 0]],
        [[21, 22, 23, 66, 32, 64, 29, 0],
         [21, 22, 23, 66, 32, 65, 29, 0],
         [21, 22, 23, 66, 32, 64, 28, 56, 12, 24, 48, 96, 0],
         [21, 22, 23, 69, 0]]]
  elif decoder.rule == 'fib':
    expected = [
        [[11, 12, 23, 35, 58, 93, 0],
         [11, 12, 23, 35, 59, 0],
         [11, 12, 23, 36, 59, 0],
         [11, 12, 23, 35, 58, 94, 0]],
        [[21, 22, 23, 45, 69, 0],
         [21, 22, 23, 46, 69, 0],
         [21, 22, 23, 45, 68, 13, 81, 94, 0],
         [21, 22, 23, 45, 68, 13, 81, 95, 0]]]
  # pyformat: enable
  self.assertEqual(expected, topk)