Пример #1
0
def model(model_name: str,
          num_classes: int,
          img_size: int,
          pretrained: bool = True) -> nn.Module:
    """Build an EfficientDet training bench for `model_name`.

    The underlying implementation comes from Ross Wightman's
    [efficientdet-pytorch](https://github.com/rwightman/efficientdet-pytorch).

    # Arguments
        model_name: Name of the EfficientDet configuration to build. Models
            with pretrained weights are listed in
            [this](https://github.com/rwightman/efficientdet-pytorch#models) table.
        num_classes: Number of classes of your dataset (including background).
        img_size: Size of the (square) images fed to the model. Must be
            divisible by 64.
        pretrained: If True, load the full pretrained detector weights (on
            COCO) from the config's URL before the classification head is
            replaced.

    # Returns
        A `DetBenchTrain` wrapping the network. It carries a `param_groups()`
        method that splits the parameters into backbone / FPN / heads.

    # Raises
        RuntimeError: if `pretrained` is True but the config has no weights URL.
    """
    config = get_efficientdet_config(model_name=model_name)
    net = EfficientDet(config, pretrained_backbone=False)

    if pretrained:
        weights_url = config.url
        if not weights_url:
            raise RuntimeError(f"No pretrained weights for {model_name}")
        cpu = torch.device("cpu")
        weights = torch.hub.load_state_dict_from_url(weights_url,
                                                     map_location=cpu)
        net.load_state_dict(weights)

    # Re-target the config to the dataset, then swap in a new
    # classification head with `num_classes` outputs.
    config.num_classes = num_classes
    config.image_size = img_size
    net.class_net = HeadNet(config,
                            num_outputs=num_classes,
                            norm_kwargs=dict(eps=0.001, momentum=0.01))

    # TODO: Break down param groups for backbone
    def param_groups_fn(bench: nn.Module) -> List[List[nn.Parameter]]:
        """Return parameters grouped as [backbone, fpn, class_net + box_net]."""
        inner = unwrap_bench(bench)
        heads = nn.Sequential(inner.class_net, inner.box_net)
        groups = [list(part.parameters())
                  for part in (inner.backbone, inner.fpn, heads)]
        check_all_model_params_in_groups2(bench, groups)
        return groups

    bench = DetBenchTrain(net, config)
    bench.param_groups = MethodType(param_groups_fn, bench)
    return bench
Пример #2
0
def model(model_name: str,
          num_classes: int,
          img_size: int,
          pretrained: bool = True) -> nn.Module:
    """Create the EfficientDet training bench specified by ``model_name``.

    Args:
        model_name (str): Name of the EfficientDet configuration to create, as
            accepted by ``get_efficientdet_config``.
        num_classes (int): Number of classes of your dataset (including background).
        img_size (int): Image size that will be fed to the model; written to
            ``config.image_size``.
        pretrained (bool): If True, load the full pretrained detector weights
            (on COCO) from ``config.url`` before the classification head is
            replaced.

    Returns:
        nn.Module: A ``DetBenchTrain`` wrapping the network. It carries a
        ``param_groups()`` method returning the parameters split into
        [backbone, fpn, class_net + box_net] groups.

    Raises:
        RuntimeError: If ``pretrained`` is True but the config has no weights URL.
    """
    config = get_efficientdet_config(model_name=model_name)

    # The whole detector is loaded from config.url below (when requested), so
    # backbone-only pretrained weights are not requested here.
    net = EfficientDet(config, pretrained_backbone=False)
    if pretrained:
        if not config.url:
            raise RuntimeError(f"No pretrained weights for {model_name}")
        state_dict = torch.hub.load_state_dict_from_url(
            config.url, map_location=torch.device("cpu"))
        net.load_state_dict(state_dict)

    # Adapt the config to the target dataset and replace the classification
    # head with one that has `num_classes` outputs.
    config.num_classes = num_classes
    config.image_size = img_size
    net.class_net = HeadNet(
        config,
        num_outputs=num_classes,
        norm_kwargs=dict(eps=0.001, momentum=0.01),
    )

    # TODO: Break down param groups for backbone
    def param_groups_fn(model: nn.Module) -> List[List[nn.Parameter]]:
        """Return parameters grouped as [backbone, fpn, class_net + box_net]."""
        unwrapped = unwrap_bench(model)

        layers = [
            unwrapped.backbone,
            unwrapped.fpn,
            nn.Sequential(unwrapped.class_net, unwrapped.box_net),
        ]
        param_groups = [list(layer.parameters()) for layer in layers]
        check_all_model_params_in_groups2(model, param_groups)

        return param_groups

    model_bench = DetBenchTrain(net, config)
    model_bench.param_groups = MethodType(param_groups_fn, model_bench)

    return model_bench
