Example #1
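All five snippets below are test methods lifted from a unittest-style test class, so they are not self-contained: they reference fixtures (self.gen, self.vocab, self.num_lines, self.num_tokens, self.sess_config) and several module aliases. A minimal sketch of the imports they appear to assume follows; the exact seqmodel module paths and the tfm alias are guesses inferred from usage, and the TensorFlow API is 1.x:

from functools import partial

import numpy as np
import tensorflow as tf  # TF 1.x: tf.train.AdamOptimizer, global_variables_initializer

# Assumed seqmodel package layout; adjust to the real module paths.
from seqmodel import generator, run, util
import seqmodel.model as tfm  # hypothetical: the module exposing SeqModel
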
    def test_train(self):
        logger = util.get_logger(level='warning')
        data = generator.read_seq_data(
            self.gen(), self.vocab, self.vocab, keep_sentence=False, seq_len=20)
        batch_iter = partial(generator.seq_batch_iter, *data,
                             batch_size=13, shuffle=True, keep_sentence=False)
        with self.test_session(config=self.sess_config) as sess:
            m = tfm.SeqModel()
            n = m.build_graph({'rnn:fn': 'seqmodel.graph.scan_rnn',
                               'cell:num_layers': 2})
            m.set_default_feed('train_loss_denom', 13)
            optimizer = tf.train.AdamOptimizer()
            train_op = optimizer.minimize(m.training_loss)
            sess.run(tf.global_variables_initializer())
            train_run_fn = partial(run.run_epoch, sess, m, batch_iter, train_op=train_op)
            eval_run_fn = partial(run.run_epoch, sess, m, batch_iter)

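            # A no-op end_epoch_fn: it returns None, so no early stop is
            # requested (assuming run.train treats a falsy return as "keep
            # training") and all 3 epochs run, as asserted below.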
            def stop_early(*args, **kwargs):
                pass

            eval_info = eval_run_fn()
            train_state = run.train(train_run_fn, logger, max_epoch=3,
                                    valid_run_epoch_fn=eval_run_fn,
                                    end_epoch_fn=stop_early)
            self.assertLess(train_state.best_eval, eval_info.eval_loss,
                            'after training, eval loss is lower.')
            self.assertEqual(train_state.cur_epoch, 3, 'train for max epoch')
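Note that train_loss_denom is fed as 13, matching batch_size=13; presumably the summed per-batch training loss is divided by this feed so the optimizer sees a per-example average.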
Example #2
    def test_run_epoch(self):
        data = generator.read_seq_data(self.gen(),
                                       self.vocab,
                                       self.vocab,
                                       keep_sentence=False,
                                       seq_len=20)
        batch_iter = partial(generator.seq_batch_iter,
                             *data,
                             batch_size=13,
                             shuffle=True,
                             keep_sentence=False)
        with self.test_session(config=self.sess_config) as sess:
            m = tfm.SeqModel()
            n = m.build_graph()
            optimizer = tf.train.AdamOptimizer()
            train_op = optimizer.minimize(m.training_loss)
            sess.run(tf.global_variables_initializer())
            run_fn = partial(run.run_epoch, sess, m, batch_iter)
            eval_info = run_fn()
            self.assertEqual(eval_info.num_tokens,
                             self.num_lines + self.num_tokens,
                             'run uses all tokens')
            self.assertAlmostEqual(eval_info.eval_loss,
                                   np.log(self.vocab.vocab_size),
                                   places=1,
                                   msg='eval loss is close to uniform.')
            for __ in range(3):
                train_info = run_fn(train_op=train_op)
            self.assertLess(train_info.eval_loss, eval_info.eval_loss,
                            'after training, eval loss is lower.')
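The assertAlmostEqual above rests on a standard identity: an untrained model is roughly a uniform predictor over the vocabulary, and the cross-entropy of a uniform distribution over V classes is log(V). A quick standalone check (V is a made-up size):

import numpy as np

V = 1000                         # hypothetical vocabulary size
uniform = np.full(V, 1.0 / V)    # uniform predictive distribution
nll = -np.log(uniform[0])        # negative log-likelihood of any true token
assert np.isclose(nll, np.log(V))  # cross-entropy of uniform == log(V)
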
Example #3
    def test_read_seq_data(self):
        x, y = generator.read_seq_data(self.gen(),
                                       self.vocab,
                                       self.vocab,
                                       keep_sentence=True)
        self.assertEqual(len(x), self.num_lines, 'number of sequences')
        self.assertEqual(len(y), self.num_lines, 'number of sequences')
        for x_, y_ in zip(x, y):
            self.assertEqual(x_[1:], y_[:-1], 'output is shifted input')
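The final assertion encodes the language-modeling target convention: y is x shifted left by one position, so predicting y[t] from the prefix of x is next-token prediction. A toy illustration with made-up token ids:

x = [7, 3, 9, 2]   # hypothetical input token ids
y = [3, 9, 2, 5]   # targets: the same stream shifted by one
assert x[1:] == y[:-1]
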
Example #4
    def test_read_seq_data_sen(self):
        x, y = generator.read_seq_data(self.gen(),
                                       self.vocab,
                                       self.vocab,
                                       keep_sentence=False,
                                       seq_len=20)
        num_seq = (self.num_lines + self.num_tokens) // 20
        if (self.num_lines + self.num_tokens) % 20 > 1:
            num_seq += 1
        self.assertEqual(len(x), num_seq, 'number of sequences')
        self.assertEqual(len(y), num_seq, 'number of sequences')
        for x_, y_ in zip(x, y):
            self.assertEqual(x_[1:], y_[:-1], 'output is shifted input')
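A worked instance of the chunk-count arithmetic above (all counts are made up). The "% 20 > 1" test, rather than "> 0", presumably reflects that a single leftover token cannot form an (input, target) pair when the target is the input shifted by one:

num_lines, num_tokens = 4, 96   # hypothetical corpus statistics
total = num_lines + num_tokens  # 100 tokens, counting one eos per line
num_seq = total // 20           # 5 full chunks of seq_len=20
if total % 20 > 1:              # a remainder of exactly 1 token yields no pair
    num_seq += 1
print(num_seq)                  # -> 5 (a total of 102 would give 6; 101 still 5)
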
Example #5
    def test_seq_batch_iter(self):
        data = generator.read_seq_data(self.gen(),
                                       self.vocab,
                                       self.vocab,
                                       keep_sentence=False,
                                       seq_len=20)
        count = 0
        for batch in generator.seq_batch_iter(*data,
                                              batch_size=13,
                                              shuffle=False,
                                              keep_sentence=False):
            count += batch.num_tokens
            self.assertTrue(batch.keep_state, 'keep_state is True')
            self.assertEqual(batch.num_tokens, sum(batch.features.seq_len),
                             'num_tokens is sum of seq_len')
        self.assertEqual(count, self.num_lines + self.num_tokens,
                         'number of tokens (including eos symbol)')
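With keep_sentence=False the data is treated as one continuous token stream: every batch reports keep_state=True (presumably signaling that RNN state should carry over to the next batch), and one full pass over the iterator sums num_tokens to the corpus size, one eos symbol per line included.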