예제 #1
0
    def execute(self, argv):
        """Run the ``convert`` command.

        Parses the command-line options and validates that a module name and
        an existing model directory were supplied. It then creates the
        destination directory ``self.dest`` and, inside a temporary working
        directory exposed as ``self._tmp_dir``, calls ``create_module_py``,
        ``create_init_py``, ``create_serving_demo_py`` and
        ``create_module_tar``.

        NOTE(review): ``argv`` is not used; options are read by
        ``self.parser.parse_args()`` from ``sys.argv``. Confirm this is
        intended before relying on ``argv``.

        Returns:
            True on success; False when the module name or model directory
            option is missing, or the model directory does not exist.
        """
        options = self.parser.parse_args()

        if not (options.module_name and options.model_dir):
            ConvertCommand.show_help()
            return False

        self.module = options.module_name
        if options.module_version is None:
            self.version = '1.0.0'
        else:
            self.version = options.module_version

        self.src = options.model_dir
        if not os.path.isdir(self.src):
            print('`{}` is not exists or not a directory path'.format(self.src))
            return False

        if options.output_dir is None:
            # Default destination: a relative "<module>_<timestamp>" path.
            self.dest = os.path.join('{}_{}'.format(self.module, str(time.time())))
        else:
            self.dest = options.output_dir

        CacheUpdater("hub_convert", self.module, self.version).start()
        os.makedirs(self.dest)

        with tmp_dir() as workdir:
            self._tmp_dir = workdir
            self.create_module_py()
            self.create_init_py()
            self.create_serving_demo_py()
            self.create_module_tar()

        print('The converted module is stored in `{}`.'.format(self.dest))
        return True
