def test_global_step_name(self):
        """Checks that the summary tag follows the global step variable's name.

        The step variable is created as 'foo' inside the 'bar' variable scope
        and registered in the GLOBAL_STEP collection. After two runs of the
        increment op through a StepCounterHook with every_n_steps=1, the test
        expects a single summary keyed by step 2 whose tag is 'bar/foo/sec'.
        """
        with ops.Graph().as_default() as graph, session_lib.Session() as sess:
            # Collections the custom step variable is registered in.
            step_collections = [
                ops.GraphKeys.GLOBAL_STEP,
                ops.GraphKeys.GLOBAL_VARIABLES,
            ]
            with variable_scope.variable_scope('bar'):
                step_var = variable_scope.get_variable(
                    'foo',
                    initializer=0,
                    trainable=False,
                    collections=step_collections)
            increment_op = state_ops.assign_add(step_var, 1)
            writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
            step_hook = basic_session_run_hooks.StepCounterHook(
                summary_writer=writer, every_n_steps=1, every_n_secs=None)

            step_hook.begin()
            sess.run(variables_lib.global_variables_initializer())
            hooked_sess = monitored_session._HookedSession(sess, [step_hook])
            for _ in range(2):
                hooked_sess.run(increment_op)
            step_hook.end(sess)

            writer.assert_summaries(
                test_case=self,
                expected_logdir=self.log_dir,
                expected_graph=graph,
                expected_summaries={})
            recorded = writer.summaries
            self.assertTrue(recorded, 'No summaries were created.')
            self.assertItemsEqual([2], recorded.keys())
            first_value = recorded[2][0].value[0]
            self.assertEqual('bar/foo/sec', first_value.tag)
    def test_step_counter_every_n_secs(self):
        """Checks StepCounterHook output when triggered by every_n_secs.

        The hook is configured with every_n_secs=0.1 and the train op is run
        three times with a 0.2 second sleep before the second and third run.
        The test expects summaries keyed by steps 2 and 3, each tagged
        'global_step/sec' with a positive simple_value.
        """
        with ops.Graph().as_default() as graph, session_lib.Session() as sess:
            step_var = variables.get_or_create_global_step()
            increment_op = state_ops.assign_add(step_var, 1)
            writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
            step_hook = basic_session_run_hooks.StepCounterHook(
                summary_writer=writer, every_n_steps=None, every_n_secs=0.1)

            step_hook.begin()
            sess.run(variables_lib.global_variables_initializer())
            hooked_sess = monitored_session._HookedSession(sess, [step_hook])
            hooked_sess.run(increment_op)
            # Sleep longer than every_n_secs before each of the next two runs.
            for _ in range(2):
                time.sleep(0.2)
                hooked_sess.run(increment_op)
            step_hook.end(sess)

            writer.assert_summaries(
                test_case=self,
                expected_logdir=self.log_dir,
                expected_graph=graph,
                expected_summaries={})
            recorded = writer.summaries
            self.assertTrue(recorded, 'No summaries were created.')
            self.assertItemsEqual([2, 3], recorded.keys())
            for step_summaries in recorded.values():
                first_value = step_summaries[0].value[0]
                self.assertEqual('global_step/sec', first_value.tag)
                self.assertGreater(first_value.simple_value, 0)
Пример #3
0
 def test_step_counter_every_n_steps(self):
   """Checks StepCounterHook output when triggered by every_n_steps.

   With every_n_steps=10 and 30 runs of the train op, the test expects
   summaries keyed by steps 11 and 21, each tagged 'global_step/sec' with a
   positive simple_value, and no call to tf_logging.warning during the runs.
   """
   expected_steps = [11, 21]
   with ops.Graph().as_default() as graph, session_lib.Session() as sess:
     variables.get_or_create_global_step()
     increment_op = training_util._increment_global_step(1)
     writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
     step_hook = basic_session_run_hooks.StepCounterHook(
         summary_writer=writer, every_n_steps=10)
     step_hook.begin()
     sess.run(variables_lib.global_variables_initializer())
     hooked_sess = monitored_session._HookedSession(sess, [step_hook])
     with test.mock.patch.object(tf_logging, 'warning') as warning_mock:
       for _ in range(30):
         time.sleep(0.01)
         hooked_sess.run(increment_op)
       # call_args stays None as long as logging.warning was never invoked.
       self.assertIsNone(warning_mock.call_args)
     step_hook.end(sess)
     writer.assert_summaries(
         test_case=self,
         expected_logdir=self.log_dir,
         expected_graph=graph,
         expected_summaries={})
     recorded = writer.summaries
     self.assertItemsEqual(expected_steps, recorded.keys())
     for step in expected_steps:
       first_value = recorded[step][0].value[0]
       self.assertEqual('global_step/sec', first_value.tag)
       self.assertGreater(first_value.simple_value, 0)
Пример #4
0
  def setUp(self):
    """Builds the summary writer, summary op and train op used by the tests.

    Creates a fake summary writer for 'log/dir', a resource variable 'var'
    whose incremented value feeds the 'my_summary' scalar summary, and a train
    op that increments a global step created inside the 'foo' variable scope.
    """
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    counter = variable_scope.get_variable(
        'var', initializer=0.0, use_resource=True)
    incremented = state_ops.assign_add(counter, 1.0)
    self.summary_op = summary_lib.scalar('my_summary', incremented)

    # The global step is created under a resource-variable scope named 'foo'.
    with variable_scope.variable_scope('foo', use_resource=True):
      step_var = variables.get_or_create_global_step()
    self.train_op = state_ops.assign_add(step_var, 1)
Пример #5
0
  def setUp(self):
    """Builds the summary writer, two summary ops and the train op.

    Creates a fake summary writer for 'log/dir', a variable whose incremented
    value feeds the 'my_summary' scalar summary, a second scalar summary
    'my_summary2' of that value times two, and a train op that increments the
    global step.
    """
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    counter = variables_lib.Variable(0.0)
    incremented = state_ops.assign_add(counter, 1.0)
    doubled = incremented * 2
    self.summary_op = summary_lib.scalar('my_summary', incremented)
    self.summary_op2 = summary_lib.scalar('my_summary2', doubled)

    step_var = variables.get_or_create_global_step()
    self.train_op = state_ops.assign_add(step_var, 1)