def test_build_share_rnn(self):
    """Sharing the RNN between encoder and decoder creates only one cell.

    Builds a graph under scope 'q' with ``share:enc_dec_rnn`` set and checks
    that no decoder-scoped LSTM variables exist and that exactly one LSTM
    kernel is present among the global variables.
    """
    with self.test_session(config=self.sess_config) as sess:
        seq2seq = model.Seq2SeqModel(check_feed_dict=False)
        base_opt = {
            'emb:vocab_size': 20,
            'emb:dim': 5,
            'cell:num_units': 10,
            'cell:cell_class': 'tensorflow.nn.rnn_cell.BasicLSTMCell',
        }
        # Duplicate every option under both the 'enc' and 'dec' prefixes.
        opt = {f'{scope}:{key}': val
               for key, val in base_opt.items()
               for scope in ('enc', 'dec')}
        opt['dec:logit:output_size'] = 2
        # 'share:enc_dec_rnn' contains a colon, so it cannot be a plain kwarg.
        node = seq2seq.build_graph(opt=opt, name='q', **{'share:enc_dec_rnn': True})
        kernel_count = 0
        for var in tf.global_variables():
            self.assertNotEqual(var.name, 'q/dec/rnn/basic_lstm_cell/kernel:0',
                                'no decoder rnn')
            self.assertNotEqual(var.name, 'q/dec/rnn/basic_lstm_cell/bias:0',
                                'no decoder rnn')
            if 'rnn/basic_lstm_cell/kernel:0' in var.name:
                kernel_count += 1
        self.assertEqual(kernel_count, 1, 'only one cell')
def test_run_epoch(self):
    """MLE pre-train, then policy-gradient train; eval loss should drop.

    Runs 5 epochs of maximum-likelihood training, measures a greedy sampling
    eval loss, then runs 10 epochs of policy-gradient training with a
    match-label reward and asserts the eval loss decreased.
    """
    data = generator.read_seq2seq_data(self.gen(), self.vocab, self.vocab)
    batch_iter = partial(generator.seq2seq_batch_iter, *data,
                         batch_size=20, shuffle=True)
    with self.test_session(config=self.sess_config) as sess:
        seq2seq = tfm.Seq2SeqModel()
        node = seq2seq.build_graph(**{'dec:attn_enc_output': True})
        seq2seq.set_default_feed('dec.train_loss_denom', 20)
        return_ph = tf.placeholder(tf.float32, shape=(None, None), name='return')
        train_op = graph.create_train_op(seq2seq.training_loss, learning_rate=0.01)
        train_pg_op = graph.create_pg_train_op(seq2seq.nll, return_ph)
        return_feed_fn = partial(seq2seq.set_default_feed, return_ph)
        sess.run(tf.global_variables_initializer())
        # Phase 1: maximum-likelihood pre-training.
        epoch_fn = partial(run.run_epoch, sess, seq2seq, batch_iter)
        for _step in range(5):
            print(epoch_fn(train_op=train_op).summary())
        reward_fn = generator.reward_match_label
        eval_info = run.run_sampling_epoch(
            sess, seq2seq, batch_iter, greedy=True, reward_fn=reward_fn)
        print(eval_info.summary('eval'))
        # Phase 2: policy-gradient training on sampled sequences.
        epoch_fn = partial(run.run_sampling_epoch, sess, seq2seq, batch_iter,
                           reward_fn=reward_fn)
        for _step in range(10):
            print(epoch_fn(train_op=train_pg_op,
                           return_feed_fn=return_feed_fn).summary())
        eval_info2 = run.run_sampling_epoch(
            sess, seq2seq, batch_iter, greedy=True, reward_fn=reward_fn)
        print(eval_info2.summary('eval'))
        self.assertLess(eval_info2.eval_loss, eval_info.eval_loss,
                        'after training, eval loss is lower.')
