예제 #1
0
    def __init__(self,
                 name=None,
                 module_dir=None,
                 signatures=None,
                 module_info=None,
                 assets=None,
                 processor=None,
                 extra_info=None,
                 version=None):
        """Create a module from a name, a module directory, or signatures.

        The first truthy source among ``name``, ``module_dir`` and
        ``signatures`` is used. Initialization runs while holding an
        exclusive ``lock.flock`` on ``CONF_HOME/config.json``; the lock is
        released and the file closed even if initialization raises. Afterwards
        a ``CacheUpdater`` is started for ``(name, version)``.

        Args:
            name: module name, passed to ``_init_with_name`` with ``version``.
            module_dir: sequence; ``module_dir[0]`` is the module directory
                and ``module_dir[1]`` the version given to ``CacheUpdater``.
            signatures: signatures used to build the module directly.
            module_info: module info used together with ``signatures``.
            assets: asset path(s) used together with ``signatures``.
            processor: class deriving from ``BaseProcessor``, used together
                with ``signatures``.
            extra_info: dict of extra info; defaults to a new empty dict.
            version: module version used together with ``name``.

        Raises:
            TypeError: if ``extra_info`` is not a dict, or ``processor`` is
                not a ``BaseProcessor`` subclass.
            ValueError: if none of ``name``, ``module_dir`` or ``signatures``
                is given.
        """
        self.desc = module_desc_pb2.ModuleDesc()
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
        self.default_signature = None
        self.module_info = None
        self.processor = None
        self.extra_info = {} if extra_info is None else extra_info
        if not isinstance(self.extra_info, dict):
            raise TypeError(
                "The extra_info should be an instance of python dict")

        # cache data
        self.last_call_name = None
        self.cache_feed_dict = None
        self.cache_fetch_dict = None
        self.cache_program = None

        # Hold an exclusive lock on the hub config while the module loads.
        # ``with`` closes the file and ``finally`` guarantees the unlock even
        # when one of the initialization paths raises (previously both the
        # handle and the lock leaked in that case).
        with open(os.path.join(CONF_HOME, 'config.json')) as fp_lock:
            lock.flock(fp_lock, lock.LOCK_EX)
            try:
                if name:
                    self._init_with_name(name=name, version=version)
                elif module_dir:
                    self._init_with_module_file(module_dir=module_dir[0])
                    name = module_dir[0].split("/")[-1]
                    version = module_dir[1]
                elif signatures:
                    # ``processor`` must be a class (not an instance); the
                    # isinstance guard replaces the opaque error issubclass()
                    # raises for non-classes with a clear message.
                    if processor and not (isinstance(processor, type) and
                                          issubclass(processor,
                                                     BaseProcessor)):
                        raise TypeError(
                            "Processor should be a subclass of paddlehub.BaseProcessor"
                        )
                    if assets:
                        self.assets = utils.to_list(assets)
                    self.processor = processor
                    self._generate_module_info(module_info)
                    self._init_with_signature(signatures=signatures)
                else:
                    raise ValueError("Module initialized parameter is empty")
            finally:
                lock.flock(fp_lock, lock.LOCK_UN)
        CacheUpdater(name, version).start()
예제 #2
0
파일: module.py 프로젝트: wxm2020/PaddleHub
    def __init__(self, name=None, directory=None, module_dir=None,
                 version=None):
        """Load a v1-format module from ``directory``.

        Returns immediately, without setting any attributes, when
        ``directory`` is falsy. Otherwise it calls the base-class
        ``__init__``, resets the v1 attributes, parses the
        ``MODULE_DESC_PBNAME`` file in ``directory`` to fill in
        ``_name``/``_author``/``_author_email``/``_version``/``_type``/
        ``_summary``, loads the inference program from
        ``ModuleHelper(directory).model_path()`` with a CPU executor, and then
        runs the ``_load_processor`` ... ``_recover_variable_info`` helpers in
        the order listed at the end of this method.

        Args:
            name: forwarded to the base-class ``__init__``.
            directory: module directory to load from.
            module_dir: forwarded to the base-class ``__init__``.
            version: forwarded to the base-class ``__init__``.
        """
        if not directory:
            return
        super(ModuleV1, self).__init__(name, directory, module_dir, version)
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
        self.default_signature = None
        self.processor = None
        self.extra_info = {}
        self._code_version = "v1"

        # parse desc
        # NOTE(review): the description is parsed into ``self._desc`` but read
        # back below through ``self.desc``; presumably ``desc`` is a property
        # returning ``_desc`` — confirm.
        self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
        self._desc = module_desc_pb2.ModuleDesc()
        with open(self.module_desc_path, "rb") as file:
            self._desc.ParseFromString(file.read())

        # Copy each module_info entry onto the instance as a Python object.
        module_info = self.desc.attr.map.data['module_info']
        self._name = utils.from_module_attr_to_pyobj(
            module_info.map.data['name'])
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self._type = utils.from_module_attr_to_pyobj(
            module_info.map.data['type'])
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])

        # cache data
        self.last_call_name = None
        self.cache_feed_dict = None
        self.cache_fetch_dict = None
        self.cache_program = None

        # Load the saved inference program on a CPU executor.
        self.helper = ModuleHelper(directory)
        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        # Blank out the ``op_callstack`` attribute of every op that has one.
        for block in self.program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        # The remaining setup steps operate on the loaded program/desc; their
        # order is kept as-is.
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()
        self._generate_extra_info()
        self._restore_parameter(self.program)
        self._recover_variable_info(self.program)
