def test_train(self):
    logger = util.get_logger(level='warning')
    data = generator.read_seq_data(
        self.gen(), self.vocab, self.vocab, keep_sentence=False, seq_len=20)
    batch_iter = partial(generator.seq_batch_iter, *data, batch_size=13,
                         shuffle=True, keep_sentence=False)
    with self.test_session(config=self.sess_config) as sess:
        m = tfm.SeqModel()
        m.build_graph({'rnn:fn': 'seqmodel.graph.scan_rnn',
                       'cell:num_layers': 2})
        # the training loss is summed over tokens; normalize by batch size
        m.set_default_feed('train_loss_denom', 13)
        optimizer = tf.train.AdamOptimizer()
        train_op = optimizer.minimize(m.training_loss)
        sess.run(tf.global_variables_initializer())
        train_run_fn = partial(run.run_epoch, sess, m, batch_iter,
                               train_op=train_op)
        eval_run_fn = partial(run.run_epoch, sess, m, batch_iter)

        def stop_early(*args, **kwargs):
            pass

        eval_info = eval_run_fn()
        train_state = run.train(train_run_fn, logger, max_epoch=3,
                                valid_run_epoch_fn=eval_run_fn,
                                end_epoch_fn=stop_early)
        self.assertLess(train_state.best_eval, eval_info.eval_loss,
                        'after training, eval loss is lower.')
        self.assertEqual(train_state.cur_epoch, 3, 'train for max epoch')
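
# A note on the partial() pattern used in the tests above and below
# (a stdlib-only sketch, not part of the original suite):
# functools.partial freezes leading positional and keyword arguments,
# which is how train_run_fn and eval_run_fn become ready-to-call epoch
# functions for run.train.
def test_partial_binding_example(self):
    def scaled_add(a, b, scale=1):
        return (a + b) * scale
    add_five = partial(scaled_add, 5, scale=2)
    self.assertEqual(add_five(1), 12)  # (5 + 1) * 2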

def test_run_epoch(self):
    data = generator.read_seq_data(
        self.gen(), self.vocab, self.vocab, keep_sentence=False, seq_len=20)
    batch_iter = partial(generator.seq_batch_iter, *data, batch_size=13,
                         shuffle=True, keep_sentence=False)
    with self.test_session(config=self.sess_config) as sess:
        m = tfm.SeqModel()
        m.build_graph()
        optimizer = tf.train.AdamOptimizer()
        train_op = optimizer.minimize(m.training_loss)
        sess.run(tf.global_variables_initializer())
        run_fn = partial(run.run_epoch, sess, m, batch_iter)
        eval_info = run_fn()
        self.assertEqual(eval_info.num_tokens,
                         self.num_lines + self.num_tokens,
                         'run uses all tokens')
        # an untrained softmax is near uniform, so per-token cross-entropy
        # should be close to log(vocab_size)
        self.assertAlmostEqual(eval_info.eval_loss,
                               np.log(self.vocab.vocab_size), places=1,
                               msg='eval loss is close to uniform.')
        for __ in range(3):
            train_info = run_fn(train_op=train_op)
        self.assertLess(train_info.eval_loss, eval_info.eval_loss,
                        'after training, eval loss is lower.')
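
# Hypothetical companion check (not part of the original suite): the
# 'eval loss is close to uniform' assertion above relies on the identity
# that the cross-entropy of a uniform distribution over V classes is
# -log(1/V) = log(V). This sketch verifies the identity directly with
# NumPy, independent of the model or graph.
def test_uniform_cross_entropy_identity(self):
    vocab_size = self.vocab.vocab_size
    uniform = np.full(vocab_size, 1.0 / vocab_size)
    self.assertAlmostEqual(-np.log(uniform[0]), np.log(vocab_size),
                           places=6)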

def test_read_seq_data(self):
    x, y = generator.read_seq_data(
        self.gen(), self.vocab, self.vocab, keep_sentence=True)
    self.assertEqual(len(x), self.num_lines, 'number of sequences')
    self.assertEqual(len(y), self.num_lines, 'number of sequences')
    for x_, y_ in zip(x, y):
        self.assertEqual(x_[1:], y_[:-1], 'output is shifted input')
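
# Hypothetical illustration (not in the original suite) of the shift
# property asserted above: in language-model data the target sequence is
# the input advanced by one position, so x[1:] == y[:-1] whenever x and y
# are slices of the same token stream. Pure Python, no generator API.
def test_shift_property_example(self):
    ids = [2, 7, 5, 5, 3]        # e.g. <s> a b b </s>
    x, y = ids[:-1], ids[1:]     # input drops the last token, target the first
    self.assertEqual(x[1:], y[:-1])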

def test_read_seq_data_sen(self):
    x, y = generator.read_seq_data(
        self.gen(), self.vocab, self.vocab, keep_sentence=False, seq_len=20)
    num_seq = (self.num_lines + self.num_tokens) // 20
    # a leftover chunk still forms a sequence if it has more than one
    # token, i.e. at least one (input, shifted-output) pair
    if (self.num_lines + self.num_tokens) % 20 > 1:
        num_seq += 1
    self.assertEqual(len(x), num_seq, 'number of sequences')
    self.assertEqual(len(y), num_seq, 'number of sequences')
    for x_, y_ in zip(x, y):
        self.assertEqual(x_[1:], y_[:-1], 'output is shifted input')
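
# Hypothetical worked example (not in the original suite) of the chunk
# count computed above: 105 total tokens with seq_len=20 yield 5 full
# chunks plus a 5-token remainder; the remainder has more than one token,
# so it forms a 6th sequence, whereas a 1-token remainder cannot form an
# (input, shifted-output) pair and is dropped.
def test_chunk_count_arithmetic(self):
    def num_chunks(total, seq_len):
        return total // seq_len + (1 if total % seq_len > 1 else 0)
    self.assertEqual(num_chunks(105, 20), 6)
    self.assertEqual(num_chunks(100, 20), 5)
    self.assertEqual(num_chunks(101, 20), 5)  # remainder of 1 is dropped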

def test_seq_batch_iter(self):
    data = generator.read_seq_data(
        self.gen(), self.vocab, self.vocab, keep_sentence=False, seq_len=20)
    count = 0
    for batch in generator.seq_batch_iter(*data, batch_size=13,
                                          shuffle=False,
                                          keep_sentence=False):
        count += batch.num_tokens
        self.assertTrue(batch.keep_state, 'keep_state is True')
        self.assertEqual(batch.num_tokens, sum(batch.features.seq_len),
                         'num_tokens is sum of seq_len')
    self.assertEqual(count, self.num_lines + self.num_tokens,
                     'number of tokens (including eos symbol)')