def test_metrics_v1(self):
    """V1 checkpoint metrics: write histogram, checkpoint size, time saved.

    Writes the same checkpoint three times, verifying after each write that
    the size metric counts the write and that time-saved never decreases,
    then restores once and checks the read histogram.
    """
    api_label = util._CHECKPOINT_V1
    checkpoint_prefix = os.path.join(self.get_temp_dir(), 'ckpt')

    with self.cached_session():
      ckpt = util.CheckpointV1()
      var = variables_lib.Variable(1.)
      self.evaluate(var.initializer)
      ckpt.v = var
      # Before any write, both the time-saved metric and the write
      # histogram start at zero.
      self.assertEqual(self._get_time_saved(api_label), 0.0)
      self.assertEqual(self._get_write_histogram_proto(api_label).num, 0.0)

      for write_count in range(1, 4):
        previous_time_saved = self._get_time_saved(api_label)
        ckpt_path = ckpt.write(file_prefix=checkpoint_prefix)
        filesize = util._get_checkpoint_size(ckpt_path)
        self.assertEqual(
            self._get_checkpoint_size(api_label, filesize), write_count)
        self.assertGreaterEqual(
            self._get_time_saved(api_label), previous_time_saved)

    self.assertEqual(self._get_write_histogram_proto(api_label).num, 3.0)

    self.assertEqual(self._get_read_histogram_proto(api_label).num, 0.0)
    time_saved = self._get_time_saved(api_label)
    ckpt.restore(ckpt_path)
    self.assertEqual(self._get_read_histogram_proto(api_label).num, 1.0)
    # Restoring a checkpoint in the same "job" does not increase training time
    # saved.
    self.assertEqual(self._get_time_saved(api_label), time_saved)
Example no. 2
0
 def testUsageGraph(self):
   """Expected usage when graph building.

   Runs three training "continuations": each builds a fresh graph, restores
   the latest checkpoint if one exists, trains for a fixed number of steps,
   and saves.  Asserts that global_step and save_counter accumulate across
   continuations.
   """
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
       with ops.Graph().as_default():
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
         root = util.CheckpointV1(
             optimizer=optimizer, model=model,
             global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             model(input_value),
             global_step=root.global_step)
         # None on the first continuation (nothing saved yet).
         checkpoint_path = checkpoint_management.latest_checkpoint(
             checkpoint_directory)
         with self.session(graph=ops.get_default_graph()) as session:
           status = root.restore(save_path=checkpoint_path)
           status.initialize_or_restore(session=session)
           if checkpoint_path is None:
             # First continuation: nothing was restored, so asserting that
             # the restore was fully consumed must fail.
             self.assertEqual(0, training_continuation)
             with self.assertRaises(AssertionError):
               status.assert_consumed()
           else:
             status.assert_consumed()
           for _ in range(num_training_steps):
             session.run(train_op)
           root.save(file_prefix=checkpoint_prefix, session=session)
           # Step and save counters keep growing across continuations.
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
           self.assertEqual(training_continuation + 1,
                            session.run(root.save_counter))
Example no. 3
0
    def testUsageGraph(self):
        """Expected usage when graph building.

        Three consecutive training runs, each in a fresh graph: restore the
        latest checkpoint (if any), train, save, and verify that the
        optimizer's iteration count and the save counter accumulate.
        """
        with context.graph_mode():
            steps_per_run = 10
            ckpt_dir = self.get_temp_dir()
            ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
            for continuation in range(3):
                with ops.Graph().as_default():
                    net = MyModel()
                    opt = adam.Adam(0.001)
                    root = trackable_utils.CheckpointV1(optimizer=opt,
                                                        model=net)
                    input_value = constant_op.constant([[3.]])
                    # Build the train op by hand via GradientTape rather
                    # than optimizer.minimize.
                    with backprop.GradientTape() as tape:
                        loss = net(input_value)
                    trainable = net.trainable_variables
                    grads = tape.gradient(loss, trainable)
                    train_op = opt.apply_gradients(zip(grads, trainable))

                    ckpt_path = checkpoint_management.latest_checkpoint(
                        ckpt_dir)
                    with self.session(
                            graph=ops.get_default_graph()) as sess:
                        status = root.restore(save_path=ckpt_path)
                        status.initialize_or_restore(session=sess)
                        if ckpt_path is None:
                            # Nothing saved yet on the very first run, so
                            # both consumption checks must fail.
                            self.assertEqual(0, continuation)
                            with self.assertRaises(AssertionError):
                                status.assert_consumed()
                            with self.assertRaises(AssertionError):
                                status.assert_existing_objects_matched()
                        else:
                            status.assert_consumed()
                            status.assert_existing_objects_matched()
                        for _ in range(steps_per_run):
                            sess.run(train_op)
                        root.save(file_prefix=ckpt_prefix, session=sess)
                        # Counters keep growing across continuations.
                        self.assertEqual(
                            (continuation + 1) * steps_per_run,
                            sess.run(root.optimizer.iterations))
                        self.assertEqual(continuation + 1,
                                         sess.run(root.save_counter))
Example no. 4
0
  def test_metrics_v1(self):
    """V1 checkpoint API records write/read durations and time saved."""
    api_label = 'V1'
    checkpoint_prefix = os.path.join(self.get_temp_dir(), 'ckpt')

    with self.cached_session():
      checkpoint = util.CheckpointV1()
      var = variables_lib.Variable(1.)
      self.evaluate(var.initializer)
      checkpoint.v = var
      # Nothing written yet: all metrics start at zero.
      self.assertEqual(self._get_time_saved(api_label), 0.0)
      self.assertEqual(self._get_write_durations(api_label).num, 0.0)
      for _ in range(3):
        before = self._get_time_saved(api_label)
        ckpt_path = checkpoint.write(file_prefix=checkpoint_prefix)
        # Time-saved is monotonically non-decreasing across writes.
        self.assertGreaterEqual(self._get_time_saved(api_label), before)

    # All three writes landed in the write-duration histogram.
    self.assertEqual(self._get_write_durations(api_label).num, 3.0)

    self.assertEqual(self._get_read_durations(api_label).num, 0.0)
    before_restore = self._get_time_saved(api_label)
    checkpoint.restore(ckpt_path)
    self.assertEqual(self._get_read_durations(api_label).num, 1.0)
    # Restoring a checkpoint does not increase training time saved.
    self.assertEqual(self._get_time_saved(api_label), before_restore)