def test_tracktor_forward(cfg_file):
    """Smoke-test Tracktor inference over a two-frame sequence."""
    cfg = _get_config_module(cfg_file)
    model_cfg = copy.deepcopy(cfg.model)
    # Strip pretrained-weight paths so the test does no file/network I/O.
    model_cfg.pretrains = None
    model_cfg.detector.pretrained = None
    from mmtrack.models import build_model
    tracker = build_model(model_cfg)
    tracker.eval()

    demo = _demo_mm_inputs((1, 3, 256, 256), num_items=[10], with_track=True)
    frames = demo.pop('imgs')
    metas = demo.pop('img_metas')

    with torch.no_grad():
        # Duplicate the single frame so the tracker sees frame_id 0 and 1.
        frames = torch.cat([frames, frames.clone()], dim=0)
        frame_list = [frame[None, :] for frame in frames]
        second_metas = copy.deepcopy(metas)
        second_metas[0]['frame_id'] = 1
        metas.extend(second_metas)

        collected = defaultdict(list)
        for frame, meta in zip(frame_list, metas):
            out = tracker.forward([frame], [[meta]], return_loss=False)
            for key, val in out.items():
                collected[key].append(val)
def test_sot_test_forward(cfg_file):
    """Smoke-test SOT inference over a duplicated two-frame sequence."""
    cfg = _get_config_module(cfg_file)
    model_cfg = copy.deepcopy(cfg.model)
    sot = build_model(model_cfg)
    sot.eval()

    demo = _demo_mm_inputs((1, 3, 127, 127), num_items=[1])
    frames = demo.pop('imgs')
    metas = demo.pop('img_metas')
    boxes = demo['gt_bboxes']

    with torch.no_grad():
        # Double the batch so the tracker runs an init frame and a track frame.
        frames = torch.cat([frames, frames.clone()], dim=0)
        frame_list = [frame[None, :] for frame in frames]
        metas.extend(copy.deepcopy(metas))
        for frame_id, meta in enumerate(metas):
            meta['frame_id'] = frame_id
        boxes.extend(copy.deepcopy(boxes))

        collected = defaultdict(list)
        for frame, meta, box in zip(frame_list, metas, boxes):
            out = sot.forward([frame], [[meta]],
                              gt_bboxes=[box],
                              return_loss=False)
            for key, val in out.items():
                collected[key].append(val)
def test_siamrpn_forward(cfg_file):
    """Test SiamRPN forward-train on a positive and a negative pair.

    The positive-pair and negative-pair passes were previously duplicated
    verbatim; they are now driven by one loop over ``is_positive_pairs``.
    """
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)
    sot = build_model(model)

    # Template branch: a non-empty ground-truth batch.
    input_shape = (1, 3, 127, 127)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[1])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']

    # Search branch at the larger SiamRPN search resolution.
    search_input_shape = (1, 3, 255, 255)
    search_mm_inputs = _demo_mm_inputs(search_input_shape, num_items=[1])
    search_img = search_mm_inputs.pop('imgs')[None]
    search_img_metas = search_mm_inputs.pop('img_metas')
    search_gt_bboxes = search_mm_inputs['gt_bboxes']
    # Prepend the batch-index column expected for search gt bboxes.
    img_inds = search_gt_bboxes[0].new_full((search_gt_bboxes[0].size(0), 1),
                                            0)
    search_gt_bboxes[0] = torch.cat((img_inds, search_gt_bboxes[0]), dim=1)

    # Identical checks for positive and negative template/search pairs.
    for is_positive in (True, False):
        losses = sot.forward(
            img=imgs,
            img_metas=img_metas,
            gt_bboxes=gt_bboxes,
            search_img=search_img,
            search_img_metas=search_img_metas,
            search_gt_bboxes=search_gt_bboxes,
            is_positive_pairs=[is_positive],
            return_loss=True)
        assert isinstance(losses, dict)
        loss, _ = sot._parse_losses(losses)
        # Ensure backward() is legal even when no loss term requires grad.
        loss.requires_grad_(True)
        assert float(loss.item()) > 0
        loss.backward()
