def test_load_model_from_stub(stub, model_args, other_args):
    """
    Check that a model loaded from ``stub`` exposes the expected attribute values.

    :param stub: SparseZoo stub to load
    :param model_args: mapping of model attribute name to its expected value;
        empty keys and attributes the model does not have are skipped
    :param other_args: extra keyword arguments forwarded to
        ``Zoo.load_model_from_stub``
    """
    model = Zoo.load_model_from_stub(stub, **other_args)
    try:
        for key in model_args:
            if key and hasattr(model, key):
                assert getattr(model, key) == model_args[key]
    finally:
        # clean up the model directory even when an assertion above fails,
        # so a failing run does not leave the directory behind
        shutil.rmtree(model.dir_path)
def load_data(
    data_path: str,
) -> List[List[numpy.ndarray]]:
    """
    Loads data from given SparseZoo stub or directory with .npz files

    :param data_path: directory path to .npz files to load or SparseZoo stub
        (prefixed with 'zoo:') whose data inputs are downloaded and used
    :return: List of loaded samples; each sample is the list of arrays from
        ``sample.values()`` of one loaded .npz file
    :raises RuntimeError: if any entry of the data directory does not end in .npz
    """
    if data_path.startswith("zoo:"):
        data_dir = Zoo.load_model_from_stub(data_path).data_inputs.downloaded_path()
    else:
        data_dir = data_path

    data_files = os.listdir(data_dir)
    # check the real extension; a substring test would accept e.g. "x.npz.bak"
    invalid_files = [name for name in data_files if not name.endswith(".npz")]
    if invalid_files:
        raise RuntimeError(
            f"All files in data directory {data_dir} must have a .npz extension "
            f"found {invalid_files}"
        )

    samples = load_numpy_list(data_dir)
    # unwrap unloaded numpy files (entries returned as paths are loaded here)
    samples = [
        load_numpy(sample) if isinstance(sample, str) else sample for sample in samples
    ]

    return [list(sample.values()) for sample in samples]
def model_to_path(model: Union[str, Model, File]) -> str:
    """
    Resolve a model reference to the path of an ONNX file on local disk.

    Accepted forms of ``model``:
      * a path to an ONNX file
      * a SparseZoo stub string prefixed with 'zoo:'
      * a SparseZoo ``Model`` object (its main ONNX file is used)
      * a SparseZoo ``File`` object (downloaded if not already local)

    :param model: the model reference to resolve
    :return: local file system path of the ONNX file
    :raises ValueError: if ``model`` is empty, resolves to a non-string, or the
        resolved path does not exist
    """
    if not model:
        raise ValueError(
            "model must be a path, sparsezoo.Model, or sparsezoo.File")

    if isinstance(model, str) and model.startswith("zoo:"):
        # stubs require sparsezoo; re-raise its import failure if there was one
        if sparsezoo_import_error is not None:
            raise sparsezoo_import_error
        model = Zoo.load_model_from_stub(model)

    # `Model` / `File` are compared against `object`, presumably a placeholder
    # used when sparsezoo could not be imported -- TODO confirm
    if Model is not object and isinstance(model, Model):
        model = model.onnx_file.downloaded_path()
    elif File is not object and isinstance(model, File):
        model = model.downloaded_path()

    if not isinstance(model, str):
        raise ValueError(f"unsupported type for model: {type(model)}")
    if not os.path.exists(model):
        raise ValueError(f"model path must exist: given {model}")

    return model
