    def testGPipeSoftmaxLayerInputfromDecoder(self):
        with self.session(use_gpu=True):
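            # Decoder-style GPipe transformer layer: cross-attention over the
            # encoder outputs plus masked (causal) self-attention.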
            depth = 4
            np.random.seed(6348575)
            p = GPipeTransformerLayer.Params()
            p.name = 'transformer'
            p.source_dim = depth
            p.has_aux_atten = True
            p.mask_self_atten = True
            p.tr_fflayer_tpl.hidden_dim = 7
            p.tr_atten_tpl.num_attention_heads = 2
            transformer = p.Instantiate()
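            # Softmax layer configured to consume decoder outputs directly
            # (inputs_from_decoder=True).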
            softmax = layers_with_gpipe.GPipeTransformerSoftmaxLayer.Params()
            softmax.name = 'softmax'
            softmax.inputs_from_decoder = True
            softmax.num_classes = 2
            softmax.input_dim = depth
            softmax = softmax.Instantiate()
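            # Random time-major test inputs from the shared helper; the task
            # ids are not needed in this test.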
            (source_vecs, _, aux_vecs, aux_paddings, _,
             _) = self._testInputs(depth=depth)
            source_padding = tf.zeros([5, 2])
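            # The layer's full output tuple feeds straight into the softmax
            # stage, mirroring how GPipe chains pipeline stages.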
            softmax_inputs = transformer.FPropDefaultTheta(
                aux_vecs, aux_paddings, source_vecs, source_padding, None,
                None, None, None)
            softmax_outputs = softmax.FPropDefaultTheta(*softmax_inputs)
            self.assertEqual([5, 2, 2], softmax_outputs.shape)

    def testTransformerLayerExtendStep(self):
        with self.session(use_gpu=True) as sess:
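            # Same decoder-style layer configuration as in the softmax test
            # above.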
            depth = 4
            np.random.seed(6348575)
            p = GPipeTransformerLayer.Params()
            p.name = 'transformer'
            p.source_dim = depth
            p.has_aux_atten = True
            p.mask_self_atten = True
            p.tr_fflayer_tpl.hidden_dim = 7
            p.tr_atten_tpl.num_attention_heads = 2
            transformer = p.Instantiate()

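            # source_vecs is time-major [5, 2, depth]; the task ids are passed
            # through so they can be checked on the output side.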
            (source_vecs, _, aux_vecs, aux_paddings, input_tasks,
             tgt_tasks) = self._testInputs(depth=depth)
            source_padding = tf.zeros([5, 2])

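            # Full-sequence forward pass; index 2 of the output tuple is the
            # transformed source sequence.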
            output1 = transformer.FPropDefaultTheta(aux_vecs, aux_paddings,
                                                    source_vecs,
                                                    source_padding, None, None,
                                                    input_tasks, tgt_tasks)
            h1 = output1[2]
            out_src_task, out_tgt_task = output1[-2], output1[-1]

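            # Replay the same sequence one step at a time with ExtendStep,
            # starting from empty cached attention keys/values.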
            h2 = []
            cached_source_vecs = tf.zeros([0, 2, 4])
            cached_source_contexts = tf.zeros([0, 2, 4])
            prefix_states = py_utils.NestedMap(key=cached_source_vecs,
                                               value=cached_source_contexts)
            for i in range(5):
                h, _, prefix_states = transformer.ExtendStep(
                    transformer.theta, source_vecs[i, :, :], prefix_states,
                    aux_vecs, aux_paddings)
                h2.append(h)

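            # Stack the per-step outputs back into a time-major tensor.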
            h2 = tf.stack(h2)

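            # Incremental decoding must reproduce the full FProp output, and
            # the task ids must pass through the layer unchanged.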
            tf.global_variables_initializer().run()
            h1_v, h2_v = sess.run([h1, h2])
            self.assertAllClose(h1_v, h2_v)
            self.assertAllClose(out_src_task, input_tasks)
            self.assertAllClose(out_tgt_task, tgt_tasks)
            self.assertAllClose(
                h1_v[2][1], [1.10429943, -1.64884555, 0.15726769, -0.00250494])