Exemplo n.º 1
0
def resolve_model(data_dir: Path, model_dir: str, model_cls: str, **kwargs):
    """Dynamically import a model class and instantiate it.

    Classes are imported by name (via ``import_class``) from the module that
    MODULE_MAP registers for them. This is used to manage dependencies within
    nboost. For example, you don't necessarily want to import pytorch models
    every time you boot up tensorflow.

    Resolution order:

    1. ``model_dir`` is an http(s) URL: instantiate ``model_cls`` with the URL.
    2. ``data_dir/model_dir`` exists: use it as a local cache. If its name is
       a CLASS_MAP key, the mapped class overrides ``model_cls``.
    3. The directory name is a CLASS_MAP key: if URL_MAP has an entry for it,
       download the file into ``data_dir`` (unless already there), extract it
       into ``data_dir`` when it is a ``.tar.gz``, and pass the absolute
       path. Otherwise pass only the bare name so the plugin can resolve it.
    4. Otherwise, if ``model_cls`` is in MODULE_MAP, pass the bare name.

    :param data_dir: directory for downloads/cached models; created if missing.
    :param model_dir: an http(s) URL, or a model name/directory relative to
        ``data_dir``.
    :param model_cls: class name to use when CLASS_MAP does not determine it.
    :param kwargs: forwarded to the model class constructor.
    :return: the instantiated model.
    :raises ImportError: if ``model_dir`` is not a URL and no model class can
        be determined for it.
    """
    logger = set_logger('resolve_model')
    data_dir.mkdir(parents=True, exist_ok=True)

    # Remote model: hand the URL straight to the plugin. Match only a URL
    # scheme prefix so local model names that merely contain "http" are not
    # misrouted here.
    if model_dir.lower().startswith(('http://', 'https://')):
        module = MODULE_MAP[model_cls]
        model = import_class(module, model_cls)  # type: Type[ModelPlugin]
        return model(model_dir=model_dir, **kwargs)

    model_dir = data_dir.joinpath(model_dir).absolute()

    if model_dir.exists():
        logger.info('Using model cache from %s', model_dir)

        if model_dir.name in CLASS_MAP:
            model_cls = CLASS_MAP[model_dir.name]
        elif model_cls not in MODULE_MAP:
            # Both placeholders need an argument; the condition checks
            # MODULE_MAP, so list the class names that are available there.
            raise ImportError('Class "%s" not in %s.' %
                              (model_cls, MODULE_MAP.keys()))

        module = MODULE_MAP[model_cls]
        model = import_class(module, model_cls)  # type: Type[ModelPlugin]
        return model(model_dir=str(model_dir), **kwargs)
    else:
        if model_dir.name in CLASS_MAP:
            model_cls = CLASS_MAP[model_dir.name]
            module = MODULE_MAP[model_cls]
            if model_dir.name in URL_MAP:  # DOWNLOAD AND CACHE
                url = URL_MAP[model_dir.name]
                binary_path = data_dir.joinpath(Path(url).name)

                if binary_path.exists():
                    logger.info('Found model cache in %s', binary_path)
                else:
                    logger.info('Downloading "%s" model.', model_dir)
                    download_file(url, binary_path)

                if binary_path.suffixes == ['.tar', '.gz']:
                    logger.info('Extracting "%s" from %s', model_dir,
                                binary_path)
                    extract_tar_gz(binary_path, data_dir)
            else:  # pass along to plugin maybe it can resolve it
                model_dir = model_dir.name

            model = import_class(module, model_cls)  # type: Type[ModelPlugin]
            return model(model_dir=str(model_dir), **kwargs)
        else:
            if model_cls in MODULE_MAP:
                module = MODULE_MAP[model_cls]
                model = import_class(module,
                                     model_cls)  # type: Type[ModelPlugin]
                return model(model_dir=model_dir.name, **kwargs)
            else:
                raise ImportError('model_dir %s not found in %s. You must '
                                  'set --model class to continue.' %
                                  (model_dir.name, CLASS_MAP.keys()))
