def load_model(cfg="models/mobile-yolo5l_voc.yaml",
               weights="./outputs/mvoc/weights/best_mvoc.pt"):
    """Build a ``Model`` from ``cfg`` and restore compatible weights from a checkpoint.

    Only checkpoint tensors whose key exists in the freshly built model AND
    whose shape matches are copied. Every other checkpoint entry is skipped and
    counted as omitted. Skipped model parameters keep their initial values
    because the final load is non-strict.

    Args:
        cfg: Model definition passed to ``Model``.
        weights: Path of a checkpoint whose ``'model'`` entry supports
            ``.float().state_dict()`` (i.e. a pickled module).

    Returns:
        The float32 model, moved to the module-level ``device``, with
        ``model.model[-1].export`` set to ``True``.
    """
    model = Model(cfg).to(device)
    ckpt = torch.load(weights, map_location=device)  # load checkpoint

    # Build the target state dict ONCE. Calling model.state_dict() inside the
    # loop rebuilt the whole dict for every checkpoint key (quadratic work).
    target = model.state_dict()
    restored = {}
    omit_num = 0
    for k, v in ckpt['model'].float().state_dict().items():
        if k in target and target[k].shape == v.shape:
            restored[k] = v
        else:
            omit_num += 1
    restore_num = len(restored)  # state_dict keys are unique, so this is the copy count

    print("Build model from", cfg)
    print("Resotre weight from", weights)
    print("Restore %d vars, ommit %d vars" % (restore_num, omit_num))

    # strict=False: parameters filtered out above are simply left as initialized.
    model.load_state_dict(restored, strict=False)
    del ckpt

    model.float()
    model.model[-1].export = True
    return model
Ejemplo n.º 2
0
def load_checkpoint(type_,
                    weights,
                    device,
                    cfg=None,
                    hyp=None,
                    nc=None,
                    recipe=None,
                    resume=None,
                    rank=-1):
    """Load a checkpoint into a model and attach its SparseML recipe.

    Args:
        type_: ``'train'``, ``'ensemble'``, or any other value. Non-train
            types prefer the checkpoint's ``'ema'`` weights when present and
            apply the recipe's final state. ``'train'`` initializes the recipe
            at ``start_epoch`` instead.
        weights: A checkpoint path, or a list/tuple of paths. Only the first
            path is read by ``torch.load``. The full value is passed to
            ``attempt_download`` and, for pickled ensembles, to
            ``attempt_load``.
        device: ``map_location`` for ``torch.load`` and the target device of
            a newly built ``Model``.
        cfg: Model config. If falsy, it falls back to ``ckpt['yaml']``, then to
            the pickled model's ``.yaml``.
        hyp: Optional hyperparameter mapping. Only ``'anchors'`` is read.
        nc: Number of classes. If falsy, ``ckpt['nc']`` is used when present.
        recipe: SparseML recipe. If falsy, ``ckpt['recipe']`` is used when present.
        resume: When truthy, anchors are not excluded from a train load.
        rank: Distributed rank passed to ``torch_distributed_zero_first``.

    Returns:
        ``(model, info)``. ``info`` holds the keys ``'ckpt'``, ``'state_dict'``,
        ``'start_epoch'``, ``'sparseml_wrapper'`` and ``'report'``.
    """
    with torch_distributed_zero_first(rank):
        attempt_download(weights)  # download if not found locally
    ckpt_path = weights[0] if isinstance(weights, (list, tuple)) else weights
    ckpt = torch.load(ckpt_path, map_location=device)  # load checkpoint
    start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
    pickled = isinstance(ckpt['model'], nn.Module)
    train_type = type_ == 'train'
    ensemble_type = type_ == 'ensemble'
    # Guarded once here. The unguarded hyp.get(...) used for exclude_anchors
    # previously raised AttributeError when hyp was None.
    anchors = hyp.get('anchors') if hyp else None

    if pickled and ensemble_type:
        # load ensemble using pickled
        cfg = None
        model = attempt_load(weights, map_location=device)  # load FP32 model
        state_dict = model.state_dict()
    else:
        # load model from config and weights
        cfg = cfg or (ckpt['yaml'] if 'yaml' in ckpt else None) or \
              (ckpt['model'].yaml if pickled else None)
        model = Model(cfg,
                      ch=3,
                      nc=ckpt['nc'] if ('nc' in ckpt and not nc) else nc,
                      anchors=anchors).to(device)
        model_key = 'ema' if (not train_type and 'ema' in ckpt
                              and ckpt['ema']) else 'model'
        state_dict = ckpt[model_key].float().state_dict(
        ) if pickled else ckpt[model_key]

    # turn gradients for params back on in case they were removed
    for p in model.parameters():
        p.requires_grad = True

    # load sparseml recipe for applying pruning and quantization
    recipe = recipe or (ckpt['recipe'] if 'recipe' in ckpt else None)
    sparseml_wrapper = SparseMLWrapper(model, recipe)
    exclude_anchors = train_type and (cfg or anchors) and not resume
    loaded = False

    if not train_type:
        # apply the recipe to create the final state of the model when not training
        sparseml_wrapper.apply()
    else:
        # initialize the recipe for training. When the checkpoint holds no
        # quantized weights (no '.zero_point' keys), restore the weights first.
        quantized_state_dict = any(
            name.endswith('.zero_point') for name in state_dict.keys())
        if not quantized_state_dict:
            state_dict = load_state_dict(model,
                                         state_dict,
                                         train=True,
                                         exclude_anchors=exclude_anchors)
            loaded = True
        sparseml_wrapper.initialize(start_epoch)

    if not loaded:
        # quantized train checkpoints (and all non-train loads) are loaded
        # only after the recipe has been applied/initialized
        state_dict = load_state_dict(model,
                                     state_dict,
                                     train=train_type,
                                     exclude_anchors=exclude_anchors)

    model.float()
    report = 'Transferred %g/%g items from %s' % (
        len(state_dict), len(model.state_dict()), weights)

    return model, {
        'ckpt': ckpt,
        'state_dict': state_dict,
        'start_epoch': start_epoch,
        'sparseml_wrapper': sparseml_wrapper,
        'report': report,
    }
