示例#1
0
  def testDynamicRnnTrainingLoop(self):
    """Runs a short training loop and checks the metrics it reports.

    Builds a graph from `self.sequence_example_file`, runs
    `basic_rnn_train.training_loop` for 5 steps with a summary frequency
    of 1, and asserts that every reported metric has a non-negative
    'loss' and 'accuracy'.
    """
    graph = basic_rnn_train.make_graph(
        sequence_example_file=self.sequence_example_file)
    # Materialize the whole iterable first so the training loop runs to
    # completion before any assertion can abort the test.
    metrics = list(basic_rnn_train.training_loop(
        graph, self.train_dir, num_training_steps=5, summary_frequency=1))

    for metric in metrics:
      # assertGreaterEqual reports the offending value on failure, unlike
      # assertTrue(value >= 0), which only says "False is not true".
      self.assertGreaterEqual(metric['loss'], 0)
      self.assertGreaterEqual(metric['accuracy'], 0)
示例#2
0
  def testEvalLoop(self):
    """Trains briefly, then checks the first metric reported by eval_loop.

    Runs `basic_rnn_train.training_loop` for 5 steps using `self.eval_dir`,
    then builds a second graph, takes the first item produced by
    `basic_rnn_train.eval_loop`, and asserts that it contains the keys
    'loss', 'log_perplexity' and 'accuracy'.
    """
    train_graph = basic_rnn_train.make_graph(
        sequence_example_file=self.sequence_example_file)
    # Exhaust the training iterable so all steps run; the metrics themselves
    # are discarded here.
    # NOTE(review): training writes into self.eval_dir rather than
    # self.train_dir, presumably so eval_loop finds the training output in
    # the directory it is given below -- confirm against eval_loop.
    list(basic_rnn_train.training_loop(
        train_graph, self.eval_dir, num_training_steps=5, summary_frequency=1))

    eval_graph = basic_rnn_train.make_graph(
        sequence_example_file=self.sequence_example_file)
    # NOTE(review): self.eval_dir is passed for both directory arguments;
    # verify this is intended by eval_loop's signature.
    eval_metrics = basic_rnn_train.eval_loop(
        eval_graph, self.eval_dir, self.eval_dir,
        num_training_steps=5, summary_frequency=1)
    # Use the built-in next() instead of the iterator's .next() method:
    # .next() exists only on Python 2 iterators, while next() works on
    # both Python 2.6+ and Python 3.
    metric = next(eval_metrics)
    # assertIn names the missing key and the container on failure.
    self.assertIn('loss', metric)
    self.assertIn('log_perplexity', metric)
    self.assertIn('accuracy', metric)