Exemplo n.º 2
0
def main(argv: List[str] = None):
    """Parse command-line arguments and run the selected indexer.

    The ``indexer`` option names an indexer class; its module is looked up
    in INDEXER_MAP and the class is imported by name via ``import_class``.
    All remaining parsed options are forwarded to the indexer's constructor
    as keyword arguments, and ``index()`` is then called on the instance.

    :param argv: argument list to parse; ``None`` lets argparse read
        ``sys.argv``.
    """
    cli_options = vars(set_parser().parse_args(argv))
    class_name = cli_options.pop('indexer')
    indexer_cls = import_class(INDEXER_MAP[class_name],
                               class_name)  # type: Type[BaseIndexer]
    indexer_cls(**cli_options).index()
Exemplo n.º 3
0
    def resolve_model(self, model_dir: Path, cls: str, **kwargs):
        """Dynamically import a model class and instantiate it.

        Classes are imported by name (via ``import_class``) from the module
        that MODULE_MAP registers for them. This is used to manage
        dependencies within nboost. For example, you don't necessarily want
        to import pytorch models every time you boot up tensorflow.

        Resolution order:

        1. ``model_dir`` exists: use it as a local cache. If its name is a
           CLASS_MAP key, the mapped class overrides ``cls``.
        2. Its name is a CLASS_MAP key: if URL_MAP has an entry for it,
           download the file into ``self.data_dir`` (unless already there),
           extract it into ``self.data_dir`` when it is a ``.tar.gz``, and
           pass the path. Otherwise pass only the bare name so the plugin can
           resolve it.
        3. Otherwise, if ``cls`` is in MODULE_MAP, pass the bare name.

        :param model_dir: path of the model directory.
        :param cls: class name to use when CLASS_MAP does not determine it.
        :param kwargs: forwarded to the model class constructor.
        :return: the instantiated model.
        :raises ImportError: if no model class can be determined.
        """
        if model_dir.exists():
            self.logger.info('Using model cache from %s', model_dir)

            if model_dir.name in CLASS_MAP:
                cls = CLASS_MAP[model_dir.name]
            elif cls not in MODULE_MAP:
                # Both placeholders need an argument; the condition checks
                # MODULE_MAP, so list the class names available there.
                raise ImportError('Class "%s" not in %s.' %
                                  (cls, MODULE_MAP.keys()))

            module = MODULE_MAP[cls]
            model = import_class(module, cls)
            return model(str(model_dir), **kwargs)

        if model_dir.name in CLASS_MAP:
            cls = CLASS_MAP[model_dir.name]
            module = MODULE_MAP[cls]

            if model_dir.name in URL_MAP:  # download and cache
                url = URL_MAP[model_dir.name]
                binary_path = self.data_dir.joinpath(Path(url).name)

                if binary_path.exists():
                    self.logger.info('Found model cache in %s', binary_path)
                else:
                    self.logger.info('Downloading "%s" model.', model_dir)
                    download_file(url, binary_path)

                if binary_path.suffixes == ['.tar', '.gz']:
                    self.logger.info('Extracting "%s" from %s', model_dir,
                                     binary_path)
                    extract_tar_gz(binary_path, self.data_dir)

                resolved_dir = str(model_dir)
            else:
                # No known download URL: previously an unguarded URL_MAP
                # lookup raised KeyError here. Pass the bare name along so
                # the plugin may be able to resolve it itself.
                resolved_dir = model_dir.name

            model = import_class(module, cls)
            return model(resolved_dir, **kwargs)

        if cls in MODULE_MAP:
            module = MODULE_MAP[cls]
            model = import_class(module, cls)
            return model(model_dir.name, **kwargs)

        raise ImportError('model_dir %s not found in %s. You must '
                          'set --model class to continue.' %
                          (model_dir.name, CLASS_MAP.keys()))
Exemplo n.º 4
0
def resolve_plugin(plugin, **cli_args):
    """Import the plugin class named ``plugin`` and return a new instance.

    The module defining the class is looked up in MODULE_MAP and the class is
    imported by name via ``import_class``. Every keyword argument is passed
    through to the class constructor unchanged.
    """
    plugin_module = MODULE_MAP[plugin]
    plugin_cls = import_class(plugin_module, plugin)
    return plugin_cls(**cli_args)