def init_model(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Initialize a model from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. Default as None.
        device (str, optional): The device where the model is put on, e.g.
            'cuda:0' or 'cpu'. Defaults to 'cuda:0'.
        cfg_options (dict, optional): Options to override some settings in
            the used config. Default to None.

    Returns:
        nn.Module: The constructed detector.

    Raises:
        TypeError: If ``config`` is neither a path string nor a
            :obj:`mmcv.Config`.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    if 'detector' in config.model:
        # The checkpoint (or random init) supplies detector weights; skip
        # loading the detector's own pretrained weights.
        config.model.detector.pretrained = None
    model = build_model(config.model)
    # We need call `init_weights()` to load pretained weights in MOT task.
    model.init_weights()
    if checkpoint is not None:
        # Load to CPU first when the target device is CPU; otherwise let
        # load_checkpoint keep the storage's original device mapping.
        map_loc = 'cpu' if device == 'cpu' else None
        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
        # Checkpoints saved without a 'meta' section must not crash here.
        if 'CLASSES' in checkpoint.get('meta', {}):
            model.CLASSES = checkpoint['meta']['CLASSES']
    if not hasattr(model, 'CLASSES'):
        # Fall back to the wrapped detector's classes when available.
        if hasattr(model, 'detector') and hasattr(model.detector, 'CLASSES'):
            model.CLASSES = model.detector.CLASSES
        else:
            print("Warning: The model doesn't have classes")
            model.CLASSES = None
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
def main():
    """Benchmark inference FPS of a tracking model on the test dataloader.

    Fix: the final iteration previously added ``elapsed`` to
    ``pure_inf_time`` twice (once in the warmup-guarded branch and again in
    the termination branch), skewing the overall FPS.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    if hasattr(cfg.model, 'detector'):
        cfg.model.detector.pretrained = None
    cfg.data.test.test_mode = True

    # build the dataloader
    samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
    if samples_per_gpu > 1:
        # Replace 'ImageToTensor' to 'DefaultFormatBundle'
        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

    # build the model and load checkpoint
    model = build_model(cfg.model)
    # We need call `init_weights()` to load pretained weights in MOT task.
    model.init_weights()
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if args.checkpoint is not None:
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    model = MMDataParallel(model, device_ids=[0])
    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0

    # benchmark with 2000 image and take the average
    for i, data in enumerate(data_loader):

        # Synchronize so the timer measures the full GPU work, not just the
        # asynchronous kernel launches.
        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(return_loss=False, rescale=True, **data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % args.log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(f'Done image [{i + 1:<3}/ 2000], fps: {fps:.1f} img / s')

        if (i + 1) == 2000:
            # `elapsed` was already accumulated above; do not add it again.
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(f'Overall fps: {fps:.1f} img / s')
            break
def main():
    """Entry point of the training script.

    Parses CLI args, resolves the config, sets up distributed training,
    logging and seeding, builds the model and dataset(s), and launches
    training. Statement order matters: the distributed env must be
    initialized before the logger, and seeding before model construction.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)

    # When USE_MMDET is set, train the bare detector with mmdet's trainer
    # instead of the full tracking pipeline.
    if cfg.get('USE_MMDET', False):
        from mmdet.apis import train_detector as train_model
        from mmtrack.models import build_detector as build_model
        if 'detector' in cfg.model:
            cfg.model = cfg.model.detector
    else:
        from mmtrack.apis import train_model
        from mmtrack.models import build_model
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    # Recorded even when None so downstream tooling can see what was used.
    cfg.seed = args.seed
    meta['seed'] = args.seed

    if cfg.get('train_cfg', False):
        model = build_model(
            cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    else:
        model = build_model(cfg.model)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        # A 2-stage workflow also evaluates on val data, which must go
        # through the same pipeline as training data here.
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmtrack version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmtrack_version=__version__,
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
def test_sot_forward(cfg_file):
    """Test SOT forward-train (positive and negative pairs) and inference.

    The positive-pair and negative-pair training passes were previously
    duplicated verbatim; they are now driven by one loop over
    ``is_positive_pairs``.
    """
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)
    from mmtrack.models import build_model
    sot = build_model(model)

    # Template branch: a non-empty ground-truth batch.
    input_shape = (1, 3, 127, 127)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[1])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']

    # Search branch at the larger search resolution.
    search_input_shape = (1, 3, 255, 255)
    search_mm_inputs = _demo_mm_inputs(search_input_shape, num_items=[1])
    search_img = search_mm_inputs.pop('imgs')[None]
    search_img_metas = search_mm_inputs.pop('img_metas')
    search_gt_bboxes = search_mm_inputs['gt_bboxes']
    # Prepend the batch-index column expected for search gt bboxes.
    img_inds = search_gt_bboxes[0].new_full((search_gt_bboxes[0].size(0), 1),
                                            0)
    search_gt_bboxes[0] = torch.cat((img_inds, search_gt_bboxes[0]), dim=1)

    # Identical checks for positive and negative template/search pairs.
    for is_positive in (True, False):
        losses = sot.forward(
            img=imgs,
            img_metas=img_metas,
            gt_bboxes=gt_bboxes,
            search_img=search_img,
            search_img_metas=search_img_metas,
            search_gt_bboxes=search_gt_bboxes,
            is_positive_pairs=[is_positive],
            return_loss=True)
        assert isinstance(losses, dict)
        loss, _ = sot._parse_losses(losses)
        # Ensure backward() is legal even when no loss term requires grad.
        loss.requires_grad_(True)
        assert float(loss.item()) > 0
        loss.backward()

    # Test forward test
    with torch.no_grad():
        imgs = torch.cat([imgs, imgs.clone()], dim=0)
        img_list = [g[None, :] for g in imgs]
        img_metas.extend(copy.deepcopy(img_metas))
        for i in range(len(img_metas)):
            img_metas[i]['frame_id'] = i
        gt_bboxes.extend(copy.deepcopy(gt_bboxes))
        results = defaultdict(list)
        for one_img, one_meta, one_gt_bboxes in zip(img_list, img_metas,
                                                    gt_bboxes):
            result = sot.forward([one_img], [[one_meta]],
                                 gt_bboxes=[one_gt_bboxes],
                                 return_loss=False)
            for k, v in result.items():
                results[k].append(v)