def modify_yolo_onnx_input_shape(
    model_path: str, image_shape: Tuple[int]
) -> Tuple[str, Optional[NamedTemporaryFile]]:
    """
    Creates a new YOLOv3 ONNX model from the given path that accepts the given input
    shape. If the given model already has the given input shape no modifications are
    made. Uses a tempfile to store the modified model file.

    :param model_path: file path to YOLOv3 ONNX model or SparseZoo stub of the model
        to be loaded
    :param image_shape: 2-tuple of the image shape to resize this yolo model to
    :return: filepath to an onnx model reshaped to the given input shape will be the
        original path if the shape is the same (or the input dims are not static
        integers). Additionally returns the NamedTemporaryFile for managing the scope
        of the object for file deletion, or None when no new file was written
    """
    original_model_path = model_path
    if model_path.startswith("zoo:"):
        # load SparseZoo Model from stub
        model = Zoo.load_model_from_stub(model_path)
        model_path = model.onnx_file.downloaded_path()
        print(f"Downloaded {original_model_path} to {model_path}")

    model = onnx.load(model_path)
    model_input = model.graph.input[0]

    initial_x = get_tensor_dim_shape(model_input, 2)
    initial_y = get_tensor_dim_shape(model_input, 3)

    if not (isinstance(initial_x, int) and isinstance(initial_y, int)):
        return model_path, None  # model graph does not have static integer input shape

    if (initial_x, initial_y) == tuple(image_shape):
        return model_path, None  # no shape modification needed

    # ratio of original to requested input dims; output dims 2 and 3 are rescaled
    # by the same factors below
    scale_x = initial_x / image_shape[0]
    scale_y = initial_y / image_shape[1]
    set_tensor_dim_shape(model_input, 2, image_shape[0])
    set_tensor_dim_shape(model_input, 3, image_shape[1])

    for model_output in model.graph.output:
        output_x = get_tensor_dim_shape(model_output, 2)
        output_y = get_tensor_dim_shape(model_output, 3)
        set_tensor_dim_shape(model_output, 2, int(output_x / scale_x))
        set_tensor_dim_shape(model_output, 3, int(output_y / scale_y))

    # NamedTemporaryFile deletes the file when the handle is closed or garbage
    # collected, so callers must keep the returned handle alive while using the path
    tmp_file = NamedTemporaryFile()
    # persist the reshaped graph; without this the returned path is an empty file
    onnx.save(model, tmp_file.name)

    print(
        f"Overwriting original model shape {(initial_x, initial_y)} to {image_shape}\n"
        f"Original model path: {original_model_path}, new temporary model saved to "
        f"{tmp_file.name}"
    )

    return tmp_file.name, tmp_file
def _parse_major_minor(version: str) -> Tuple[int, ...]:
    """
    Extract the leading (major, minor) integers of a version string so versions
    compare numerically, e.g. "1.10.0" -> (1, 10) and "1.16.3+cu118" -> (1, 16)
    """
    import re

    return tuple(int(part) for part in re.findall(r"\d+", version)[:2])


