Example #1
 def test_build_share_rnn(self):
     with self.test_session(config=self.sess_config) as sess:
         m = model.Seq2SeqModel(check_feed_dict=False)
         opt = {
             'emb:vocab_size': 20,
             'emb:dim': 5,
             'cell:num_units': 10,
             'cell:cell_class': 'tensorflow.nn.rnn_cell.BasicLSTMCell'
         }
         opt = {
             f'{n}:{k}': v
             for k, v in opt.items() for n in ('enc', 'dec')
         }
         opt['dec:logit:output_size'] = 2
         c = 0
         n = m.build_graph(opt=opt, name='q', **{'share:enc_dec_rnn': True})
         for v in tf.global_variables():
             self.assertNotEqual(v.name,
                                 'q/dec/rnn/basic_lstm_cell/kernel:0',
                                 'no decoder rnn')
             self.assertNotEqual(v.name, 'q/dec/rnn/basic_lstm_cell/bias:0',
                                 'no decoder rnn')
             if 'rnn/basic_lstm_cell/kernel:0' in v.name:
                 c += 1
         self.assertEqual(c, 1, 'only one cell')
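
This test leans on TensorFlow 1.x naming for BasicLSTMCell variables: a cell built under a scope exposes `<scope>/basic_lstm_cell/kernel` and `.../bias`, so with `share:enc_dec_rnn` exactly one kernel should exist and none of it under the decoder scope. A minimal sketch of the same counting idiom, independent of the seqmodel API (the helper name is illustrative):

    import tensorflow as tf

    def count_vars_matching(substring):
        # Count global variables whose name contains the given substring;
        # with a shared encoder/decoder cell the LSTM kernel should match exactly once.
        return sum(1 for v in tf.global_variables() if substring in v.name)

    # e.g. count_vars_matching('rnn/basic_lstm_cell/kernel:0') == 1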
Example #2
 def test_run_epoch(self):
     data = generator.read_seq2seq_data(self.gen(), self.vocab, self.vocab)
     batch_iter = partial(generator.seq2seq_batch_iter, *data,
                          batch_size=20, shuffle=True)
     with self.test_session(config=self.sess_config) as sess:
         m = tfm.Seq2SeqModel()
         n = m.build_graph(**{'dec:attn_enc_output': True})
         m.set_default_feed('dec.train_loss_denom', 20)
         return_ph = tf.placeholder(tf.float32, shape=(None, None), name='return')
         train_op = graph.create_train_op(m.training_loss, learning_rate=0.01)
         train_pg_op = graph.create_pg_train_op(m.nll, return_ph)
         return_feed_fn = partial(m.set_default_feed, return_ph)
         sess.run(tf.global_variables_initializer())
         run_fn = partial(run.run_epoch, sess, m, batch_iter)
         for __ in range(5):  # do some pre-train
             train_info = run_fn(train_op=train_op)
             print(train_info.summary())
         reward_fn = generator.reward_match_label
         eval_info = run.run_sampling_epoch(
             sess, m, batch_iter, greedy=True, reward_fn=reward_fn)
         print(eval_info.summary('eval'))
         run_fn = partial(run.run_sampling_epoch, sess, m, batch_iter,
                          reward_fn=reward_fn)
         for __ in range(10):
             train_info = run_fn(train_op=train_pg_op, return_feed_fn=return_feed_fn)
             print(train_info.summary())
         eval_info2 = run.run_sampling_epoch(
             sess, m, batch_iter, greedy=True, reward_fn=reward_fn)
         print(eval_info2.summary('eval'))
         self.assertLess(eval_info2.eval_loss, eval_info.eval_loss,
                         'after training, eval loss is lower.')
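
The flow here is cross-entropy pre-training via `create_train_op`, followed by sampling epochs in which `reward_match_label` scores the sampled sequences and the result is fed through the `return` placeholder into the policy-gradient train op. A hedged sketch of the REINFORCE-style surrogate that `create_pg_train_op` presumably minimizes (the function and argument names are illustrative, not the library's internals):

    import tensorflow as tf

    def pg_surrogate_loss(token_nll, returns):
        # token_nll: per-token negative log-likelihood of sampled tokens, shape [time, batch]
        # returns:   per-token return fed via the 'return' placeholder, shape [time, batch]
        # REINFORCE-style weighting: tokens from higher-return samples get larger gradients.
        return tf.reduce_sum(token_nll * tf.stop_gradient(returns))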
Example #3
 def test_build_share_emb_rnn_reuse(self):
     with self.test_session(config=self.sess_config) as sess:
         m = model.Seq2SeqModel(check_feed_dict=False)
         opt = {
             'emb:vocab_size': 20,
             'emb:dim': 5,
             'cell:num_units': 10,
             'cell:cell_class': 'tensorflow.nn.rnn_cell.BasicLSTMCell'
         }
         opt = {
             f'{n}:{k}': v
             for k, v in opt.items() for n in ('enc', 'dec')
         }
         opt['dec:logit:output_size'] = 2
         c_rnn = 0
         c_emb = 0
         n = m.build_graph(opt=opt,
                           name='q',
                           **{
                               'share:enc_dec_emb': True,
                               'share:enc_dec_rnn': True
                           })
         for v in tf.global_variables():
             self.assertNotEqual(v.name,
                                 'q/dec/rnn/basic_lstm_cell/kernel:0',
                                 'no decoder rnn')
             self.assertNotEqual(v.name, 'q/dec/rnn/basic_lstm_cell/bias:0',
                                 'no decoder rnn')
             if 'rnn/basic_lstm_cell/kernel:0' in v.name:
                 c_rnn += 1
             self.assertNotEqual(v.name, 'q/dec/embedding:0',
                                 'no decoder emb')
             if '/embedding:0' in v.name:
                 c_emb += 1
         self.assertEqual(c_emb, 1, 'only one embedding')
         self.assertEqual(c_rnn, 1, 'only one cell')
         num_vars = len(tf.global_variables())
         n = m.build_graph(opt=opt,
                           name='q',
                           reuse=True,
                           **{
                               'share:enc_dec_emb': True,
                               'share:enc_dec_rnn': True
                           })
         num_vars_ = len(tf.global_variables())
         self.assertEqual(num_vars, num_vars_,
                          'no new variables when reuse is True')
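
The reuse check at the end is the standard TF 1.x variable-scope behavior: building the same graph again with `reuse=True` should fetch existing variables rather than create new ones, so the length of `tf.global_variables()` must not grow. A minimal standalone illustration of that behavior (unrelated to the seqmodel scopes):

    import tensorflow as tf

    with tf.variable_scope('demo'):
        w = tf.get_variable('w', shape=(3, 3))
    with tf.variable_scope('demo', reuse=True):
        w_again = tf.get_variable('w', shape=(3, 3))  # returns the existing variable
    assert w is w_again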
Example #4
 def test_build_attn(self):
     with self.test_session(config=self.sess_config) as sess:
         m = model.Seq2SeqModel(check_feed_dict=False)
         opt = {
             'emb:vocab_size': 20,
             'emb:dim': 5,
             'cell:num_units': 10,
             'cell:cell_class': 'tensorflow.nn.rnn_cell.BasicLSTMCell'
         }
         opt = {
             f'{n}:{k}': v
             for k, v in opt.items() for n in ('enc', 'dec')
         }
         opt['dec:logit:output_size'] = 2
         expected_vars = {
             'embedding:0': (20, 5),
             'rnn/basic_lstm_cell/kernel:0': (10 + 5, 10 * 4),
             'rnn/basic_lstm_cell/bias:0': (10 * 4, )
         }
         expected_vars = {
             f't/{n}/{k}': v
             for k, v in expected_vars.items() for n in ('enc', 'dec')
         }
         expected_vars.update({
             't/dec/logit_w:0': (2, 10),
             't/dec/logit_b:0': (2, ),
             't/dec/attention/dense/kernel:0': (20, 10),
             't/dec/attention/dense/bias:0': (10, )
         })
         n = m.build_graph(opt, name='t', **{'dec:attn_enc_output': True})
         for v in tf.global_variables():
             # print(f'{v.name}, {v.get_shape()}')
             self.assertTrue(v.name in expected_vars,
                             'expected variable scope/name')
             self.assertEqual(v.shape, expected_vars[v.name],
                              'shape is correct')
         for k, v in m._fetches.items():
             if k is not None:
                 self.assertNotEqual(v[0], v[1], 'fetch array is set')
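
The expected shapes follow from the hyperparameters: a BasicLSTMCell kernel is (input_dim + num_units, 4 * num_units), and the attention dense layer takes a 2 * num_units input, presumably the cell output concatenated with the attention context (an assumption read off the (20, 10) shape, not confirmed from the library source). Illustrative arithmetic only:

    emb_dim, num_units, output_size = 5, 10, 2
    lstm_kernel = (emb_dim + num_units, 4 * num_units)   # (15, 40): [input; state] -> 4 gates
    lstm_bias = (4 * num_units,)                         # (40,)
    attn_dense_kernel = (2 * num_units, num_units)       # (20, 10)
    logit_w = (output_size, num_units)                   # (2, 10)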