コード例 #1
0
 def test_build(self):
     """Build the graph with explicit options and check the created variables.

     Verifies that every global variable created by ``build_graph`` is one of
     the expected scope/names with the expected shape, that every expected
     variable was actually created, and that each non-``None`` entry of the
     model's fetch table was populated.
     """
     with self.test_session(config=self.sess_config):
         m = model.SeqModel(check_feed_dict=False)
         opt = {
             'emb:vocab_size': 20,
             'emb:dim': 5,
             'cell:num_units': 10,
             'cell:cell_class': 'tensorflow.nn.rnn_cell.BasicLSTMCell',
             'logit:output_size': 2
         }
         expected_vars = {
             't/embedding:0': (20, 5),
             # (cell:num_units + emb:dim, 4 * cell:num_units)
             't/rnn/basic_lstm_cell/kernel:0': (10 + 5, 10 * 4),
             't/rnn/basic_lstm_cell/bias:0': (10 * 4, ),
             't/logit_w:0': (2, 10),
             't/logit_b:0': (2, )
         }
         m.build_graph(opt, name='t')
         created_vars = tf.global_variables()
         for v in created_vars:
             self.assertIn(v.name, expected_vars,
                           'expected variable scope/name')
             self.assertEqual(v.shape, expected_vars[v.name],
                              'shape is correct')
         # The loop above only rejects unexpected variables; also make sure
         # none of the expected ones is missing.
         self.assertEqual(set(v.name for v in created_vars),
                          set(expected_vars),
                          'all expected variables are created')
         for k, v in m._fetches.items():
             if k is not None:
                 self.assertNotEqual(v[0], v[1], 'fetch array is set')
コード例 #2
0
    def test_train(self):
        """Training for ``max_epoch`` epochs lowers the best validation loss.

        The validation loss of the freshly initialized model is recorded
        first; ``run.train`` must then stop at epoch 3 with a best eval loss
        below that starting value.
        """
        batch_size = 13
        logger = util.get_logger(level='warning')
        data = generator.read_seq_data(
            self.gen(), self.vocab, self.vocab, keep_sentence=False, seq_len=20)
        batch_iter = partial(generator.seq_batch_iter, *data,
                             batch_size=batch_size, shuffle=True,
                             keep_sentence=False)
        with self.test_session(config=self.sess_config) as sess:
            seq_model = tfm.SeqModel()
            seq_model.build_graph({'rnn:fn': 'seqmodel.graph.scan_rnn',
                                   'cell:num_layers': 2})
            # The training-loss denominator is fed the batch size.
            seq_model.set_default_feed('train_loss_denom', batch_size)
            train_op = tf.train.AdamOptimizer().minimize(
                seq_model.training_loss)
            sess.run(tf.global_variables_initializer())

            train_epoch = partial(run.run_epoch, sess, seq_model, batch_iter,
                                  train_op=train_op)
            eval_epoch = partial(run.run_epoch, sess, seq_model, batch_iter)

            def keep_training(*args, **kwargs):
                # End-of-epoch hook that does nothing and returns None;
                # presumably this disables early stopping in run.train.
                return None

            start_info = eval_epoch()
            final_state = run.train(train_epoch, logger, max_epoch=3,
                                    valid_run_epoch_fn=eval_epoch,
                                    end_epoch_fn=keep_training)
            self.assertLess(final_state.best_eval, start_info.eval_loss,
                            'after training, eval loss is lower.')
            self.assertEqual(final_state.cur_epoch, 3, 'train for max epoch')
コード例 #3
0
 def test_run_epoch(self):
     """``run.run_epoch`` consumes every token, and training lowers the loss.

     An untrained model's eval loss should be close to the uniform-model
     cross-entropy ln(vocab_size); after three training epochs the reported
     loss must be below that initial value.
     """
     data = generator.read_seq_data(self.gen(), self.vocab, self.vocab,
                                    keep_sentence=False, seq_len=20)
     batch_iter = partial(generator.seq_batch_iter, *data, batch_size=13,
                          shuffle=True, keep_sentence=False)
     with self.test_session(config=self.sess_config) as sess:
         seq_model = tfm.SeqModel()
         seq_model.build_graph()
         train_op = tf.train.AdamOptimizer().minimize(seq_model.training_loss)
         sess.run(tf.global_variables_initializer())
         run_fn = partial(run.run_epoch, sess, seq_model, batch_iter)

         initial_info = run_fn()
         expected_tokens = self.num_lines + self.num_tokens
         self.assertEqual(initial_info.num_tokens, expected_tokens,
                          'run uses all tokens')
         uniform_loss = np.log(self.vocab.vocab_size)
         self.assertAlmostEqual(initial_info.eval_loss, uniform_loss,
                                places=1,
                                msg='eval loss is close to uniform.')

         # Three training epochs; only the last epoch's info is checked.
         epoch_infos = [run_fn(train_op=train_op) for _ in range(3)]
         last_info = epoch_infos[-1]
         self.assertLess(last_info.eval_loss, initial_info.eval_loss,
                         'after training, eval loss is lower.')
