def create_model(num_classes=21, device=torch.device('cpu')):
    """Build a ResNet50-FPN ``RetinaNet640`` and initialise it from a checkpoint.

    The state dict at ``./res50_fpn_ssd640.pth`` is loaded and every entry whose
    dot-separated key contains a ``class_predictor`` component is dropped. The
    classification head's shape depends on ``num_classes``, so it cannot be
    reused. The box-regression predictor weights do not depend on
    ``num_classes`` and are kept. The filtered weights are loaded non-strictly.
    Any missing or unexpected keys are printed.

    Args:
        num_classes: number of output classes for the classification head.
        device: device onto which checkpoint tensors are mapped while loading.
            Without this, a checkpoint saved from GPU tensors cannot be
            deserialised on a CPU-only machine.

    Returns:
        The constructed model with the compatible pretrained weights loaded.
    """
    backbone = resnet50_fpn_backbone()
    model = RetinaNet640(backbone=backbone, num_classes=num_classes)

    pre_ssd_path = "./res50_fpn_ssd640.pth"
    # map_location makes loading independent of the device the tensors were saved on.
    pre_weights_dict = torch.load(pre_ssd_path, map_location=device)

    # Drop the class-predictor weights: their shape depends on num_classes.
    # The regression-predictor weights can be reused because they do not.
    filtered_weights = {
        k: v
        for k, v in pre_weights_dict.items()
        if "class_predictor" not in k.split(".")
    }

    # strict=False: the class-predictor keys are intentionally absent.
    missing_keys, unexpected_keys = model.load_state_dict(filtered_weights,
                                                          strict=False)
    if missing_keys or unexpected_keys:
        print("missing_keys: ", missing_keys)
        print("unexpected_keys: ", unexpected_keys)

    return model
def create_model(num_classes):
    """Build a ``RetinaNet640`` on a ResNet50-FPN backbone without loading a checkpoint.

    NOTE(review): this ``def`` re-binds the name ``create_model``. It therefore
    shadows the earlier ``create_model`` in this module that loads
    ``./res50_fpn_ssd640.pth``, and callers get this version. Confirm which
    definition is intended.

    Args:
        num_classes: number of output classes for the classification head.

    Returns:
        The constructed model.
    """
    fpn_backbone = resnet50_fpn_backbone()
    return RetinaNet640(backbone=fpn_backbone, num_classes=num_classes)