def download(self):
    """Download the model binary and cache it under the package data path.

    If the model directory is already cached, this is a no-op. Otherwise,
    when the model name is known in MODEL_MAP, the archive is fetched
    (unless the tarball itself is already cached) and extracted; any other
    name is logged as a fallback to the pytorch/tf plugins.

    Raises:
        NotADirectoryError: the archive was downloaded and extracted but
            the expected model directory still does not exist.
    """
    # Ensure the cache root exists before touching anything inside it.
    self.data_dir.mkdir(parents=True, exist_ok=True)

    # Fast path: the extracted model is already on disk.
    if self.model_dir.exists():
        self.logger.info('Using model cache from %s', self.model_dir)
        return

    self.logger.info('Did not find model cache in %s', self.model_dir)

    # Unknown model name: nothing we can fetch here, defer to the plugin.
    if self.model_dir.name not in MODEL_MAP:
        self.logger.warning(
            'Could not find finetuned model "%s" in '
            '%s. Falling back to pytorch/tf', self.model_dir,
            MODEL_MAP.keys())
        return

    url = MODEL_MAP[self.model_dir.name]
    archive_path = self.data_dir.joinpath(Path(url).name)

    # Reuse a previously downloaded tarball when present.
    if archive_path.exists():
        self.logger.info('Found model cache in %s', archive_path)
    else:
        self.logger.info('Downloading "%s" model.', self.model_dir)
        download_file(url, archive_path)

    self.logger.info('Extracting "%s" from %s', self.model_dir, archive_path)
    extract_tar_gz(archive_path, self.data_dir)

    # The extracted archive must produce the expected directory.
    if not self.model_dir.exists():
        raise NotADirectoryError('Could not download finetuned '
                                 'model to "%s".' % self.model_dir)
def resolve_model(data_dir: Path, model_dir: str, model_cls: str, **kwargs):
    """Dynamically import class from a module in the CLASS_MAP. This is used
    to manage dependencies within nboost. For example, you don't necessarily
    want to import pytorch models everytime you boot up tensorflow...

    :param data_dir: root directory used to cache downloaded models
    :param model_dir: model name, local directory name, or an http(s) url
    :param model_cls: plugin class name used to look up the module to import
    :param kwargs: forwarded to the plugin constructor
    :return: an instantiated ModelPlugin subclass
    :raises ImportError: when neither the model name nor the class can be
        resolved against CLASS_MAP / MODULE_MAP
    """
    logger = set_logger('resolve_model')
    data_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): substring test also matches local names containing
    # "http"; presumably meant as a url check — kept as-is for compatibility.
    if 'http' in model_dir:
        # Remote model: hand the url straight to the plugin.
        module = MODULE_MAP[model_cls]
        model = import_class(module, model_cls)
        return model(model_dir=model_dir)

    model_dir = data_dir.joinpath(model_dir).absolute()

    if model_dir.exists():
        # Cached model directory: resolve the class from the dir name, or
        # fall back to the caller-supplied class.
        logger.info('Using model cache from %s', model_dir)
        if model_dir.name in CLASS_MAP:
            model_cls = CLASS_MAP[model_dir.name]
        elif model_cls not in MODULE_MAP:
            # BUG FIX: original did `% CLASS_MAP.keys()`, supplying one value
            # to a two-placeholder format string, which raised TypeError
            # instead of the intended ImportError.
            raise ImportError('Class "%s" not in %s.'
                              % (model_cls, MODULE_MAP.keys()))
        module = MODULE_MAP[model_cls]
        model = import_class(module, model_cls)  # type: Type[ModelPlugin]
        return model(model_dir=str(model_dir), **kwargs)

    if model_dir.name in CLASS_MAP:
        # Known model name that is not cached yet.
        model_cls = CLASS_MAP[model_dir.name]
        module = MODULE_MAP[model_cls]

        if model_dir.name in URL_MAP:
            # Download and cache the model binary.
            url = URL_MAP[model_dir.name]
            binary_path = data_dir.joinpath(Path(url).name)

            if binary_path.exists():
                logger.info('Found model cache in %s', binary_path)
            else:
                logger.info('Downloading "%s" model.', model_dir)
                download_file(url, binary_path)

            # Only .tar.gz archives need extraction; other binaries are
            # handed to the plugin as-is.
            if binary_path.suffixes == ['.tar', '.gz']:
                logger.info('Extracting "%s" from %s', model_dir, binary_path)
                extract_tar_gz(binary_path, data_dir)
        else:
            # pass along to plugin maybe it can resolve it
            model_dir = model_dir.name

        model = import_class(module, model_cls)  # type: Type[ModelPlugin]
        return model(model_dir=str(model_dir), **kwargs)

    if model_cls in MODULE_MAP:
        # Unknown model name but explicit --model class: let the plugin
        # try to resolve the bare name (e.g. a huggingface identifier).
        module = MODULE_MAP[model_cls]
        model = import_class(module, model_cls)  # type: Type[ModelPlugin]
        return model(model_dir=model_dir.name, **kwargs)

    raise ImportError('model_dir %s not found in %s. You must '
                      'set --model class to continue.'
                      % (model_dir.name, CLASS_MAP.keys()))
