Example #1
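These test methods assume surrounding imports and a test-class scaffold that are not shown on this page. A minimal sketch of the imports they appear to rely on (module paths as in the public Lingvo repository; an assumption, not verified against the exact revision these snippets came from):

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style sessions and initializers are used below
from lingvo.core import py_utils
from lingvo.core.layers_with_gpipe import GPipeEvolvedTransformerDecoderLayer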
    def testEvolvedTransformerDecoderLayerConstruction(self):
        # Construction-only smoke test: building the layer from its Params
        # must not raise.
        p = GPipeEvolvedTransformerDecoderLayer.Params()
        p.name = 'gpipe_evolved_transformer_decoder'
        p.source_dim = 16
        p.transformer_tpl.tr_atten_tpl.num_attention_heads = 2
        p.has_aux_atten = True
        p.mask_self_atten = True
        _ = GPipeEvolvedTransformerDecoderLayer(p)
Example #2
    def testEvolvedTransformerDecoderLayerExtendStep(self):
        with self.session(use_gpu=True) as sess:
            np.random.seed(6348575)
            depth = 4
            p = GPipeEvolvedTransformerDecoderLayer.Params()
            p.name = 'gpipe_evolved_transformer_decoder'
            p.source_dim = depth
            p.has_aux_atten = True
            p.mask_self_atten = True
            p.tr_double_heads_atten_tpl.num_attention_heads = 2
            p.tr_atten_tpl.num_attention_heads = 2
            p.transformer_tpl.tr_atten_tpl.num_attention_heads = 2
            et_decoder = GPipeEvolvedTransformerDecoderLayer(p)

            (source_vecs, _, aux_vecs,
             aux_paddings) = self._testInputs(depth=depth)
            source_padding = tf.zeros([5, 2])

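            # Full-sequence forward pass. The returned tuple appears to mirror
            # the input order, so index [2] presumably selects the transformed
            # source_vecs (5 steps, matching the ExtendStep loop below).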
            h1 = et_decoder.FPropDefaultTheta(
                aux_vecs,
                aux_paddings,
                source_vecs,
                source_padding,
                None,
                None,
                None,
                None,
            )[2]
            h2 = []

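            # Zero-length (time dimension 0) key/value caches; each ExtendStep
            # call is expected to append one decoded step to these states.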
            double_head_attention_states = py_utils.NestedMap(
                key=tf.zeros([0, 2, 4]), value=tf.zeros([0, 2, 4]))
            transformer_layer_states = py_utils.NestedMap(
                key=tf.zeros([0, 2, 4]), value=tf.zeros([0, 2, 4]))
            branched_convs_input = tf.zeros([0, 2, 4])

            prefix_states = py_utils.NestedMap(
                double_head_attention_states=double_head_attention_states,
                transformer_layer_states=transformer_layer_states,
                branched_convs_input=branched_convs_input)

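            # Decode one source position at a time; the stacked per-step
            # outputs should match the full-sequence FProp result h1.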
            for i in range(5):
                h, _, prefix_states = et_decoder.ExtendStep(
                    et_decoder.theta, source_vecs[i, :, :], prefix_states,
                    aux_vecs, aux_paddings)
                h2.append(h)

            h2 = tf.stack(h2)

            tf.global_variables_initializer().run()
            h1_v, h2_v = sess.run([h1, h2])
            self.assertAllClose(h1_v, h2_v)
Example #3
    def testEvolvedTransformerDecoderLayerFProp(self):
        with self.session(use_gpu=True) as sess:
            np.random.seed(6348575)
            depth = 4
            p = GPipeEvolvedTransformerDecoderLayer.Params()
            p.name = 'gpipe_evolved_transformer_decoder'
            p.source_dim = depth
            p.has_aux_atten = True
            p.mask_self_atten = True
            p.tr_double_heads_atten_tpl.num_attention_heads = 2
            p.tr_atten_tpl.num_attention_heads = 2
            p.transformer_tpl.tr_atten_tpl.num_attention_heads = 2
            transformer = GPipeEvolvedTransformerDecoderLayer(p)

            (source_vecs, source_padding, aux_vecs, aux_paddings, _,
             _) = self._testInputs(depth=depth)

            output = transformer.FPropDefaultTheta(aux_vecs, aux_paddings,
                                                   source_vecs, source_padding,
                                                   None, None, None, None,
                                                   None, None)
            h = output[0]

            tf.global_variables_initializer().run()
            actual_layer_output = sess.run([h])[0]
            tf.logging.info(np.array_repr(actual_layer_output))
            # pylint: disable=bad-whitespace
            # pyformat: disable
            expected_layer_output = [
                [[0.5904724, 0.05267439, 0.89581013, 0.63010913],
                 [0.79584485, 0.07670615, 0.40381077, 0.26504567]],
                [[0.35448784, 0.28477612, 0.05394353, 0.06531866],
                 [0.44413447, 0.81940264, 0.98786688, 0.35846332]],
                [[0.66811442, 0.07942203, 0.56781054, 0.83598584],
                 [0.45858502, 0.44949403, 0.06522893, 0.10947803]],
                [[0.58166796, 0.94657594, 0.17643142, 0.02062288],
                 [0.40596515, 0.01996579, 0.93727112, 0.97478259]],
                [[0.34873158, 0.0095871, 0.34063059, 0.64620447],
                 [0.70584863, 0.69263214, 0.38247514, 0.28985959]],
                [[0.66496903, 0.20383522, 0.35497066, 0.66646087],
                 [0.0787568, 0.26172587, 0.23034802, 0.88751978]],
                [[0.68153989, 0.81061888, 0.90142977, 0.87612331],
                 [0.15129775, 0.56084079, 0.87029755, 0.37908044]]
            ]
            # pyformat: enable
            # pylint: enable=bad-whitespace
            self.assertAllClose(expected_layer_output, actual_layer_output)
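
Examples #2 and #3 both call a self._testInputs helper that is not included on this page (they also unpack its result into four and six values respectively, which suggests the snippets come from different revisions of the test). A hypothetical reconstruction of the four-value form, with shapes inferred from the snippets above (source: 5 time steps x batch 2; aux: 7 time steps x batch 2; the fixed seed and the exact return contents are assumptions):

    def _testInputs(self, depth=3, dtype=tf.float32):
        # Hypothetical helper, not from the original snippets.
        np.random.seed(505837249)  # arbitrary fixed seed for reproducibility
        # [time, batch, depth] layout, as indexed by source_vecs[i, :, :] above.
        source_vecs = tf.constant(np.random.rand(5, 2, depth), dtype=dtype)
        source_padding = tf.zeros([5, 2], dtype=dtype)
        aux_vecs = tf.constant(np.random.rand(7, 2, depth), dtype=dtype)
        aux_paddings = tf.zeros([7, 2], dtype=dtype)
        return source_vecs, source_padding, aux_vecs, aux_paddings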