예제 #3
0
 def check_module_valid(self, module_path):
     """Check whether ``module_path`` holds a readable module description.

     Args:
         module_path: directory of an installed module.

     Returns:
         ``(True, info)``. ``info`` holds the ``'version'`` read from
         ``module_desc.pb`` when that file exists, and is empty when it does
         not (this version does not treat a missing file as invalid).
         Returns ``(False, None)`` if building the path or reading/parsing
         the file raises.
     """
     # TODO(wuzewu): code
     info = {}
     try:
         desc_pb_path = os.path.join(module_path, 'module_desc.pb')
         # isfile() is already False for missing paths.
         if os.path.isfile(desc_pb_path):
             desc = module_desc_pb2.ModuleDesc()
             with open(desc_pb_path, "rb") as fp:
                 desc.ParseFromString(fp.read())
             info['version'] = desc.attr.map.data["module_info"].map.data[
                 "version"].s
     except Exception:
         # Any failure to read the description marks the module invalid.
         # Unlike a bare ``except:`` this lets KeyboardInterrupt/SystemExit
         # propagate.
         return False, None
     return True, info
예제 #4
0
    def __init__(self,
                 name=None,
                 module_dir=None,
                 signatures=None,
                 module_info=None,
                 assets=None,
                 processor=None,
                 extra_info=None,
                 version=None):
        """Create a module from a name, a module directory, or signatures.

        The first truthy source among ``name``, ``module_dir`` and
        ``signatures`` is used.

        Args:
            name: module name, passed to ``_init_with_name`` with ``version``.
            module_dir: sequence whose first item is the module directory.
            signatures: signatures used to build the module directly.
            module_info: module info used together with ``signatures``.
            assets: asset path(s) used together with ``signatures``.
            processor: class deriving from ``BaseProcessor``, used together
                with ``signatures``.
            extra_info: dict of extra info; defaults to a new empty dict.
            version: module version used together with ``name``.

        Raises:
            TypeError: if ``extra_info`` is not a dict, or ``processor`` is
                not a ``BaseProcessor`` subclass.
            ValueError: if none of ``name``, ``module_dir`` or ``signatures``
                is given.
        """
        self.desc = module_desc_pb2.ModuleDesc()
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
        self.default_signature = None
        self.module_info = None
        self.processor = None
        self.extra_info = {} if extra_info is None else extra_info
        if not isinstance(self.extra_info, dict):
            raise TypeError(
                "The extra_info should be an instance of python dict")

        # cache data
        self.last_call_name = None
        self.cache_feed_dict = None
        self.cache_fetch_dict = None
        self.cache_program = None

        # TODO(wuzewu): print more module loading info log
        if name:
            self._init_with_name(name=name, version=version)
        elif module_dir:
            # Only the directory (first item) is used here.
            self._init_with_module_file(module_dir=module_dir[0])
        elif signatures:
            # ``processor`` must be a class (not an instance); the isinstance
            # guard replaces the opaque error issubclass() raises for
            # non-classes with a clear message.
            if processor and not (isinstance(processor, type) and
                                  issubclass(processor, BaseProcessor)):
                raise TypeError(
                    "Processor should be a subclass of paddlehub.BaseProcessor"
                )
            if assets:
                self.assets = utils.to_list(assets)
            self.processor = processor
            self._generate_module_info(module_info)
            self._init_with_signature(signatures=signatures)
        else:
            raise ValueError("Module initialized parameter is empty")