예제 #2
0
def download(name,
             save_path,
             version=None,
             decompress=True,
             resource_type='Model',
             extra=None):
    """Download a hub resource and store it at ``save_path/name``.

    The final location is ``os.path.realpath(os.path.join(save_path, name))``.
    If that path already exists, the function returns immediately without
    contacting the server. Otherwise the resource URL is looked up on the hub
    server and downloaded into a temporary directory. When ``decompress`` is
    true and the download is a tar archive, it is uncompressed. The result is
    then moved to the final location.

    Args:
        name: Resource name; also the file/directory name under ``save_path``.
        save_path: Directory that receives the resource; created if missing.
        version: Optional resource version forwarded to the server query.
        decompress: Uncompress tar archives before moving them into place.
        resource_type: Resource type forwarded to the server query.
        extra: Optional extra query parameters forwarded to the server.
            ``None`` (the default) is sent as an empty dict.

    Returns:
        None.

    Raises:
        ServerConnectionError: if the hub server check fails.
        ResourceNotFoundError: if the server has no result for ``name``.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if extra is None:
        extra = {}

    target = os.path.realpath(os.path.join(save_path, name))
    if os.path.exists(target):
        return

    if not hub.HubServer()._server_check():
        raise ServerConnectionError

    search_result = hub.HubServer().get_resource_url(
        name, resource_type=resource_type, version=version, extra=extra)
    if not search_result:
        raise ResourceNotFoundError(name, version)
    CacheUpdater("x_download", name, version).start()
    url = search_result['url']

    with tmp_dir() as _dir:
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        _, _, savefile = default_downloader.download_file(url=url,
                                                          save_path=_dir,
                                                          print_progress=True)
        # Only inspect the downloaded file when decompression was requested.
        if decompress and tarfile.is_tarfile(savefile):
            _, _, savefile = default_downloader.uncompress(file=savefile,
                                                           print_progress=True)
        shutil.move(savefile, target)
예제 #3
0
def create_module(directory, name, author, email, module_type, summary,
                  version):
    """Package ``directory`` as a hub module archive.

    Writes ``"{name}-{version}.{HUB_PACKAGE_SUFFIX}"`` as a gzip-compressed
    tar file, relative to the current working directory. The archive holds a
    copy of ``directory`` under ``./<name>/``, to which are added:

    * ``module_desc.pb``: a serialized ``ModuleDesc`` whose ``module_info``
      map records name, author, author_email, type, summary and version;
    * whatever ``ModuleChecker.generate_check_info`` writes into the copy;
    * ``__init__.py``, created empty if the copy has none.

    Args:
        directory: Source directory to package; it is copied, not modified.
        name: Module name; also the top-level directory inside the archive.
        author: Author name recorded in the description.
        email: Author e-mail, recorded as ``author_email``.
        module_type: Module type, recorded as ``type``.
        summary: Short summary recorded in the description.
        version: Module version; recorded and used in the archive file name.
    """
    save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)

    with tmp_dir() as base_dir:
        # package the module
        with tarfile.open(save_file, "w:gz") as tar:
            module_dir = os.path.join(base_dir, name)
            shutil.copytree(directory, module_dir)

            # record module info and serialize
            desc = module_desc_pb2.ModuleDesc()
            attr = desc.attr
            attr.type = module_desc_pb2.MAP
            module_info = attr.map.data['module_info']
            module_info.type = module_desc_pb2.MAP
            utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
            utils.from_pyobj_to_module_attr(author,
                                            module_info.map.data['author'])
            utils.from_pyobj_to_module_attr(
                email, module_info.map.data['author_email'])
            utils.from_pyobj_to_module_attr(module_type,
                                            module_info.map.data['type'])
            utils.from_pyobj_to_module_attr(summary,
                                            module_info.map.data['summary'])
            utils.from_pyobj_to_module_attr(version,
                                            module_info.map.data['version'])
            module_desc_path = os.path.join(module_dir, "module_desc.pb")
            with open(module_desc_path, "wb") as f:
                f.write(desc.SerializeToString())

            # generate check info
            checker = ModuleChecker(module_dir)
            checker.generate_check_info()

            # add __init__ (append mode creates it if missing, keeps it if not)
            module_init = os.path.join(module_dir, "__init__.py")
            with open(module_init, "a"):
                pass

            # Member names are "./<name>/..." (paths relative to base_dir).
            # They are passed as arcname instead of chdir-ing into base_dir,
            # so an error while archiving can never leave the process in a
            # different working directory.
            for dirname, _, filenames in os.walk(module_dir):
                for filename in filenames:
                    path = os.path.join(dirname, filename)
                    arcname = os.path.join(".", os.path.relpath(path, base_dir))
                    tar.add(path, arcname=arcname)
예제 #4
0
파일: convert.py 프로젝트: zolagz/PaddleHub
    def run(self, module, version, src, dest):
        """Generate a module package from ``src`` into ``dest``.

        Stores the arguments on the instance and creates ``dest`` with
        ``os.makedirs`` (which raises if it already exists). Inside a temporary
        directory exposed as ``self._tmp_dir``, it then calls
        ``create_module_py``, ``create_init_py``, ``create_serving_demo_py``
        and ``create_module_tar`` in that order.

        Returns:
            True once every step has completed.
        """
        self.module, self.version = module, version
        self.src, self.dest = src, dest

        os.makedirs(self.dest)

        with tmp_dir() as workspace:
            self._tmp_dir = workspace
            self.create_module_py()
            self.create_init_py()
            self.create_serving_demo_py()
            self.create_module_tar()

        return True
예제 #5
0
파일: nlp_module.py 프로젝트: wuhuaha/beike
    def __init__(self,
                 name=None,
                 directory=None,
                 module_dir=None,
                 version=None,
                 max_seq_len=128,
                 **kwargs):
        """Initialise the transformer module.

        If ``directory`` is falsy, this returns immediately without calling
        the parent initialiser. Otherwise it initialises the parent with the
        given arguments and stores ``max_seq_len``. When
        ``version_compare(paddle.__version__, '1.8')`` is truthy, it also
        exports the program returned by ``self.context`` with
        ``fluid.io.save_inference_model`` into a temporary directory. That
        export is then loaded as ``self.model_runner`` via
        ``fluid.dygraph.StaticModelRunner``.
        """
        if not directory:
            return
        super(TransformerModule, self).__init__(
            name=name,
            directory=directory,
            module_dir=module_dir,
            version=version,
            **kwargs)
        self.max_seq_len = max_seq_len

        if not version_compare(paddle.__version__, '1.8'):
            return

        feed_keys = ('input_ids', 'position_ids', 'segment_ids', 'input_mask')
        with tmp_dir() as export_dir:
            inputs, outputs, program = self.context(max_seq_len=max_seq_len)
            feed_names = [inputs[key].name for key in feed_keys]
            fetch_vars = [outputs["pooled_output"], outputs["sequence_output"]]
            fluid.io.save_inference_model(
                dirname=export_dir,
                main_program=program,
                feeded_var_names=feed_names,
                target_vars=fetch_vars,
                executor=fluid.Executor(fluid.CPUPlace()))

            # The runner must be built while the exported files still exist.
            with fluid.dygraph.guard():
                self.model_runner = fluid.dygraph.StaticModelRunner(export_dir)
예제 #6
0
    def install_module(self,
                       module_name=None,
                       module_dir=None,
                       module_package=None,
                       module_version=None,
                       upgrade=False,
                       extra=None):
        """Install a module from the hub server, a package file or a directory.

        Sources, by argument:

        * ``module_name``: if a module of that name is already installed (and
          ``module_version`` is unset or equals the installed version),
          nothing is downloaded. Otherwise the module URL is looked up on the
          hub server, then downloaded and uncompressed into a temporary
          directory.
        * ``module_package``: the ``.tar.gz`` file is extracted into a
          temporary directory. Its first member is taken as the module
          directory, and dashes in that name become underscores.
        * ``module_dir``: a user directory, which is copied (never moved). If
          the same name and version is already installed, nothing is copied.

        The module ends up in ``MODULE_HOME/<module_name>`` (dashes replaced
        by underscores). Any existing directory at that path is removed
        first. When the server supplied an md5, it is written to
        ``md5.txt`` inside the module.

        Args:
            module_name: Name of the module to fetch from the hub server.
            module_dir: Local module directory to install from.
            module_package: Local ``.tar.gz`` module package to install from.
            module_version: Requested version for ``module_name``.
            upgrade: Accepted for interface compatibility; not read here.
            extra: Extra query parameters for ``HubServer.get_module_url``.

        Returns:
            ``(success, tips, info)``, where ``tips`` is a human-readable
            message. ``info`` is one of:

            * the existing ``self.modules_dict`` entry, if already installed;
            * ``(install_dir, version)`` on a fresh install, where
              ``version`` is None unless it came from the hub server;
            * ``None`` or the current ``module_dir``, on failure.
        """
        md5_value = installed_module_version = None
        from_user_dir = bool(module_dir)
        with tmp_dir() as _dir:
            if module_name:
                self.all_modules(update=True)
                module_info = self.modules_dict.get(module_name, None)
                if module_info:
                    if not module_version or module_version == self.modules_dict[
                            module_name][1]:
                        module_dir = self.modules_dict[module_name][0]
                        module_tag = module_name if not module_version else '%s-%s' % (
                            module_name, module_version)
                        tips = "Module %s already installed in %s" % (
                            module_tag, module_dir)
                        return True, tips, self.modules_dict[module_name]

                search_result = hub.HubServer().get_module_url(
                    module_name, version=module_version, extra=extra)
                name = search_result.get('name', None)
                url = search_result.get('url', None)
                md5_value = search_result.get('md5', None)
                installed_module_version = search_result.get('version', None)
                # No URL, or a name/version other than requested, counts as a
                # failed lookup; explain the most likely cause to the user.
                if not url or (module_version is not None
                               and installed_module_version != module_version
                               ) or (name != module_name):
                    if hub.HubServer()._server_check() is False:
                        tips = "Request Hub-Server unsuccessfully, please check your network."
                        return False, tips, None
                    module_versions_info = hub.HubServer().search_module_info(
                        module_name)
                    if module_versions_info is None:
                        tips = "Can't find module %s, please check your spelling." \
                               % (module_name)
                    elif module_version is not None and module_version not in [
                            item[1] for item in module_versions_info
                    ]:
                        tips = "Can't find module %s with version %s, all versions are listed below." \
                               % (module_name, module_version)
                        tips += paint_modules_info(module_versions_info)
                    else:
                        # NOTE(review): `sys_hub_verion` is spelled this way in
                        # SOURCE; confirm it matches the module-level name.
                        tips = "The version of PaddlePaddle(%s) or PaddleHub(%s) can not match module, please upgrade your PaddlePaddle or PaddleHub according to the form below." \
                               % (sys_paddle_version, sys_hub_verion)
                        tips += paint_modules_info(module_versions_info)

                    return False, tips, None

                # NOTE(review): the success flags returned by download_file and
                # uncompress are not checked; a failure surfaces only through
                # module_dir being falsy below.
                _, tips, module_zip_file = default_downloader.download_file(
                    url=url,
                    save_path=_dir,
                    save_name=module_name,
                    replace=True,
                    print_progress=True)
                _, tips, module_dir = default_downloader.uncompress(
                    file=module_zip_file,
                    dirname=os.path.join(_dir, "tmp_module"),
                    delete_file=True,
                    print_progress=True)

            if module_package:
                # NOTE(review): members are extracted without rejecting
                # absolute paths or ".." components, so an untrusted package
                # could write outside _dir.
                with tarfile.open(module_package, "r:gz") as tar:
                    file_names = tar.getnames()
                    # The first member is taken as the module's directory.
                    module_name = file_names[0]
                    module_dir = os.path.join(_dir, module_name)
                    for file_name in file_names:
                        tar.extract(file_name, _dir)
                    # Module directory names use underscores, not dashes.
                    if "-" in module_name:
                        module_name = module_name.replace("-", "_")
                        new_module_dir = os.path.join(_dir, module_name)
                        shutil.move(module_dir, new_module_dir)
                        module_dir = new_module_dir
                    module_name = hub.Module(directory=module_dir).name

            if from_user_dir:
                # Load the module once and read both attributes from it,
                # instead of constructing hub.Module twice.
                user_module = hub.Module(directory=module_dir)
                module_name = user_module.name
                module_version = user_module.version
                self.all_modules(update=False)
                module_info = self.modules_dict.get(module_name, None)
                if module_info:
                    if module_version == module_info[1]:
                        module_dir = self.modules_dict[module_name][0]
                        module_tag = module_name if not module_version else '%s-%s' % (
                            module_name, module_version)
                        tips = "Module %s already installed in %s" % (
                            module_tag, module_dir)
                        return True, tips, self.modules_dict[module_name]

            if module_dir:
                if md5_value:
                    with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
                              "w") as fp:
                        fp.write(md5_value)

                save_path = os.path.join(MODULE_HOME,
                                         module_name.replace("-", "_"))
                # Nothing to relocate when module_dir already is the target.
                if save_path != module_dir:
                    if os.path.exists(save_path):
                        shutil.rmtree(save_path)
                    # A user's own directory is copied so it stays in place.
                    if from_user_dir:
                        shutil.copytree(module_dir, save_path)
                    else:
                        shutil.move(module_dir, save_path)
                module_dir = save_path
                tips = "Successfully installed %s" % module_name
                if installed_module_version:
                    tips += "-%s" % installed_module_version
                return True, tips, (module_dir, installed_module_version)
            tips = "Download %s-%s failed" % (module_name, module_version)
            return False, tips, module_dir
예제 #7
0
    def install_module(self,
                       module_name=None,
                       module_dir=None,
                       module_package=None,
                       module_version=None,
                       upgrade=False,
                       extra=None):
        """Install a module from the hub server, a package file or a directory.

        Sources, by argument:

        * ``module_name``: if a module of that name is already installed (and
          ``module_version`` is unset or equals the installed version),
          nothing is downloaded. Otherwise the module URL is looked up on the
          hub server, and the archive is downloaded into a temporary
          directory and uncompressed into ``MODULE_HOME``.
        * ``module_package``: the ``.tar.gz`` file is extracted into a
          temporary directory. The parent of its first member is taken as the
          module directory.
        * ``module_dir``: a user directory, which is copied (never moved).

        If a module of the resolved name is already listed in
        ``self.modules_dict``, that entry is returned without installing.
        Otherwise the module is placed at ``MODULE_HOME/<module_name>``, and
        any other directory at that path is removed first. When the server
        supplied an md5, it is written to ``md5.txt`` inside the module.

        Args:
            module_name: Name of the module to fetch from the hub server.
            module_dir: Local module directory to install from.
            module_package: Local ``.tar.gz`` module package to install from.
            module_version: Requested version for ``module_name``.
            upgrade: Accepted for interface compatibility; not read here.
            extra: Extra query parameters for ``HubServer.get_module_url``.

        Returns:
            ``(success, tips, info)``, where ``tips`` is a human-readable
            message. ``info`` is one of:

            * the existing ``self.modules_dict`` entry, if already installed;
            * ``(install_dir, version)`` on a fresh install, where
              ``version`` is None unless it came from the hub server;
            * ``None`` or the current ``module_dir``, on failure.
        """
        md5_value = installed_module_version = None
        from_user_dir = bool(module_dir)
        with tmp_dir() as _dir:
            if module_name:
                self.all_modules(update=True)
                module_info = self.modules_dict.get(module_name, None)
                if module_info:
                    if not module_version or module_version == self.modules_dict[
                            module_name][1]:
                        module_dir = self.modules_dict[module_name][0]
                        module_tag = module_name if not module_version else '%s-%s' % (
                            module_name, module_version)
                        tips = "Module %s already installed in %s" % (
                            module_tag, module_dir)
                        return True, tips, self.modules_dict[module_name]

                search_result = hub.HubServer().get_module_url(
                    module_name, version=module_version, extra=extra)
                name = search_result.get('name', None)
                url = search_result.get('url', None)
                md5_value = search_result.get('md5', None)
                installed_module_version = search_result.get('version', None)
                # No URL, or a name/version other than requested, counts as a
                # failed lookup; explain the most likely cause to the user.
                if not url or (module_version is not None
                               and installed_module_version != module_version
                               ) or (name != module_name):
                    if hub.HubServer()._server_check() is False:
                        tips = "Request Hub-Server unsuccessfully, please check your network."
                        return False, tips, None
                    module_versions_info = hub.HubServer().search_module_info(
                        module_name)
                    if module_versions_info:
                        # Show the available versions as a table so the user
                        # can pick a compatible PaddlePaddle/PaddleHub pair.
                        if utils.is_windows():
                            placeholders = [20, 8, 14, 14]
                        else:
                            placeholders = [30, 8, 16, 16]
                        tp = TablePrinter(titles=[
                            "ResourceName", "Version", "PaddlePaddle",
                            "PaddleHub"
                        ],
                                          placeholders=placeholders)
                        module_versions_info.sort(
                            key=cmp_to_key(utils.sort_version_key))
                        for resource_name, resource_version, paddle_version, \
                                hub_version in module_versions_info:
                            colors = ["yellow", None, None, None]

                            tp.add_line(contents=[
                                resource_name, resource_version,
                                utils.strflist_version(paddle_version),
                                utils.strflist_version(hub_version)
                            ],
                                        colors=colors)
                        tips = "The version of PaddlePaddle or PaddleHub " \
                               "can not match module, please upgrade your " \
                               "PaddlePaddle or PaddleHub according to the form " \
                               "below." + tp.get_text()
                    else:
                        tips = "Can't find module %s" % module_name
                        if module_version:
                            tips += " with version %s" % module_version
                    return False, tips, None

                # NOTE(review): the success flags returned by download_file and
                # uncompress are not checked; a failure surfaces only through
                # module_dir being falsy below.
                _, tips, module_zip_file = default_downloader.download_file(
                    url=url,
                    save_path=_dir,
                    save_name=module_name,
                    replace=True,
                    print_progress=True)
                _, tips, module_dir = default_downloader.uncompress(
                    file=module_zip_file,
                    dirname=MODULE_HOME,
                    delete_file=True,
                    print_progress=True)

            if module_package:
                # NOTE(review): members are extracted without rejecting
                # absolute paths or ".." components, so an untrusted package
                # could write outside _dir.
                with tarfile.open(module_package, "r:gz") as tar:
                    file_names = tar.getnames()
                    module_dir = os.path.split(file_names[0])[0]
                    module_dir = os.path.join(_dir, module_dir)
                    for file_name in file_names:
                        tar.extract(file_name, _dir)

            if module_dir:
                if not module_name:
                    module_name = hub.Module(directory=module_dir).name
                self.all_modules(update=False)
                module_info = self.modules_dict.get(module_name, None)
                if module_info:
                    module_dir = self.modules_dict[module_name][0]
                    module_tag = module_name if not module_version else '%s-%s' % (
                        module_name, module_version)
                    tips = "Module %s already installed in %s" % (module_tag,
                                                                  module_dir)
                    return True, tips, self.modules_dict[module_name]

            if module_dir:
                if md5_value:
                    with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
                              "w") as fp:
                        fp.write(md5_value)

                save_path = os.path.join(MODULE_HOME, module_name)
                # A download is uncompressed straight into MODULE_HOME, so
                # module_dir may already be the target: never delete it then.
                # Otherwise remove any stale directory (shutil.move with a
                # single argument raised TypeError here) before relocating.
                if save_path != module_dir:
                    if os.path.exists(save_path):
                        shutil.rmtree(save_path)
                    # A user's own directory is copied so it stays in place.
                    if from_user_dir:
                        shutil.copytree(module_dir, save_path)
                    else:
                        shutil.move(module_dir, save_path)
                module_dir = save_path
                tips = "Successfully installed %s" % module_name
                if installed_module_version:
                    tips += "-%s" % installed_module_version
                return True, tips, (module_dir, installed_module_version)
            tips = "Download %s-%s failed" % (module_name, module_version)
            return False, tips, module_dir
예제 #8
0
def create_module(directory, name, author, email, module_type, summary,
                  version):
    """Package ``directory`` as a hub module archive, printing progress.

    Writes ``"{name}-{version}.{HUB_PACKAGE_SUFFIX}"`` as a gzip-compressed
    tar file, relative to the current working directory. The archive holds a
    ``./<name>`` directory entry followed by every file of a copy of
    ``directory``. To that copy are added:

    * ``module_desc.pb``: a serialized ``ModuleDesc`` whose ``module_info``
      map records name, author, author_email, type, summary and version;
    * whatever ``ModuleChecker.generate_check_info`` writes into the copy;
    * ``__init__.py``, created empty if the copy has none.

    Progress is reported through ``progress`` while files are added.

    Args:
        directory: Source directory to package; it is copied, not modified.
        name: Module name; also the top-level directory inside the archive.
        author: Author name recorded in the description.
        email: Author e-mail, recorded as ``author_email``.
        module_type: Module type, recorded as ``type``.
        summary: Short summary recorded in the description.
        version: Module version; recorded and used in the archive file name.
    """
    save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)

    with tmp_dir() as base_dir:
        # package the module
        with tarfile.open(save_file, "w:gz") as tar:
            module_dir = os.path.join(base_dir, name)
            shutil.copytree(directory, module_dir)

            # record module info and serialize
            desc = module_desc_pb2.ModuleDesc()
            attr = desc.attr
            attr.type = module_desc_pb2.MAP
            module_info = attr.map.data['module_info']
            module_info.type = module_desc_pb2.MAP
            utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
            utils.from_pyobj_to_module_attr(author,
                                            module_info.map.data['author'])
            utils.from_pyobj_to_module_attr(
                email, module_info.map.data['author_email'])
            utils.from_pyobj_to_module_attr(module_type,
                                            module_info.map.data['type'])
            utils.from_pyobj_to_module_attr(summary,
                                            module_info.map.data['summary'])
            utils.from_pyobj_to_module_attr(version,
                                            module_info.map.data['version'])
            module_desc_path = os.path.join(module_dir, "module_desc.pb")
            with open(module_desc_path, "wb") as f:
                f.write(desc.SerializeToString())

            # generate check info
            checker = ModuleChecker(module_dir)
            checker.generate_check_info()

            # add __init__ (append mode creates it if missing, keeps it if not)
            module_init = os.path.join(module_dir, "__init__.py")
            with open(module_init, "a"):
                pass

            # Member names are "./<name>/..." (paths relative to base_dir).
            # They are passed as arcname instead of chdir-ing into base_dir,
            # so an error while archiving can never leave the process in a
            # different working directory.
            def _arcname(path):
                return os.path.join(".", os.path.relpath(path, base_dir))

            tar.add(module_dir, arcname=_arcname(module_dir), recursive=False)
            files = [
                os.path.join(dirname, filename)
                for dirname, _, filenames in os.walk(module_dir)
                for filename in filenames
            ]

            total_length = len(files)
            print("Create Module {}-{}".format(name, version))
            for index, path in enumerate(files):
                # float() on the numerator keeps true division on Python 2 too.
                done = int(float(index) / total_length * 50)
                progress("[%-50s] %.2f%%" %
                         ('=' * done, float(index) / total_length * 100))
                tar.add(path, arcname=_arcname(path))
            progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
            print("Module package saved as {}".format(save_file))