Ejemplo n.º 3
0
def load_checkpoint(type_,
                    weights,
                    device,
                    cfg=None,
                    hyp=None,
                    nc=None,
                    recipe=None,
                    resume=None,
                    rank=-1):
    """Load a checkpoint into a model and attach its SparseML recipe.

    Args:
        type_: ``'train'``, ``'ema'`` or ``'ensemble'``.
            - ``'ema'``/``'ensemble'`` prefer the checkpoint's ``'ema'``
              weights when present, apply the recipe's final state, and load
              strictly.
            - Any other value initializes the recipe at ``start_epoch``.
            - Only ``'train'`` intersects the state dict with the model and
              loads non-strictly.
        weights: A checkpoint path, or a list/tuple of paths. Only the first
            path is read by ``torch.load``. The full value is passed to
            ``attempt_download`` and, for pickled ensembles, to
            ``attempt_load``.
        device: ``map_location`` for ``torch.load`` and the target device of
            a newly built ``Model``.
        cfg: Model config. If falsy, it falls back to ``ckpt['yaml']``, then to
            the pickled model's ``.yaml``.
        hyp: Optional hyperparameter mapping. Only ``'anchors'`` is read.
        nc: Number of classes. If falsy, ``ckpt['nc']`` is used when present.
        recipe: SparseML recipe. If falsy, ``ckpt['recipe']`` is used when present.
        resume: When truthy, ``'anchor'`` keys are not excluded on train loads.
        rank: Distributed rank passed to ``torch_distributed_zero_first``.

    Returns:
        ``(model, info)``. ``info`` holds the keys ``'ckpt'``, ``'state_dict'``,
        ``'start_epoch'``, ``'sparseml_wrapper'`` and ``'report'``.
    """
    with torch_distributed_zero_first(rank):
        attempt_download(weights)  # download if not found locally
    # torch.load needs a single path. A list/tuple (an ensemble) previously crashed here.
    ckpt_path = weights[0] if isinstance(weights, (list, tuple)) else weights
    ckpt = torch.load(ckpt_path, map_location=device)  # load checkpoint
    start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
    pickled = isinstance(ckpt['model'], nn.Module)
    # Guarded once here. The unguarded hyp.get(...) in the train exclude list
    # previously raised AttributeError when hyp was None.
    anchors = hyp.get('anchors') if hyp else None

    if pickled and type_ == 'ensemble':
        # load ensemble using pickled
        cfg = None
        model = attempt_load(weights, map_location=device)  # load FP32 model
        state_dict = model.state_dict()
    else:
        # load model from config and weights
        cfg = cfg or (ckpt['yaml'] if 'yaml' in ckpt else None) or \
              (ckpt['model'].yaml if pickled else None)
        model = Model(cfg,
                      ch=3,
                      nc=ckpt['nc'] if ('nc' in ckpt and not nc) else nc,
                      anchors=anchors).to(device)
        model_key = 'ema' if (type_ in ['ema', 'ensemble'] and 'ema' in ckpt
                              and ckpt['ema']) else 'model'
        state_dict = ckpt[model_key].float().state_dict(
        ) if pickled else ckpt[model_key]

    # turn gradients for params back on in case they were removed
    for p in model.parameters():
        p.requires_grad = True

    # load sparseml recipe for applying pruning and quantization
    recipe = recipe or (ckpt['recipe'] if 'recipe' in ckpt else None)
    sparseml_wrapper = SparseMLWrapper(model, recipe)
    if type_ in ['ema', 'ensemble']:
        # apply the recipe to create the final state of the model when not training
        sparseml_wrapper.apply()
    else:
        # initialize the recipe for training
        sparseml_wrapper.initialize(start_epoch)

    if type_ == 'train':
        # load any missing weights from the model
        exclude = ['anchor'] if (cfg or anchors) and not resume else []  # exclude keys
        state_dict = intersect_dicts(state_dict,
                                     model.state_dict(),
                                     exclude=exclude)  # intersect

    model.load_state_dict(state_dict, strict=type_ != 'train')  # load
    model.float()
    report = 'Transferred %g/%g items from %s' % (
        len(state_dict), len(model.state_dict()), weights)

    return model, {
        'ckpt': ckpt,
        'state_dict': state_dict,
        'start_epoch': start_epoch,
        'sparseml_wrapper': sparseml_wrapper,
        'report': report,
    }