Пример #3
0
    def __init__(self, num_classes=11, checkpoint_path=None):
        """Build an EfficientDet-D5 (``tf_efficientdet_d5_ap``) model wrapped in a bench.

        Args:
            num_classes: Number of output classes for the re-initialised head.
            checkpoint_path: If None, the local
                ``./effdet_model/tf_efficientdet_d5_ap-3673ae5d.pth`` weights are
                loaded and the network is wrapped in ``DetBenchTrain``. Otherwise
                the checkpoint at this path is loaded (with the first two dotted
                components of every key stripped) and the network is wrapped in
                ``DetBenchPredict``.
        """
        super(EfficientDet5AP, self).__init__()
        config = get_efficientdet_config('tf_efficientdet_d5_ap')

        config.image_size = [512, 512]
        config.norm_kwargs = dict(eps=.001, momentum=.01)
        config.soft_nms = True
        config.label_smoothing = 0.1
        # config.legacy_focal = True

        net = EfficientDet(config, pretrained_backbone=False)

        if checkpoint_path is None:
            checkpoint = torch.load(
                './effdet_model/tf_efficientdet_d5_ap-3673ae5d.pth')
            net.load_state_dict(checkpoint)
            net.reset_head(num_classes=num_classes)
            # NOTE(review): reset_head already rebuilds class_net; this second
            # HeadNet relies on reset_head having updated config.num_classes —
            # confirm with the effdet version in use.
            net.class_net = HeadNet(config, num_outputs=config.num_classes)
            self.model = DetBenchTrain(net, config)

        else:
            checkpoint = torch.load(checkpoint_path)
            # Strip a two-level key prefix (presumably "model.model." from a
            # wrapper like this one — confirm) so keys match the bare network.
            checkpoint2 = {
                '.'.join(k.split('.')[2:]): v
                for k, v in checkpoint.items()
            }
            # 'boxes' is not a key of the bare network's state dict (presumably
            # the bench's anchor buffer); drop it if present instead of raising
            # KeyError when it is absent.
            checkpoint2.pop('boxes', None)

            net.reset_head(num_classes=num_classes)
            net.class_net = HeadNet(config, num_outputs=config.num_classes)

            net.load_state_dict(checkpoint2)
            self.model = DetBenchPredict(net)
Пример #4
0
def get_model_detection_efficientdet(model_name,
                                     num_classes,
                                     target_dim,
                                     freeze_batch_norm=False):
    """Build an EfficientDet ``DetBenchTrain`` with a ``num_classes`` head.

    Pretrained weights are loaded from ``config.url``. The classification head
    is then replaced with a new ``HeadNet`` of ``num_classes`` outputs.

    Args:
        model_name: EfficientDet config name passed to ``get_efficientdet_config``.
        num_classes: Number of output classes for the new head.
        target_dim: Currently unused (the ``image_size`` override is disabled).
        freeze_batch_norm: If True, call ``freeze_bn`` on the backbone.
    """
    print("Using EffDet detection model")

    config = effdet.get_efficientdet_config(model_name)
    det_model = EfficientDet(config, pretrained_backbone=False)
    load_pretrained(det_model, config.url)

    import omegaconf

    # The config is read-only once the model is built; unlock it temporarily.
    # (config.image_size is intentionally left unchanged; target_dim is unused.)
    with omegaconf.read_write(config):
        config.num_classes = num_classes
    det_model.class_net = HeadNet(config, num_outputs=num_classes)

    if freeze_batch_norm:
        # Only the backbone is passed to freeze_bn; the BiFPN's BN layers are
        # not frozen here.
        print("Freezing batch normalization weights")
        freeze_bn(det_model.backbone)

    # Keep the model's own config in sync too (a no-op if it is the same
    # object as `config` — presumably it is).
    with omegaconf.read_write(det_model.config):
        det_model.config.num_classes = num_classes

    return DetBenchTrain(det_model, config)
