Ejemplo n.º 1
0
    def test_gru_decoder_teacher_force(self):
        """Checks the output shape of `GruDecoder.teacher_force`.

        The test runs once with `cond_only_init=True` and once with `False`.
        When `cond_only_init` is True the decoder's `h` is set to `z`;
        otherwise it uses `default_h`. In both cases the result must have
        shape `(bs, max_seq_length, vs)`.
        """
        vs = 9  # vocabulary size (rows of the V x E embedding matrix)
        default_h = 10  # `h` used when cond_only_init is False
        z = 2  # `h` used when cond_only_init is True
        bs = 5  # batch size
        es = 3  # embedding size (columns of the V x E embedding matrix)
        max_seq_length = 11
        nl = 6

        for cond_only_init in [True, False]:
            # Pick `h` per iteration. Reassigning a shared `h` would leak `z`
            # into the later `False` iteration, so that case would never run
            # with `default_h`.
            h = z if cond_only_init else default_h
            tf.reset_default_graph()
            emb_matrix_VxE = tf.get_variable(
                'emb',
                shape=[vs, es],
                initializer=tf.truncated_normal_initializer(stddev=0.01))
            decoder = m.GruDecoder(h,
                                   vs,
                                   emb_matrix_VxE,
                                   10,
                                   nl,
                                   cond_only_init=cond_only_init)
            # Test teacher_force shapes. One length per batch element (bs=5).
            seq_lengths = tf.constant([3, 2, 4, 1, 2])
            dec_out = decoder.teacher_force(
                tf.random.normal((bs, h)),  # state
                tf.random.normal((bs, max_seq_length, es)),  # dec_inputs
                seq_lengths)  # lengths

            with self.session() as ss:
                ss.run(tf.initializers.global_variables())
                dec_out_np = ss.run(dec_out)

            self.assertEqual((bs, max_seq_length, vs), dec_out_np.shape)
Ejemplo n.º 2
0
    def test_gru_decoder_decode_v(self):
        """Checks `GruDecoder.decode_v` for the argmax, random and beam methods.

        The test runs once with `cond_only_init=True` and once with `False`.
        When `cond_only_init` is True the decoder's `h` is set to `z`;
        otherwise it uses `default_h`. For every decoding method the result
        must have a leading batch dimension `bs` and contain only symbol ids
        accepted by `assertIntsInRange(..., 0, vs)`.
        """
        vs = 9  # vocabulary size (rows of the V x E embedding matrix)
        default_h = 10  # `h` used when cond_only_init is False
        z = 2  # `h` used when cond_only_init is True
        bs = 5  # batch size
        es = 3  # embedding size (columns of the V x E embedding matrix)
        nl = 6
        k = 3  # beam_size for method='beam'
        alpha = 0.6  # alpha for method='beam'

        for cond_only_init in [True, False]:
            # Pick `h` per iteration. Reassigning a shared `h` would leak `z`
            # into the later `False` iteration, so that case would never run
            # with `default_h`.
            h = z if cond_only_init else default_h
            tf.reset_default_graph()
            emb_matrix_VxE = tf.get_variable(
                'emb',
                shape=[vs, es],
                initializer=tf.truncated_normal_initializer(stddev=0.01))
            decoder = m.GruDecoder(h,
                                   vs,
                                   emb_matrix_VxE,
                                   10,
                                   nl,
                                   cond_only_init=cond_only_init)

            symb = decoder.decode_v(tf.random.normal((bs, h)), method='argmax')
            symbr = decoder.decode_v(tf.random.normal((bs, h)),
                                     method='random')
            symbb = decoder.decode_v(tf.random.normal((bs, h)),
                                     method='beam',
                                     first_token=0,
                                     beam_size=k,
                                     alpha=alpha)

            with self.session() as ss:
                ss.run(tf.initializers.global_variables())
                symb_np, symbr_np, symbb_np = ss.run([symb, symbr, symbb])

            self.assertEqual(bs, symb_np.shape[0])
            self.assertEqual(bs, symbr_np.shape[0])
            self.assertEqual(bs, symbb_np.shape[0])

            # Check that every decoded symbol is a valid vocabulary id
            # (bounds 0 and vs).
            self.assertIntsInRange(symb_np, 0, vs)
            self.assertIntsInRange(symbr_np, 0, vs)
            self.assertIntsInRange(symbb_np, 0, vs)
Ejemplo n.º 3
0
    def test_gru_decoder_decode_v_gumbel(self):
        """Checks the outputs of `GruDecoder.decode_v_gumbel`.

        The test runs once with `cond_only_init=True` and once with `False`.
        When `cond_only_init` is True the decoder's `h` is set to `z`;
        otherwise it uses `default_h`. The symbol output must be rank 2 and
        the embedding output must be rank 3. Both must have a leading batch
        dimension `bs`, and the embedding output's last dimension must be
        `es`. Symbols must pass `assertIntsInRange(..., 0, vs)`.
        """
        vs = 9  # vocabulary size (rows of the V x E embedding matrix)
        default_h = 10  # `h` used when cond_only_init is False
        z = 2  # `h` used when cond_only_init is True
        bs = 5  # batch size
        es = 3  # embedding size (columns of the V x E embedding matrix)
        nl = 6

        for cond_only_init in [True, False]:
            # Pick `h` per iteration. Reassigning a shared `h` would leak `z`
            # into the later `False` iteration, so that case would never run
            # with `default_h`.
            h = z if cond_only_init else default_h
            tf.reset_default_graph()
            emb_matrix_VxE = tf.get_variable(
                'emb',
                shape=[vs, es],
                initializer=tf.truncated_normal_initializer(stddev=0.01))
            decoder = m.GruDecoder(h,
                                   vs,
                                   emb_matrix_VxE,
                                   10,
                                   nl,
                                   cond_only_init=cond_only_init)

            symb_BxM, symb_emb_BxMxE = decoder.decode_v_gumbel(
                tf.random.normal((bs, h)))

            with self.session() as ss:
                ss.run(tf.initializers.global_variables())
                symb_np, symbe_np = ss.run([symb_BxM, symb_emb_BxMxE])

            # pylint: disable=g-generic-assert
            self.assertEqual(2, len(symb_np.shape))
            self.assertEqual(3, len(symbe_np.shape))
            self.assertEqual(bs, symb_np.shape[0])
            self.assertEqual(bs, symbe_np.shape[0])
            self.assertEqual(es, symbe_np.shape[2])

            # Check that every decoded symbol is a valid vocabulary id
            # (bounds 0 and vs).
            self.assertIntsInRange(symb_np, 0, vs)