예제 #5
0
def create_module(directory, name, author, email, module_type, summary,
                  version):
    """Package ``directory`` as a PaddleHub module archive.

    The directory is copied into a temporary directory as ``<name>/``. A
    ``module_desc.pb`` with the given module info, the ``ModuleChecker`` check
    info and an ``__init__.py`` are added to the copy. Every file is then
    written to the gzipped tarball ``<name>-<version>.<HUB_PACKAGE_SUFFIX>``
    in the current working directory, stored under ``./<name>/...``.
    """
    save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)

    with tmp_dir() as base_dir:
        # package the module
        with tarfile.open(save_file, "w:gz") as tar:
            module_dir = os.path.join(base_dir, name)
            shutil.copytree(directory, module_dir)

            # record module info and serialize
            desc = module_desc_pb2.ModuleDesc()
            attr = desc.attr
            attr.type = module_desc_pb2.MAP
            module_info = attr.map.data['module_info']
            module_info.type = module_desc_pb2.MAP
            for key, value in (('name', name), ('author', author),
                               ('author_email', email), ('type', module_type),
                               ('summary', summary), ('version', version)):
                utils.from_pyobj_to_module_attr(value,
                                                module_info.map.data[key])
            module_desc_path = os.path.join(module_dir, "module_desc.pb")
            with open(module_desc_path, "wb") as f:
                f.write(desc.SerializeToString())

            # generate check info
            checker = ModuleChecker(module_dir)
            checker.generate_check_info()

            # add __init__ (append mode keeps any existing content)
            module_init = os.path.join(module_dir, "__init__.py")
            with open(module_init, "a") as file:
                file.write("")

            # Store each file as "./<name>/..." relative to base_dir. An
            # explicit arcname avoids changing the process-wide working
            # directory (which was not restored if tar.add raised) and the
            # fragile str.replace on the path.
            for dirname, _, files in os.walk(module_dir):
                for filename in files:
                    file_path = os.path.join(dirname, filename)
                    arcname = os.path.join(
                        ".", os.path.relpath(file_path, base_dir))
                    tar.add(file_path, arcname=arcname)
예제 #6
0
 def check_module_valid(self, module_path):
     """Check whether ``module_path`` contains a readable ``module_desc.pb``.

     Args:
         module_path: directory of an installed module.

     Returns:
         ``(True, {'version': ...})`` when ``module_desc.pb`` exists and
         parses, otherwise ``(False, None)``. A warning is logged when the
         file is missing or cannot be read.
     """
     try:
         desc_pb_path = os.path.join(module_path, 'module_desc.pb')
         # isfile() is already False for missing paths.
         if os.path.isfile(desc_pb_path):
             desc = module_desc_pb2.ModuleDesc()
             with open(desc_pb_path, "rb") as fp:
                 desc.ParseFromString(fp.read())
             info = {
                 'version':
                 desc.attr.map.data["module_info"].map.data["version"].s
             }
             return True, info
         logger.warning(
             "%s does not exist, the module will be reinstalled" %
             desc_pb_path)
     except Exception as err:
         # Still best effort (the module is reported invalid), but the reason
         # is logged instead of being silently swallowed by a bare ``except``.
         logger.warning("Failed to check module at %s: %s" %
                        (module_path, err))
     return False, None
예제 #7
0
 def check_module_valid(self, module_path):
     """Check whether ``module_path`` holds a usable module.

     Two layouts are recognised, in this order:

     * ``module_desc.pb`` exists: it is parsed and its ``version`` and
       ``name`` entries are returned.
     * ``module.py`` exists: it is imported as ``<basename>.module`` with the
       parent directory temporarily prepended to ``sys.path``. The first
       ``hub.Module`` subclass whose defining file is that ``module.py``
       supplies ``_version`` and ``_name``.

     Args:
         module_path: directory of an installed module.

     Returns:
         ``(True, {'version': ..., 'name': ...})`` on success, otherwise
         ``(False, None)``. A warning is logged when neither file exists.
     """
     try:
         desc_pb_path = os.path.join(module_path, 'module_desc.pb')
         if os.path.isfile(desc_pb_path):
             desc = module_desc_pb2.ModuleDesc()
             with open(desc_pb_path, "rb") as fp:
                 desc.ParseFromString(fp.read())
             module_info = desc.attr.map.data["module_info"]
             info = {
                 'version': module_info.map.data["version"].s,
                 'name': module_info.map.data["name"].s,
             }
             return True, info

         module_file = os.path.realpath(
             os.path.join(module_path, 'module.py'))
         if os.path.exists(module_file):
             basename = os.path.basename(module_path)
             dirname = os.path.dirname(module_path)
             sys.path.insert(0, dirname)
             try:
                 _module = importlib.import_module(
                     "{}.module".format(basename))
             finally:
                 # Undo the sys.path change even if importing module.py fails
                 # (previously the entry leaked on an import error).
                 if dirname in sys.path:
                     sys.path.remove(dirname)

             for _, cls in inspect.getmembers(_module, inspect.isclass):
                 # Run the subclass test first, so that classes from modules
                 # lacking ``__file__`` cannot abort the scan.
                 if not issubclass(cls, hub.Module):
                     continue
                 owner = sys.modules.get(cls.__module__)
                 cls_file = getattr(owner, '__file__', None)
                 if cls_file and os.path.realpath(cls_file).startswith(
                         module_file):
                     return True, {'version': cls._version, 'name': cls._name}
             # module.py defines no hub.Module subclass of its own.
             return False, None

         logger.warning(
             "%s does not exist, the module will be reinstalled" %
             desc_pb_path)
     except Exception:
         # Best effort: any failure to read or import the module marks it as
         # invalid.
         pass
     return False, None