Пример #5
0
Файл: main.py Проект: dodler/kgl
def get_net():
    """Return a ``DetBenchTrain`` around a 14-class, 512-px EfficientDet-D0.

    The backbone is built with ``pretrained_backbone=True``. The
    classification head is replaced with a new 14-output ``HeadNet``.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d0')
    detector = EfficientDet(cfg, pretrained_backbone=True)
    cfg.num_classes = 14
    cfg.image_size = 512
    bn_params = dict(eps=.001, momentum=.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=bn_params)
    return DetBenchTrain(detector, cfg)
Пример #6
0
def get_net():
    """Return a 1-class, 512-px EfficientDet-D5 ``DetBenchTrain``.

    The D5 weights are loaded from a local file before the classification
    head is replaced with a 1-output ``HeadNet``.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d5')
    detector = EfficientDet(cfg, pretrained_backbone=False)
    weights = torch.load('./input/efficientdet/efficientdet_d5-ef44aea8.pth')
    detector.load_state_dict(weights)

    cfg.num_classes = 1
    cfg.image_size = 512
    detector.class_net = HeadNet(
        cfg,
        num_outputs=cfg.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01),
    )
    return DetBenchTrain(detector, cfg)
Пример #7
0
def get_net():
    """Return a 1-class, 512-px EfficientDet-D5 ``DetBenchTrain``.

    The D5 weights are read from a Google-Drive (Colab) path before the
    classification head is replaced with a 1-output ``HeadNet``.
    """
    weights_file = r'/content/gdrive/My Drive/Colab Notebooks/globalwheat/input/efficientdet-pytorch-master/efficientdet_d5-ef44aea8.pth'

    cfg = get_efficientdet_config('tf_efficientdet_d5')
    detector = EfficientDet(cfg, pretrained_backbone=False)
    detector.load_state_dict(torch.load(weights_file))

    cfg.num_classes = 1
    cfg.image_size = 512
    head_norm = dict(eps=.001, momentum=.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=head_norm)
    return DetBenchTrain(detector, cfg)
Пример #8
0
def get_net(model_name):
    """Build a 29-class ``DetBenchTrain`` for the EfficientDet ``model_name``.

    BatchNorm settings (eps=1e-3, momentum=0.01) are written into the config
    before the network is built. ``HeadNet`` is then created from that same
    config without an explicit ``norm_kwargs`` argument.
    """
    cfg = get_efficientdet_config(model_name)
    cfg.norm_kwargs = dict(eps=.001, momentum=.01)
    cfg.num_classes = 29
    # image_size is left at the config default.
    detector = EfficientDet(cfg, pretrained_backbone=True)

    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes)
    return DetBenchTrain(detector, cfg)
Пример #9
0
    def __init__(self, model, num_class):
        """Wrap a ``tf_efficientdet_<model>`` network in ``DetBenchTrain``.

        Args:
            model: Variant suffix appended to ``'tf_efficientdet_'``.
            num_class: Number of output classes for the new head.
        """
        super(EfficientDetCus, self).__init__()

        cfg = get_efficientdet_config(f'tf_efficientdet_{model}')
        cfg.num_classes = num_class
        cfg.image_size = [TRAIN_SIZE, TRAIN_SIZE]

        detector = EfficientDet(config=cfg, pretrained_backbone=False)
        detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes)
        self.model = DetBenchTrain(detector, cfg)
Пример #10
0
def get_net(ckpt_path = None):
    """Build a 1-class, 512-px EfficientDet-D5 ``DetBenchTrain``.

    Args:
        ckpt_path: Optional path to a training checkpoint. Its
            ``'model_state_dict'`` entry is loaded into the network after the
            1-class head has been installed. If None, no weights are loaded
            (the network is built with ``pretrained_backbone=False``).

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
    """
    # Load from the caller-supplied path. The old code ignored `ckpt_path`
    # and always read a hard-coded file.
    checkpoint = torch.load(ckpt_path) if ckpt_path is not None else None

    config = get_efficientdet_config('tf_efficientdet_d5')
    net = EfficientDet(config, pretrained_backbone=False)
    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(config, num_outputs=config.num_classes,
                            norm_kwargs=dict(eps=.001, momentum=.01))
    if checkpoint is not None:
        net.load_state_dict(checkpoint['model_state_dict'])
    return DetBenchTrain(net, config)
Пример #11
0
def get_net():
    """Build a 1-class EfficientDet-D5 ``DetBenchTrain`` driven by ``args``.

    Reads the module-level ``args`` mapping (presumably parsed command-line
    options):
    - ``args["weights_path"]`` is the state dict loaded before the head swap.
    - ``args["image_size"]`` is written to the config.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d5')
    detector = EfficientDet(cfg, pretrained_backbone=False)
    detector.load_state_dict(torch.load(args["weights_path"]))

    cfg.num_classes = 1
    cfg.image_size = args["image_size"]
    head_norm = dict(eps=.001, momentum=.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=head_norm)
    return DetBenchTrain(detector, cfg)
Пример #12
0
def get_net():
    """Build a 1-class, 1024-px EfficientDet-D6 ``DetBenchTrain``.

    The 1-class head is installed first. The weights from
    ``pretrained/efficientdet_d6-51cb0132.pth`` are then applied through the
    free ``load_state_dict`` helper, not the module method.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d6')
    detector = EfficientDet(cfg, pretrained_backbone=False)
    cfg.num_classes = 1
    cfg.image_size = 1024
    detector.class_net = HeadNet(
        cfg,
        num_outputs=cfg.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01),
    )
    # NOTE(review): the helper presumably tolerates the class-head size
    # mismatch with the pretrained file — confirm.
    weights = torch.load('pretrained/efficientdet_d6-51cb0132.pth')
    load_state_dict(detector, weights)
    return DetBenchTrain(detector, cfg)
