def load_model(cfg="models/mobile-yolo5l_voc.yaml", weights="./outputs/mvoc/weights/best_mvoc.pt"):
    """Build a ``Model`` from ``cfg`` and restore every shape-compatible weight from ``weights``.

    Checkpoint entries whose key is missing from the freshly built model, or whose
    tensor shape differs from the model's, are skipped rather than raising. The
    number of restored and skipped entries is printed.

    Args:
        cfg: model config passed to ``Model``.
        weights: checkpoint path for ``torch.load``. Its ``'model'`` entry must expose
            ``.float()`` and ``.state_dict()`` (presumably a pickled ``nn.Module`` --
            TODO confirm for every checkpoint this is used with).

    Returns:
        The model, converted with ``.float()``, with ``model.model[-1].export`` set
        to ``True``.

    NOTE(review): uses a module-level ``device`` that is not defined in this
    function; confirm where it is set before calling.
    """
    model = Model(cfg).to(device)
    ckpt = torch.load(weights, map_location=device)  # load checkpoint

    # Build the model's state dict once. Calling model.state_dict() for every
    # checkpoint key rebuilds the whole mapping each time (quadratic overall).
    model_state = model.state_dict()
    ckpt_state = ckpt['model'].float().state_dict()
    restored = {
        k: v
        for k, v in ckpt_state.items()
        if k in model_state and model_state[k].shape == v.shape
    }
    restor_num = len(restored)
    ommit_num = len(ckpt_state) - restor_num

    print("Build model from", cfg)
    print("Resotre weight from", weights)
    print("Restore %d vars, ommit %d vars" % (restor_num, ommit_num))

    # strict=False: keys skipped above keep the values the model was built with.
    model.load_state_dict(restored, strict=False)
    del ckpt  # drop the checkpoint reference before finishing up
    model.float()
    model.model[-1].export = True
    return model
def load_checkpoint(type_, weights, device, cfg=None, hyp=None, nc=None, recipe=None, resume=None, rank=-1):
    """Load a checkpoint, build the model from it and wrap it with a SparseML recipe.

    NOTE(review): a second ``load_checkpoint`` is defined later in this file; at
    import time that later definition replaces this one.

    Args:
        type_: ``'train'`` restores weights for training and initializes the recipe.
            ``'ensemble'`` with a pickled checkpoint loads through ``attempt_load``.
            Any value other than ``'train'`` applies the recipe
            (``sparseml_wrapper.apply()``) and prefers the checkpoint's ``'ema'``
            weights when present.
        weights: checkpoint path, or a list/tuple of paths. Only the first entry is
            read with ``torch.load``; the full value goes to ``attempt_download`` /
            ``attempt_load``.
        device: ``map_location`` for ``torch.load`` and target device for ``Model``.
        cfg: model config. Falls back to the checkpoint's ``'yaml'`` entry, then to
            ``ckpt['model'].yaml`` for pickled checkpoints.
        hyp: optional hyper-parameter mapping; only its ``'anchors'`` entry is read.
        nc: number of classes. When falsy, the checkpoint's ``'nc'`` is used if present.
        recipe: SparseML recipe. Falls back to the checkpoint's ``'recipe'`` entry.
        resume: when truthy, ``exclude_anchors`` passed to ``load_state_dict`` is falsy.
        rank: process rank given to ``torch_distributed_zero_first`` around the download.

    Returns:
        ``(model, extras)``. ``extras`` holds ``'ckpt'``, ``'state_dict'``,
        ``'start_epoch'``, ``'sparseml_wrapper'`` and ``'report'``.
    """
    with torch_distributed_zero_first(rank):
        attempt_download(weights)  # download if not found locally

    # For a multi-weight spec, only the first checkpoint is read here. It supplies
    # the epoch / yaml / nc / recipe metadata used below.
    ckpt_path = weights[0] if isinstance(weights, (list, tuple)) else weights
    ckpt = torch.load(ckpt_path, map_location=device)  # load checkpoint
    start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
    pickled = isinstance(ckpt['model'], nn.Module)
    train_type = type_ == 'train'
    ensemble_type = type_ == 'ensemble'

    if pickled and ensemble_type:  # load ensemble using pickled
        cfg = None
        model = attempt_load(weights, map_location=device)  # load FP32 model
        state_dict = model.state_dict()
    else:  # load model from config and weights
        cfg = cfg or (ckpt['yaml'] if 'yaml' in ckpt else None) or \
            (ckpt['model'].yaml if pickled else None)
        model = Model(cfg, ch=3, nc=ckpt['nc'] if ('nc' in ckpt and not nc) else nc,
                      anchors=hyp.get('anchors') if hyp else None).to(device)
        # non-training loads prefer the EMA weights when the checkpoint has them
        model_key = 'ema' if (not train_type and 'ema' in ckpt and ckpt['ema']) else 'model'
        state_dict = ckpt[model_key].float().state_dict() if pickled else ckpt[model_key]

    # turn gradients for params back on in case they were removed
    for p in model.parameters():
        p.requires_grad = True

    # load sparseml recipe for applying pruning and quantization
    recipe = recipe or (ckpt['recipe'] if 'recipe' in ckpt else None)
    sparseml_wrapper = SparseMLWrapper(model, recipe)
    # ``hyp`` may be None: guard the lookup instead of raising AttributeError.
    exclude_anchors = train_type and (cfg or (hyp or {}).get('anchors')) and not resume
    loaded = False

    if not train_type:
        # apply the recipe to create the final state of the model when not training
        sparseml_wrapper.apply()
    else:
        # Dense weights are restored *before* the recipe is initialized. A state
        # dict that already carries quantization params ('.zero_point' keys) is
        # loaded only after initialization, below -- presumably because those keys
        # exist on the model only once the recipe has modified it (TODO confirm).
        quantized_state_dict = any(name.endswith('.zero_point') for name in state_dict)
        if not quantized_state_dict:
            state_dict = load_state_dict(model, state_dict, train=True, exclude_anchors=exclude_anchors)
            loaded = True
        sparseml_wrapper.initialize(start_epoch)

    if not loaded:
        state_dict = load_state_dict(model, state_dict, train=train_type, exclude_anchors=exclude_anchors)

    model.float()
    report = 'Transferred %g/%g items from %s' % (
        len(state_dict), len(model.state_dict()), weights)

    return model, {
        'ckpt': ckpt,
        'state_dict': state_dict,
        'start_epoch': start_epoch,
        'sparseml_wrapper': sparseml_wrapper,
        'report': report,
    }