예제 #8
0
    def __init__(self,
                 name=None,
                 directory=None,
                 module_dir=None,
                 version=None):
        """Initialize a (v2) module from the description file in ``directory``.

        Returns without doing anything when ``directory`` is falsy, or when
        ``id(self)`` is already a key of ``Module._record``. That second check
        makes repeated ``__init__`` calls on the same object no-ops.
        Otherwise it:

        * records ``id(self)``;
        * binds ``self._run_func``;
        * parses the ``MODULE_DESC_PBNAME`` file in ``directory`` into
          ``self._desc``;
        * copies the module info fields onto the instance;
        * calls ``self._initialize()``.

        Args:
            name: not read by this method.
            directory: module directory containing ``MODULE_DESC_PBNAME``.
            module_dir: not read by this method.
            version: not read by this method.
        """
        # Avoid module being initialized multiple times
        # NOTE(review): CPython may reuse an id once an object has been garbage
        # collected. Unless entries are removed from ``Module._record``
        # elsewhere, a new instance landing on a recycled id would skip
        # initialization — confirm.
        if not directory or id(self) in Module._record:
            return
        Module._record[id(self)] = True

        # Look up the runnable method registered for this class, keyed by its
        # fully qualified "<module>.<ClassName>"; None when not registered.
        mod = self.__class__.__module__ + "." + self.__class__.__name__
        if mod in _module_runnable_func:
            _run_func_name = _module_runnable_func[mod]
            self._run_func = getattr(self, _run_func_name)
        else:
            self._run_func = None
        self._code_version = "v2"
        self._directory = directory
        # NOTE(review): presumably ``self.directory`` / ``self.desc`` are
        # properties over ``_directory`` / ``_desc`` — confirm.
        self.module_desc_path = os.path.join(self.directory,
                                             MODULE_DESC_PBNAME)
        self._desc = module_desc_pb2.ModuleDesc()
        with open(self.module_desc_path, "rb") as file:
            self._desc.ParseFromString(file.read())

        # Copy each module_info entry onto the instance as a Python object.
        module_info = self.desc.attr.map.data['module_info']
        self._name = utils.from_module_attr_to_pyobj(
            module_info.map.data['name'])
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self._type = utils.from_module_attr_to_pyobj(
            module_info.map.data['type'])
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])

        self._initialize()