Пример #13
0
def set_train_effdet(config,
                     num_classes: int = 1,
                     device: torch.device = 'cuda:0'):
    """Init EfficientDet in train mode and move it to ``device``.

    Args:
        config: EfficientDet config used for both the network and the bench.
        num_classes: Number of outputs of the replacement classification head.
        device: Target device, as a ``torch.device`` or a string such as
            ``'cuda:0'``.

    Returns:
        The ``DetBenchTrain`` in training mode, on ``device``.
    """
    model = EfficientDet(config, pretrained_backbone=False)
    # NOTE(review): config.num_classes is not updated here. If the bench reads
    # the class count from config, the caller must already have set it to
    # `num_classes` — confirm.
    model.class_net = HeadNet(config,
                              num_outputs=num_classes,
                              norm_kwargs=dict(eps=.001, momentum=.01))
    model = DetBenchTrain(model, config)
    model = model.train()

    # nn.Module has no `to_device` method (the old call raised AttributeError).
    # `.to()` moves the module in place and returns it.
    return model.to(device)
Пример #14
0
def get_net():
    """Build a 1-class, 1024-px EfficientDet-D6 bench from a fold-0 checkpoint.

    After the 1-class head is installed, the checkpoint's
    ``'model_state_dict'`` is applied via the ``load_state_dict`` helper.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d6')
    detector = EfficientDet(cfg, pretrained_backbone=False)
    cfg.num_classes = 1
    cfg.image_size = 1024
    head_norm = dict(eps=.001, momentum=.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=head_norm)

    ckpt_file = 'effdet6-baseline-1024-4x8-sa-fold0/best-checkpoint-052epoch.bin'
    saved = torch.load(ckpt_file)
    load_state_dict(detector, saved['model_state_dict'])

    return DetBenchTrain(detector, cfg)
Пример #15
0
def get_net():
    """Build a 1-class, 512-px EfficientDet-D5 ``DetBenchTrain``.

    The ``tf_efficientdet_d5_51`` weights are loaded from a local path before
    the classification head is replaced with a 1-output ``HeadNet``.
    """
    weights_file = '/home/eragon/Documents/scripts/efficientdet-pytorch/checkpoints/tf_efficientdet_d5_51-c79f9be6.pth'

    cfg = get_efficientdet_config('tf_efficientdet_d5')
    detector = EfficientDet(cfg, pretrained_backbone=False)
    detector.load_state_dict(torch.load(weights_file))

    cfg.num_classes = 1
    cfg.image_size = 512
    detector.class_net = HeadNet(
        cfg,
        num_outputs=cfg.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01),
    )
    return DetBenchTrain(detector, cfg)
def get_effdet_train(model_path: str):
    """Return a 1-class, 512-px EfficientDet-D5 bench in training mode.

    Args:
        model_path: Path to a state dict for the full D5 network. It is loaded
            before the classification head is replaced by a 1-class ``HeadNet``.
    """
    cfg = get_efficientdet_config("tf_efficientdet_d5")
    detector = EfficientDet(cfg, pretrained_backbone=False)
    state = torch.load(model_path)

    cfg.num_classes = 1
    cfg.image_size = 512
    detector.load_state_dict(state)

    bn_params = dict(eps=0.001, momentum=0.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=bn_params)
    bench = DetBenchTrain(detector, cfg)
    return bench.train()
Пример #17
0
def get_model_instance_segmentation_efficientnet(model_name,
                                                 num_classes,
                                                 target_dim,
                                                 freeze_batch_norm=False):
    """Build a torchvision ``MaskRCNN`` on top of an EfficientDet backbone + FPN.

    Steps:
    - A temporary EfficientDet is created and its weights are loaded from
      ``config.url``.
    - Its backbone and FPN are wrapped in ``BackboneWithCustomFPN`` and used as
      the MaskRCNN backbone.
    - The MaskRCNN box predictor is replaced by a ``DetBenchTrain`` around an
      ``EfficientDetBB``, built from a new ``HeadNet`` (``num_classes``
      outputs) and the temporary model's box head.

    Args:
        model_name: EfficientDet config name.
        num_classes: Number of classes; written to the config and passed to MaskRCNN.
        target_dim: Written to ``config.image_size`` and used as MaskRCNN's
            ``min_size`` / ``max_size``.
        freeze_batch_norm: If True, call ``freeze_bn`` on ``model.backbone``.

    Returns:
        The configured ``MaskRCNN`` model.
    """
    # NOTE(review): this message says "detection" although a MaskRCNN
    # (instance segmentation) model is being built.
    print("Using EffDet detection model")

    # NOTE(review): feature-map names are given as ints here; newer torchvision
    # versions expect string keys — confirm against the installed version and
    # the keys BackboneWithCustomFPN actually returns.
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
                                                    output_size=7,
                                                    sampling_ratio=2)
    # ofekp: note that roi_pooler is passed to box_roi_pooler in the MaskRCNN network
    # and is not being used in roi_heads.py

    # Presumably the same MultiScaleRoIAlign class, imported directly.
    mask_roi_pool = MultiScaleRoIAlign(featmap_names=[0, 1, 2, 3],
                                       output_size=14,
                                       sampling_ratio=2)

    config = effdet.get_efficientdet_config(model_name)
    # Temporary full detector: only its backbone, FPN and box_net are reused below.
    efficientDetModelTemp = EfficientDet(config, pretrained_backbone=False)
    load_pretrained(efficientDetModelTemp, config.url)
    config.num_classes = num_classes
    config.image_size = target_dim

    # Output width of the custom FPN wrapper, taken from the config (e.g. 288
    # channels for 'tf_efficientdet_d5', per the original note).
    out_channels = config.fpn_channels
    backbone_fpn = BackboneWithCustomFPN(
        config, efficientDetModelTemp.backbone, efficientDetModelTemp.fpn,
        out_channels
    )  # TODO(ofekp): pretrained! # from the repo trainable_layers=trainable_backbone_layers=3
    model = MaskRCNN(
        backbone_fpn,
        min_size=target_dim,
        max_size=target_dim,
        num_classes=num_classes,
        mask_roi_pool=mask_roi_pool,
        #                  rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler)

    # for training with different number of classes (default is 90) we need to add this line
    # TODO(ofekp): we might want to init weights of the new HeadNet
    class_net = HeadNet(config,
                        num_outputs=config.num_classes,
                        norm_kwargs=dict(eps=.001, momentum=.01))
    # Combine the new class head with the pretrained box head.
    efficientDetModel = EfficientDetBB(config, class_net,
                                       efficientDetModelTemp.box_net)
    model.roi_heads.box_predictor = DetBenchTrain(efficientDetModel, config)

    if freeze_batch_norm:
        # model.backbone is presumably backbone_fpn (MaskRCNN keeps its backbone
        # argument as .backbone), so this covers the EfficientDet backbone and
        # the BiFPN.
        print("Freezing batch normalization weights")
        freeze_bn(model.backbone)

    return model
Пример #18
0
def load_train_net(checkpoint_path):
    """Rebuild a 1-class, 1024-px EfficientDet-D5 training bench from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint dict. Its ``'model_state_dict'``
            entry is loaded after the 1-class head has been installed.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d5')
    detector = EfficientDet(cfg, pretrained_backbone=False)

    cfg.num_classes = 1
    cfg.image_size = 1024
    head_norm = dict(eps=.001, momentum=.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=head_norm)

    saved = torch.load(checkpoint_path)
    detector.load_state_dict(saved['model_state_dict'])

    return DetBenchTrain(detector, cfg)
