def test_gru_decoder_teacher_force(self):
  """Checks the output shape of GruDecoder.teacher_force in both init modes.

  The decoder is built once with cond_only_init=True and once with
  cond_only_init=False; in each case the teacher-forced output must have
  shape (bs, max_seq_length, vs).
  """
  vs = 9   # rows of the embedding matrix (the "V" in emb_matrix_VxE)
  h = 10   # state width used when cond_only_init is False
  z = 2    # state width used when cond_only_init is True
  bs = 5   # leading (batch) dimension of the state and decoder inputs
  es = 3   # columns of the embedding matrix (the "E" in emb_matrix_VxE)
  max_seq_length = 11
  nl = 6   # fifth positional GruDecoder argument (presumably layer count)

  for cond_only_init in [True, False]:
    # Choose the width per iteration instead of rebinding `h`: the True case
    # runs first, and overwriting `h` there would silently carry the reduced
    # width into the False case.
    state_size = z if cond_only_init else h

    tf.reset_default_graph()
    emb_matrix_VxE = tf.get_variable(
        'emb', shape=[vs, es],
        initializer=tf.truncated_normal_initializer(stddev=0.01))
    decoder = m.GruDecoder(state_size, vs, emb_matrix_VxE, 10, nl,
                           cond_only_init=cond_only_init)

    # Test teacher_force shapes
    seq_lengths = tf.constant([3, 2, 4, 1, 2])
    dec_out = decoder.teacher_force(
        tf.random.normal((bs, state_size)),  # state
        tf.random.normal((bs, max_seq_length, es)),  # dec_inputs
        seq_lengths)  # lengths

    with self.session() as ss:
      ss.run(tf.initializers.global_variables())
      dec_out_np = ss.run(dec_out)
      self.assertEqual((bs, max_seq_length, vs), dec_out_np.shape)
def test_gru_decoder_decode_v(self):
  """Checks GruDecoder.decode_v for the argmax, random and beam methods.

  For cond_only_init both True and False, each decoding method must return
  an array whose leading dimension is bs and whose values pass
  assertIntsInRange(..., 0, vs).
  """
  vs = 9   # rows of the embedding matrix (the "V" in emb_matrix_VxE)
  h = 10   # state width used when cond_only_init is False
  z = 2    # state width used when cond_only_init is True
  bs = 5   # leading (batch) dimension of the state
  es = 3   # columns of the embedding matrix (the "E" in emb_matrix_VxE)
  nl = 6   # fifth positional GruDecoder argument (presumably layer count)
  k = 3    # beam_size passed to the 'beam' method
  alpha = 0.6  # alpha passed to the 'beam' method

  for cond_only_init in [True, False]:
    # Choose the width per iteration instead of rebinding `h`: the True case
    # runs first, and overwriting `h` there would silently carry the reduced
    # width into the False case.
    state_size = z if cond_only_init else h

    tf.reset_default_graph()
    emb_matrix_VxE = tf.get_variable(
        'emb', shape=[vs, es],
        initializer=tf.truncated_normal_initializer(stddev=0.01))
    decoder = m.GruDecoder(state_size, vs, emb_matrix_VxE, 10, nl,
                           cond_only_init=cond_only_init)

    symb = decoder.decode_v(tf.random.normal((bs, state_size)),
                            method='argmax')
    symbr = decoder.decode_v(tf.random.normal((bs, state_size)),
                             method='random')
    symbb = decoder.decode_v(tf.random.normal((bs, state_size)),
                             method='beam', first_token=0,
                             beam_size=k, alpha=alpha)

    with self.session() as ss:
      ss.run(tf.initializers.global_variables())
      symb_np, symbr_np, symbb_np = ss.run([symb, symbr, symbb])
      self.assertEqual(bs, symb_np.shape[0])
      self.assertEqual(bs, symbr_np.shape[0])
      self.assertEqual(bs, symbb_np.shape[0])

      # Check that every decoded symbol is within the vocabulary.
      self.assertIntsInRange(symb_np, 0, vs)
      self.assertIntsInRange(symbr_np, 0, vs)
      self.assertIntsInRange(symbb_np, 0, vs)
def test_gru_decoder_decode_v_gumbel(self):
  """Checks the outputs of GruDecoder.decode_v_gumbel in both init modes.

  For cond_only_init both True and False, the first output must be rank 2
  with leading dimension bs and values passing assertIntsInRange(..., 0, vs);
  the second output must be rank 3 with leading dimension bs and last
  dimension es.
  """
  vs = 9   # rows of the embedding matrix (the "V" in emb_matrix_VxE)
  h = 10   # state width used when cond_only_init is False
  z = 2    # state width used when cond_only_init is True
  bs = 5   # leading (batch) dimension of the state
  es = 3   # columns of the embedding matrix (the "E" in emb_matrix_VxE)
  nl = 6   # fifth positional GruDecoder argument (presumably layer count)

  for cond_only_init in [True, False]:
    # Choose the width per iteration instead of rebinding `h`: the True case
    # runs first, and overwriting `h` there would silently carry the reduced
    # width into the False case.
    state_size = z if cond_only_init else h

    tf.reset_default_graph()
    emb_matrix_VxE = tf.get_variable(
        'emb', shape=[vs, es],
        initializer=tf.truncated_normal_initializer(stddev=0.01))
    decoder = m.GruDecoder(state_size, vs, emb_matrix_VxE, 10, nl,
                           cond_only_init=cond_only_init)

    symb_BxM, symb_emb_BxMxE = decoder.decode_v_gumbel(
        tf.random.normal((bs, state_size)))

    with self.session() as ss:
      ss.run(tf.initializers.global_variables())
      symb_np, symbe_np = ss.run([symb_BxM, symb_emb_BxMxE])
      # pylint: disable=g-generic-assert
      self.assertEqual(2, len(symb_np.shape))
      self.assertEqual(3, len(symbe_np.shape))
      self.assertEqual(bs, symb_np.shape[0])
      self.assertEqual(bs, symbe_np.shape[0])
      self.assertEqual(es, symbe_np.shape[2])

      # Check that every decoded symbol is within the vocabulary.
      self.assertIntsInRange(symb_np, 0, vs)