def load_checkpoint(type_, weights, device, cfg=None, hyp=None, nc=None, recipe=None, resume=None, rank=-1):
    """Load a checkpoint, build the model from it and wrap it with a SparseML recipe.

    NOTE(review): this definition replaces an earlier ``load_checkpoint`` in this
    file; it is the one bound at import time.

    Args:
        type_: ``'ema'`` / ``'ensemble'`` apply the recipe (``sparseml_wrapper.apply()``)
            and prefer the checkpoint's ``'ema'`` weights when present. Any other
            value initializes the recipe at ``start_epoch``. ``'train'`` additionally
            filters the weights through ``intersect_dicts`` and loads them
            non-strictly; every other value loads strictly.
        weights: checkpoint path, or a list/tuple of paths. Only the first entry is
            read with ``torch.load``; the full value goes to ``attempt_download`` /
            ``attempt_load``.
        device: ``map_location`` for ``torch.load`` and target device for ``Model``.
        cfg: model config. Falls back to the checkpoint's ``'yaml'`` entry, then to
            ``ckpt['model'].yaml`` for pickled checkpoints.
        hyp: optional hyper-parameter mapping; only its ``'anchors'`` entry is read.
        nc: number of classes. When falsy, the checkpoint's ``'nc'`` is used if present.
        recipe: SparseML recipe. Falls back to the checkpoint's ``'recipe'`` entry.
        resume: when truthy, ``'anchor'`` keys are not excluded for ``'train'``.
        rank: process rank given to ``torch_distributed_zero_first`` around the download.

    Returns:
        ``(model, extras)``. ``extras`` holds ``'ckpt'``, ``'state_dict'``,
        ``'start_epoch'``, ``'sparseml_wrapper'`` and ``'report'``.
    """
    with torch_distributed_zero_first(rank):
        attempt_download(weights)  # download if not found locally

    # An ensemble may pass several weight files. torch.load cannot open a list,
    # so read the first checkpoint for metadata; attempt_load still gets them all.
    ckpt_path = weights[0] if isinstance(weights, (list, tuple)) else weights
    ckpt = torch.load(ckpt_path, map_location=device)  # load checkpoint
    start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
    pickled = isinstance(ckpt['model'], nn.Module)
    apply_recipe = type_ in ('ema', 'ensemble')

    if pickled and type_ == 'ensemble':  # load ensemble using pickled
        cfg = None
        model = attempt_load(weights, map_location=device)  # load FP32 model
        state_dict = model.state_dict()
    else:  # load model from config and weights
        cfg = cfg or (ckpt['yaml'] if 'yaml' in ckpt else None) or \
            (ckpt['model'].yaml if pickled else None)
        model = Model(cfg, ch=3, nc=ckpt['nc'] if ('nc' in ckpt and not nc) else nc,
                      anchors=hyp.get('anchors') if hyp else None).to(device)
        model_key = 'ema' if (apply_recipe and 'ema' in ckpt and ckpt['ema']) else 'model'
        state_dict = ckpt[model_key].float().state_dict() if pickled else ckpt[model_key]

    # turn gradients for params back on in case they were removed
    for p in model.parameters():
        p.requires_grad = True

    # load sparseml recipe for applying pruning and quantization
    recipe = recipe or (ckpt['recipe'] if 'recipe' in ckpt else None)
    sparseml_wrapper = SparseMLWrapper(model, recipe)

    if apply_recipe:
        # apply the recipe to create the final state of the model when not training
        sparseml_wrapper.apply()
    else:
        # initialize the recipe for training
        sparseml_wrapper.initialize(start_epoch)

    if type_ == 'train':
        # load any missing weights from the model; ``hyp`` may be None here
        exclude = ['anchor'] if (cfg or (hyp or {}).get('anchors')) and not resume else []  # exclude keys
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect

    model.load_state_dict(state_dict, strict=type_ != 'train')  # load
    model.float()
    report = 'Transferred %g/%g items from %s' % (
        len(state_dict), len(model.state_dict()), weights)

    return model, {
        'ckpt': ckpt,
        'state_dict': state_dict,
        'start_epoch': start_epoch,
        'sparseml_wrapper': sparseml_wrapper,
        'report': report,
    }