def _load_model(args) -> Tuple[Any, List[str]]:
    """
    Resolve, reshape, and load the model described by ``args`` for benchmarking.

    :param args: parsed CLI arguments; reads ``engine``, ``num_cores``,
        ``model_filepath``, ``batch_size`` and ``max_sequence_length``, and
        overwrites ``model_filepath`` with the path of the model actually loaded
    :return: tuple of the loaded model (DeepSparse compiled model or onnxruntime
        InferenceSession) and the input names returned by
        ``overwrite_transformer_onnx_model_inputs``
    :raises ValueError: if num_cores is set with onnxruntime < 1.7, or the engine
        is neither DEEPSPARSE_ENGINE nor ORT_ENGINE
    """
    if args.engine == ORT_ENGINE and ort_error is not None:
        raise ort_error

    # validation -- compare versions numerically; as strings "1.10" < "1.7"
    if (
        args.num_cores is not None
        and args.engine == ORT_ENGINE
        and _parse_major_minor(onnxruntime.__version__) < (1, 7)
    ):
        raise ValueError(
            "overriding default num_cores not supported for onnxruntime < 1.7.0. "
            "If using an older build with OpenMP, try setting the OMP_NUM_THREADS "
            "environment variable"
        )

    # load model from sparsezoo if necessary
    if args.model_filepath.startswith("zoo:"):
        zoo_model = Zoo.load_model_from_stub(args.model_filepath)
        downloaded_path = zoo_model.onnx_file.downloaded_path()
        print(
            f"downloaded sparsezoo model {args.model_filepath} to {downloaded_path}"
        )
        args.model_filepath = downloaded_path

    # overwrite the static ONNX input shapes with the requested batch size and
    # max sequence length
    input_names = []
    if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]:
        # the third return value stays bound until the model is loaded below --
        # presumably a temp-file handle for the rewritten model; TODO confirm
        args.model_filepath, input_names, _tmp_model_file = (
            overwrite_transformer_onnx_model_inputs(
                args.model_filepath,
                batch_size=args.batch_size,
                max_length=args.max_sequence_length,
            )
        )

    # load model
    if args.engine == DEEPSPARSE_ENGINE:
        print(f"Compiling deepsparse model for {args.model_filepath}")
        model = compile_model(args.model_filepath, args.batch_size, args.num_cores)
        print(f"Engine info: {model}")
    elif args.engine == ORT_ENGINE:
        print(f"loading onnxruntime model for {args.model_filepath}")
        sess_options = onnxruntime.SessionOptions()
        if args.num_cores is not None:
            sess_options.intra_op_num_threads = args.num_cores
        sess_options.log_severity_level = 3
        sess_options.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        model = onnxruntime.InferenceSession(
            args.model_filepath, sess_options=sess_options
        )
    else:
        # previously fell through to an UnboundLocalError on `model`
        raise ValueError(f"unsupported engine {args.engine}")

    return model, input_names
def test_search_sparse_recipes_from_stub(model_stub, other_args):
    """
    Check that every sparse recipe found for ``model_stub`` carries model metadata
    that matches the model loaded from the same stub.

    :param model_stub: SparseZoo stub to load and to search recipes for
    :param other_args: extra keyword arguments forwarded to
        ``Zoo.load_model_from_stub``
    """
    model = Zoo.load_model_from_stub(model_stub, **other_args)
    recipes = Zoo.search_sparse_recipes(model_stub)
    assert len(recipes) > 0

    # metadata attributes that must agree between each recipe and the model
    matched_fields = (
        "domain",
        "sub_domain",
        "architecture",
        "sub_architecture",
        "framework",
        "repo",
        "dataset",
        "training_scheme",
    )
    for recipe in recipes:
        assert recipe
        for field in matched_fields:
            assert getattr(recipe.model_metadata, field) == getattr(model, field)
def get_onnx_path_and_configs(
    model_path: str,
) -> Tuple[str, Optional[str], Optional[str]]:
    """
    Locate the ONNX file and optional config/tokenizer directories for a model.

    :param model_path: path to onnx file, transformers sparsezoo stub, or directory
        containing `model.onnx`, `config.json`, and/or `tokenizer.json` files.
        If no `model.onnx` file is found in a model directory, an exception will
        be raised
    :return: tuple of ONNX file path, parent directory of config file if it exists,
        and parent directory of tokenizer config file if it exists. (Parent
        directories returned instead of absolute path for compatibility with
        transformers .from_pretrained() method)
    :raises ValueError: if a directory lacks the ONNX file, or ``model_path`` is not
        a file, directory, or 'zoo:' stub
    """
    # a plain file is taken to be the ONNX model itself, with no configs
    if os.path.isfile(model_path):
        return model_path, None, None

    config_dir = None
    tokenizer_dir = None

    if os.path.isdir(model_path):
        dir_entries = set(os.listdir(model_path))
        if _MODEL_DIR_ONNX_NAME not in dir_entries:
            raise ValueError(
                f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory "
                f"{model_path}. Be sure that an export of the model is written to "
                f"{os.path.join(model_path, _MODEL_DIR_ONNX_NAME)}"
            )
        onnx_path = os.path.join(model_path, _MODEL_DIR_ONNX_NAME)
        if _MODEL_DIR_CONFIG_NAME in dir_entries:
            config_dir = model_path
        if _MODEL_DIR_TOKENIZER_NAME in dir_entries:
            tokenizer_dir = model_path
    elif model_path.startswith("zoo:"):
        zoo_model = Zoo.load_model_from_stub(model_path)
        onnx_path = zoo_model.onnx_file.downloaded_path()
        for framework_file in zoo_model.framework_files:
            if framework_file.display_name == _MODEL_DIR_CONFIG_NAME:
                config_dir = _get_file_parent(framework_file.downloaded_path())
            if "tokenizer" in framework_file.display_name:
                tokenizer_dir = _get_file_parent(framework_file.downloaded_path())
    else:
        raise ValueError(
            f"model_path {model_path} is not a valid file, directory, or zoo stub"
        )

    return onnx_path, config_dir, tokenizer_dir