def test_vis_forward(cfg_file):
    """Test a VIS model's forward-train (non-empty and empty ground truth)
    and forward-test paths."""
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)
    from mmtrack.models import build_model
    vis = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (1, 3, 256, 256)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10], with_track=True)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_instance_ids = mm_inputs['gt_instance_ids']
    gt_masks = mm_inputs['gt_masks']
    # Reference frame with a different number of instances (11 vs 10) to
    # exercise cross-frame matching.
    ref_input_shape = (1, 3, 256, 256)
    ref_mm_inputs = _demo_mm_inputs(
        ref_input_shape, num_items=[11], with_track=True)
    ref_img = ref_mm_inputs.pop('imgs')
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    ref_gt_instance_ids = ref_mm_inputs['gt_instance_ids']
    losses = vis.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        ref_img=ref_img,
        ref_img_metas=ref_img_metas,
        ref_gt_bboxes=ref_gt_bboxes,
        ref_gt_labels=ref_gt_labels,
        gt_instance_ids=gt_instance_ids,
        gt_masks=gt_masks,
        ref_gt_instance_ids=ref_gt_instance_ids,
        ref_gt_masks=ref_gt_masks,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = vis._parse_losses(losses)
    # Make backward() valid even if no loss term requires grad.
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0], with_track=True)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_instance_ids = mm_inputs['gt_instance_ids']
    gt_masks = mm_inputs['gt_masks']
    ref_input_shape = (1, 3, 256, 256)
    ref_mm_inputs = _demo_mm_inputs(
        ref_input_shape, num_items=[0], with_track=True)
    ref_img = ref_mm_inputs.pop('imgs')
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    ref_gt_instance_ids = ref_mm_inputs['gt_instance_ids']
    losses = vis.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        ref_img=ref_img,
        ref_img_metas=ref_img_metas,
        ref_gt_bboxes=ref_gt_bboxes,
        ref_gt_labels=ref_gt_labels,
        gt_instance_ids=gt_instance_ids,
        gt_masks=gt_masks,
        ref_gt_instance_ids=ref_gt_instance_ids,
        ref_gt_masks=ref_gt_masks,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = vis._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward test
    with torch.no_grad():
        # Duplicate the frame so inference runs on frame_id 0 and 1.
        imgs = torch.cat([imgs, imgs.clone()], dim=0)
        img_list = [g[None, :] for g in imgs]
        img2_metas = copy.deepcopy(img_metas)
        img2_metas[0]['frame_id'] = 1
        img_metas.extend(img2_metas)
        results = defaultdict(list)
        for one_img, one_meta in zip(img_list, img_metas):
            result = vis.forward([one_img], [[one_meta]],
                                 rescale=True,
                                 return_loss=False)
            for k, v in result.items():
                results[k].append(v)
