def restore_checkpoint(self):
    self.verify_graph_was_created()

    # TODO: find better way to load checkpoints that were saved with a global network into the online network
    if self.task_parameters.checkpoint_restore_dir:
        if self.task_parameters.framework_type == Frameworks.tensorflow and \
                'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_dir):
            # TODO-fixme checkpointing
            # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
            # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt" filename
            # pattern. The names used are maintained in a CheckpointState protobuf file named 'checkpoint'. Using
            # Coach's '.coach_checkpoint' protobuf file results in an error when trying to restore the model, as
            # the checkpoint names defined do not match the actual checkpoint names.
            checkpoint = self._get_checkpoint_state_tf()
        else:
            checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)

        if checkpoint is None:
            screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
        else:
            screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
            self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)

        [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
def restore_checkpoint(self):
    self.verify_graph_was_created()

    # TODO: find better way to load checkpoints that were saved with a global network into the online network
    if self.task_parameters.checkpoint_restore_path:
        if os.path.isdir(self.task_parameters.checkpoint_restore_path):
            # a checkpoint dir
            if self.task_parameters.framework_type == Frameworks.tensorflow and \
                    'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_path):
                # TODO-fixme checkpointing
                # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
                # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when trying to
                # restore the model, as the checkpoint names defined do not match the actual checkpoint names.
                checkpoint = self._get_checkpoint_state_tf(self.task_parameters.checkpoint_restore_path)
            else:
                checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_path)

            if checkpoint is None:
                raise ValueError("No checkpoint to restore in: {}".format(
                    self.task_parameters.checkpoint_restore_path))

            model_checkpoint_path = checkpoint.model_checkpoint_path
            checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path

            # Set the last checkpoint ID - only in the case of the path being a dir
            chkpt_state_reader = CheckpointStateReader(self.task_parameters.checkpoint_restore_path,
                                                       checkpoint_state_optional=False)
            self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
        else:
            # a checkpoint file
            if self.task_parameters.framework_type == Frameworks.tensorflow:
                model_checkpoint_path = self.task_parameters.checkpoint_restore_path
                checkpoint_restore_dir = os.path.dirname(model_checkpoint_path)
            else:
                raise ValueError("Currently, restoring a checkpoint using the --checkpoint_restore_file argument"
                                 " is only supported with TensorFlow.")

        screen.log_title("Loading checkpoint: {}".format(model_checkpoint_path))
        self.checkpoint_saver.restore(self.sess, model_checkpoint_path)

        [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
def test_get_checkpoint_state():
    files = ['4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext', '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext']
    with tempfile.TemporaryDirectory() as temp_dir:
        [open(os.path.join(temp_dir, fn), 'a').close() for fn in files]

        checkpoint_state = checkpoint.get_checkpoint_state(temp_dir, all_checkpoints=True)
        assert checkpoint_state.model_checkpoint_path == os.path.join(temp_dir, '4.test.ckpt')
        assert checkpoint_state.all_model_checkpoint_paths == \
            [os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])]

        reader = checkpoint.CheckpointStateReader(temp_dir, checkpoint_state_optional=False)
        assert reader.get_latest() is None
        assert len(reader.get_all()) == 0

        reader = checkpoint.CheckpointStateReader(temp_dir)
        assert reader.get_latest().num == 4
        assert [ckp.num for ckp in reader.get_all()] == [1, 2, 3, 4]
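# A minimal, self-contained sketch (not Coach's implementation) of the naming convention the test above
# relies on: checkpoint files are prefixed with their checkpoint number ("<num>.<name>.ckpt..."), files
# whose leading dot-separated token is not an integer (e.g. 'prefix.10.test.ckpt.ext') are ignored, and
# the latest checkpoint is the one with the highest number. The helper name is illustrative only.
import os
import re


def latest_checkpoint_number(checkpoint_dir):
    """Return the highest leading checkpoint number found in checkpoint_dir, or None if there is none."""
    numbers = []
    for filename in os.listdir(checkpoint_dir):
        match = re.match(r'^(\d+)\..*\.ckpt', filename)  # e.g. '4.test.ckpt.ext' -> 4
        if match:
            numbers.append(int(match.group(1)))
    return max(numbers) if numbers else None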
def restore_checkpoint(self):
    self.verify_graph_was_created()

    # TODO: find better way to load checkpoints that were saved with a global network into the online network
    if self.task_parameters.checkpoint_restore_path:
        restored_checkpoint_paths = []
        for agent_params in self.agents_params:
            if len(self.agents_params) == 1:
                agent_checkpoint_restore_path = self.task_parameters.checkpoint_restore_path
            else:
                agent_checkpoint_restore_path = os.path.join(self.task_parameters.checkpoint_restore_path,
                                                             agent_params.name)
            if os.path.isdir(agent_checkpoint_restore_path):
                # a checkpoint dir
                if self.task_parameters.framework_type == Frameworks.tensorflow and \
                        'checkpoint' in os.listdir(agent_checkpoint_restore_path):
                    # TODO-fixme checkpointing
                    # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
                    # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                    # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                    # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when trying
                    # to restore the model, as the checkpoint names defined do not match the actual checkpoint names.
                    raise NotImplementedError('Checkpointing not implemented for TF monitored training session')
                else:
                    checkpoint = get_checkpoint_state(agent_checkpoint_restore_path, all_checkpoints=True)

                if checkpoint is None:
                    raise ValueError("No checkpoint to restore in: {}".format(agent_checkpoint_restore_path))

                model_checkpoint_path = checkpoint.model_checkpoint_path
                checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
                restored_checkpoint_paths.append(model_checkpoint_path)

                # Set the last checkpoint ID - only in the case of the path being a dir
                chkpt_state_reader = CheckpointStateReader(agent_checkpoint_restore_path,
                                                           checkpoint_state_optional=False)
                self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
            else:
                # a checkpoint file
                if self.task_parameters.framework_type == Frameworks.tensorflow:
                    model_checkpoint_path = agent_checkpoint_restore_path
                    checkpoint_restore_dir = os.path.dirname(model_checkpoint_path)
                    restored_checkpoint_paths.append(model_checkpoint_path)
                else:
                    raise ValueError("Currently, restoring a checkpoint using the --checkpoint_restore_file argument"
                                     " is only supported with TensorFlow.")

            try:
                self.checkpoint_saver[agent_params.name].restore(self.sess[agent_params.name],
                                                                 model_checkpoint_path)
            except Exception as ex:
                raise ValueError("Failed to restore {}'s checkpoint: {}".format(agent_params.name, ex))

            all_checkpoints = sorted(list(set([c.name for c in checkpoint.all_checkpoints])))  # remove duplicates :-(
            if self.num_checkpoints_to_keep < len(all_checkpoints):
                checkpoint_to_delete = all_checkpoints[-self.num_checkpoints_to_keep - 1]
                agent_checkpoint_to_delete = os.path.join(agent_checkpoint_restore_path, checkpoint_to_delete)
                for file in glob.glob("{}*".format(agent_checkpoint_to_delete)):
                    os.remove(file)

        [manager.restore_checkpoint(checkpoint_restore_dir) for manager in self.level_managers]
        [manager.post_training_commands() for manager in self.level_managers]

        screen.log_dict(OrderedDict([("Restoring from path", restored_checkpoint_path)
                                     for restored_checkpoint_path in restored_checkpoint_paths]),
                        prefix="Checkpoint")
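# A standalone sketch of the retention step used in the function above (illustrative only, not Coach's
# API): given the checkpoint names sorted from oldest to newest and a retention count, glob-delete every
# file belonging to the checkpoint that falls just outside the retention window. The helper name and
# signature are assumptions made for this example.
import glob
import os


def prune_old_checkpoint(checkpoint_dir, checkpoint_names, num_checkpoints_to_keep):
    """checkpoint_names is assumed to be sorted from oldest to newest, with duplicates already removed."""
    if num_checkpoints_to_keep < len(checkpoint_names):
        # The checkpoint immediately older than the newest `num_checkpoints_to_keep` entries
        checkpoint_to_delete = checkpoint_names[-num_checkpoints_to_keep - 1]
        for path in glob.glob("{}*".format(os.path.join(checkpoint_dir, checkpoint_to_delete))):
            os.remove(path)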