def test_create_global_step(self):
    """Check the variable returned by ``common.create_global_step()``.

    Asserts that the new step is named ``"global_step:0"``, has dtype
    ``tf.int64``, compares equal to 0, and compares equal to 1 after a
    single ``assign_add(1)``.
    """
    global_step = common.create_global_step()

    # Properties of the freshly created step.
    self.assertEqual(global_step.name, "global_step:0")
    self.assertEqual(global_step.dtype, tf.int64)
    self.assertEqual(global_step, 0)

    # One increment must be reflected in the value.
    global_step.assign_add(1)
    self.assertEqual(global_step, 1)
def __init__(self):
    """Create a global step and keep it on the instance as ``step_counter``."""
    counter = common.create_global_step()
    self.step_counter = counter