def fix_onnx_input_shape(
    model_path: str,
    image_shape: Optional[Tuple[int]],
) -> Tuple[str, Optional[NamedTemporaryFile], Tuple]:
    """
    Creates a new ONNX model from the given path that accepts the given input
    shape. If the given model already has the given input shape no modifications are
    made. Uses a tempfile to store the modified model file.

    :param model_path: file path to ONNX model or SparseZoo stub of the model
        to be loaded
    :param image_shape: 2-tuple of the image shape to resize this model to, or None
        if no resizing needed
    :return: filepath to an onnx model reshaped to the given input shape will be the
        original path if the shape is the same. Additionally returns the
        NamedTemporaryFile for managing the scope of the object for file deletion
        (None when no new file was written). Additionally returns the image-shape to
        benchmark the new model with.
    """
    original_model_path = model_path
    if model_path.startswith("zoo:"):
        # load SparseZoo Model from stub
        model = Zoo.load_model_from_stub(model_path)
        model_path = model.onnx_file.downloaded_path()
        print(f"Downloaded {original_model_path} to {model_path}")

    model = onnx.load(model_path)
    model_input = model.graph.input[0]

    original_x = get_tensor_dim_shape(model_input, 2)
    original_y = get_tensor_dim_shape(model_input, 3)
    original_image_shape = (original_x, original_y)

    if image_shape is None or original_image_shape == tuple(image_shape):
        return model_path, None, original_image_shape  # no shape modification needed

    set_tensor_dim_shape(model_input, 2, image_shape[0])
    set_tensor_dim_shape(model_input, 3, image_shape[1])

    # NamedTemporaryFile deletes the file when the handle is closed or garbage
    # collected, so callers must keep the returned handle alive while using the path
    tmp_file = NamedTemporaryFile()
    # persist the reshaped graph; without this the returned path is an empty file
    onnx.save(model, tmp_file.name)

    print(
        f"Overwriting original model shape {original_image_shape} to {image_shape}\n"
        f"Original model path: {original_model_path}, new temporary model saved to "
        f"{tmp_file.name}"
    )

    return tmp_file.name, tmp_file, image_shape
