Example #1
def Load_Model(checkpoint_path, which_model='effi5', img_size=1280):
    if which_model == 'effi5':
        config = get_efficientdet_config(
            'tf_efficientdet_d5')  # tf_effi5 model structure

    elif which_model == 'effi4':
        config = get_efficientdet_config(
            'tf_efficientdet_d4')  # tf_effi4 model structure

    elif which_model == 'effi6':
        config = get_efficientdet_config(
            'tf_efficientdet_d6')  # tf_effi6 model structure

    else:
        raise ValueError(f'Unknown model variant: {which_model}')

    config.image_size = (img_size, img_size)
    config.num_classes = 32
    config.norm_kwargs = dict(eps=.001, momentum=.01)
    net = EfficientDet(config, pretrained_backbone=False)
    net.class_net = HeadNet(config, num_outputs=config.num_classes)

    ckp = torch.load(checkpoint_path)
    net.load_state_dict(ckp['model_state_dict'])
    del ckp

    net = DetBenchPredict(net)
    net.eval()
    return net
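None of the excerpts in this listing include their imports. A minimal set that should cover them is sketched below; the module paths assume Ross Wightman's effdet package, where older releases expose the DetBenchEval/DetBenchTrain wrappers and newer ones DetBenchPredict.

# Imports assumed by the examples in this listing (not shown in the originals).
import gc
import re
from collections import OrderedDict

import torch

from effdet import get_efficientdet_config, EfficientDet, DetBenchPredict
from effdet.efficientdet import HeadNet
# Older effdet releases ship the eval/train bench wrappers instead:
# from effdet import DetBenchEval, DetBenchTrain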
Example #2

def load_net(checkpoint_path):
    config = get_efficientdet_config('tf_efficientdet_d7')
    net = EfficientDet(config, pretrained_backbone=False)

    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(config, num_outputs=config.num_classes, norm_kwargs=dict(eps=.001, momentum=.01))
    print(f"Load model {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    net.cuda()
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        if "anchors" in k:
            print("Ignore: ",k)
            continue
        name = re.sub("model.",'',k) if k.startswith('model') else k
        new_state_dict[name] = v

    # net.load_state_dict(checkpoint['model_state_dict'])
    net.load_state_dict(new_state_dict)
    del checkpoint
    gc.collect()

    net = DetBenchPredict(net, config)
    net.eval()
    return net.cuda()
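The "model."-prefix stripping and the skipped "anchors" keys above are consistent with a checkpoint that saved the state dict of the bench wrapper (whose EfficientDet submodule is registered as .model and which owns anchor buffers) rather than of the bare network. A hedged illustration of how such a checkpoint might have been produced:

# Hypothetical training-side save that yields "model."-prefixed keys plus
# "anchors.*" buffers in the checkpoint consumed by load_net() above.
# `net` is a trained EfficientDet and `config` its config (both assumed).
bench = DetBenchTrain(net, config)   # older effdet API; wraps net as bench.model
torch.save({'model_state_dict': bench.state_dict()}, 'best-checkpoint.pth')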
Example #3
def load_net(checkpoint_path):
    config = get_efficientdet_config('tf_efficientdet_d5')
    net = EfficientDet(config, pretrained_backbone=False)
    config.num_classes = 2
    config.image_size = 512
    net.class_net = HeadNet(config,
                            num_outputs=config.num_classes,
                            norm_kwargs=dict(eps=.001, momentum=.01))
    checkpoint = torch.load(checkpoint_path)
    net.load_state_dict(checkpoint['model_state_dict'])
    net = DetBenchEval(net, config)
    net.eval()
    return net.cuda()
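A hedged usage sketch for this DetBenchEval wrapper (checkpoint path, batch, and scales are illustrative, not from the original): the older effdet bench is called with the image batch plus per-image scale factors, the same call form used in Example #8.

# Illustrative only; path, batch size and scales are assumptions.
net = load_net('best-checkpoint.pth')
images = torch.zeros(4, 3, 512, 512).cuda()   # dummy batch at config.image_size
scales = torch.ones(4).cuda()                 # scale back to the original image sizes
with torch.no_grad():
    detections = net(images, scales)          # same call form as Example #8
# Each row of detections[i] holds four box coordinates, a score and a class id,
# which is how Example #8 unpacks them (det[0:4], det[4], det[5]).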
Example #4
def load_net():
    config = get_efficientdet_config('tf_efficientdet_d0')
    net = EfficientDet(config, pretrained_backbone=False)

    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(config, num_outputs=config.num_classes, norm_kwargs=dict(eps=.001, momentum=.01))

    # DIR_PATH is assumed to be defined elsewhere in the original source
    checkpoint = torch.load(DIR_PATH + '/models/effdet_trained.pth')
    net.load_state_dict(checkpoint["model_state_dict"])
    del checkpoint
    gc.collect()

    net = DetBenchEval(net, config)
    net.eval()
    return net.cuda()
Example #5

def get_effdet_eval(checkpoint_path: str):
    config = get_efficientdet_config("tf_efficientdet_d5")
    model = EfficientDet(config, pretrained_backbone=False)

    config.num_classes = 1
    config.image_size = 512
    model.class_net = HeadNet(
        config,
        num_outputs=config.num_classes,
        norm_kwargs=dict(eps=0.001, momentum=0.01),
    )
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["model_state_dict"])

    del checkpoint
    gc.collect()

    model = DetBenchEval(model, config)
    model.eval()
    return model
Example #6
def load_model_for_eval(checkpoint_path, variant):
    config = get_efficientdet_config(f"tf_efficientdet_{variant}")
    net = EfficientDet(config, pretrained_backbone=False)

    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(
        config,
        num_outputs=config.num_classes,
        norm_kwargs=dict(eps=0.001, momentum=0.01),
    )

    checkpoint = torch.load(checkpoint_path)
    net.load_state_dict(checkpoint["model_state_dict"])

    del checkpoint
    gc.collect()

    net = DetBenchEval(net, config)
    net.eval()

    return net.cuda()
Example #7
def set_eval_effdet(checkpoint_path: str,
                    config,
                    num_classes: int = 1,
                    device: torch.device = torch.device('cuda:0')):
    """Init EfficientDet to validation mode"""
    net = EfficientDet(config, pretrained_backbone=False)
    net.class_net = HeadNet(config,
                            num_outputs=num_classes,
                            norm_kwargs=dict(eps=.001, momentum=.01))

    checkpoint = torch.load(checkpoint_path)
    net.load_state_dict(checkpoint)

    net = DetBenchEval(net, config)
    net = net.eval()

    return net.to(device)
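Unlike the earlier loaders, this helper expects the caller to build the config; a hedged example of such a call (variant, sizes and checkpoint path are assumptions, not from the original) might look like this.

# Illustrative only; every value here is an assumption.
config = get_efficientdet_config('tf_efficientdet_d5')
config.num_classes = 1
config.image_size = 512
config.norm_kwargs = dict(eps=.001, momentum=.01)
net = set_eval_effdet('best-checkpoint.pth', config,
                      num_classes=1, device=torch.device('cuda:0'))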
Example #8
def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    config = get_efficientdet_config(args.model)
    model = EfficientDet(config)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint)

    param_count = sum([m.numel() for m in model.parameters()])
    print('Model %s created, param count: %d' % (args.model, param_count))

    bench = DetBenchEval(model, config)
    bench = bench.cuda()
    if has_amp:
        print('Using AMP mixed precision.')
        bench = amp.initialize(bench, opt_level='O1')
    else:
        print('AMP not installed, running network in FP32.')

    if args.num_gpu > 1:
        bench = torch.nn.DataParallel(bench,
                                      device_ids=list(range(args.num_gpu)))

    if 'test' in args.anno:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'image_info_{args.anno}.json')
        image_dir = 'test2017'
    else:
        annotation_path = os.path.join(args.data, 'annotations',
                                       f'instances_{args.anno}.json')
        image_dir = args.anno
    dataset = CocoDetection(os.path.join(args.data, image_dir),
                            annotation_path)

    loader = create_loader(dataset,
                           input_size=config.image_size,
                           batch_size=args.batch_size,
                           use_prefetcher=args.prefetcher,
                           interpolation=args.interpolation,
                           num_workers=args.workers)

    img_ids = []
    results = []
    model.eval()
    batch_time = AverageMeter()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            output = bench(input, target['scale'])
            output = output.cpu()
            sample_ids = target['img_id'].cpu()
            for index, sample in enumerate(output):
                image_id = int(sample_ids[index])
                for det in sample:
                    score = float(det[4])
                    if score < .001:  # stop when below this threshold, scores in descending order
                        break
                    coco_det = dict(image_id=image_id,
                                    bbox=det[0:4].tolist(),
                                    score=score,
                                    category_id=int(det[5]))
                    img_ids.append(image_id)
                    results.append(coco_det)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                print(
                    'Test: [{0:>4d}/{1}]  '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                    .format(
                        i,
                        len(loader),
                        batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                    ))

    with open(args.results, 'w') as f:
        json.dump(results, f, indent=4)
    if 'test' not in args.anno:
        coco_results = dataset.coco.loadRes(args.results)
        coco_eval = COCOeval(dataset.coco, coco_results, 'bbox')
        coco_eval.params.imgIds = img_ids  # score only ids we've used
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

    return results
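validate() is driven by an argparse-style namespace; a minimal sketch covering only the fields the function actually reads (all values below are assumptions, only the field names mirror the code) could look like this.

# Illustrative only; field names mirror how `args` is used above, the values
# are assumptions.
from types import SimpleNamespace

args = SimpleNamespace(
    model='tf_efficientdet_d1', checkpoint='', pretrained=True,
    no_prefetcher=False, num_gpu=1,
    data='/data/coco', anno='val2017',
    batch_size=32, workers=4, interpolation='bilinear',
    log_freq=10, results='./results.json',
)
results = validate(args)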