Пример #19
0
def get_net():
    """Build a 1-class, 512-px EfficientDet-D5 ``DetBenchTrain``.

    The D5 weights are loaded from a hard-coded local Windows path before the
    classification head is replaced with a 1-output ``HeadNet``.
    """
    weights_file = r'C:\Users\Aministration\Downloads\efficientdet_d5-ef44aea8.pth'

    cfg = get_efficientdet_config('tf_efficientdet_d5')
    detector = EfficientDet(cfg, pretrained_backbone=False)
    detector.load_state_dict(torch.load(weights_file))

    cfg.num_classes = 1
    cfg.image_size = 512
    detector.class_net = HeadNet(
        cfg,
        num_outputs=cfg.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01),
    )
    return DetBenchTrain(detector, cfg)
Пример #20
0
def get_net():
    """Build a 1-class, 512-px EfficientDet-D7 ``DetBenchTrain``.

    The ``tf_efficientdet_d7_53`` weights are loaded from the working
    directory before the classification head is replaced with a 1-output
    ``HeadNet``.
    """
    weights_file = r'./tf_efficientdet_d7_53-6d1d7a95.pth'

    cfg = get_efficientdet_config('tf_efficientdet_d7')
    detector = EfficientDet(cfg, pretrained_backbone=False)
    detector.load_state_dict(torch.load(weights_file))

    cfg.num_classes = 1
    cfg.image_size = 512
    head_norm = dict(eps=.001, momentum=.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=head_norm)
    return DetBenchTrain(detector, cfg)