def _load_data(args, input_names) -> List[List[numpy.ndarray]]:
    """
    Load benchmark samples from .npz files and resize them along the first axis.

    :param args: parsed CLI arguments; reads ``data_path`` (directory of .npz files
        or a 'zoo:' stub whose data inputs are downloaded) and
        ``max_sequence_length``
    :param input_names: model input names used to order each sample's arrays
    :return: list of samples; each sample is a list of arrays whose first axis is
        zero-padded or truncated to ``args.max_sequence_length``
    :raises RuntimeError: if any entry of the data directory does not end in .npz
    """
    if args.data_path.startswith("zoo:"):
        data_dir = Zoo.load_model_from_stub(
            args.data_path).data_inputs.downloaded_path()
    else:
        data_dir = args.data_path

    data_files = os.listdir(data_dir)
    # check the real extension; a substring test would accept e.g. "x.npz.bak"
    invalid_files = [name for name in data_files if not name.endswith(".npz")]
    if invalid_files:
        raise RuntimeError(
            f"All files in data directory {data_dir} must have a .npz extension "
            f"found {invalid_files}")

    samples = load_numpy_list(data_dir)
    # unwrap unloaded numpy files (entries returned as paths are loaded here)
    samples = [
        load_numpy(sample) if isinstance(sample, str) else sample
        for sample in samples
    ]

    max_length = args.max_sequence_length
    processed_samples = []
    warning_given = False
    for sample in samples:
        if (not all(inp_name in sample for inp_name in input_names)
                or len(input_names) != len(sample)):
            # names do not line up with the model: fall back to stored order
            if not warning_given:
                warnings.warn(
                    "input sample found whose input names do not match the model input "
                    "names, this may cause an exception during benchmarking")
                warning_given = True
            arrays = list(sample.values())
        else:
            arrays = [sample[inp_name] for inp_name in input_names]

        resized_arrays = []
        for array in arrays:
            # zero-pad or truncate the first axis to exactly max_length
            resized = numpy.zeros([max_length, *array.shape[1:]], dtype=array.dtype)
            keep = min(array.shape[0], max_length)
            resized[:keep] = array[:keep]
            resized_arrays.append(resized)
        processed_samples.append(resized_arrays)

    return processed_samples
