예제 #1
0
파일: __init__.py 프로젝트: neuzxy/Paddle
    def clean_redundant_check_points(self,
                                     root_path,
                                     fs=None,
                                     checkpoint_num=1):
        """
        This function removes old checkpoint directories under root_path.

        Directories named ``<prefix>.<n>`` (prefix is
        ``self._checkoint_prefix``) are removed with ``fs.rmr`` when
        ``n <= max_no - checkpoint_num``, where ``max_no`` is the value
        returned by ``self._get_last_checkpoint_no``. Nothing is done when
        that value is negative.

        Args:
            root_path: directory that contains the numbered checkpoint
                directories.
            fs: file-system client providing ``list_dirs`` and ``rmr``.
                A ``LocalFS`` is created when it is omitted.
            checkpoint_num: how many of the highest-numbered checkpoints to
                keep; values below 1 are treated as 1.
        """
        # Built per call instead of in the signature, where the default
        # would be evaluated once and shared by every call.
        if fs is None:
            fs = LocalFS()

        max_no = self._get_last_checkpoint_no(root_path, fs)
        if max_no < 0:
            return

        checkpoint_num = max(checkpoint_num, 1)
        highest_removed = max_no - checkpoint_num

        for dir_name in fs.list_dirs(root_path):
            parts = dir_name.split(".")
            if len(parts) != 2 or parts[0] != self._checkoint_prefix:
                continue

            try:
                n = int(parts[1])
            except ValueError:
                # The suffix is not a number, so this is not a checkpoint.
                continue

            if n > highest_removed:
                continue

            path = "{}/{}.{}".format(root_path, self._checkoint_prefix, n)
            try:
                fs.rmr(path)
            except Exception as e:
                # Cleanup is best-effort: report the failure and keep
                # going with the remaining directories.
                print("failed to remove checkpoint {}: {}".format(path, e))
예제 #2
0
파일: __init__.py 프로젝트: neuzxy/Paddle
    def save_check_point(self,
                         executor,
                         path,
                         train_status,
                         main_program=None,
                         fs=None,
                         local_cache_path=".cache",
                         remain_all_checkpoint=True):
        """
        This function save persistables and current epoch num to path.

        The checkpoint is written as ``<path>/<prefix>.<n>`` where ``n`` is
        one more than the value returned by
        ``self._get_last_checkpoint_no`` (0 when that value is negative).
        Data is first written to ``<path>/<prefix>.<n>.tmp`` -- or, when
        ``fs.need_upload_download()`` is true, to a directory under
        ``local_cache_path`` that is then uploaded to the ``.tmp`` path --
        and finally moved to its real name with ``fs.mv``.

        Args:
            executor: passed through to ``self.save_persistables``.
            path: directory that holds the numbered checkpoints; created
                with ``fs.mkdir`` when ``fs.stat(path)`` is falsy.
            train_status: passed through to ``self._save_train_status``.
            main_program: program to save; ``self._transpiled_program`` is
                used when it is None.
            fs: file-system client. A ``LocalFS`` is created when omitted.
            local_cache_path: local directory for the staging copy used
                when ``fs`` needs upload/download.
            remain_all_checkpoint: when False,
                ``self.clean_redundant_check_points`` is called on ``path``
                after saving.
        """
        # Built per call instead of in the signature, where the default
        # would be evaluated once and shared by every call.
        if fs is None:
            fs = LocalFS()

        if main_program is None:
            main_program = self._transpiled_program

        if not fs.stat(path):
            fs.mkdir(path)

        max_no = self._get_last_checkpoint_no(path, fs=fs)
        if max_no < 0:
            max_no = -1
        checkpoint_no = max_no + 1

        real_path = "{}/{}.{}".format(path, self._checkoint_prefix,
                                      checkpoint_no)
        tmp_path = "{}.tmp".format(real_path)
        saved_path = tmp_path

        use_cache = fs.need_upload_download()
        cache_path = None
        if use_cache:
            # Stage locally first; the data is uploaded once it is complete.
            local_fs = LocalFS()
            cache_path = "{}/{}.{}.saved_cache".format(local_cache_path,
                                                       self._checkoint_prefix,
                                                       checkpoint_no)
            if not local_fs.stat(cache_path):
                local_fs.mkdir(cache_path)
            saved_path = cache_path

        self.save_persistables(executor=executor,
                               dirname=saved_path,
                               main_program=main_program,
                               filename=self._param_file_name)
        self._save_train_status(path=saved_path, train_status=train_status)

        if use_cache:
            fs.delete(tmp_path)
            fs.upload(cache_path, tmp_path)
        # Presumably moved last so that a partially written checkpoint
        # never appears under its final name.
        fs.mv(tmp_path, real_path)

        if not remain_all_checkpoint:
            # Forward fs so cleanup runs on the same file system the
            # checkpoint was written to, not on a default LocalFS.
            self.clean_redundant_check_points(path, fs=fs)
예제 #3
0
파일: __init__.py 프로젝트: neuzxy/Paddle
    def load_check_point(self,
                         executor,
                         path,
                         trainer_id,
                         main_program=None,
                         fs=None,
                         local_cache_path=".cache",
                         ignore_empty=True):
        """
        This function load persistables and current epoch num from path.

        The checkpoint loaded is ``<path>/<prefix>.<max_no>`` where
        ``max_no`` is the value returned by
        ``self._get_last_checkpoint_no``. When ``fs.need_upload_download()``
        is true the checkpoint is first downloaded into a per-trainer
        directory under ``local_cache_path`` and loaded from there.

        Args:
            executor: passed through to ``io.load_persistables``.
            path: directory that holds the numbered checkpoints.
            trainer_id: used as a suffix of the local cache directory name.
            main_program: program to load into;
                ``self._transpiled_program`` is used when it is None.
            fs: file-system client. A ``LocalFS`` is created when omitted.
            local_cache_path: local directory for the downloaded copy.
            ignore_empty: when True, a missing checkpoint yields None
                instead of an error.

        Returns:
            The result of ``self._load_train_status`` for the loaded
            directory, or None when no checkpoint exists and
            ``ignore_empty`` is True.

        Raises:
            AssertionError: no checkpoint exists and ``ignore_empty`` is
                False.
        """
        # Built per call instead of in the signature, where the default
        # would be evaluated once and shared by every call.
        if fs is None:
            fs = LocalFS()

        max_no = self._get_last_checkpoint_no(path, fs)
        if max_no < 0:
            if not ignore_empty:
                # Raised explicitly so the check is not stripped under
                # `python -O`; the exception type is kept for callers.
                raise AssertionError("Can't find checkpoint")
            return None

        real_path = "{}/{}.{}".format(path, self._checkoint_prefix, max_no)
        load_path = real_path
        if fs.need_upload_download():
            local_fs = LocalFS()
            cache_path = "{}/{}.{}.load_cache.{}".format(
                local_cache_path, self._checkoint_prefix, max_no, trainer_id)
            # Remove any existing copy before downloading.
            if local_fs.stat(cache_path):
                local_fs.delete(cache_path)
            fs.download(real_path, cache_path)
            load_path = cache_path

        if main_program is None:
            main_program = self._transpiled_program

        io.load_persistables(executor=executor,
                             dirname=load_path,
                             main_program=main_program,
                             filename=self._param_file_name)

        return self._load_train_status(load_path)
예제 #4
0
 def test_local_check_point(self):
     """Run ``_test_check_point`` with a ``LocalFS`` and ./my_paddle_model."""
     local_fs = LocalFS()
     self._test_check_point(local_fs, "./my_paddle_model")