def testLoadLatestCheckpoint(self):
    """Checks that the highest saved checkpoint number is reported as latest.

    Two checkpoints with consecutive iteration numbers are written to
    `self._test_subdir` (each one stores its own iteration number as data),
    after which `checkpointer.get_latest_checkpoint_number` must report the
    larger of the two numbers.
    """
    saver = checkpointer.Checkpointer(self._test_subdir)
    earlier_iteration = 1729
    later_iteration = earlier_iteration + 1
    # Save in increasing order; the iteration number doubles as the payload.
    for iteration in (earlier_iteration, later_iteration):
        saver.save_checkpoint(iteration, iteration)
    reported = checkpointer.get_latest_checkpoint_number(self._test_subdir)
    self.assertEqual(later_iteration, reported)
def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
    """Creates the checkpointer and resumes from the latest checkpoint, if any.

    `self.checkpointer` is always (re)created and `self.start_iteration` is
    always reset to 0. If `checkpointer.get_latest_checkpoint_number` reports
    a non-negative checkpoint number for `self.checkpoint_dir`, that
    checkpoint is loaded and handed to `self.agent.unbundle`. Only when the
    agent reports success (a truthy return value) are the logger data and the
    start iteration restored from the checkpoint.

    This method returns nothing; results are stored on `self`.

    Args:
      checkpoint_file_prefix: prefix forwarded to `checkpointer.Checkpointer`.
    """
    self.checkpointer = checkpointer.Checkpointer(self.checkpoint_dir,
                                                  checkpoint_file_prefix)
    self.start_iteration = 0

    newest_version = checkpointer.get_latest_checkpoint_number(
        self.checkpoint_dir)
    if newest_version < 0:
        # A negative number means there is nothing to resume from.
        return

    bundle = self.checkpointer.load_checkpoint(newest_version)
    restored = self.agent.unbundle(self.checkpoint_dir, newest_version, bundle)
    if not restored:
        # The agent could not reload its state: start from iteration 0.
        return

    # Both keys must be present before any experiment state is touched.
    assert 'logs' in bundle
    assert 'current_iteration' in bundle
    self.logger.data = bundle['logs']
    # The stored iteration is used as the last finished one, so resume at
    # the iteration after it.
    self.start_iteration = bundle['current_iteration'] + 1
    print('Reloaded checkpoint and will start from iteration ',
          self.start_iteration)
def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
    """Reloads the latest checkpoint if it exists.

    A `Checkpointer` is created for `self._checkpoint_dir` and stored in
    `self._checkpointer`; `self._start_iteration` is reset to 0. Then
    `checkpointer.get_latest_checkpoint_number` is asked for the largest
    checkpoint number in that directory.

    If the number is non-negative, the bundled data is loaded from that
    checkpoint and passed to the agent's `unbundle`. If the agent unbundles
    successfully, the bundle is checked for the keys 'logs' and
    'current_iteration', the logger's data is replaced by the bundled logs,
    and `self._start_iteration` is set to the bundled iteration plus one.

    Nothing is returned: the checkpointer and the start iteration are stored
    on `self` as `self._checkpointer` and `self._start_iteration`.

    Args:
      checkpoint_file_prefix: str, the checkpoint file prefix.
    """
    self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                   checkpoint_file_prefix)
    self._start_iteration = 0

    # The existence of checkpoint 0 means that iteration 0 has finished, so
    # in that case the experiment continues from iteration 1.
    newest_version = checkpointer.get_latest_checkpoint_number(
        self._checkpoint_dir)
    if newest_version < 0:
        # No checkpoint available: keep the fresh start at iteration 0.
        return

    bundle = self._checkpointer.load_checkpoint(newest_version)
    agent_restored = self._agent.unbundle(
        self._checkpoint_dir, newest_version, bundle)
    if not agent_restored:
        return

    # Validate the bundle before overwriting any experiment state.
    assert 'logs' in bundle
    assert 'current_iteration' in bundle
    self._logger.data = bundle['logs']
    self._start_iteration = bundle['current_iteration'] + 1
    tf.logging.info('Reloaded checkpoint and will start from iteration %d',
                    self._start_iteration)
def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
    """Sets up checkpointing and resumes from the newest checkpoint if present.

    Steps performed:
      1. Build a `Checkpointer` over `self._checkpoint_dir` and keep it in
         `self._checkpointer`; reset `self._start_iteration` to 0.
      2. Query `checkpointer.get_latest_checkpoint_number` for the largest
         checkpoint number in `self._checkpoint_dir`. A negative result ends
         the method with the defaults from step 1.
      3. Load that checkpoint and pass its data to `self._agent.unbundle`.
         A falsy result also ends the method with the defaults from step 1.
      4. Verify the loaded data has the keys 'logs' and 'current_iteration',
         restore the logger's data from 'logs', and set
         `self._start_iteration` to 'current_iteration' plus one.

    The method has no return value; its results live in `self._checkpointer`
    and `self._start_iteration`.

    Args:
      checkpoint_file_prefix: str, the checkpoint file prefix.
    """
    self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                   checkpoint_file_prefix)
    self._start_iteration = 0

    # Checkpoint 0 existing means iteration 0 is complete, hence the "+ 1"
    # when computing the start iteration further down.
    latest_version = checkpointer.get_latest_checkpoint_number(
        self._checkpoint_dir)
    if latest_version < 0:
        return

    saved_state = self._checkpointer.load_checkpoint(latest_version)
    if not self._agent.unbundle(self._checkpoint_dir, latest_version,
                                saved_state):
        # The agent declined the bundle; do not restore logs or iteration.
        return

    assert 'logs' in saved_state
    assert 'current_iteration' in saved_state
    self._logger.data = saved_state['logs']
    self._start_iteration = saved_state['current_iteration'] + 1
    tf.logging.info('Reloaded checkpoint and will start from iteration %d',
                    self._start_iteration)
def testLoadLatestCheckpointWithEmptyDir(self):
    """Expects -1 as the latest checkpoint number when nothing was saved.

    No checkpoint is written by this test before querying
    `self._test_subdir` (presumably left empty by the test fixture).
    """
    reported = checkpointer.get_latest_checkpoint_number(self._test_subdir)
    self.assertEqual(-1, reported)
def testLoadLatestCheckpointWithInvalidDir(self):
    """Expects -1 as the latest checkpoint number for a missing directory."""
    missing_dir = '/does/not/exist'
    reported = checkpointer.get_latest_checkpoint_number(missing_dir)
    self.assertEqual(-1, reported)