def testGPipeSoftmaxLayerInputfromDecoder(self):
  """Feeds a GPipe transformer layer's outputs directly into the softmax layer.

  A GPipeTransformerLayer (aux attention on, masked self attention on) is run
  forward, and its whole output tuple is splatted into a
  GPipeTransformerSoftmaxLayer configured with `inputs_from_decoder = True`.
  Only the static shape of the softmax output is checked.
  """
  with self.session(use_gpu=True):
    model_dim = 4
    np.random.seed(6348575)

    # Transformer layer configuration.
    layer_params = GPipeTransformerLayer.Params()
    layer_params.name = 'transformer'
    layer_params.source_dim = model_dim
    layer_params.has_aux_atten = True
    layer_params.mask_self_atten = True
    layer_params.tr_fflayer_tpl.hidden_dim = 7
    layer_params.tr_atten_tpl.num_attention_heads = 2
    transformer_layer = layer_params.Instantiate()

    # Softmax layer that consumes the transformer's output tuple as-is.
    softmax_params = layers_with_gpipe.GPipeTransformerSoftmaxLayer.Params()
    softmax_params.name = 'softmax'
    softmax_params.inputs_from_decoder = True
    softmax_params.num_classes = 2
    softmax_params.input_dim = model_dim
    softmax_layer = softmax_params.Instantiate()

    source_vecs, _, aux_vecs, aux_paddings, _, _ = self._testInputs(
        depth=model_dim)
    source_padding = tf.zeros([5, 2])

    layer_outputs = transformer_layer.FPropDefaultTheta(
        aux_vecs, aux_paddings, source_vecs, source_padding,
        None, None, None, None)
    logits = softmax_layer.FPropDefaultTheta(*layer_outputs)

    # Presumably (time, batch, num_classes); the 2 in the last position
    # matches num_classes above.
    self.assertEqual([5, 2, 2], logits.shape)
def testTransformerLayerExtendStep(self):
  """Checks that step-by-step ExtendStep decoding matches a full FProp pass.

  The layer is run once over the whole sequence with FPropDefaultTheta. It is
  then run again one time step at a time with ExtendStep, carrying the
  `prefix_states` cache forward between steps. The test then checks that:
    * the stacked per-step outputs equal element 2 of the FProp output tuple;
    * the last two FProp outputs echo back the input/target task tensors;
    * one position of the FProp output matches hard-coded golden values.
  """
  with self.session(use_gpu=True) as sess:
    depth = 4
    # Time and batch sizes of the inputs. NOTE(review): these must agree with
    # the shapes produced by self._testInputs — confirm if that helper changes.
    seq_len = 5
    batch = 2
    np.random.seed(6348575)

    p = GPipeTransformerLayer.Params()
    p.name = 'transformer'
    p.source_dim = depth
    p.has_aux_atten = True
    p.mask_self_atten = True
    p.tr_fflayer_tpl.hidden_dim = 7
    p.tr_atten_tpl.num_attention_heads = 2
    # Build through Params.Instantiate(), consistent with the other tests.
    transformer = p.Instantiate()

    (source_vecs, _, aux_vecs, aux_paddings, input_tasks,
     tgt_tasks) = self._testInputs(depth=depth)
    source_padding = tf.zeros([seq_len, batch])

    # Full-sequence forward pass.
    output1 = transformer.FPropDefaultTheta(aux_vecs, aux_paddings,
                                            source_vecs, source_padding, None,
                                            None, input_tasks, tgt_tasks)
    h1 = output1[2]
    out_src_task, out_tgt_task = output1[-2], output1[-1]

    # Incremental decoding. Start from an empty (zero-length) prefix cache
    # whose batch and feature dimensions come from the test configuration
    # rather than from hard-coded literals.
    prefix_states = py_utils.NestedMap(
        key=tf.zeros([0, batch, depth]),
        value=tf.zeros([0, batch, depth]))
    h2 = []
    for i in range(seq_len):
      h, _, prefix_states = transformer.ExtendStep(
          transformer.theta, source_vecs[i, :, :], prefix_states, aux_vecs,
          aux_paddings)
      h2.append(h)
    h2 = tf.stack(h2)

    tf.global_variables_initializer().run()
    h1_v, h2_v = sess.run([h1, h2])
    self.assertAllClose(h1_v, h2_v)
    self.assertAllClose(out_src_task, input_tasks)
    self.assertAllClose(out_tgt_task, tgt_tasks)
    # Golden values for the seed above. There are `depth` (=4) entries, so
    # they presumably need regenerating if depth or the seed changes.
    self.assertAllClose(
        h1_v[2][1], [1.10429943, -1.64884555, 0.15726769, -0.00250494])