def test_build_share_emb_rnn_reuse(self):
    """Sharing both embedding and RNN yields one of each; reuse adds no vars.

    Builds a graph under scope 'q' with both ``share:enc_dec_emb`` and
    ``share:enc_dec_rnn``, verifies exactly one embedding and one LSTM cell
    exist, then rebuilds with ``reuse=True`` and verifies the global-variable
    count is unchanged.
    """
    with self.test_session(config=self.sess_config) as sess:
        m = model.Seq2SeqModel(check_feed_dict=False)
        opt = {
            'emb:vocab_size': 20, 'emb:dim': 5, 'cell:num_units': 10,
            'cell:cell_class': 'tensorflow.nn.rnn_cell.BasicLSTMCell'}
        # Duplicate every option under both the 'enc' and 'dec' prefixes.
        opt = {f'{n}:{k}': v for k, v in opt.items() for n in ('enc', 'dec')}
        opt['dec:logit:output_size'] = 2
        c_rnn = 0
        c_emb = 0
        n = m.build_graph(opt=opt, name='q', **{
            'share:enc_dec_emb': True, 'share:enc_dec_rnn': True})
        for v in tf.global_variables():
            self.assertNotEqual(v.name, 'q/dec/rnn/basic_lstm_cell/kernel:0',
                                'no decoder rnn')
            self.assertNotEqual(v.name, 'q/dec/rnn/basic_lstm_cell/bias:0',
                                'no decoder rnn')
            if 'rnn/basic_lstm_cell/kernel:0' in v.name:
                c_rnn += 1
            # BUG FIX: this graph is built under name 'q', so a stray decoder
            # embedding would be 'q/dec/embedding:0'; the old check against
            # 't/dec/embedding:0' was vacuously true.
            self.assertNotEqual(v.name, 'q/dec/embedding:0', 'no decoder emb')
            if '/embedding:0' in v.name:
                c_emb += 1
        self.assertEqual(c_emb, 1, 'only one embedding')
        self.assertEqual(c_rnn, 1, 'only one cell')
        num_vars = len(tf.global_variables())
        n = m.build_graph(opt=opt, name='q', reuse=True, **{
            'share:enc_dec_emb': True, 'share:enc_dec_rnn': True})
        num_vars_ = len(tf.global_variables())
        # BUG FIX: the original compared num_vars with itself, so the reuse
        # assertion could never fail; compare against the post-rebuild count.
        self.assertEqual(num_vars, num_vars_,
                         'no new variables when reuse is True')
def test_build_attn(self):
    """An attention-enabled graph creates exactly the expected variables.

    Builds a graph under scope 't' with ``dec:attn_enc_output`` set and checks
    every global variable's name and shape against a precomputed table,
    including the decoder attention projection. Also checks that each fetch
    entry registered under a non-None key has distinct first two elements.
    """
    with self.test_session(config=self.sess_config) as sess:
        seq2seq = model.Seq2SeqModel(check_feed_dict=False)
        base_opt = {
            'emb:vocab_size': 20,
            'emb:dim': 5,
            'cell:num_units': 10,
            'cell:cell_class': 'tensorflow.nn.rnn_cell.BasicLSTMCell',
        }
        # Duplicate every option under both the 'enc' and 'dec' prefixes.
        opt = {f'{scope}:{key}': val
               for key, val in base_opt.items()
               for scope in ('enc', 'dec')}
        opt['dec:logit:output_size'] = 2
        # Shapes shared by both encoder and decoder scopes.
        shared_shapes = {
            'embedding:0': (20, 5),
            'rnn/basic_lstm_cell/kernel:0': (10 + 5, 10 * 4),
            'rnn/basic_lstm_cell/bias:0': (10 * 4, ),
        }
        expected_vars = {f't/{scope}/{key}': shape
                         for key, shape in shared_shapes.items()
                         for scope in ('enc', 'dec')}
        # Decoder-only variables: logit layer and attention projection.
        expected_vars.update({
            't/dec/logit_w:0': (2, 10),
            't/dec/logit_b:0': (2, ),
            't/dec/attention/dense/kernel:0': (20, 10),
            't/dec/attention/dense/bias:0': (10, ),
        })
        node = seq2seq.build_graph(opt, name='t', **{'dec:attn_enc_output': True})
        for var in tf.global_variables():
            self.assertTrue(var.name in expected_vars,
                            'expected variable scope/name')
            self.assertEqual(var.shape, expected_vars[var.name],
                             'shape is correct')
        for key, fetch in seq2seq._fetches.items():
            if key is not None:
                self.assertNotEqual(fetch[0], fetch[1], 'fetch array is set')