def resolve_model(self, model_dir: Path, cls: str, **kwargs):
    """Dynamically import class from a module in the CLASS_MAP. This is used
    to manage dependencies within nboost. For example, you don't necessarily
    want to import pytorch models everytime you boot up tensorflow...

    :param model_dir: absolute path of the (possibly not yet cached) model
    :param cls: plugin class name used to look up the module to import
    :param kwargs: forwarded to the plugin constructor
    :return: an instantiated plugin model
    :raises ImportError: when neither the model name nor the class can be
        resolved against CLASS_MAP / MODULE_MAP
    """
    if model_dir.exists():
        # Cached model directory: resolve the class from the dir name, or
        # fall back to the caller-supplied class.
        self.logger.info('Using model cache from %s', model_dir)
        if model_dir.name in CLASS_MAP:
            cls = CLASS_MAP[model_dir.name]
        elif cls not in MODULE_MAP:
            # BUG FIX: original did `% CLASS_MAP.keys()`, supplying one value
            # to a two-placeholder format string, which raised TypeError
            # instead of the intended ImportError.
            raise ImportError('Class "%s" not in %s.'
                              % (cls, MODULE_MAP.keys()))
        module = MODULE_MAP[cls]
        model = import_class(module, cls)
        return model(str(model_dir), **kwargs)

    if model_dir.name in CLASS_MAP:
        # Known model name that is not cached yet.
        cls = CLASS_MAP[model_dir.name]
        module = MODULE_MAP[cls]

        # BUG FIX: original indexed URL_MAP unconditionally, raising an
        # opaque KeyError for models in CLASS_MAP without a download url.
        # Guarded lookup (matching the module-level resolve_model variant)
        # skips the download step and lets the plugin resolve the name.
        if model_dir.name in URL_MAP:
            # Download and cache the model binary.
            url = URL_MAP[model_dir.name]
            binary_path = self.data_dir.joinpath(Path(url).name)

            if binary_path.exists():
                self.logger.info('Found model cache in %s', binary_path)
            else:
                self.logger.info('Downloading "%s" model.', model_dir)
                download_file(url, binary_path)

            # Only .tar.gz archives need extraction; other binaries are
            # handed to the plugin as-is.
            if binary_path.suffixes == ['.tar', '.gz']:
                self.logger.info('Extracting "%s" from %s',
                                 model_dir, binary_path)
                extract_tar_gz(binary_path, self.data_dir)

        model = import_class(module, cls)
        return model(str(model_dir), **kwargs)

    if cls in MODULE_MAP:
        # Unknown model name but explicit --model class: let the plugin
        # try to resolve the bare name (e.g. a huggingface identifier).
        module = MODULE_MAP[cls]
        model = import_class(module, cls)
        return model(model_dir.name, **kwargs)

    raise ImportError('model_dir %s not found in %s. You must '
                      'set --model class to continue.'
                      % (model_dir.name, CLASS_MAP.keys()))