Beispiel #1
0
    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        On a parameter-server role the checkpoint prefix is tagged with
        "PServer_<rank>_". The graph file is written once per callback
        instance, and any running thread named "asyn_save_ckpt" is joined
        before the checkpoint is saved.

        Args:
            run_context (RunContext): Context of the train running.
        """
        if _is_role_pserver():
            # This method runs on every step, so the tag must be applied only
            # once; prepending unconditionally would grow the prefix
            # ("PServer_0_PServer_0_...") step after step.
            ps_tag = "PServer_" + str(_get_ps_mode_rank()) + "_"
            if not self._prefix.startswith(ps_tag):
                self._prefix = ps_tag + self._prefix
        cb_params = run_context.original_args()
        _make_directory(self._directory)
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory,
                                           self._prefix + '-graph.meta')
            # A stale graph file is removed first, but only in graph mode.
            if os.path.isfile(graph_file_name) and context.get_context(
                    "mode") == context.GRAPH_MODE:
                os.remove(graph_file_name)
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        # Wait for every thread named "asyn_save_ckpt" to finish before saving
        # again (presumably an earlier asynchronous save -- confirm against
        # the code that starts these threads).
        for thread in threading.enumerate():
            # `Thread.name` replaces `getName()`, deprecated since Python 3.10.
            if thread.name == "asyn_save_ckpt":
                thread.join()
        self._save_ckpt(cb_params)
Beispiel #2
0
    def __init__(self, prefix='CKP', directory=None, config=None):
        """
        Set up the checkpoint callback state.

        Args:
            prefix (str): Prefix for checkpoint file names; it must pass
                `_check_file_name_prefix`. Default: 'CKP'.
            directory (str): Directory handed to `_make_directory`. When
                None, `_cur_dir` is used instead. Default: None.
            config (CheckpointConfig): Checkpoint configuration. When None,
                a default `CheckpointConfig()` is created. Default: None.

        Raises:
            ValueError: If `prefix` is rejected by `_check_file_name_prefix`.
            TypeError: If `config` is given but is not a CheckpointConfig.
        """
        super(ModelCheckpoint, self).__init__()
        self._latest_ckpt_file_name = ""
        # Three separate clock reads, one per timer attribute.
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        self._last_triggered_step = 0

        # Reject an invalid prefix before anything touches the file system.
        if not _check_file_name_prefix(prefix):
            raise ValueError(
                "Prefix {} for checkpoint file name invalid, "
                "please check and correct it and then continue.".format(
                    prefix))
        self._prefix = prefix

        self._directory = (_cur_dir if directory is None
                           else _make_directory(directory))

        if config is None:
            config = CheckpointConfig()
        elif not isinstance(config, CheckpointConfig):
            raise TypeError("config should be CheckpointConfig type.")
        self._config = config

        # get existing checkpoint files
        self._manager = CheckpointManager()
        # NOTE(review): presumably changes the prefix when checkpoints with
        # the same name already exist in the directory -- confirm in helper.
        self._prefix = _chg_ckpt_file_name_if_same_exist(
            self._directory, self._prefix)
        self._graph_saved = False
    def __init__(self, prefix='Exception', directory=None, config=None):
        """
        Set up the callback and register `self.save` for SIGTERM and SIGINT.

        Args:
            prefix (str): Prefix for the checkpoint file name; it must be a
                string that does not contain "/". Default: 'Exception'.
            directory (str): Directory handed to `_make_directory`. When
                None, `_cur_dir` is used instead. Default: None.
            config (CheckpointConfig): Checkpoint configuration. When None,
                a default `CheckpointConfig()` is created. Default: None.

        Raises:
            ValueError: If `prefix` is not a string or contains "/".
            TypeError: If `config` is given but is not a CheckpointConfig.
        """
        super(ExceptionCheckpoint, self).__init__()
        self.cb_params = None

        # The check rejects "/" anywhere in the prefix, so the message says
        # exactly that (it used to mention only the first letter).
        if not isinstance(prefix, str) or prefix.find('/') >= 0:
            raise ValueError(
                f"For 'ExceptionCheckpoint', the argument 'prefix' must be string and "
                f"can't contain \"/\", but got 'prefix' type: {type(prefix)}, 'prefix': {prefix}.")
        if directory is not None:
            self._directory = _make_directory(directory)
        else:
            self._directory = _cur_dir
        # NOTE(review): presumably adjusts the prefix when files with the same
        # name already exist in the directory -- confirm in the helper.
        self._prefix = _check_bpckpt_file_name_if_same_exist(self._directory,
                                                             prefix)

        if config is None:
            self._config = CheckpointConfig()
        else:
            if not isinstance(config, CheckpointConfig):
                raise TypeError(
                    "For 'ExceptionCheckpoint', the argument 'config' should be CheckpointConfig type, "
                    "but got {}.".format(type(config)))
            self._config = config

        # A missing or empty append_dict falls back to an empty dict, so both
        # counters default to 0.
        self._append_dict = self._config.append_dict or {}
        self._append_epoch_num = self._append_dict.get("epoch_num", 0)
        self._append_step_num = self._append_dict.get("step_num", 0)

        # Install the handlers last: a signal arriving mid-construction, or a
        # validation error above, must not leave the process-wide handlers
        # bound to a partially initialised instance.
        signal.signal(signal.SIGTERM, self.save)
        signal.signal(signal.SIGINT, self.save)