Example #1
    def restore_checkpoint(self):
        self.verify_graph_was_created()

        # TODO: find better way to load checkpoints that were saved with a global network into the online network
        if self.task_parameters.checkpoint_restore_dir:
            if self.task_parameters.framework_type == Frameworks.tensorflow and\
                    'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_dir):
                # TODO-fixme checkpointing
                # MonitoredTrainingSession manages save/restore checkpoints autonomously. In doing so,
                # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt" filename
                # pattern. The names used are maintained in a CheckpointState protobuf file named 'checkpoint'. Using
                # Coach's '.coach_checkpoint' protobuf file results in an error when trying to restore the model, as
                # the checkpoint names defined there do not match the actual checkpoint names.
                checkpoint = self._get_checkpoint_state_tf()
            else:
                checkpoint = get_checkpoint_state(
                    self.task_parameters.checkpoint_restore_dir)

            if checkpoint is None:
                screen.warning("No checkpoint to restore in: {}".format(
                    self.task_parameters.checkpoint_restore_dir))
            else:
                screen.log_title("Loading checkpoint: {}".format(
                    checkpoint.model_checkpoint_path))
                self.checkpoint_saver.restore(self.sess,
                                              checkpoint.model_checkpoint_path)

            for manager in self.level_managers:
                manager.restore_checkpoint(
                    self.task_parameters.checkpoint_restore_dir)
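
The `_get_checkpoint_state_tf` helper referenced above is not shown in this example. A minimal sketch of what it might do, assuming it simply delegates to TensorFlow's reader of the 'checkpoint' CheckpointState protobuf file that MonitoredTrainingSession writes (the standalone function form and its name are illustrative, not Coach's actual implementation):

import tensorflow as tf

def get_checkpoint_state_tf(checkpoint_dir):
    # Sketch only: read the 'checkpoint' CheckpointState protobuf file written by
    # MonitoredTrainingSession and return an object exposing model_checkpoint_path,
    # as used by checkpoint_saver.restore() above. Coach's _get_checkpoint_state_tf
    # may differ.
    return tf.train.get_checkpoint_state(checkpoint_dir)
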
Example #2
    def restore_checkpoint(self):
        self.verify_graph_was_created()

        # TODO: find better way to load checkpoints that were saved with a global network into the online network
        if self.task_parameters.checkpoint_restore_path:
            if os.path.isdir(self.task_parameters.checkpoint_restore_path):
                # a checkpoint dir
                if self.task_parameters.framework_type == Frameworks.tensorflow and\
                        'checkpoint' in os.listdir(self.task_parameters.checkpoint_restore_path):
                    # TODO-fixme checkpointing
                    # MonitoredTrainingSession manages save/restore checkpoints autonomously. In doing so,
                    # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                    # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                    # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when trying to
                    # restore the model, as the checkpoint names defined there do not match the actual checkpoint names.
                    checkpoint = self._get_checkpoint_state_tf(
                        self.task_parameters.checkpoint_restore_path)
                else:
                    checkpoint = get_checkpoint_state(
                        self.task_parameters.checkpoint_restore_path)

                if checkpoint is None:
                    raise ValueError("No checkpoint to restore in: {}".format(
                        self.task_parameters.checkpoint_restore_path))
                model_checkpoint_path = checkpoint.model_checkpoint_path
                checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path

                # Set the last checkpoint ID - only in the case of the path being a dir
                chkpt_state_reader = CheckpointStateReader(
                    self.task_parameters.checkpoint_restore_path,
                    checkpoint_state_optional=False)
                self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
            else:
                # a checkpoint file
                if self.task_parameters.framework_type == Frameworks.tensorflow:
                    model_checkpoint_path = self.task_parameters.checkpoint_restore_path
                    checkpoint_restore_dir = os.path.dirname(
                        model_checkpoint_path)
                else:
                    raise ValueError(
                        "Currently restoring a checkpoint using the --checkpoint_restore_file argument is"
                        " only supported when with tensorflow.")

            screen.log_title(
                "Loading checkpoint: {}".format(model_checkpoint_path))

            self.checkpoint_saver.restore(self.sess, model_checkpoint_path)

            for manager in self.level_managers:
                manager.restore_checkpoint(checkpoint_restore_dir)
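
Example #2 extends Example #1 by accepting either a checkpoint directory or a single checkpoint file. The branch structure can be read in isolation; the following standalone sketch (the helper name `split_restore_path` is made up for illustration and is not part of Coach) shows how a restore path is normalized into the pair used later for restoring the session and the level managers:

import os

def split_restore_path(checkpoint_restore_path):
    # Hypothetical helper mirroring the dir/file branches above: return
    # (model_checkpoint_path, checkpoint_restore_dir). For a directory, the
    # concrete checkpoint file still has to be resolved afterwards (e.g. via
    # get_checkpoint_state), so None is returned for it here.
    if os.path.isdir(checkpoint_restore_path):
        return None, checkpoint_restore_path
    return checkpoint_restore_path, os.path.dirname(checkpoint_restore_path)
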
Example #3
import os
import tempfile

from rl_coach import checkpoint  # assumed import path for Coach's checkpoint utilities


def test_get_checkpoint_state():
    files = [
        '4.test.ckpt.ext', '2.test.ckpt.ext', '3.test.ckpt.ext',
        '1.test.ckpt.ext', 'prefix.10.test.ckpt.ext'
    ]
    with tempfile.TemporaryDirectory() as temp_dir:
        for fn in files:
            open(os.path.join(temp_dir, fn), 'a').close()
        checkpoint_state = checkpoint.get_checkpoint_state(
            temp_dir, all_checkpoints=True)
        assert checkpoint_state.model_checkpoint_path == os.path.join(
            temp_dir, '4.test.ckpt')
        assert checkpoint_state.all_model_checkpoint_paths == \
               [os.path.join(temp_dir, f[:-4]) for f in sorted(files[:-1])]

        reader = checkpoint.CheckpointStateReader(
            temp_dir, checkpoint_state_optional=False)
        assert reader.get_latest() is None
        assert len(reader.get_all()) == 0

        reader = checkpoint.CheckpointStateReader(temp_dir)
        assert reader.get_latest().num == 4
        assert [ckp.num for ckp in reader.get_all()] == [1, 2, 3, 4]
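
The assertions above pin down the filename convention: checkpoints are named '<number>.<name>.ckpt.<ext>', the trailing extension is stripped from the reported paths, the highest number becomes model_checkpoint_path, and files that do not start with a checkpoint number (such as 'prefix.10.test.ckpt.ext') are ignored. A minimal directory scan satisfying these assertions could look like the sketch below; it is illustrative only and not Coach's actual get_checkpoint_state implementation:

import os
import re
from types import SimpleNamespace

def get_checkpoint_state_sketch(checkpoint_dir, all_checkpoints=True):
    # Illustrative stand-in for checkpoint.get_checkpoint_state(); the
    # all_checkpoints flag is accepted only for parity with the call above and
    # is unused in this sketch.
    ckpt_re = re.compile(r'^(\d+)\..*\.ckpt\.\w+$')   # e.g. '4.test.ckpt.ext'
    found = []
    for fn in os.listdir(checkpoint_dir):
        match = ckpt_re.match(fn)
        if match:
            # drop the trailing extension, keeping e.g. '4.test.ckpt'
            found.append((int(match.group(1)), os.path.splitext(fn)[0]))
    if not found:
        return None
    return SimpleNamespace(
        model_checkpoint_path=os.path.join(checkpoint_dir, max(found)[1]),
        all_model_checkpoint_paths=sorted(
            os.path.join(checkpoint_dir, name) for _, name in found))

Run against the temp dir created in the test, this sketch returns '.../4.test.ckpt' as the latest checkpoint and the four numbered paths in sorted order, matching the first two assertions.
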
Example #4
    def restore_checkpoint(self):
        self.verify_graph_was_created()

        # TODO: find better way to load checkpoints that were saved with a global network into the online network
        if self.task_parameters.checkpoint_restore_path:
            restored_checkpoint_paths = []
            for agent_params in self.agents_params:
                if len(self.agents_params) == 1:
                    agent_checkpoint_restore_path = self.task_parameters.checkpoint_restore_path
                else:
                    agent_checkpoint_restore_path = os.path.join(
                        self.task_parameters.checkpoint_restore_path,
                        agent_params.name)
                if os.path.isdir(agent_checkpoint_restore_path):
                    # a checkpoint dir
                    if self.task_parameters.framework_type == Frameworks.tensorflow and\
                            'checkpoint' in os.listdir(agent_checkpoint_restore_path):
                        # TODO-fixme checkpointing
                        # MonitoredTrainingSession manages save/restore checkpoints autonomously. In doing so,
                        # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                        # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                        # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when
                        # trying to restore the model, as the checkpoint names defined there do not match the actual
                        # checkpoint names.
                        raise NotImplementedError(
                            'Checkpointing not implemented for TF monitored training session'
                        )
                    else:
                        checkpoint = get_checkpoint_state(
                            agent_checkpoint_restore_path,
                            all_checkpoints=True)

                    if checkpoint is None:
                        raise ValueError(
                            "No checkpoint to restore in: {}".format(
                                agent_checkpoint_restore_path))
                    model_checkpoint_path = checkpoint.model_checkpoint_path
                    checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
                    restored_checkpoint_paths.append(model_checkpoint_path)

                    # Set the last checkpoint ID - only in the case of the path being a dir
                    chkpt_state_reader = CheckpointStateReader(
                        agent_checkpoint_restore_path,
                        checkpoint_state_optional=False)
                    self.checkpoint_id = chkpt_state_reader.get_latest(
                    ).num + 1
                else:
                    # a checkpoint file
                    if self.task_parameters.framework_type == Frameworks.tensorflow:
                        model_checkpoint_path = agent_checkpoint_restore_path
                        checkpoint_restore_dir = os.path.dirname(
                            model_checkpoint_path)
                        restored_checkpoint_paths.append(model_checkpoint_path)
                    else:
                        raise ValueError(
                            "Currently restoring a checkpoint using the --checkpoint_restore_file argument is"
                            " only supported when with tensorflow.")

                try:
                    self.checkpoint_saver[agent_params.name].restore(
                        self.sess[agent_params.name], model_checkpoint_path)
                except Exception as ex:
                    raise ValueError(
                        "Failed to restore {}'s checkpoint: {}".format(
                            agent_params.name, ex))

                # Prune old checkpoints - only when restoring from a checkpoint dir,
                # since only then is the full checkpoint state available
                if os.path.isdir(agent_checkpoint_restore_path):
                    # remove duplicates :-(
                    all_checkpoints = sorted(
                        {c.name for c in checkpoint.all_checkpoints})
                    if self.num_checkpoints_to_keep < len(all_checkpoints):
                        checkpoint_to_delete = all_checkpoints[
                            -self.num_checkpoints_to_keep - 1]
                        agent_checkpoint_to_delete = os.path.join(
                            agent_checkpoint_restore_path, checkpoint_to_delete)
                        for file in glob.glob(
                                "{}*".format(agent_checkpoint_to_delete)):
                            os.remove(file)

            for manager in self.level_managers:
                manager.restore_checkpoint(checkpoint_restore_dir)
            for manager in self.level_managers:
                manager.post_training_commands()

            screen.log_dict(
                OrderedDict([
                    ("Restoring from path", restored_checkpoint_path)
                    for restored_checkpoint_path in restored_checkpoint_paths
                ]),
                prefix="Checkpoint")