Пример #21
0
def get_net():
    """Build a 1-class, 512-px EfficientDet-D5 bench from a GWD checkpoint.

    The network is built with a 1-class head and loaded from
    ``../input/eff_checkpoint/bestv2.bin`` (its ``'model_state_dict'`` entry).
    The class head is then replaced by a new ``HeadNet``.
    """
    # Model configuration (a dict-like config object).
    cfg = get_efficientdet_config('tf_efficientdet_d5')
    cfg.num_classes = 1
    cfg.image_size = 512

    # Build the network from the configuration above.
    detector = EfficientDet(cfg, pretrained_backbone=False)

    # Load the GWD-pretrained checkpoint.
    saved = torch.load(r'../input/eff_checkpoint/bestv2.bin')
    detector.load_state_dict(saved['model_state_dict'])

    # (Loading COCO-pretrained weights was planned here but is not implemented.)

    # NOTE(review): the class head just loaded from the checkpoint is replaced
    # below by a new 1-class HeadNet, so its trained weights are discarded —
    # confirm this is intended.
    # norm_kwargs sets the BatchNorm2d parameters.
    bn_params = dict(eps=.001, momentum=.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=bn_params)
    return DetBenchTrain(detector, cfg)
def load_net(checkpoint_path):
    """Load a 1-class, 512-px EfficientDet-D5 training bench onto CUDA.

    Args:
        checkpoint_path: Path to a checkpoint dict. It is mapped to the
            module-level ``DEVICE`` while loading, and its
            ``'model_state_dict'`` entry is loaded after the 1-class head is
            installed.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d5')
    detector = EfficientDet(cfg, pretrained_backbone=False)

    cfg.num_classes = 1
    cfg.image_size = 512
    detector.class_net = HeadNet(
        cfg,
        num_outputs=cfg.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01),
    )

    saved = torch.load(checkpoint_path, map_location=DEVICE)
    detector.load_state_dict(saved['model_state_dict'])

    # Drop the checkpoint and collect right away to release its memory early.
    del saved
    gc.collect()

    return DetBenchTrain(detector, cfg).cuda()
Пример #23
0
def get_train_model(config_name='tf_efficientdet_d0', model_ckpt=None):
    """Build a 1-class EfficientDet training bench, optionally from a checkpoint.

    Args:
        config_name: EfficientDet config name (built with a pretrained backbone).
        model_ckpt: Optional checkpoint path. Its ``'model_state_dict'`` is
            loaded after the 1-class head is installed.

    The image size comes from the module-level ``IMG_SIZE``.
    """
    cfg = get_efficientdet_config(config_name)
    detector = EfficientDet(cfg, pretrained_backbone=True)

    cfg.num_classes = 1
    cfg.image_size = IMG_SIZE
    head_norm = dict(eps=.001, momentum=.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=head_norm)

    if model_ckpt is not None:
        saved = torch.load(model_ckpt)
        detector.load_state_dict(saved['model_state_dict'])
    return DetBenchTrain(detector, cfg)
def get_effdet_train_hotstart(checkpoint_path: str):
    """Resume a 1-class, 512-px EfficientDet-D5 bench in training mode.

    Args:
        checkpoint_path: Path to a checkpoint dict. Its ``"model_state_dict"``
            is loaded after the 1-class head is installed.
    """
    cfg = get_efficientdet_config("tf_efficientdet_d5")
    detector = EfficientDet(cfg, pretrained_backbone=False)

    cfg.num_classes = 1
    cfg.image_size = 512
    bn_params = dict(eps=0.001, momentum=0.01)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes,
                                 norm_kwargs=bn_params)

    saved = torch.load(checkpoint_path)
    detector.load_state_dict(saved["model_state_dict"])

    # Release the checkpoint dict before building the bench.
    del saved
    gc.collect()

    return DetBenchTrain(detector, cfg).train()
Пример #25
0
def get_train_net(eval_net):
    """Build a CUDA 1-class, 512-px EfficientDet-D5 training bench from an eval wrapper.

    Args:
        eval_net: Object whose ``.model`` attribute holds a network with a
            compatible state dict (presumably an eval/predict bench). Its
            weights are copied into the new network.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d5')
    detector = EfficientDet(cfg, pretrained_backbone=False)

    cfg.num_classes = 1
    cfg.image_size = 512
    detector.class_net = HeadNet(
        cfg,
        num_outputs=cfg.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01),
    )

    detector.load_state_dict(eval_net.model.state_dict())
    gc.collect()

    bench = DetBenchTrain(detector, cfg).train()
    return bench.cuda()