예제 #9
0
    def serialize_to_path(self, path=None, exe=None):
        """Save this module to ``path``.

        The saved module consists of:

        * the inference program, with every variable and parameter file
          given this module's name prefix;
        * the processor, when one is set;
        * the assets;
        * the check info;
        * the serialized ``self.desc``.

        Args:
            path: output directory; defaults to ``./<self.name>``.
            exe: executor used by ``fluid.io.save_inference_model``; a CPU
                ``fluid.Executor`` is created when falsy.
        """
        self._check_signatures()
        self._generate_desc()
        # create module path for saving
        if path is None:
            path = os.path.join(".", self.name)
        self.helper = ModuleHelper(path)
        utils.mkdir(self.helper.module_dir)

        logger.info("PaddleHub version = %s" % version.hub_version)
        logger.info("PaddleHub Module proto version = %s" %
                    version.module_proto_version)
        logger.info("Paddle version = %s" % paddle.__version__)

        # Gather feeds/fetches of all signatures. Duplicates are dropped
        # while keeping first-seen order, so the saved model is reproducible;
        # list(set(...)) produced an arbitrary order.
        feeded_var_names = list(
            dict.fromkeys(var.name for sign in self.signatures.values()
                          for var in sign.inputs))
        target_vars = list(
            dict.fromkeys(var for sign in self.signatures.values()
                          for var in sign.outputs))

        # save inference program
        program = self.program.clone()

        # Blank out the recorded op call stacks before saving.
        for block in program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])

        if not exe:
            exe = fluid.Executor(place=fluid.CPUPlace())
        model_path = self.helper.model_path()
        utils.mkdir(model_path)
        fluid.io.save_inference_model(
            model_path,
            feeded_var_names=feeded_var_names,
            target_vars=target_vars,
            main_program=program,
            executor=exe)

        # Add this module's name prefix to every not-yet-prefixed variable in
        # the saved program. The file is read and closed before it is
        # rewritten in place.
        model_file = os.path.join(model_path, "__model__")
        with open(model_file, "rb") as file:
            program_desc_str = file.read()
        rename_program = fluid.framework.Program.parse_from_string(
            program_desc_str)
        name_prefix = self.get_name_prefix()
        varlist = {
            var: block
            for block in rename_program.blocks for var in block.vars
            if name_prefix not in var
        }
        for old_name, block in varlist.items():
            block._rename_var(old_name,
                              self.get_var_name_with_prefix(old_name))
        with open(model_file, "wb") as f:
            f.write(rename_program.desc.serialize_to_string())

        # Rename the saved parameter files to match the prefixed variables.
        for filename in os.listdir(model_path):
            if filename == "__model__" or name_prefix in filename:
                continue
            os.rename(
                os.path.join(model_path, filename),
                os.path.join(model_path,
                             self.get_var_name_with_prefix(filename)))

        # create processor file
        if self.processor:
            self._dump_processor()

        # create assets
        self._dump_assets()

        # create check info
        checker = ModuleChecker(self.helper.module_dir)
        checker.generate_check_info()

        # Serialize module_desc pb
        module_pb = self.desc.SerializeToString()
        with open(self.helper.module_desc_path(), "wb") as f:
            f.write(module_pb)
예제 #10
0
def create_module(directory, name, author, email, module_type, summary,
                  version):
    """Package ``directory`` as a PaddleHub module archive, showing progress.

    Steps:

    * Copy the directory into a temporary directory as ``<name>/``.
    * Add a ``module_desc.pb`` with the given module info, the
      ``ModuleChecker`` check info and an ``__init__.py`` to the copy.
    * Write the ``./<name>`` directory entry and every file to the gzipped
      tarball ``<name>-<version>.<HUB_PACKAGE_SUFFIX>`` in the current
      working directory, printing a progress bar along the way.

    The working directory is always restored, even if archiving fails.
    """
    save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)

    with tmp_dir() as base_dir:
        # package the module
        with tarfile.open(save_file, "w:gz") as tar:
            module_dir = os.path.join(base_dir, name)
            shutil.copytree(directory, module_dir)

            # record module info and serialize
            desc = module_desc_pb2.ModuleDesc()
            attr = desc.attr
            attr.type = module_desc_pb2.MAP
            module_info = attr.map.data['module_info']
            module_info.type = module_desc_pb2.MAP
            for key, value in (('name', name), ('author', author),
                               ('author_email', email), ('type', module_type),
                               ('summary', summary), ('version', version)):
                utils.from_pyobj_to_module_attr(value,
                                                module_info.map.data[key])
            module_desc_path = os.path.join(module_dir, "module_desc.pb")
            with open(module_desc_path, "wb") as f:
                f.write(desc.SerializeToString())

            # generate check info
            checker = ModuleChecker(module_dir)
            checker.generate_check_info()

            # add __init__ (append mode keeps any existing content)
            module_init = os.path.join(module_dir, "__init__.py")
            with open(module_init, "a") as file:
                file.write("")

            # Archive paths are relative to base_dir, e.g. "./<name>/...".
            rel_module_dir = os.path.join(
                ".", os.path.relpath(module_dir, base_dir))
            _cwd = os.getcwd()
            os.chdir(base_dir)
            try:
                tar.add(rel_module_dir, recursive=False)
                files = [
                    os.path.join(dirname, filename)
                    for dirname, _, subfiles in os.walk(rel_module_dir)
                    for filename in subfiles
                ]

                total_length = len(files)
                print("Create Module {}-{}".format(name, version))
                for index, file in enumerate(files):
                    # float() first so the ratio is correct under Python 2
                    # integer division as well.
                    ratio = float(index) / total_length
                    progress("[%-50s] %.2f%%" %
                             ('=' * int(ratio * 50), ratio * 100))
                    tar.add(file)
                progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
                print("Module package saved as {}".format(save_file))
            finally:
                # Restore the process-wide cwd even if archiving fails.
                os.chdir(_cwd)