def test_build(self):
    """Graph construction creates exactly the expected variables and fetches."""
    with self.test_session(config=self.sess_config) as sess:
        seq_model = model.SeqModel(check_feed_dict=False)
        options = {
            'emb:vocab_size': 20,
            'emb:dim': 5,
            'cell:num_units': 10,
            'cell:cell_class': 'tensorflow.nn.rnn_cell.BasicLSTMCell',
            'logit:output_size': 2,
        }
        # Variable name -> expected static shape under scope 't'.
        expected_vars = {
            't/embedding:0': (20, 5),
            't/rnn/basic_lstm_cell/kernel:0': (10 + 5, 10 * 4),
            't/rnn/basic_lstm_cell/bias:0': (10 * 4, ),
            't/logit_w:0': (2, 10),
            't/logit_b:0': (2, ),
        }
        nodes = seq_model.build_graph(options, name='t')
        # Every created variable must be accounted for, with the right shape.
        for var in tf.global_variables():
            self.assertTrue(var.name in expected_vars,
                            'expected variable scope/name')
            self.assertEqual(var.shape, expected_vars[var.name],
                             'shape is correct')
        # All mode-specific fetch arrays (keyed by mode) should be populated.
        for mode, fetch_pair in seq_model._fetches.items():
            if mode is not None:
                self.assertNotEqual(fetch_pair[0], fetch_pair[1],
                                    'fetch array is set')
def test_train(self):
    """End-to-end training loop lowers eval loss and runs for max_epoch."""
    logger = util.get_logger(level='warning')
    data = generator.read_seq_data(
        self.gen(), self.vocab, self.vocab, keep_sentence=False, seq_len=20)
    batch_iter = partial(generator.seq_batch_iter, *data, batch_size=13,
                         shuffle=True, keep_sentence=False)
    with self.test_session(config=self.sess_config) as sess:
        seq_model = tfm.SeqModel()
        nodes = seq_model.build_graph(
            {'rnn:fn': 'seqmodel.graph.scan_rnn', 'cell:num_layers': 2})
        # Denominator matches the batch size used above.
        seq_model.set_default_feed('train_loss_denom', 13)
        train_op = tf.train.AdamOptimizer().minimize(seq_model.training_loss)
        sess.run(tf.global_variables_initializer())
        train_run_fn = partial(run.run_epoch, sess, seq_model, batch_iter,
                               train_op=train_op)
        eval_run_fn = partial(run.run_epoch, sess, seq_model, batch_iter)

        def stop_early(*args, **kwargs):
            # No-op end-of-epoch hook; training runs to max_epoch.
            pass

        eval_info = eval_run_fn()
        train_state = run.train(
            train_run_fn, logger, max_epoch=3,
            valid_run_epoch_fn=eval_run_fn, end_epoch_fn=stop_early)
        self.assertLess(train_state.best_eval, eval_info.eval_loss,
                        'after training, eval loss is lower.')
        self.assertEqual(train_state.cur_epoch, 3, 'train for max epoch')
def test_run_epoch(self):
    """run_epoch consumes all tokens, starts near uniform loss, and improves."""
    data = generator.read_seq_data(
        self.gen(), self.vocab, self.vocab, keep_sentence=False, seq_len=20)
    batch_iter = partial(generator.seq_batch_iter, *data, batch_size=13,
                         shuffle=True, keep_sentence=False)
    with self.test_session(config=self.sess_config) as sess:
        seq_model = tfm.SeqModel()
        nodes = seq_model.build_graph()
        train_op = tf.train.AdamOptimizer().minimize(seq_model.training_loss)
        sess.run(tf.global_variables_initializer())
        epoch_fn = partial(run.run_epoch, sess, seq_model, batch_iter)
        eval_info = epoch_fn()
        # Token count includes one end-of-line token per line.
        self.assertEqual(eval_info.num_tokens,
                         self.num_lines + self.num_tokens,
                         'run uses all tokens')
        # An untrained model should score close to a uniform distribution.
        self.assertAlmostEqual(eval_info.eval_loss,
                               np.log(self.vocab.vocab_size), places=1,
                               msg='eval loss is close to uniform.')
        for __ in range(3):
            train_info = epoch_fn(train_op=train_op)
        self.assertLess(train_info.eval_loss, eval_info.eval_loss,
                        'after training, eval loss is lower.')