コード例 #4
0
 def test_build_reuse(self):
     """Rebuilding the graph with ``reuse=True`` creates no new variables.

     Also checks that the cell built with ``no_dropout=True`` passed as a
     keyword is a plain ``BasicLSTMCell``.
     """
     with self.test_session(config=self.sess_config):
         m = model.SeqModel(check_feed_dict=False)
         m.build_graph()
         num_vars_before = len(tf.global_variables())
         n = m.build_graph(reuse=True, no_dropout=True)
         self.assertEqual(type(n['cell']), tf.nn.rnn_cell.BasicLSTMCell,
                          'overwrite default options with kwargs')
         num_vars_after = len(tf.global_variables())
         # Compare the variable counts taken before and after the reuse build.
         self.assertEqual(num_vars_after, num_vars_before,
                          'no new variables when reuse is True')
コード例 #5
0
 def test_build_no_output(self):
     """Disabling the logit and loss outputs leaves only prediction fetches.

     Checks that no ``logit`` node is returned, that the ``None``/train/eval
     fetch entries stay unset while the predict entry is set, and that
     disabling only the logit output (loss output left at its default) makes
     ``build_graph`` raise ``ValueError``.
     """
     with self.test_session(config=self.sess_config):
         m = model.SeqModel(check_feed_dict=False)
         n = m.build_graph(**{'out:logit': False, 'out:loss': False})
         self.assertNotIn('logit', n, 'logit is not in nodes')
         for k, v in m._fetches.items():
             if k is None or k == m._TRAIN_ or k == m._EVAL_:
                 self.assertEqual(v[0], v[1], 'fetch array is not set')
             if k == m._PREDICT_:
                 self.assertNotEqual(v[0], v[1],
                                     'predict fetch array is set')
         # Use a fresh graph and a fresh model. Rebuilding into the current
         # graph without reuse would likely collide with the variables created
         # above. TensorFlow reports such collisions (tf.get_variable) as a
         # ValueError, which would satisfy assertRaises even without any
         # option validation.
         with tf.Graph().as_default():
             fresh_model = model.SeqModel(check_feed_dict=False)
             self.assertRaises(ValueError, fresh_model.build_graph,
                               **{'out:logit': False})
コード例 #6
0
 def test_build_overwrite_opt(self):
     """Options given in ``opt`` and as keyword overrides replace defaults.

     The first build uses ``opt`` directly (vocab size 20, input keep prob
     0.5), which must yield a 20-row embedding and a ``DropoutWrapper`` cell.
     The reuse build overrides the keep prob to 1.0 via kwargs, which must
     yield a plain ``BasicLSTMCell``.
     """
     with self.test_session(config=self.sess_config):
         seq_model = model.SeqModel(check_feed_dict=False)
         opt = {
             'emb:vocab_size': 20,
             'cell:in_keep_prob': 0.5,
             'logit:output_size': 2
         }
         nodes = seq_model.build_graph(opt)
         vocab_rows = nodes['emb_vars'].get_shape()[0]
         self.assertEqual(vocab_rows, 20, 'overwrite default options')
         cell_type = type(nodes['cell'])
         self.assertEqual(cell_type, tf.nn.rnn_cell.DropoutWrapper,
                          'overwrite default options')

         nodes = seq_model.build_graph(opt, reuse=True,
                                       **{'cell:in_keep_prob': 1.0})
         cell_type = type(nodes['cell'])
         self.assertEqual(cell_type, tf.nn.rnn_cell.BasicLSTMCell,
                          'overwrite default options with kwargs')
コード例 #7
0
 def test_build_decode(self):
     """Enabling greedy and sampling decoding adds both decode outputs.

     Each of ``decode_greedy`` and ``decode_sampling`` must appear both in
     the node dict returned by ``build_graph`` and in the model's predict
     fetch dict.
     """
     with self.test_session(config=self.sess_config):
         m = model.SeqModel(check_feed_dict=False)
         n = m.build_graph(
             **{
                 'out:logit': True,
                 'out:decode': True,
                 'decode:add_greedy': True,
                 'decode:add_sampling': True
             })
         for key in ('decode_greedy', 'decode_sampling'):
             # assertIn shows the container on failure, unlike
             # assertTrue(key in ...), which only reports "False is not true".
             self.assertIn(key, n, key + ' in nodes')
             self.assertIn(key, m._predict_fetch, key + ' in predict dict')