Example #1
def test_tracktor_forward(cfg_file):
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)
    model.pretrains = None
    model.detector.pretrained = None

    from mmtrack.models import build_model
    mot = build_model(model)
    mot.eval()

    input_shape = (1, 3, 256, 256)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10], with_track=True)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
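    # Test-mode forward: duplicate the image to simulate a two-frame video,
    # give the second frame a new frame_id, and track frame by frame.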
    with torch.no_grad():
        imgs = torch.cat([imgs, imgs.clone()], dim=0)
        img_list = [g[None, :] for g in imgs]
        img2_metas = copy.deepcopy(img_metas)
        img2_metas[0]['frame_id'] = 1
        img_metas.extend(img2_metas)
        results = defaultdict(list)
        for one_img, one_meta in zip(img_list, img_metas):
            result = mot.forward([one_img], [[one_meta]], return_loss=False)
            for k, v in result.items():
                results[k].append(v)
Example #2
def test_sot_test_forward(cfg_file):
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)
    from mmtrack.models import build_model
    sot = build_model(model)
    sot.eval()

    input_shape = (1, 3, 127, 127)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[1])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']

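    # At test time the tracker initializes the target template from the
    # ground-truth bbox of the first frame, so gt_bboxes is passed along.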
    with torch.no_grad():
        imgs = torch.cat([imgs, imgs.clone()], dim=0)
        img_list = [g[None, :] for g in imgs]
        img_metas.extend(copy.deepcopy(img_metas))
        for i in range(len(img_metas)):
            img_metas[i]['frame_id'] = i
        gt_bboxes.extend(copy.deepcopy(gt_bboxes))
        results = defaultdict(list)
        for one_img, one_meta, one_gt_bboxes in zip(img_list, img_metas,
                                                    gt_bboxes):
            result = sot.forward([one_img], [[one_meta]],
                                 gt_bboxes=[one_gt_bboxes],
                                 return_loss=False)
            for k, v in result.items():
                results[k].append(v)
Example #3
def test_siamrpn_forward(cfg_file):
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)

    from mmtrack.models import build_model
    sot = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (1, 3, 127, 127)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[1])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']

    search_input_shape = (1, 3, 255, 255)
    search_mm_inputs = _demo_mm_inputs(search_input_shape, num_items=[1])
    search_img = search_mm_inputs.pop('imgs')[None]
    search_img_metas = search_mm_inputs.pop('img_metas')
    search_gt_bboxes = search_mm_inputs['gt_bboxes']
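    # Prepend the batch image index to each search gt bbox so every row
    # becomes (batch_ind, x1, y1, x2, y2), the format expected in training.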
    img_inds = search_gt_bboxes[0].new_full((search_gt_bboxes[0].size(0), 1),
                                            0)
    search_gt_bboxes[0] = torch.cat((img_inds, search_gt_bboxes[0]), dim=1)

    losses = sot.forward(img=imgs,
                         img_metas=img_metas,
                         gt_bboxes=gt_bboxes,
                         search_img=search_img,
                         search_img_metas=search_img_metas,
                         search_gt_bboxes=search_gt_bboxes,
                         is_positive_pairs=[True],
                         return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = sot._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

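    # A negative (non-matching) template/search pair must still yield a
    # valid, positive loss.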
    losses = sot.forward(img=imgs,
                         img_metas=img_metas,
                         gt_bboxes=gt_bboxes,
                         search_img=search_img,
                         search_img_metas=search_img_metas,
                         search_gt_bboxes=search_gt_bboxes,
                         is_positive_pairs=[False],
                         return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = sot._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()
Example #4
def init_model(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Initialize a model from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. Defaults to None.
        device (str, optional): The device where the model is put on.
            Defaults to 'cuda:0'.
        cfg_options (dict, optional): Options to override some settings in
            the used config. Defaults to None.

    Returns:
        nn.Module: The constructed model.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    if 'detector' in config.model:
        config.model.detector.pretrained = None
    model = build_model(config.model)
    # We need to call `init_weights()` to load pretrained weights in the MOT task.
    model.init_weights()
    if checkpoint is not None:
        map_loc = 'cpu' if device == 'cpu' else None
        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
    if not hasattr(model, 'CLASSES'):
        if hasattr(model, 'detector') and hasattr(model.detector, 'CLASSES'):
            model.CLASSES = model.detector.CLASSES
        else:
            print("Warning: The model doesn't have classes")
            model.CLASSES = None
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
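
A minimal usage sketch for the helper above, assuming it is exported as
`mmtrack.apis.init_model` (as in MMTracking); the config and checkpoint paths
below are placeholders, not files bundled with this example:

from mmtrack.apis import init_model

# Placeholder paths; point these at a real MMTracking config and checkpoint.
model = init_model('path/to/config.py',
                   checkpoint='path/to/checkpoint.pth',
                   device='cpu')
print(model.CLASSES)  # class names restored from the checkpoint meta, if present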
Example #5
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    if hasattr(cfg.model, 'detector'):
        cfg.model.detector.pretrained = None
    cfg.data.test.test_mode = True

    # build the dataloader
    samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
    if samples_per_gpu > 1:
        # Replace 'ImageToTensor' with 'DefaultFormatBundle'
        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset,
                                   samples_per_gpu=1,
                                   workers_per_gpu=cfg.data.workers_per_gpu,
                                   dist=False,
                                   shuffle=False)

    # build the model and load checkpoint
    model = build_model(cfg.model)
    # We need to call `init_weights()` to load pretrained weights in the MOT task.
    model.init_weights()
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if args.checkpoint is not None:
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    model = MMDataParallel(model, device_ids=[0])

    model.eval()

    # the first several iterations may be very slow, so skip them
    num_warmup = 5
    pure_inf_time = 0

    # benchmark with 2000 images and take the average
    for i, data in enumerate(data_loader):

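        # CUDA kernels are asynchronous; synchronize before and after so
        # perf_counter measures the actual GPU execution time.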
        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(return_loss=False, rescale=True, **data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % args.log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(f'Done image [{i + 1:<3}/ 2000], fps: {fps:.1f} img / s')

        if (i + 1) == 2000:
            # `elapsed` was already added to `pure_inf_time` above; adding it
            # again here would double-count the last iteration.
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(f'Overall fps: {fps:.1f} img / s')
            break
Example #6
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)

    if cfg.get('USE_MMDET', False):
        from mmdet.apis import train_detector as train_model
        from mmtrack.models import build_detector as build_model
        if 'detector' in cfg.model:
            cfg.model = cfg.model.detector
    else:
        from mmtrack.apis import train_model
        from mmtrack.models import build_model
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed

    if cfg.get('train_cfg', False):
        model = build_model(cfg.model,
                            train_cfg=cfg.train_cfg,
                            test_cfg=cfg.test_cfg)
    else:
        model = build_model(cfg.model)

    datasets = [build_dataset(cfg.data.train)]
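    # A two-stage workflow [('train', ...), ('val', ...)] also evaluates on a
    # copy of the val set that reuses the training pipeline.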
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmtrack version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(mmtrack_version=__version__,
                                          config=cfg.pretty_text,
                                          CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_model(model,
                datasets,
                cfg,
                distributed=distributed,
                validate=(not args.no_validate),
                timestamp=timestamp,
                meta=meta)
Example #7
def test_sot_forward(cfg_file):
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)

    from mmtrack.models import build_model
    sot = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (1, 3, 127, 127)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[1])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']

    search_input_shape = (1, 3, 255, 255)
    search_mm_inputs = _demo_mm_inputs(search_input_shape, num_items=[1])
    search_img = search_mm_inputs.pop('imgs')[None]
    search_img_metas = search_mm_inputs.pop('img_metas')
    search_gt_bboxes = search_mm_inputs['gt_bboxes']
    img_inds = search_gt_bboxes[0].new_full((search_gt_bboxes[0].size(0), 1),
                                            0)
    search_gt_bboxes[0] = torch.cat((img_inds, search_gt_bboxes[0]), dim=1)

    losses = sot.forward(img=imgs,
                         img_metas=img_metas,
                         gt_bboxes=gt_bboxes,
                         search_img=search_img,
                         search_img_metas=search_img_metas,
                         search_gt_bboxes=search_gt_bboxes,
                         is_positive_pairs=[True],
                         return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = sot._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    losses = sot.forward(img=imgs,
                         img_metas=img_metas,
                         gt_bboxes=gt_bboxes,
                         search_img=search_img,
                         search_img_metas=search_img_metas,
                         search_gt_bboxes=search_gt_bboxes,
                         is_positive_pairs=[False],
                         return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = sot._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward test
    with torch.no_grad():
        imgs = torch.cat([imgs, imgs.clone()], dim=0)
        img_list = [g[None, :] for g in imgs]
        img_metas.extend(copy.deepcopy(img_metas))
        for i in range(len(img_metas)):
            img_metas[i]['frame_id'] = i
        gt_bboxes.extend(copy.deepcopy(gt_bboxes))
        results = defaultdict(list)
        for one_img, one_meta, one_gt_bboxes in zip(img_list, img_metas,
                                                    gt_bboxes):
            result = sot.forward([one_img], [[one_meta]],
                                 gt_bboxes=[one_gt_bboxes],
                                 return_loss=False)
            for k, v in result.items():
                results[k].append(v)
Example #8
def test_vis_forward(cfg_file):
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)

    from mmtrack.models import build_model
    vis = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (1, 3, 256, 256)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10], with_track=True)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_instance_ids = mm_inputs['gt_instance_ids']
    gt_masks = mm_inputs['gt_masks']

    ref_input_shape = (1, 3, 256, 256)
    ref_mm_inputs = _demo_mm_inputs(ref_input_shape,
                                    num_items=[11],
                                    with_track=True)
    ref_img = ref_mm_inputs.pop('imgs')
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    ref_gt_instance_ids = ref_mm_inputs['gt_instance_ids']

    losses = vis.forward(img=imgs,
                         img_metas=img_metas,
                         gt_bboxes=gt_bboxes,
                         gt_labels=gt_labels,
                         ref_img=ref_img,
                         ref_img_metas=ref_img_metas,
                         ref_gt_bboxes=ref_gt_bboxes,
                         ref_gt_labels=ref_gt_labels,
                         gt_instance_ids=gt_instance_ids,
                         gt_masks=gt_masks,
                         ref_gt_instance_ids=ref_gt_instance_ids,
                         ref_gt_masks=ref_gt_masks,
                         return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = vis._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0], with_track=True)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_instance_ids = mm_inputs['gt_instance_ids']
    gt_masks = mm_inputs['gt_masks']

    ref_input_shape = (1, 3, 256, 256)
    ref_mm_inputs = _demo_mm_inputs(ref_input_shape,
                                    num_items=[0],
                                    with_track=True)
    ref_img = ref_mm_inputs.pop('imgs')
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    ref_gt_instance_ids = ref_mm_inputs['gt_instance_ids']

    losses = vis.forward(img=imgs,
                         img_metas=img_metas,
                         gt_bboxes=gt_bboxes,
                         gt_labels=gt_labels,
                         ref_img=ref_img,
                         ref_img_metas=ref_img_metas,
                         ref_gt_bboxes=ref_gt_bboxes,
                         ref_gt_labels=ref_gt_labels,
                         gt_instance_ids=gt_instance_ids,
                         gt_masks=gt_masks,
                         ref_gt_instance_ids=ref_gt_instance_ids,
                         ref_gt_masks=ref_gt_masks,
                         return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = vis._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward test
    with torch.no_grad():
        imgs = torch.cat([imgs, imgs.clone()], dim=0)
        img_list = [g[None, :] for g in imgs]
        img2_metas = copy.deepcopy(img_metas)
        img2_metas[0]['frame_id'] = 1
        img_metas.extend(img2_metas)
        results = defaultdict(list)
        for one_img, one_meta in zip(img_list, img_metas):
            result = vis.forward([one_img], [[one_meta]],
                                 rescale=True,
                                 return_loss=False)
            for k, v in result.items():
                results[k].append(v)
Example #9
def test_vid_fgfa_style_forward(cfg_file):
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)
    model.pretrains = None
    model.detector.pretrained = None

    from mmtrack.models import build_model
    detector = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (1, 3, 256, 256)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    img_metas[0]['is_video_data'] = True
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_masks = mm_inputs['gt_masks']

    ref_input_shape = (2, 3, 256, 256)
    ref_mm_inputs = _demo_mm_inputs(ref_input_shape, num_items=[9, 11])
    ref_img = ref_mm_inputs.pop('imgs')[None]
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_img_metas[0]['is_video_data'] = True
    ref_img_metas[1]['is_video_data'] = True
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']

    losses = detector.forward(img=imgs,
                              img_metas=img_metas,
                              gt_bboxes=gt_bboxes,
                              gt_labels=gt_labels,
                              ref_img=ref_img,
                              ref_img_metas=[ref_img_metas],
                              ref_gt_bboxes=ref_gt_bboxes,
                              ref_gt_labels=ref_gt_labels,
                              gt_masks=gt_masks,
                              ref_gt_masks=ref_gt_masks,
                              return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    img_metas[0]['is_video_data'] = True
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_masks = mm_inputs['gt_masks']

    ref_mm_inputs = _demo_mm_inputs(ref_input_shape, num_items=[0, 0])
    ref_imgs = ref_mm_inputs.pop('imgs')[None]
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_img_metas[0]['is_video_data'] = True
    ref_img_metas[1]['is_video_data'] = True
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']

    losses = detector.forward(img=imgs,
                              img_metas=img_metas,
                              gt_bboxes=gt_bboxes,
                              gt_labels=gt_labels,
                              ref_img=ref_imgs,
                              ref_img_metas=[ref_img_metas],
                              ref_gt_bboxes=ref_gt_bboxes,
                              ref_gt_labels=ref_gt_labels,
                              gt_masks=gt_masks,
                              ref_gt_masks=ref_gt_masks,
                              return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward test with frame_stride=1 and frame_range=[-1,0]
    with torch.no_grad():
        imgs = torch.cat([imgs, imgs.clone()], dim=0)
        img_list = [g[None, :] for g in imgs]
        img_metas.extend(copy.deepcopy(img_metas))
        for i in range(len(img_metas)):
            img_metas[i]['frame_id'] = i
            img_metas[i]['num_left_ref_imgs'] = 1
            img_metas[i]['frame_stride'] = 1
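        # Reference clips per test frame: the original two-image ref batch for
        # frame 0, and frame 0 itself for frame 1 (frame_range=[-1, 0]).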
        ref_imgs = [ref_imgs.clone(), imgs[[0]][None].clone()]
        ref_img_metas = [
            copy.deepcopy(ref_img_metas),
            copy.deepcopy([img_metas[0]])
        ]
        results = defaultdict(list)
        for one_img, one_meta, ref_img, ref_img_meta in zip(
                img_list, img_metas, ref_imgs, ref_img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      ref_img=[ref_img],
                                      ref_img_metas=[[ref_img_meta]],
                                      return_loss=False)
            for k, v in result.items():
                results[k].append(v)
Example #10
def test_mot_forward_train(cfg_file):
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)

    from mmtrack.models import build_model
    qdtrack = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (1, 3, 256, 256)
    mm_inputs = _demo_mm_inputs(
        input_shape, num_items=[10], num_classes=2, with_track=True)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_instance_ids = mm_inputs['gt_instance_ids']
    gt_masks = mm_inputs['gt_masks']

    ref_input_shape = (1, 3, 256, 256)
    ref_mm_inputs = _demo_mm_inputs(
        ref_input_shape, num_items=[10], num_classes=2, with_track=True)
    ref_img = ref_mm_inputs.pop('imgs')
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    ref_gt_instance_ids = ref_mm_inputs['gt_instance_ids']

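    # For each gt instance in the key frame, find the index of the instance
    # with the same instance id in the reference frame (-1 if unmatched).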
    match_tool = MatchInstances()
    gt_match_indices, _ = match_tool._match_gts(gt_instance_ids[0],
                                                ref_gt_instance_ids[0])
    gt_match_indices = [torch.tensor(gt_match_indices)]

    losses = qdtrack.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        gt_masks=gt_masks,
        gt_match_indices=gt_match_indices,
        ref_img=ref_img,
        ref_img_metas=ref_img_metas,
        ref_gt_bboxes=ref_gt_bboxes,
        ref_gt_labels=ref_gt_labels,
        ref_gt_masks=ref_gt_masks,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = qdtrack._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(
        input_shape, num_items=[0], num_classes=2, with_track=True)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_instance_ids = mm_inputs['gt_instance_ids']
    gt_masks = mm_inputs['gt_masks']

    ref_mm_inputs = _demo_mm_inputs(
        ref_input_shape, num_items=[0], num_classes=2, with_track=True)
    ref_img = ref_mm_inputs.pop('imgs')
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    ref_gt_instance_ids = ref_mm_inputs['gt_instance_ids']

    gt_match_indices, _ = match_tool._match_gts(gt_instance_ids[0],
                                                ref_gt_instance_ids[0])
    gt_match_indices = [torch.tensor(gt_match_indices)]

    losses = qdtrack.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        gt_masks=gt_masks,
        gt_match_indices=gt_match_indices,
        ref_img=ref_img,
        ref_img_metas=ref_img_metas,
        ref_gt_bboxes=ref_gt_bboxes,
        ref_gt_labels=ref_gt_labels,
        ref_gt_masks=ref_gt_masks,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = qdtrack._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()
Example #11
def test_stark_forward():
    # test stage-1 forward
    config = _get_config_module('sot/stark/stark_st1_r50_500e_got10k.py')
    model = copy.deepcopy(config.model)

    from mmtrack.models import build_model
    sot = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (2, 3, 128, 128)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[1, 1])
    imgs = mm_inputs.pop('imgs')[None]
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
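    # The padding mask flags padded pixels (True) that the transformer
    # attention should ignore.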
    padding_mask = torch.zeros((2, 128, 128), dtype=torch.bool)
    padding_mask[0, 100:128, 100:128] = 1
    padding_mask = padding_mask[None]

    search_input_shape = (1, 3, 320, 320)
    search_mm_inputs = _demo_mm_inputs(search_input_shape, num_items=[1])
    search_img = search_mm_inputs.pop('imgs')[None]
    search_img_metas = search_mm_inputs.pop('img_metas')
    search_gt_bboxes = search_mm_inputs['gt_bboxes']
    search_padding_mask = torch.zeros((1, 320, 320), dtype=torch.bool)
    search_padding_mask[0, 0:20, 0:20] = 1
    search_padding_mask = search_padding_mask[None]
    img_inds = search_gt_bboxes[0].new_full((search_gt_bboxes[0].size(0), 1),
                                            0)
    search_gt_bboxes[0] = torch.cat((img_inds, search_gt_bboxes[0]), dim=1)

    losses = sot.forward(img=imgs,
                         img_metas=img_metas,
                         gt_bboxes=gt_bboxes,
                         padding_mask=padding_mask,
                         search_img=search_img,
                         search_img_metas=search_img_metas,
                         search_gt_bboxes=search_gt_bboxes,
                         search_padding_mask=search_padding_mask,
                         return_loss=True)
    assert isinstance(losses, dict)
    assert losses['loss_bbox'] > 0
    loss, _ = sot._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # test stage-2 forward
    config = _get_config_module('sot/stark/stark_st2_r50_50e_got10k.py')
    model = copy.deepcopy(config.model)
    sot = build_model(model)
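    # Stage 2 additionally trains a classification (score) head, so
    # ground-truth labels for the search image are required.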
    search_gt_labels = [torch.ones((1, 2))]

    losses = sot.forward(img=imgs,
                         img_metas=img_metas,
                         gt_bboxes=gt_bboxes,
                         padding_mask=padding_mask,
                         search_img=search_img,
                         search_img_metas=search_img_metas,
                         search_gt_bboxes=search_gt_bboxes,
                         search_padding_mask=search_padding_mask,
                         search_gt_labels=search_gt_labels,
                         return_loss=True)
    assert isinstance(losses, dict)
    assert losses['loss_cls'] > 0
    loss, _ = sot._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()