def test_build_reuse(self):
    """Rebuilding the graph with reuse=True must not create new variables.

    Also checks that kwargs (no_dropout=True) override default options,
    leaving a plain BasicLSTMCell instead of a dropout-wrapped one.
    """
    with self.test_session(config=self.sess_config) as sess:
        m = model.SeqModel(check_feed_dict=False)
        n = m.build_graph()
        num_vars = len(tf.global_variables())
        n = m.build_graph(reuse=True, no_dropout=True)
        self.assertEqual(type(n['cell']), tf.nn.rnn_cell.BasicLSTMCell,
                         'overwrite default options with kwargs')
        num_vars_ = len(tf.global_variables())
        # BUG FIX: the original asserted num_vars == num_vars (always true),
        # leaving num_vars_ unused. Compare the counts before and after the
        # reused build so the test actually detects variable duplication.
        self.assertEqual(num_vars, num_vars_,
                         'no new variables when reuse is True')
def test_build_no_output(self):
    """Disabling logit/loss outputs leaves only predict fetches populated."""
    with self.test_session(config=self.sess_config) as sess:
        seq_model = model.SeqModel(check_feed_dict=False)
        nodes = seq_model.build_graph(**{'out:logit': False,
                                         'out:loss': False})
        self.assertTrue('logit' not in nodes, 'logit is not in nodes')
        for mode, fetch_pair in seq_model._fetches.items():
            # Without logit/loss, training and eval fetches stay empty.
            if mode in (None, seq_model._TRAIN_, seq_model._EVAL_):
                self.assertEqual(fetch_pair[0], fetch_pair[1],
                                 'fetch array is not set')
            if mode == seq_model._PREDICT_:
                self.assertNotEqual(fetch_pair[0], fetch_pair[1],
                                    'predict fetch array is set')
        # A loss output without a logit output is an invalid configuration.
        self.assertRaises(ValueError, seq_model.build_graph,
                          **{'out:logit': False})
def test_build_overwrite_opt(self):
    """Options dict overrides defaults; kwargs override the options dict."""
    with self.test_session(config=self.sess_config) as sess:
        seq_model = model.SeqModel(check_feed_dict=False)
        options = {
            'emb:vocab_size': 20,
            'cell:in_keep_prob': 0.5,
            'logit:output_size': 2,
        }
        nodes = seq_model.build_graph(options)
        emb_rows = nodes['emb_vars'].get_shape()[0]
        self.assertEqual(emb_rows, 20, 'overwrite default options')
        # in_keep_prob < 1 wraps the cell in dropout.
        self.assertEqual(type(nodes['cell']), tf.nn.rnn_cell.DropoutWrapper,
                         'overwrite default options')
        nodes = seq_model.build_graph(options, reuse=True,
                                      **{'cell:in_keep_prob': 1.0})
        # Keyword override restores keep_prob to 1.0, so no dropout wrapper.
        self.assertEqual(type(nodes['cell']), tf.nn.rnn_cell.BasicLSTMCell,
                         'overwrite default options with kwargs')
def test_build_decode(self):
    """Decode options add greedy and sampling nodes to graph and fetches."""
    with self.test_session(config=self.sess_config) as sess:
        seq_model = model.SeqModel(check_feed_dict=False)
        decode_opts = {
            'out:logit': True,
            'out:decode': True,
            'decode:add_greedy': True,
            'decode:add_sampling': True,
        }
        nodes = seq_model.build_graph(**decode_opts)
        self.assertTrue('decode_greedy' in nodes,
                        'decode_greedy in nodes')
        self.assertTrue('decode_greedy' in seq_model._predict_fetch,
                        'decode_greedy in predict dict')
        self.assertTrue('decode_sampling' in nodes,
                        'decode_sampling in nodes')
        self.assertTrue('decode_sampling' in seq_model._predict_fetch,
                        'decode_sampling in predict dict')