def Load_Model(checkpoint_path, which_model='effi5', img_size=1280, num_classes=32):
    """Build an EfficientDet D4/D5/D6 predictor and load trained weights into it.

    Args:
        checkpoint_path: path to a checkpoint dict containing 'model_state_dict'.
        which_model: short alias, one of 'effi4', 'effi5', 'effi6'.
        img_size: square input resolution written into the config.
        num_classes: number of detection classes for the replaced class head
            (default 32, matching the previously hard-coded value).

    Returns:
        The network wrapped in DetBenchPredict, switched to eval mode.
        No device transfer is performed here.

    Raises:
        ValueError: if ``which_model`` is not a supported alias.
    """
    # Alias -> effdet config name (TF-ported model structures).
    variants = {
        'effi4': 'tf_efficientdet_d4',
        'effi5': 'tf_efficientdet_d5',
        'effi6': 'tf_efficientdet_d6',
    }
    try:
        config_name = variants[which_model]
    except KeyError:
        # Previously an unknown alias fell through every branch and crashed
        # later with UnboundLocalError on `config`.
        raise ValueError(
            f"Unsupported which_model={which_model!r}; "
            f"expected one of {sorted(variants)}"
        ) from None

    config = get_efficientdet_config(config_name)
    config.image_size = (img_size, img_size)
    config.num_classes = num_classes
    config.norm_kwargs = dict(eps=.001, momentum=.01)

    net = EfficientDet(config, pretrained_backbone=False)
    # Replace the default classification head with one sized for our classes.
    net.class_net = HeadNet(config, num_outputs=config.num_classes)

    ckp = torch.load(checkpoint_path)
    net.load_state_dict(ckp['model_state_dict'])
    del ckp  # drop the checkpoint dict as soon as the weights are copied

    net = DetBenchPredict(net)
    net.eval()
    return net
