Esempio n. 1
0
  def testEvalLoop(self):
    """Trains briefly, then checks that eval_loop's first result has metrics.

    A short training run writes into self.eval_dir, which is then passed as
    both directory arguments to eval_loop. The first value eval_loop yields
    must contain the keys 'loss', 'log_perplexity' and 'accuracy'.
    """
    train_graph = basic_rnn_train.make_graph(
        sequence_example_file=self.sequence_example_file)
    # training_loop is drained with list() so that all training steps run
    # before evaluation starts.
    list(basic_rnn_train.training_loop(
        train_graph, self.eval_dir, num_training_steps=5, summary_frequency=1))

    eval_graph = basic_rnn_train.make_graph(
        sequence_example_file=self.sequence_example_file)
    # Use the builtin next() rather than the Python 2-only `.next()` method,
    # which does not exist on Python 3 iterators/generators.
    metric = next(basic_rnn_train.eval_loop(
        eval_graph, self.eval_dir, self.eval_dir,
        num_training_steps=5, summary_frequency=1))
    self.assertIn('loss', metric)
    self.assertIn('log_perplexity', metric)
    self.assertIn('accuracy', metric)
Esempio n. 2
0
 def testRnnLayerSize(self):
   """Two rnn_layer_sizes entries should yield ops for Cell0 and Cell1 LSTMs."""
   two_layer_hparams = '{"rnn_layer_sizes":[100, 100]}'
   expected_cells = [
       'RNN/MultiRNNCell/Cell0/LSTMCell',
       'RNN/MultiRNNCell/Cell1/LSTMCell',
   ]
   built_graph = basic_rnn_train.make_graph(
       sequence_example_file=self.sequence_example_file,
       hparams_string=two_layer_hparams)
   names = [operation.name for operation in built_graph.get_operations()]
   self.assertTrue(StringsContainSubstrings(names, expected_cells))
Esempio n. 3
0
  def testDynamicRnnTrainingLoop(self):
    """Runs a short training loop and checks every reported metric is sane.

    training_loop is run for 5 steps with summary_frequency=1. It must yield
    at least one metrics entry, and each entry must have a non-negative
    'loss' and 'accuracy'.
    """
    graph = basic_rnn_train.make_graph(
        sequence_example_file=self.sequence_example_file)
    metrics = list(basic_rnn_train.training_loop(
        graph, self.train_dir, num_training_steps=5, summary_frequency=1))

    # Guard against a vacuous pass: with an empty list the loop below would
    # assert nothing at all.
    self.assertTrue(metrics, 'training_loop yielded no metrics')
    for metric in metrics:
      self.assertGreaterEqual(metric['loss'], 0)
      self.assertGreaterEqual(metric['accuracy'], 0)
Esempio n. 4
0
  def testDynamicRnnGraphCorrectness(self):
    """Checks op names in the default graph built by make_graph.

    The graph must contain LSTM cell ops both at 'RNN/MultiRNNCell/...' and
    under 'RNN/while/...'. It must not contain a 'MultiRNNCell_1' copy of the
    cell, nor any 'InputQueueingStateSaver' ops.
    """
    built_graph = basic_rnn_train.make_graph(
        sequence_example_file=self.sequence_example_file)
    names = [operation.name for operation in built_graph.get_operations()]

    required = [
        'RNN/MultiRNNCell/Cell0/LSTMCell',
        'RNN/while/MultiRNNCell/Cell0/LSTMCell',
    ]
    forbidden = [
        'RNN/MultiRNNCell_1/Cell0/LSTMCell',
        'InputQueueingStateSaver',
    ]
    self.assertTrue(StringsContainSubstrings(names, required))
    self.assertTrue(StringsExcludeSubstrings(names, forbidden))