예제 #1
0
    def save_checkpoint(self,
                        path,
                        slists,
                        trainer_id=None,
                        local_cache_path=".cache"):
        """
        Serialize every object in slists into a new numbered checkpoint
        directory below path.

        Objects are first written to "<path>/<prefix>.<no>.tmp" and that
        directory is moved to "<path>/<prefix>.<no>" only after all of them
        have been written. When self._fs reports need_upload_download(), the
        objects are serialized into a local cache directory which is then
        uploaded to the temporary path; the local cache is removed afterwards,
        whether or not serialization/upload succeeded.

        Args:
            path: checkpoint root directory on self._fs; created if missing.
            slists: iterable of objects exposing serialize(dir_path).
            trainer_id: optional suffix appended to the local cache
                directory name.
            local_cache_path: local directory under which the staging cache
                directory is created.

        Returns:
            Tuple (real_path, checkpoint_no) describing the saved checkpoint.

        Raises:
            AssertionError: if path, or the local cache path, exists but is
                not a directory.
        """
        if not self._fs.is_exist(path):
            self._fs.mkdirs(path)
        else:
            assert self._fs.is_dir(path), "path:{} must be a directory".format(
                path)

        # Any negative "last checkpoint" number is treated as "none saved
        # yet", so numbering starts at 0.
        checkpoint_no = max(self._get_last_checkpoint_no(path), -1) + 1

        real_path = "{}/{}.{}".format(path, self._checkpoint_prefix,
                                      checkpoint_no)
        tmp_path = "{}.tmp".format(real_path)

        from paddle.distributed.fleet.utils.fs import LocalFS
        local_fs = LocalFS()

        if not self._fs.need_upload_download():
            # The file system is written to directly: serialize straight
            # into the temporary directory.
            for s in slists:
                s.serialize(tmp_path)
            self._fs.mv(tmp_path, real_path)
            return real_path, checkpoint_no

        cache_path = "{}/{}.{}.saved_cache".format(local_cache_path,
                                                   self._checkpoint_prefix,
                                                   checkpoint_no)
        if trainer_id is not None:
            cache_path = "{}.{}".format(cache_path, trainer_id)

        if not local_fs.is_exist(cache_path):
            local_fs.mkdirs(cache_path)
        else:
            assert local_fs.is_dir(cache_path), \
                "cache path:{} must be a directory".format(cache_path)

        try:
            for s in slists:
                s.serialize(cache_path)

            # Drop any leftover temporary directory before uploading the
            # freshly written cache in its place.
            self._fs.delete(tmp_path)
            self._fs.upload(cache_path, tmp_path)
        finally:
            # Always remove the local staging directory so a failed attempt
            # does not leave files behind for the next save to pick up.
            local_fs.delete(cache_path)

        self._fs.mv(tmp_path, real_path)

        return real_path, checkpoint_no
예제 #2
0
    def load_checkpoint(self,
                        path,
                        slists,
                        trainer_id=None,
                        local_cache_path=".cache",
                        checkpoint_no=None,
                        ignore_empty=True):
        """
        Deserialize every object in slists from a checkpoint directory
        below path.

        When checkpoint_no is None the number returned by
        self._get_last_checkpoint_no(path) is used. When self._fs reports
        need_upload_download(), the checkpoint is first downloaded into a
        local cache directory, the objects are deserialized from there, and
        the cache is removed afterwards, whether or not loading succeeded.

        Args:
            path: checkpoint root directory on self._fs.
            slists: iterable of objects exposing deserialize(dir_path).
            trainer_id: optional suffix appended to the local cache
                directory name.
            local_cache_path: local directory under which the download cache
                directory is created.
            checkpoint_no: non-negative int selecting a specific checkpoint,
                or None to use the last one.
            ignore_empty: when False, a missing checkpoint is an error
                instead of producing a None result.

        Returns:
            The checkpoint path "<path>/<prefix>.<no>" that was loaded, or
            None when no checkpoint was found and ignore_empty is True.

        Raises:
            AssertionError: if no checkpoint is found and ignore_empty is
                False, or if checkpoint_no is given but is not a
                non-negative int.
        """
        if checkpoint_no is None:
            max_no = self._get_last_checkpoint_no(path)

            if not ignore_empty:
                assert max_no >= 0, "Can't find checkpoint"

            if max_no < 0:
                return None

            checkpoint_no = max_no
        else:
            assert isinstance(checkpoint_no, int)
            assert checkpoint_no >= 0

        from paddle.distributed.fleet.utils.fs import LocalFS
        local_fs = LocalFS()

        real_path = "{}/{}.{}".format(path, self._checkpoint_prefix,
                                      checkpoint_no)

        if not self._fs.need_upload_download():
            # The file system is read directly: no local staging required.
            for s in slists:
                s.deserialize(real_path)
            return real_path

        cache_path = "{}/{}.{}.load_cache".format(local_cache_path,
                                                  self._checkpoint_prefix,
                                                  checkpoint_no)
        if trainer_id is not None:
            cache_path = "{}.{}".format(cache_path, trainer_id)

        if not local_fs.is_exist(local_cache_path):
            local_fs.mkdirs(local_cache_path)
        # Start from a clean slate so leftovers of an earlier load cannot
        # mix with the downloaded checkpoint.
        if local_fs.is_exist(cache_path):
            local_fs.delete(cache_path)

        try:
            self._fs.download(real_path, cache_path)
            for s in slists:
                s.deserialize(cache_path)
        finally:
            # Remove the local copy even when download/deserialize raised;
            # the download may have failed before creating the directory.
            if local_fs.is_exist(cache_path):
                local_fs.delete(cache_path)

        return real_path
예제 #3
0
    def _test_upload_dir(self, fs):
        """
        Check that fs.upload accepts a whole local directory.

        A local directory is created and two files are touched inside it;
        the directory is uploaded through fs, the destination is asserted to
        exist, and finally both the destination and the local directory are
        deleted.
        """
        local_dir = os.path.abspath("./test_upload_dir")
        # NOTE(review): the destination name is spelled differently from the
        # source directory, which keeps the two paths distinct - confirm
        # that this is intentional before "fixing" the spelling.
        upload_target = os.path.abspath("./test_uolpad_dir")

        local_fs = LocalFS()
        local_fs.mkdirs(local_dir)
        for relative_name in ("./test_upload_dir/file1",
                              "./test_upload_dir/file2"):
            local_fs.touch(os.path.abspath(relative_name))

        fs.upload(local_dir, upload_target)

        uploaded = fs.is_exist(upload_target)
        self.assertTrue(uploaded)

        # Clean up the uploaded copy first, then the local source directory.
        fs.delete(upload_target)
        local_fs.delete(local_dir)