def load_net(checkpoint_path):
    """Build a single-class EfficientDet-D7 predictor on the GPU from a checkpoint.

    Keys containing "anchors" are skipped, and a leading "model." prefix is
    stripped from every other key. The prefix presumably comes from the
    checkpoint being saved from a wrapper that held the net under `.model`;
    verify this against the training code.

    Args:
        checkpoint_path: path to a checkpoint dict containing 'model_state_dict'.

    Returns:
        The network wrapped in DetBenchPredict, in eval mode, moved to CUDA.
    """
    config = get_efficientdet_config('tf_efficientdet_d7')
    net = EfficientDet(config, pretrained_backbone=False)
    config.num_classes = 1
    config.image_size = 512
    net.class_net = HeadNet(config, num_outputs=config.num_classes,
                            norm_kwargs=dict(eps=.001, momentum=.01))

    print(f"Load model {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    net.cuda()

    prefix = "model."
    new_state_dict = OrderedDict()
    for k, v in checkpoint['model_state_dict'].items():
        if "anchors" in k:
            print("Ignore: ", k)
            continue
        # Strip only a literal leading "model." prefix. The previous
        # re.sub("model.", ...) treated "." as "any char" and removed every
        # match anywhere in the key.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    net.load_state_dict(new_state_dict)

    del checkpoint
    gc.collect()

    net = DetBenchPredict(net, config)
    net.eval()
    return net.cuda()
def load_net(checkpoint_path):
    """Return a two-class EfficientDet-D5 evaluation bench on the GPU.

    The checkpoint at ``checkpoint_path`` must be a dict whose
    'model_state_dict' entry matches the rebuilt network.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d5')
    model = EfficientDet(cfg, pretrained_backbone=False)

    # Reconfigure for our task, then swap in a matching classification head.
    cfg.num_classes = 2
    cfg.image_size = 512
    model.class_net = HeadNet(
        cfg,
        num_outputs=cfg.num_classes,
        norm_kwargs=dict(eps=.001, momentum=.01),
    )

    state = torch.load(checkpoint_path)
    model.load_state_dict(state['model_state_dict'])

    bench = DetBenchEval(model, cfg)
    bench.eval()
    return bench.cuda()
def load_net():
    """Return a single-class EfficientDet-D0 evaluation bench on the GPU.

    Weights are read from '<DIR_PATH>/models/effdet_trained.pth', where
    DIR_PATH is a module-level name defined elsewhere in this file.
    """
    cfg = get_efficientdet_config('tf_efficientdet_d0')
    model = EfficientDet(cfg, pretrained_backbone=False)

    cfg.num_classes = 1
    cfg.image_size = 512
    head = HeadNet(cfg, num_outputs=cfg.num_classes,
                   norm_kwargs=dict(eps=.001, momentum=.01))
    model.class_net = head

    weights_file = DIR_PATH + '/models/effdet_trained.pth'
    ckpt = torch.load(weights_file)
    model.load_state_dict(ckpt["model_state_dict"])
    # Release the checkpoint eagerly.
    del ckpt
    gc.collect()

    bench = DetBenchEval(model, cfg)
    bench.eval()
    return bench.cuda()
def get_effdet_eval(checkpoint_path: str):
    """Return a single-class EfficientDet-D5 wrapped in DetBenchEval, in eval mode.

    No device transfer is performed; the caller decides where the module lives.

    Args:
        checkpoint_path: path to a checkpoint dict containing 'model_state_dict'.
    """
    cfg = get_efficientdet_config("tf_efficientdet_d5")
    net = EfficientDet(cfg, pretrained_backbone=False)

    cfg.num_classes = 1
    cfg.image_size = 512
    norm_kwargs = dict(eps=0.001, momentum=0.01)
    net.class_net = HeadNet(cfg, num_outputs=cfg.num_classes, norm_kwargs=norm_kwargs)

    ckpt = torch.load(checkpoint_path)
    net.load_state_dict(ckpt["model_state_dict"])
    del ckpt
    gc.collect()

    bench = DetBenchEval(net, cfg)
    bench.eval()
    return bench
def load_model_for_eval(checkpoint_path, variant):
    """Build a single-class EfficientDet of the given variant for evaluation on GPU.

    Args:
        checkpoint_path: path to a checkpoint dict containing 'model_state_dict'.
        variant: suffix of the effdet config name, e.g. 'd0' gives
            'tf_efficientdet_d0'.

    Returns:
        The network wrapped in DetBenchEval, in eval mode, moved to CUDA.
    """
    config_name = f"tf_efficientdet_{variant}"
    cfg = get_efficientdet_config(config_name)
    model = EfficientDet(cfg, pretrained_backbone=False)

    cfg.num_classes = 1
    cfg.image_size = 512
    model.class_net = HeadNet(
        cfg,
        num_outputs=cfg.num_classes,
        norm_kwargs=dict(eps=0.001, momentum=0.01),
    )

    ckpt = torch.load(checkpoint_path)
    model.load_state_dict(ckpt["model_state_dict"])
    del ckpt
    gc.collect()

    bench = DetBenchEval(model, cfg)
    bench.eval()
    return bench.cuda()
def set_eval_effdet(checkpoint_path: str, config, num_classes: int = 1,
                    device: "str | torch.device" = 'cuda:0'):
    """Init EfficientDet to validation mode.

    Args:
        checkpoint_path: path to a file holding a bare state dict; it is passed
            directly to ``load_state_dict``.
        config: effdet model config. NOTE: ``config.num_classes`` is set to
            ``num_classes`` in place.
        num_classes: number of outputs of the replacement classification head.
        device: target device, either a string such as 'cuda:0' or a
            torch.device.

    Returns:
        The network wrapped in DetBenchEval, in eval mode, on ``device``.
    """
    # Keep the config in sync with the head built below. DetBenchEval receives
    # this config, and its post-processing presumably reshapes class outputs
    # using config.num_classes (as in effdet's bench implementation). Verify
    # against the installed effdet version.
    config.num_classes = num_classes
    net = EfficientDet(config, pretrained_backbone=False)
    net.class_net = HeadNet(config, num_outputs=num_classes,
                            norm_kwargs=dict(eps=.001, momentum=.01))

    # Load onto CPU first so a checkpoint saved on another GPU still loads;
    # the whole bench is moved to `device` below.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(state_dict)

    net = DetBenchEval(net, config)
    net = net.eval()
    # nn.Module has no `to_device`; `.to(device)` is the correct call.
    return net.to(device)
def validate(args):
    """Run EfficientDet inference over a COCO split and optionally score it.

    Builds the model from ``args.model``, optionally loads ``args.checkpoint``,
    and runs it over the COCO images for ``args.anno``. Detections with
    score >= 0.001 are written as JSON to ``args.results``. For non-test
    splits, COCO bbox evaluation is then run on the image ids that produced
    detections.

    Args:
        args: parsed CLI namespace. Fields read: pretrained, checkpoint,
            no_prefetcher, model, num_gpu, anno, data, batch_size,
            interpolation, workers, log_freq, results. This function also
            sets ``args.pretrained`` and ``args.prefetcher``.

    Returns:
        List of detection dicts with keys image_id, bbox, score, category_id.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    config = get_efficientdet_config(args.model)
    model = EfficientDet(config)
    if args.checkpoint:
        load_checkpoint(model, args.checkpoint)

    param_count = sum(m.numel() for m in model.parameters())
    print('Model %s created, param count: %d' % (args.model, param_count))

    bench = DetBenchEval(model, config)
    bench = bench.cuda()
    if has_amp:
        print('Using AMP mixed precision.')
        bench = amp.initialize(bench, opt_level='O1')
    else:
        print('AMP not installed, running network in FP32.')

    if args.num_gpu > 1:
        bench = torch.nn.DataParallel(bench, device_ids=list(range(args.num_gpu)))

    # Test splits ship image-info files only (no ground truth), so they are
    # not scored below.
    is_test_split = 'test' in args.anno
    if is_test_split:
        annotation_path = os.path.join(args.data, 'annotations', f'image_info_{args.anno}.json')
        image_dir = 'test2017'
    else:
        annotation_path = os.path.join(args.data, 'annotations', f'instances_{args.anno}.json')
        image_dir = args.anno
    dataset = CocoDetection(os.path.join(args.data, image_dir), annotation_path)

    loader = create_loader(
        dataset,
        input_size=config.image_size,
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=args.interpolation,
        num_workers=args.workers)

    img_ids = []
    results = []
    model.eval()
    batch_time = AverageMeter()
    end = time.time()
    with torch.no_grad():
        # `images` rather than `input` to avoid shadowing the builtin.
        for i, (images, target) in enumerate(loader):
            output = bench(images, target['scale'])
            output = output.cpu()
            sample_ids = target['img_id'].cpu()
            for index, sample in enumerate(output):
                image_id = int(sample_ids[index])
                for det in sample:
                    score = float(det[4])
                    if score < .001:  # stop when below this threshold, scores in descending order
                        break
                    coco_det = dict(
                        image_id=image_id,
                        bbox=det[0:4].tolist(),
                        score=score,
                        category_id=int(det[5]))
                    img_ids.append(image_id)
                    results.append(coco_det)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.log_freq == 0:
                print(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    .format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=images.size(0) / batch_time.avg,
                    ))

    # Close (and thereby flush) the results file before COCO reads it back.
    # The previous bare open() relied on garbage collection to do that.
    with open(args.results, 'w') as f:
        json.dump(results, f, indent=4)

    if not is_test_split:
        coco_results = dataset.coco.loadRes(args.results)
        coco_eval = COCOeval(dataset.coco, coco_results, 'bbox')
        coco_eval.params.imgIds = img_ids  # score only ids we've used
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

    return results