def main():
    """
    Train an image classification model with timm, applying a SparseML recipe.

    Steps visible here: distributed/AMP setup, optional weight loading from
    SparseZoo, model/optimizer/EMA/DDP creation, data loaders, SparseML
    ``ScheduledOptimizer`` wrapping (which may replace the LR scheduler and epoch
    count), the epoch loop, and a final ONNX export from rank 0.
    """
    from datetime import datetime  # local import: used for the run directory name

    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPUs.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)

    ####################################################################################
    # Start - SparseML optional load weights from SparseZoo
    ####################################################################################
    if args.initial_checkpoint == "zoo":
        # Load checkpoint from base weights associated with given SparseZoo recipe
        if args.sparseml_recipe.startswith("zoo:"):
            args.initial_checkpoint = Zoo.download_recipe_base_framework_files(
                args.sparseml_recipe, extensions=[".pth.tar", ".pth"]
            )[0]
        else:
            raise ValueError(
                "Attempting to load weights from SparseZoo recipe, but not given a "
                "SparseZoo recipe stub. When initial-checkpoint is set to 'zoo'. "
                "sparseml-recipe must start with 'zoo:' and be a SparseZoo model "
                f"stub. sparseml-recipe was set to {args.sparseml_recipe}"
            )
    elif args.initial_checkpoint.startswith("zoo:"):
        # Load weights from a SparseZoo model stub; take the first matching file
        # so checkpoint_path is a single path rather than a list
        zoo_model = Zoo.load_model_from_stub(args.initial_checkpoint)
        args.initial_checkpoint = zoo_model.download_framework_files(
            extensions=[".pth"]
        )[0]
    ####################################################################################
    # End - SparseML optional load weights from SparseZoo
    ####################################################################################

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
        batch_size=args.batch_size)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
        batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking; only local rank 0 gets an
    # output directory, other ranks keep output_dir == ''
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    ####################################################################################
    # Start SparseML Integration
    ####################################################################################
    sparseml_loggers = (
        [PythonLogger(), TensorBoardLogger(log_path=output_dir)]
        if output_dir
        else None
    )
    manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe)
    optimizer = ScheduledOptimizer(
        optimizer,
        model,
        manager,
        steps_per_epoch=len(loader_train),
        loggers=sparseml_loggers
    )
    # override lr scheduler if recipe makes any LR updates
    if any("LearningRate" in str(modifier) for modifier in manager.modifiers):
        _logger.info("Disabling timm LR scheduler, managing LR using SparseML recipe")
        lr_scheduler = None
    if manager.max_epochs:
        _logger.info(
            f"Overriding max_epochs to {manager.max_epochs} from SparseML recipe"
        )
    num_epochs = manager.max_epochs or num_epochs
    ####################################################################################
    # End SparseML Integration
    ####################################################################################

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
                    log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # only the rank with an output directory writes the summary; otherwise
            # every other rank would append to ./summary.csv concurrently
            if output_dir:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        #################################################################################
        # Start SparseML ONNX Export
        #################################################################################
        if output_dir:
            _logger.info(
                f"training complete, exporting ONNX to {output_dir}/model.onnx"
            )
            exporter = ModuleExporter(model, output_dir)
            exporter.export_onnx(torch.randn((1, *data_config["input_size"])))
        #################################################################################
        # End SparseML ONNX Export
        #################################################################################

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
from sparsezoo import Zoo if opt.weights == "zoo": # Load checkpoint from base weights associated with given SparseZoo recipe if opt.sparseml_recipe.startswith("zoo:"): opt.weights = Zoo.download_recipe_base_framework_files( opt.sparseml_recipe, extensions=[".pt", ".pth"])[0] else: raise ValueError( "Attempting to load weights from SparseZoo recipe, but not given a " "SparseZoo recipe stub. When --weights is set to 'zoo'. " "sparseml-recipe must start with 'zoo:' and be a SparseZoo model " f"stub. sparseml-recipe was set to {opt.sparseml_recipe}") elif opt.weights.startswith("zoo:"): # Load weights from a SparseZoo model stub zoo_model = Zoo.load_model_from_stub(opt.weights) opt.weights = zoo_model.download_framework_files( extensions=[".pt", ".pth"])[0] #################################################################################### # End - SparseML optional load weights from SparseZoo #################################################################################### # Resume if opt.resume: # resume an interrupted run ckpt = opt.resume if isinstance( opt.resume, str) else get_latest_run() # specified or most recent path assert os.path.isfile( ckpt), 'ERROR: --resume checkpoint does not exist' apriori = opt.global_rank, opt.local_rank with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
def load_model(
    path: str,
    model: Module,
    strict: bool = False,
    ignore_error_tensors: List[str] = None,
    fix_data_parallel: bool = True,
):
    """
    Load a saved state dict from ``path`` into ``model``.

    :param path: the path to the pth file to load the state dict from.
        May also be a SparseZoo stub path preceded by 'zoo:' with the optional
        `?recipe_type=` argument. If given a recipe type, the base model weights
        for that recipe will be loaded.
    :param model: the model to load the state dict into
    :param strict: True to enforce that all tensors match between the model
        and the file; False otherwise
    :param ignore_error_tensors: names of tensors to reconcile instead of erroring:
        a name present in both with different shapes keeps the model's tensor, a
        name only in the model keeps the model's tensor, and a name only in the
        file is dropped. None means no names.
    :param fix_data_parallel: strip a leading "module." from saved keys, as left
        by DataParallel-style wrappers
    """
    if path.startswith("zoo:"):
        if "recipe_type=" in path:
            candidates = Zoo.download_recipe_base_framework_files(
                path, extensions=[".pth"])
        else:
            stub_model = Zoo.load_model_from_stub(path)
            candidates = stub_model.download_framework_files(extensions=[".pth"])
        path = candidates[0]

    model_dict = torch.load(path, map_location="cpu")
    current_dict = model.state_dict()

    # checkpoints may wrap the weights under a "state_dict" key
    if "state_dict" in model_dict:
        model_dict = model_dict["state_dict"]

    if fix_data_parallel:
        prefix = "module."
        # snapshot the keys first since the dict is mutated while renaming
        for name in list(model_dict.keys()):
            if name.startswith(prefix):
                model_dict[name[len(prefix):]] = model_dict.pop(name)

    for name in ignore_error_tensors or []:
        in_file = name in model_dict
        in_model = name in current_dict
        if in_file and in_model:
            if current_dict[name].shape != model_dict[name].shape:
                model_dict[name] = current_dict[name]
        elif in_model:
            model_dict[name] = current_dict[name]
        elif in_file:
            del model_dict[name]

    model.load_state_dict(model_dict, strict)