def test_vid_fgfa_style_forward(cfg_file):
    """Test FGFA-style VID forward-train (non-empty and empty ground truth)
    and forward-test with reference frames."""
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)
    # Strip pretrained-weight paths so the test does no file/network I/O.
    model.pretrains = None
    model.detector.pretrained = None
    from mmtrack.models import build_model
    detector = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (1, 3, 256, 256)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    img_metas[0]['is_video_data'] = True
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_masks = mm_inputs['gt_masks']
    # Two reference frames with different instance counts.
    ref_input_shape = (2, 3, 256, 256)
    ref_mm_inputs = _demo_mm_inputs(ref_input_shape, num_items=[9, 11])
    # [None] adds the batch dim expected for reference images.
    ref_img = ref_mm_inputs.pop('imgs')[None]
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_img_metas[0]['is_video_data'] = True
    ref_img_metas[1]['is_video_data'] = True
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    losses = detector.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        ref_img=ref_img,
        ref_img_metas=[ref_img_metas],
        ref_gt_bboxes=ref_gt_bboxes,
        ref_gt_labels=ref_gt_labels,
        gt_masks=gt_masks,
        ref_gt_masks=ref_gt_masks,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    # Make backward() valid even if no loss term requires grad.
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    img_metas[0]['is_video_data'] = True
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_masks = mm_inputs['gt_masks']
    ref_mm_inputs = _demo_mm_inputs(ref_input_shape, num_items=[0, 0])
    # NOTE: named `ref_imgs` (plural) here; reused below in forward-test.
    ref_imgs = ref_mm_inputs.pop('imgs')[None]
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_img_metas[0]['is_video_data'] = True
    ref_img_metas[1]['is_video_data'] = True
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    losses = detector.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        ref_img=ref_imgs,
        ref_img_metas=[ref_img_metas],
        ref_gt_bboxes=ref_gt_bboxes,
        ref_gt_labels=ref_gt_labels,
        gt_masks=gt_masks,
        ref_gt_masks=ref_gt_masks,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = detector._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward test with frame_stride=1 and frame_range=[-1,0]
    with torch.no_grad():
        imgs = torch.cat([imgs, imgs.clone()], dim=0)
        img_list = [g[None, :] for g in imgs]
        img_metas.extend(copy.deepcopy(img_metas))
        for i in range(len(img_metas)):
            img_metas[i]['frame_id'] = i
            img_metas[i]['num_left_ref_imgs'] = 1
            img_metas[i]['frame_stride'] = 1
        # Frame 0 uses the empty-batch reference images; frame 1 uses
        # frame 0 itself as its (left) reference.
        ref_imgs = [ref_imgs.clone(), imgs[[0]][None].clone()]
        ref_img_metas = [
            copy.deepcopy(ref_img_metas),
            copy.deepcopy([img_metas[0]])
        ]
        results = defaultdict(list)
        for one_img, one_meta, ref_img, ref_img_meta in zip(
                img_list, img_metas, ref_imgs, ref_img_metas):
            result = detector.forward([one_img], [[one_meta]],
                                      ref_img=[ref_img],
                                      ref_img_metas=[[ref_img_meta]],
                                      return_loss=False)
            for k, v in result.items():
                results[k].append(v)
def test_mot_forward_train(cfg_file):
    """Test QDTrack-style MOT forward-train with non-empty and empty
    ground-truth batches, including instance matching across frames."""
    config = _get_config_module(cfg_file)
    model = copy.deepcopy(config.model)
    from mmtrack.models import build_model
    qdtrack = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (1, 3, 256, 256)
    mm_inputs = _demo_mm_inputs(
        input_shape, num_items=[10], num_classes=2, with_track=True)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_instance_ids = mm_inputs['gt_instance_ids']
    gt_masks = mm_inputs['gt_masks']
    ref_input_shape = (1, 3, 256, 256)
    ref_mm_inputs = _demo_mm_inputs(
        ref_input_shape, num_items=[10], num_classes=2, with_track=True)
    ref_img = ref_mm_inputs.pop('imgs')
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    ref_gt_instance_ids = ref_mm_inputs['gt_instance_ids']
    # Match key-frame instances to reference-frame instances by instance id.
    match_tool = MatchInstances()
    gt_match_indices, _ = match_tool._match_gts(gt_instance_ids[0],
                                                ref_gt_instance_ids[0])
    gt_match_indices = [torch.tensor(gt_match_indices)]
    losses = qdtrack.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        gt_masks=gt_masks,
        gt_match_indices=gt_match_indices,
        ref_img=ref_img,
        ref_img_metas=ref_img_metas,
        ref_gt_bboxes=ref_gt_bboxes,
        ref_gt_labels=ref_gt_labels,
        ref_gt_masks=ref_gt_masks,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = qdtrack._parse_losses(losses)
    # Make backward() valid even if no loss term requires grad.
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # Test forward train with an empty truth batch
    mm_inputs = _demo_mm_inputs(
        input_shape, num_items=[0], num_classes=2, with_track=True)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    gt_labels = mm_inputs['gt_labels']
    gt_instance_ids = mm_inputs['gt_instance_ids']
    gt_masks = mm_inputs['gt_masks']
    ref_mm_inputs = _demo_mm_inputs(
        ref_input_shape, num_items=[0], num_classes=2, with_track=True)
    ref_img = ref_mm_inputs.pop('imgs')
    ref_img_metas = ref_mm_inputs.pop('img_metas')
    ref_gt_bboxes = ref_mm_inputs['gt_bboxes']
    ref_gt_labels = ref_mm_inputs['gt_labels']
    ref_gt_masks = ref_mm_inputs['gt_masks']
    ref_gt_instance_ids = ref_mm_inputs['gt_instance_ids']
    # Reuses `match_tool` from the non-empty case above.
    gt_match_indices, _ = match_tool._match_gts(gt_instance_ids[0],
                                                ref_gt_instance_ids[0])
    gt_match_indices = [torch.tensor(gt_match_indices)]
    losses = qdtrack.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        gt_labels=gt_labels,
        gt_masks=gt_masks,
        gt_match_indices=gt_match_indices,
        ref_img=ref_img,
        ref_img_metas=ref_img_metas,
        ref_gt_bboxes=ref_gt_bboxes,
        ref_gt_labels=ref_gt_labels,
        ref_gt_masks=ref_gt_masks,
        return_loss=True)
    assert isinstance(losses, dict)
    loss, _ = qdtrack._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()