Пример #26
0
def get_model_(variant, model_dir):
    """Build a 1-class, 512-px EfficientDet bench from local pretrained weights.

    Args:
        variant: EfficientDet variant; only ``"d0"`` has a weights file wired up.
        model_dir: Directory containing ``efficientdet_d0-d92fd44f.pth``.

    Raises:
        ValueError: If ``variant`` is not ``"d0"``. The old code failed here
            with an UnboundLocalError.
        FileNotFoundError: If the weights file is missing.
    """
    if variant != "d0":
        raise ValueError(
            f"No pretrained weights file configured for variant {variant!r}")
    checkpoint = torch.load(f"{model_dir}/efficientdet_d0-d92fd44f.pth")

    config = get_efficientdet_config(f"tf_efficientdet_{variant}")
    net = EfficientDet(config, pretrained_backbone=False)
    net.load_state_dict(checkpoint)

    # Re-target to a single class and replace the classification head.
    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(
        config,
        num_outputs=config.num_classes,
        norm_kwargs=dict(eps=0.001, momentum=0.01),
    )
    return DetBenchTrain(net, config)
Пример #27
0
def get_model(variant, model_dir, load_path):
    """Build a 1-class, 512-px EfficientDet ``DetBenchTrain`` for ``variant``.

    Args:
        variant: One of ``"d0"`` … ``"d7"``; selects ``tf_efficientdet_<variant>``.
        model_dir: Directory holding the pretrained ``efficientdet_<variant>-*.pth`` files.
        load_path: Controls which weights are loaded.
            - Falsy: the pretrained weights for ``variant`` are loaded from
              ``model_dir`` before the head is replaced.
            - Truthy: the path of a training checkpoint whose
              ``"model_state_dict"`` is loaded after the 1-class head has been
              installed.

    Raises:
        ValueError: If ``load_path`` is falsy and ``variant`` has no known weights file.
    """
    pretrained_files = {
        "d0": "efficientdet_d0-d92fd44f.pth",
        "d1": "efficientdet_d1-4c7ebaf2.pth",
        "d2": "efficientdet_d2-cb4ce77d.pth",
        "d3": "efficientdet_d3-b0ea2cbc.pth",
        "d4": "efficientdet_d4-5b370b7a.pth",
        "d5": "efficientdet_d5-ef44aea8.pth",
        "d6": "efficientdet_d6-51cb0132.pth",
        # The old if/elif chain tested "d6" twice, so this file was unreachable.
        "d7": "efficientdet_d7-f05bf714.pth",
    }

    if not load_path:
        if variant not in pretrained_files:
            raise ValueError(
                f"No pretrained weights file known for variant {variant!r}")
        pretrained_state = torch.load(f"{model_dir}/{pretrained_files[variant]}")

    config = get_efficientdet_config(f"tf_efficientdet_{variant}")
    net = EfficientDet(config, pretrained_backbone=False)
    if not load_path:
        net.load_state_dict(pretrained_state)

    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(
        config,
        num_outputs=config.num_classes,
        norm_kwargs=dict(eps=0.001, momentum=0.01),
    )

    if load_path:
        # Read the checkpoint from the path we were given. The old code used an
        # unbound `checkpoint_path` here.
        checkpoint = torch.load(load_path)
        net.load_state_dict(checkpoint["model_state_dict"])

    return DetBenchTrain(net, config)
