Example #1
def train_detector(model,
                   dataset,
                   cfg,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = logging.getLogger(__name__)  # standard-library logger (requires `import logging`)
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [build_dataloader(ds, data=cfg.data) for ds in dataset]
    if torch.cuda.is_available():
        model = model.cuda(cfg.gpu_ids[0])
        model.device = cfg.gpu_ids[0]
        if torch.cuda.device_count() > 1:
            # wrap with torch.nn.DataParallel for multi-GPU training
            model = DataParallel(model, device_ids=cfg.gpu_ids)
    else:
        model.device = 'cpu'

    # build runner
    optimizer = cfg.optimizer

    ema = cfg.get('ema', None)  # optional EMA settings, if present in the config
    runner = Runner(model,
                    batch_processor,
                    optimizer,
                    cfg.work_dir,
                    logger=logger,
                    meta=meta,
                    ema=ema)
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # register eval hooks; they must be registered before the logging hooks, otherwise the evaluation results never appear in the log
    if validate:
        cfg.data.val.train = False
        val_dataset = build_from_dict(cfg.data.val, DATASET)
        val_dataloader = build_dataloader(val_dataset,
                                          shuffle=False,
                                          data=cfg.data)
        eval_cfg = cfg.get('evaluation', {})
        from yolodet.models.hooks.eval_hook import EvalHook
        runner.register_hook(EvalHook(val_dataloader, **eval_cfg))

    # register training hooks (lr schedule, optimizer step, checkpointing, logging)
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
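For context, a minimal sketch of how this function might be driven from a training script, following the Config/build_from_dict pattern used in the other examples below. The model-building call and the DETECTORS registry name are assumptions for illustration, not taken from the yolodet sources.

import time

cfg = Config.fromfile('cfg/yolov5_coco_gpu.py')       # illustrative config path
train_set = build_from_dict(cfg.data.train, DATASET)
model = build_from_dict(cfg.model, DETECTORS)         # assumed model registry name
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
train_detector(model, train_set, cfg, validate=True, timestamp=timestamp)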
Example #2
        # bbox is a normalized (cx, cy, w, h) box; convert it to corner form
        # and scale to the pixel size (w, h) of the image
        cx, cy, _w, _h = bbox
        x, y = cx - _w / 2, cy - _h / 2
        x1, y1, x2, y2 = x, y, x + _w, y + _h
        x1, y1, x2, y2 = int(x1 * w), int(y1 * h), int(x2 * w), int(y2 * h)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 1)
    return img
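The lines above are the tail of a drawing helper whose signature is cut off in the excerpt. For reference, the same normalized-to-pixel conversion as a standalone sketch with a small worked example; the name denormalize_cxcywh is illustrative and not part of the original file.

def denormalize_cxcywh(bbox, img_w, img_h):
    # (cx, cy, w, h) in [0, 1] -> integer pixel corners (x1, y1, x2, y2)
    cx, cy, bw, bh = bbox
    x1 = int((cx - bw / 2) * img_w)
    y1 = int((cy - bh / 2) * img_h)
    x2 = int((cx + bw / 2) * img_w)
    y2 = int((cy + bh / 2) * img_h)
    return x1, y1, x2, y2

# e.g. denormalize_cxcywh((0.5, 0.5, 0.2, 0.4), 640, 480) == (256, 144, 384, 336)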

file = '/disk2/project/pytorch-YOLOv4/cfg/dataset_test.py'

cfg = Config.fromfile(file)

dataset = build_from_dict(cfg.data.train, DATASET)

dataloader = build_dataloader(dataset, data=cfg.data)

for i, data_batch in enumerate(dataloader):
    if i > 30:
        break
    for idx, data in enumerate(data_batch['img']):
        gt = data_batch['gt_bboxes'][idx]
        gt_xywh = xyxy2xywh(gt)  # x, y, w, h
        # count the non-padded ground-truth boxes (padded rows are all zeros)
        n_gt = (gt.sum(dim=-1) > 0).sum(dim=-1)
        n = int(n_gt)
        if n == 0:
            continue
        gt = gt[:n].cpu().numpy()
        gt_xywh = gt_xywh[:n].cpu().numpy()
        data = data.cpu().numpy() * 255  # scale the image tensor back to 0-255
        data = data.transpose(1, 2, 0)
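The excerpt stops before anything is drawn or saved. A minimal, illustrative continuation of the inner loop is sketched below; it assumes numpy is imported as np, that gt_xywh holds normalized (cx, cy, w, h) boxes as the helper above implies, and it reuses the denormalize_cxcywh sketch from earlier. None of these lines are from the original file.

        # --- illustrative continuation, not part of the original excerpt ---
        data = np.ascontiguousarray(data, dtype=np.uint8)  # cv2 needs a contiguous uint8 HWC array
        img_h, img_w = data.shape[:2]
        for box in gt_xywh:
            x1, y1, x2, y2 = denormalize_cxcywh(box, img_w, img_h)
            data = cv2.rectangle(data, (x1, y1), (x2, y2), (0, 255, 0), 1)
        cv2.imwrite('batch{}_img{}.jpg'.format(i, idx), data)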
Example #3
                        help='report mAP by class')
    parser.add_argument('--half',
                        action='store_true',
                        help='fp16 half precision')
    opt = parser.parse_args()

    print(opt)

    cfg = Config.fromfile(opt.config)
    cfg.data.val.train = False
    val_dataset = build_from_dict(cfg.data.val, DATASET)
    val_dataloader = build_dataloader(val_dataset,
                                      data=cfg.data,
                                      shuffle=False)
    device = select_device(opt.device)
    model = init_detector(opt.config, checkpoint=opt.checkpoint, device=device)
    result = single_gpu_test(model,
                             val_dataloader,
                             half=opt.half,
                             conf_thres=opt.conf_thres,
                             iou_thres=opt.iou_thres,
                             merge=opt.merge,
                             save_json=opt.save_json,
                             augment=opt.augment,
                             verbose=opt.verbose,
                             coco_val_path=opt.coco_val_path)
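This excerpt shows only the tail of the argument parser. For context, a minimal sketch of a parser that would supply every option the script reads from opt; the flag names mirror the attributes used above, but the defaults and help texts are illustrative and not taken from the original file.

import argparse

def make_parser():
    # illustrative reconstruction of the options referenced via `opt` above
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the model config (.py)')
    parser.add_argument('--checkpoint', type=str, help='path to the trained .pth checkpoint')
    parser.add_argument('--device', default='', help='cuda device id(s) or cpu')
    parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.65, help='IoU threshold for NMS')
    parser.add_argument('--merge', action='store_true', help='use merge NMS')
    parser.add_argument('--save-json', action='store_true', help='save results as a COCO-style json')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
    parser.add_argument('--coco-val-path', type=str, default='', help='path to the COCO val annotation file')
    parser.add_argument('--half', action='store_true', help='fp16 half precision')
    return parser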