def test_stark_forward():
    """Test STARK forward-train for stage 1 (bbox loss) and stage 2
    (classification loss); stage 2 reuses the stage-1 input tensors."""
    # test stage-1 forward
    config = _get_config_module('sot/stark/stark_st1_r50_500e_got10k.py')
    model = copy.deepcopy(config.model)
    from mmtrack.models import build_model
    sot = build_model(model)

    # Test forward train with a non-empty truth batch
    input_shape = (2, 3, 128, 128)
    mm_inputs = _demo_mm_inputs(input_shape, num_items=[1, 1])
    # [None] adds the extra leading dim STARK expects for template images.
    imgs = mm_inputs.pop('imgs')[None]
    img_metas = mm_inputs.pop('img_metas')
    gt_bboxes = mm_inputs['gt_bboxes']
    # Mark a corner region of the first template as padding.
    padding_mask = torch.zeros((2, 128, 128), dtype=bool)
    padding_mask[0, 100:128, 100:128] = 1
    padding_mask = padding_mask[None]
    search_input_shape = (1, 3, 320, 320)
    search_mm_inputs = _demo_mm_inputs(search_input_shape, num_items=[1])
    search_img = search_mm_inputs.pop('imgs')[None]
    search_img_metas = search_mm_inputs.pop('img_metas')
    search_gt_bboxes = search_mm_inputs['gt_bboxes']
    search_padding_mask = torch.zeros((1, 320, 320), dtype=bool)
    search_padding_mask[0, 0:20, 0:20] = 1
    search_padding_mask = search_padding_mask[None]
    # Prepend the batch-index column expected for search gt bboxes.
    img_inds = search_gt_bboxes[0].new_full((search_gt_bboxes[0].size(0), 1),
                                            0)
    search_gt_bboxes[0] = torch.cat((img_inds, search_gt_bboxes[0]), dim=1)
    losses = sot.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        padding_mask=padding_mask,
        search_img=search_img,
        search_img_metas=search_img_metas,
        search_gt_bboxes=search_gt_bboxes,
        search_padding_mask=search_padding_mask,
        return_loss=True)
    assert isinstance(losses, dict)
    # Stage 1 trains the bbox regression head.
    assert losses['loss_bbox'] > 0
    loss, _ = sot._parse_losses(losses)
    # Make backward() valid even if no loss term requires grad.
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()

    # test stage-2 forward
    config = _get_config_module('sot/stark/stark_st2_r50_50e_got10k.py')
    model = copy.deepcopy(config.model)
    sot = build_model(model)
    # Stage 2 additionally needs classification labels for the search frame.
    search_gt_labels = [torch.ones((1, 2))]
    losses = sot.forward(
        img=imgs,
        img_metas=img_metas,
        gt_bboxes=gt_bboxes,
        padding_mask=padding_mask,
        search_img=search_img,
        search_img_metas=search_img_metas,
        search_gt_bboxes=search_gt_bboxes,
        search_padding_mask=search_padding_mask,
        search_gt_labels=search_gt_labels,
        return_loss=True)
    assert isinstance(losses, dict)
    # Stage 2 trains the classification head.
    assert losses['loss_cls'] > 0
    loss, _ = sot._parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) > 0
    loss.backward()