Пример #28
0
def create_model(num_classes=1,
                 image_size=512,
                 architecture="tf_efficientnetv2_l"):
    """Build a ``DetBenchTrain`` for the EfficientDet config ``architecture``.

    Side effect: on every call, (re)registers a ``"tf_efficientnetv2_l"``
    entry in the global ``efficientdet_model_param_dict``. The entry uses the
    ``tf_efficientnetv2_l`` backbone, drop_path_rate 0.2 and an empty weights
    URL, so that name can be passed to ``get_efficientdet_config``.

    Args:
        num_classes: Number of output classes.
        image_size: Side length of the square input (stored as a pair).
        architecture: Config name to build.
    """
    v2l_name = "tf_efficientnetv2_l"
    efficientdet_model_param_dict[v2l_name] = dict(
        name=v2l_name,
        backbone_name=v2l_name,
        backbone_args=dict(drop_path_rate=0.2),
        num_classes=num_classes,
        url="",
    )

    cfg = get_efficientdet_config(architecture)
    cfg.update({"num_classes": num_classes})
    cfg.update({"image_size": (image_size, image_size)})

    detector = EfficientDet(cfg, pretrained_backbone=True)
    detector.class_net = HeadNet(cfg, num_outputs=cfg.num_classes)
    return DetBenchTrain(detector, cfg)
Пример #29
0
    def __init__(self, num_classes=11, checkpoint=None):
        """Build an EfficientDet-D6 training bench with a re-initialised head.

        Args:
            num_classes: Number of output classes for the new head.
            checkpoint: A state dict to load into the network, or a path to one
                (read with ``torch.load``). If None, the local
                ``./effdet_model/tf_efficientdet_d6_52-4eda3773.pth`` is used.
        """
        import os

        super(EfficientDet6, self).__init__()
        config = get_efficientdet_config('tf_efficientdet_d6')

        config.image_size = [512, 512]
        config.norm_kwargs = dict(eps=.001, momentum=.01)
        config.soft_nms = True
        config.label_smoothing = 0.1
        # Per-channel normalisation statistics (presumably of the training set).
        config.mean = [0.46009655, 0.43957878, 0.41827092]
        config.std = [0.2108204, 0.20766491, 0.21656131]

        net = EfficientDet(config, pretrained_backbone=False)
        if checkpoint is None:
            checkpoint = './effdet_model/tf_efficientdet_d6_52-4eda3773.pth'
        # Accept a path as well as an already-loaded state dict.
        if isinstance(checkpoint, (str, os.PathLike)):
            checkpoint = torch.load(checkpoint)

        net.load_state_dict(checkpoint)

        net.reset_head(num_classes=num_classes)
        # NOTE(review): reset_head already rebuilds class_net; this relies on it
        # having updated config.num_classes — confirm with the effdet version.
        net.class_net = HeadNet(config, num_outputs=config.num_classes)

        self.model = DetBenchTrain(net, config)
def _extract_net_state_dict(state_dict, valid_keys):
    """Strip the ``"model.model."`` prefix and keep only keys in ``valid_keys``.

    Keys without that prefix are dropped. Key order is preserved.
    """
    from collections import OrderedDict

    prefix = "model.model."
    extracted = OrderedDict()
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue
        new_key = key[len(prefix):]
        if new_key in valid_keys:
            extracted[new_key] = value
    return extracted


def get_train_model(config_name,
                    img_size,
                    model_ckpt=None,
                    useGN=False,
                    light=True):
    """Build a 1-class EfficientDet ``DetBenchTrain``, optionally from a checkpoint.

    Args:
        config_name: EfficientDet config name (built with a pretrained backbone).
        img_size: Written to ``config.image_size``.
        model_ckpt: Optional path to a checkpoint containing a ``"state_dict"`` entry.
        useGN: If truthy, replace every ``nn.BatchNorm2d`` with ``nn.GroupNorm``
            (2 groups) via ``convert_layers`` before weights are loaded.
        light: Chooses how the checkpoint is loaded.
            - Truthy: only keys prefixed ``"model.model."`` are loaded (presumably
              a Lightning-module checkpoint). The prefix is stripped and keys the
              network does not have are dropped.
            - Falsy: the checkpoint's state dict is loaded as-is.
    """
    config = get_efficientdet_config(config_name)
    model = EfficientDet(config, pretrained_backbone=True)
    config.num_classes = 1
    config.image_size = img_size
    model.class_net = HeadNet(
        config,
        num_outputs=config.num_classes,
        norm_kwargs=dict(eps=0.001, momentum=0.01),
    )
    if useGN:
        # Replace BatchNorm with GroupNorm
        model = convert_layers(model,
                               nn.BatchNorm2d,
                               nn.GroupNorm,
                               True,
                               num_groups=2)
    if model_ckpt is not None:
        state_dict = torch.load(model_ckpt)["state_dict"]
        if light:
            # Build the key set once. The old loop rebuilt model.state_dict()
            # on every iteration.
            new_state_dict = _extract_net_state_dict(
                state_dict, set(model.state_dict()))
            model.load_state_dict(new_state_dict)
            print(f"loaded {len(new_state_dict)} keys")
        else:
            model.load_state_dict(state_dict)

    return DetBenchTrain(model, config)