示例#1
0
    def _load_processor(self):
        # Some module does not have a processor(e.g. ernie)
        if not 'processor_info' in self.desc:
            return

        python_path = os.path.join(self.directory, 'python')
        processor_name = self.desc.processor_info
        self.processor = utils.load_py_module(python_path, processor_name)
        self.processor = self.processor.Processor(module=self)
示例#2
0
    def load(cls, directory: str) -> Generic:
        '''Load the Module object defined in the specified directory.

        Two layouts are handled:

        * If ``directory`` contains a ``module_desc.pb`` file, loading is
          delegated to ``ModuleV1.load(directory)`` and its result is returned.
        * Otherwise ``<basename>.module`` is imported with the parent directory
          as the search root. The first class, in ``inspect.getmembers`` order
          (sorted by name), that has a ``_hook_by_hub`` attribute and subclasses
          ``RunModule`` is taken as the user module class.

        For the second layout, the returned class gets ``directory``, ``source``
        and ``branch`` attributes. ``source`` and ``branch`` are read from an
        optional ``_source_info.yaml`` in ``directory`` and default to ``''``.

        Args:
            directory: Path of the module directory. Any trailing path
                separators are ignored.

        Returns:
            The result of ``ModuleV1.load`` for a directory with
            ``module_desc.pb``; otherwise the user module class itself (not an
            instance).

        Raises:
            InvalidHubModule: if no matching class is found in the imported
                module.
        '''
        # Strip *every* trailing separator (not just one), so the basename
        # computed below is not empty for inputs such as 'a/b//'.
        directory = directory.rstrip(os.sep + (os.altsep or ''))

        # If the module description file existed, try to load as ModuleV1
        desc_file = os.path.join(directory, 'module_desc.pb')
        if os.path.exists(desc_file):
            return ModuleV1.load(directory)

        # Import '<basename>.module' using the parent directory as search root.
        dirname, basename = os.path.split(directory)
        py_module = utils.load_py_module(dirname, '{}.module'.format(basename))

        # getmembers already yields (name, class) pairs; use the class object
        # directly instead of re-looking it up in the module's __dict__.
        for _name, _cls in inspect.getmembers(py_module, inspect.isclass):
            if hasattr(_cls, '_hook_by_hub') and issubclass(_cls, RunModule):
                user_module_cls = _cls
                break
        else:
            raise InvalidHubModule(directory)

        user_module_cls.directory = directory

        # Optional metadata describing where the module was obtained from.
        source_info_file = os.path.join(directory, '_source_info.yaml')
        if os.path.exists(source_info_file):
            info = parser.yaml_parser.parse(source_info_file)
            user_module_cls.source = info.get('source', '')
            user_module_cls.branch = info.get('branch', '')
        else:
            user_module_cls.source = ''
            user_module_cls.branch = ''

        # In the case of multiple cards, the following code can set each process to use the correct place.
        # It re-sets the device to its bare type (the part of get_device() before
        # ':'), dropping any explicit card index.
        # NOTE(review): presumably paddle then selects the per-process card
        # itself — confirm.
        if issubclass(user_module_cls, paddle.nn.Layer):
            place = paddle.get_device().split(':')[0]
            paddle.set_device(place)

        return user_module_cls