def __init__(self,
                 backbone,
                 twin,
                 twin_load_from,
                 neck=None,
                 bbox_head=None,
                 pair_module=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build the detector plus a frozen "twin" detector loaded from a checkpoint.

        Args:
            backbone: config dict for the backbone (built via ``builder``).
            twin: config dict for the twin detector (built via ``build_detector``).
            twin_load_from: checkpoint path/URL for the twin's weights.
            neck: optional neck config; ``self.neck`` is only created when given.
            bbox_head: bbox head config.
            pair_module: optional pair-module config; may override ``neck_first``.
            train_cfg / test_cfg: training/testing config objects (stored as-is).
            pretrained: passed to ``init_weights`` for the main detector only.
        """
        super(TwinV2SingleStageDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        # Whether the neck runs before the pair module; the pair module itself
        # may override this via its own `neck_first` attribute below.
        self.neck_first = True
        if pair_module is not None:
            self.pair_module = builder.build_pair_module(
                pair_module)
            if hasattr(self.pair_module, 'neck_first'):
                self.neck_first = self.pair_module.neck_first
        self.bbox_head = builder.build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # NOTE: init_weights runs before the twin is built, so the twin's
        # checkpoint weights below are not clobbered by re-initialization.
        self.init_weights(pretrained=pretrained)

        # Build twin model
        print(" Loading Twin's weights...")
        self.twin = build_detector(twin, train_cfg=self.train_cfg, test_cfg=self.test_cfg)
        # self.twin = MMDataParallel(twin, device_ids=[0]).cuda()  # TODO check device id?
        load_checkpoint(self.twin, twin_load_from, map_location='cpu', strict=False, logger=None)
        print(" Finished loading twin's weight. ")
        # eval() only toggles train-mode flags (BN/dropout); the twin's params
        # still have requires_grad=True — presumably grads are never taken
        # through it. TODO confirm.
        self.twin.eval()

        # memory cache for testing
        self.key_feat = None
Beispiel #2
0
    def __init__(self,
                 det_load_from,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 tx_head=None,
                 temporal_module=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build a detector, load frozen detection weights, then add a tx head.

        Args:
            det_load_from: checkpoint path for the (frozen) detector weights.
            backbone: backbone config dict.
            neck: optional neck config; ``self.neck`` only exists when given.
            bbox_head: bbox head config.
            tx_head: transformer-head config, built *after* loading weights so
                it starts from fresh initialization.
            temporal_module: must be None (not supported by this class).
            train_cfg / test_cfg: stored as-is.
            pretrained: unused here — ``init_weights`` is intentionally skipped
                because the full detector state is loaded from ``det_load_from``.
        """
        super(SeqTxSingleStage, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        assert temporal_module is None
        self.bbox_head = builder.build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # self.init_weights(pretrained=pretrained)
        load_checkpoint(self,
                        det_load_from,
                        map_location='cpu',
                        strict=False,
                        logger=None)
        # Put the loaded detector parts in eval mode. Guard the optional neck:
        # the original called self.neck.eval() unconditionally, which raises
        # AttributeError whenever the config omits the neck (neck=None).
        self.backbone.eval()
        if neck is not None:
            self.neck.eval()
        self.bbox_head.eval()

        # Built after load_checkpoint so the tx head keeps fresh init weights.
        self.tx_head = builder.build_head(tx_head)

        # memory cache for sequence testing
        self.test_seq_len = 4
        self.memory = []
Beispiel #3
0
def test_load_classes_name():
    """CLASSES metadata should round-trip through save/load_checkpoint,
    both for a bare model and for a DDP-wrapped one."""
    import os
    import tempfile

    from mmcv.runner import load_checkpoint, save_checkpoint

    ckpt_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')

    # Bare model without a CLASSES attribute: meta exists but has no CLASSES.
    plain = Model()
    save_checkpoint(plain, ckpt_path)
    ckpt = load_checkpoint(plain, ckpt_path)
    assert 'meta' in ckpt
    assert 'CLASSES' not in ckpt['meta']

    # Once CLASSES is set on the model it must be recorded in the meta.
    plain.CLASSES = ('class1', 'class2')
    save_checkpoint(plain, ckpt_path)
    ckpt = load_checkpoint(plain, ckpt_path)
    assert 'meta' in ckpt
    assert 'CLASSES' in ckpt['meta']
    assert ckpt['meta']['CLASSES'] == ('class1', 'class2')

    # Same two checks for a DDP-style wrapper around a fresh model.
    wrapped = DDPWrapper(Model())
    save_checkpoint(wrapped, ckpt_path)
    ckpt = load_checkpoint(wrapped, ckpt_path)
    assert 'meta' in ckpt
    assert 'CLASSES' not in ckpt['meta']

    wrapped.module.CLASSES = ('class1', 'class2')
    save_checkpoint(wrapped, ckpt_path)
    ckpt = load_checkpoint(wrapped, ckpt_path)
    assert 'meta' in ckpt
    assert 'CLASSES' in ckpt['meta']
    assert ckpt['meta']['CLASSES'] == ('class1', 'class2')

    # remove the temp file
    os.remove(ckpt_path)
Beispiel #4
0
def test_load_checkpoint():
    """revise_keys must be able to add or strip a 'backbone.' prefix."""
    import os
    import re
    import tempfile

    class PrefixModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = Model()

    prefixed = PrefixModel()
    plain = Model()
    ckpt_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')

    # Load a plain checkpoint into the prefixed model by prepending the prefix.
    torch.save(plain.state_dict(), ckpt_path)
    state_dict = load_checkpoint(prefixed,
                                 ckpt_path,
                                 revise_keys=[(r'^', 'backbone.')])
    backbone_sd = prefixed.backbone.state_dict()
    for key in backbone_sd:
        assert torch.equal(backbone_sd[key], state_dict[key])

    # Load a prefixed checkpoint into the plain model by stripping the prefix.
    torch.save(prefixed.state_dict(), ckpt_path)
    state_dict = load_checkpoint(plain,
                                 ckpt_path,
                                 revise_keys=[(r'^backbone\.', '')])
    plain_sd = plain.state_dict()
    for key, value in state_dict.items():
        stripped = re.sub(r'^backbone\.', '', key)
        assert torch.equal(plain_sd[stripped], value)

    os.remove(ckpt_path)
def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func,
                gpu_id, idx_queue, result_queue):
    """Inference worker: build the model on one GPU, then serve dataset
    indices from ``idx_queue`` and push ``(idx, result)`` to ``result_queue``
    forever (the loop never exits; the worker is expected to be killed)."""
    net = model_cls(**model_kwargs)
    load_checkpoint(net, checkpoint, map_location='cpu')
    torch.cuda.set_device(gpu_id)
    net.cuda()
    net.eval()
    with torch.no_grad():
        while True:
            sample_idx = idx_queue.get()
            sample = dataset[sample_idx]
            output = net(**data_func(sample, gpu_id))
            result_queue.put((sample_idx, output))
    def after_build_model(self, model):
        """Remove all pruned channels in finetune stage.

        We add this function to ensure that this happens before DDP's
        optimizer's initialization
        """
        # Nothing to deploy while we are still in the pruning stage itself.
        if self.pruning:
            return
        for name, module in model.named_modules():
            add_pruning_attrs(module)
        assert self.deploy_from, 'You have to give a ckpt' \
            'containing the structure information of the pruning model'
        load_checkpoint(model, self.deploy_from)
        deploy_pruning(model)
Beispiel #7
0
def extract_feats(model,
                  datasets,
                  cfg,
                  save_dir,
                  data_split='train',
                  logger=None):
    """Extract per-bbox features/labels from a dataset and save them as .npy.

    Args:
        model: detector exposing ``feats_extract`` (run on GPU in eval mode).
        datasets: list of datasets; only the first one is iterated.
        cfg: config providing ``load_from``, ``train_cfg`` and ``data`` options.
        save_dir: directory for the output ``*_feats.npy`` / ``*_labels.npy``.
        data_split: tag used in the output filenames (default 'train').
        logger: optional logger; a module logger is used when None.
    """
    import logging

    # The body logs unconditionally, so the documented default logger=None
    # crashed with AttributeError. Fall back to a module-level logger.
    if logger is None:
        logger = logging.getLogger(__name__)

    # positional args: map_location='cpu', strict=False, logger
    load_checkpoint(model, cfg.load_from, 'cpu', False, logger)
    fg_th = cfg.train_cfg.rcnn.assigner.pos_iou_thr
    bg_th = cfg.train_cfg.rcnn.assigner.neg_iou_thr

    logger.info('load checkpoint from %s', cfg.load_from)
    logger.info(
        f'fg_iou_thr {fg_th} bg_iou_thr {bg_th} data_split {data_split} save_dir {save_dir}'
    )

    model.eval()
    model = model.cuda()
    data_loaders = [
        build_dataloader(ds,
                         cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu,
                         1,
                         dist=False) for ds in datasets
    ]

    feats = []
    labels = []
    for index, data in enumerate(data_loaders[0]):
        bbox_feats, bbox_labels, bboxes = model.feats_extract(
            data['img'].data[0], data['img_meta'].data[0],
            data['gt_bboxes'].data[0], data['gt_labels'].data[0])
        logger.info(
            f"{index:05}/{len(data_loaders[0])} feats shape - {bbox_feats.shape}"
        )

        feats.append(bbox_feats.data.cpu().numpy())
        labels.append(bbox_labels.data.cpu().numpy())
        # free GPU tensors eagerly to keep memory bounded across the loop
        del data, bbox_feats, bbox_labels, bboxes

    feats = np.concatenate(feats)
    labels = np.concatenate(labels)

    split = f'{fg_th}_{bg_th}'

    np.save(f'{save_dir}/{data_split}_{split}_feats.npy', feats)
    np.save(f'{save_dir}/{data_split}_{split}_labels.npy', labels)
    # import pdb; pdb.set_trace()
    print(f"{labels.shape} num of features")
Beispiel #8
0
    def __init__(self,
                 backbone,
                 midrange,
                 midrange_load_from,
                 middle='B',  # 'S'
                 neck=None,
                 bbox_head=None,
                 pair_module=None,
                 pair_module2=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build the detector plus a frozen "midrange" detector from a checkpoint.

        Args:
            backbone: backbone config dict.
            midrange: config for the auxiliary midrange detector.
            midrange_load_from: checkpoint path for the midrange weights.
            middle: mode flag, 'B' or 'S' (selects test behavior elsewhere).
            neck / bbox_head / pair_module / pair_module2: optional sub-module
                configs; attributes are only created for the ones given.
            train_cfg / test_cfg: stored as-is; test_cfg also supplies
                `test_interval` and `memory_size` below.
            pretrained: passed to ``init_weights`` for the main detector only.
        """
        super(LMPSingleStageDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        if pair_module is not None:
            self.pair_module = builder.build_pair_module(
                pair_module)
        if pair_module2 is not None:
            self.pair_module2 = builder.build_pair_module(
                pair_module2)

        self.bbox_head = builder.build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        # init_weights runs before the midrange model is built so the
        # checkpoint weights loaded below are not re-initialized.
        self.init_weights(pretrained=pretrained)

        # Build MidRange model
        print(" Loading MidRange's weights...")
        self.midrange = build_detector(
            midrange, train_cfg=self.train_cfg, test_cfg=self.test_cfg)
        load_checkpoint(self.midrange, midrange_load_from, map_location='cpu',
                        strict=False, logger=None)
        self.midrange.eval()
        self.middle = middle

        # memory cache for testing
        self.test_interval = test_cfg.get('test_interval', 10)
        self.memory_size = test_cfg.get('memory_size', 1)

        self.key_feat_pre = None  # This would not be used in simple_test_b
        self.key_feat_post = None
        # I'm sorry I have to be explicit
        self.g_ref_list_list = None
        self.phi_ref_list_list = None
    def __init__(self,
                 runner,
                 filename=None,
                 decay=0.9998,
                 out_dir=None,
                 interval=-1,
                 save_optimizer=True,
                 max_keep_ckpts=-1,
                 meta=None,
                 device='',
                 **kwargs):
        """EMA hook: keep an exponential-moving-average copy of runner.model.

        The EMA copy starts as a deep copy of the current model (optionally
        warm-started from ``filename``), is kept in eval mode with gradients
        disabled, and registers itself on the runner as ``runner.ema``.
        """
        # checkpoint-saving settings
        self.interval = interval
        self.out_dir = out_dir
        self.save_optimizer = save_optimizer
        self.max_keep_ckpts = max_keep_ckpts
        self.args = kwargs
        self.create_symlink = True
        self.filename = filename
        self.meta = meta

        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(runner.model)
        if self.filename:
            runner.logger.info('load EMA checkpoint for EMA model from %s',
                               filename)
            load_checkpoint(self.ema,
                            self.filename,
                            map_location='cpu',
                            strict=False)
        self.ema.eval()

        # number of EMA updates so far; resume from the runner's iteration
        self.updates = runner.iter

        def _ramped_decay(num_updates):
            # decay exponential ramp (to help early epochs)
            return decay * (1 - math.exp(-num_updates / 2000))

        self.decay = _ramped_decay

        # perform ema on different device from model if set
        self.device = device
        if device:
            self.ema.to(device=device)
        for param in self.ema.parameters():
            param.requires_grad_(False)

        runner.ema = self
Beispiel #10
0
    def load_checkpoint(self, filename, map_location='cpu', strict=True):
        """Load weights into self.model, downloading ``filename`` first when
        it is an HTTP(S) URL (rank 0 downloads; all ranks sync on a barrier).

        Returns whatever mmcv's ``load_checkpoint`` returns (the checkpoint).
        """
        import subprocess

        self.logger.info('load checkpoint from %s', filename)

        if filename.startswith(('http://', 'https://')):
            url = filename
            filename = '../' + url.split('/')[-1]
            if get_dist_info()[0] == 0:
                # rank 0 downloads; other ranks wait at the barrier below
                if osp.isfile(filename):
                    os.remove(filename)  # replaces os.system('rm ...')
                # list-form subprocess avoids building a shell string from the
                # URL (the original os.system call was injection-prone);
                # check=False mirrors os.system ignoring the exit status
                subprocess.run(['wget', '-N', '-q', '-P', '../', url],
                               check=False)
            dist.barrier()

        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)
Beispiel #11
0
def prune_mask_rcnn_only(args: PruneParams):
    """
    Just prune without retraining.
    Args:
        args: (PruneParams).

    Returns: (MaskRCNN) pruned model in cuda.
    """
    cfg = mmcv.Config.fromfile(args.config)
    model = build_detector(cfg.model,
                           train_cfg=cfg.get('train_cfg'),
                           test_cfg=cfg.get('test_cfg'))
    assert cfg.model['type'] == 'MaskRCNN', 'model type should be MaskRCNN!'

    # load checkpoint
    checkpoint = cp.load_checkpoint(model=model, filename=args.checkpoint)

    num_before = sum([p.nelement() for p in model.backbone.parameters()])
    print('Before pruning, Backbone Params = %.2fM' % (num_before / 1E6))

    # PRUNE FILTERS
    # func = {"ResNet50": prune_resnet50, "ResNet101": prune_resnet101}
    assert args.backbone in ['ResNet50', 'ResNet101'], "Wrong backbone type!"
    # per-arch block indices that must be skipped during pruning
    skip = {
        'ResNet34': [2, 8, 14, 16, 26, 28, 30, 32],
        'ResNet50': [2, 11, 20, 23, 89, 92, 95, 98],
        'ResNet101': [2, 11, 20, 23, 89, 92, 95, 98]
    }
    pf_cfg, new_backbone = prune_top2_layers(arch=args.backbone,
                                             net=model.backbone,
                                             skip_block=skip[args.backbone],
                                             prs=str2list(args.prs),
                                             cuda=True)
    model.backbone = new_backbone

    num_after = sum([p.nelement() for p in model.backbone.parameters()])
    print('After  pruning: Backbone Params = %.2fM' % (num_after / 1E6))
    print("Prune rate: %.2f%%" % ((num_before - num_after) / num_before * 100))

    # replace checkpoint['state_dict'] with the pruned model's weights
    checkpoint['state_dict'] = cp.weights_to_cpu(cp.get_state_dict(model))
    mmcv.mkdir_or_exist(osp.dirname(args.result_path))

    # save and immediately flush buffer
    torch.save(checkpoint, args.result_path)
    # splitext only strips the extension; the original split('.')[0] truncated
    # the whole path at the first dot (e.g. in a directory name like 'v1.2').
    with open(osp.splitext(args.result_path)[0] + '_cfg.txt', 'w') as f:
        f.write(str(pf_cfg))
    # the docstring promises the pruned model; the original returned None
    return model
Beispiel #12
0
def worker_func_art(model_cls, model_kwargs, checkpoint, dataset, data_func,
                    gpu_id, idx_queue, result_queue, post_processor,
                    img_prefix, show=True, show_path=None):
    """ store the img_name, img_shape, ori_shape in data_queue
        return single_pred_results.
    """
    model = model_cls(**model_kwargs)
    load_checkpoint(model, checkpoint, map_location='cpu')
    torch.cuda.set_device(gpu_id)
    model.cuda()
    model.eval()
    with torch.no_grad():
        while True:
            idx = idx_queue.get()
            data = dataset[idx]
            data_dict = data_func(data, gpu_id)
            img_metas = data_dict['img_meta'][0]
            img_metas_0 = img_metas[0]
            is_aug = bool(len(data['img']) > 1)
            h, w, _ = img_metas_0['img_shape']
            ori_h, ori_w, _ = img_metas_0['ori_shape']
            filename = img_metas_0['filename']
            img_name = osp.splitext(filename)[0].replace('gt_', 'res_')
            scales = (ori_w * 1.0 / w, ori_h * 1.0 / h)
            result = model(**data_func(data, gpu_id))
            if isinstance(result, tuple):
                bbox_result, segm_result = result
            else:
                bbox_result, segm_result = result, None
            vs_bbox_result = np.vstack(bbox_result)
            if segm_result is None:
                pred_bboxes, pred_bbox_scores = [], []
            else:
                segm_scores = np.asarray(vs_bbox_result[:, -1])
                segms = mmcv.concat_list(segm_result)
                # the bboxes returned by processor are fit to the original images.
                if not is_aug:
                    # print('1-metas:{:d}'.format(len(img_metas)))
                    # single simple test the predicted mask use the size of rescaled img.
                    pred_bboxes, pred_bbox_scores = post_processor.process(segms, segm_scores,
                                                                           mask_shape=img_metas_0['img_shape'],
                                                                           scale_factor=scales)
                else:
                    # aug test the predicted mask use the size of original img
                    pred_bboxes, pred_bbox_scores = post_processor.process(segms, segm_scores,
                                                                           mask_shape=img_metas_0['ori_shape'],
                                                                           scale_factor=(1.0, 1.0))
            # save the results.
            single_pred_results = []
            for pred_bbox, pred_bbox_score in zip(pred_bboxes, pred_bbox_scores):
                pred_bbox = np.asarray(pred_bbox).reshape((-1, 2)).astype(np.int32)
                pred_bbox = pred_bbox.tolist()
                single_bbox_dict = {
                    "points": pred_bbox,
                    "confidence": float(pred_bbox_score),
                }
                single_pred_results.append(single_bbox_dict)
            pred_result = {
                "img_name": img_name,
                "single_pred_results": single_pred_results
            }
            result_queue.put((idx, pred_result))

            if show:
                img = cv2.imread(osp.join(img_prefix, filename))
                for idx in range(len(single_pred_results)):
                    bbox = np.asarray(single_pred_results[idx]["points"]).reshape(-1, 2).astype(np.int64)
                    cv2.drawContours(img, [bbox], -1, (0, 255, 0), 2)
                cv2.imwrite(osp.join(show_path, filename), img)
Beispiel #13
0
def test_load_checkpoint_metadata():
    """Checkpoints carry a `_version` in their metadata; a model's
    `_load_from_state_dict` can use it to migrate old parameter names."""
    import os
    import tempfile

    from mmcv.runner import load_checkpoint, save_checkpoint

    class ModelV1(nn.Module):
        # v1 layout: layers named conv1 / conv2
        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv1 = nn.Conv2d(3, 3, 1)
            self.conv2 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv1.weight)
            nn.init.normal_(self.conv2.weight)

    class ModelV2(nn.Module):
        # v2 layout: same layers renamed to conv0 / conv1
        _version = 2

        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv0 = nn.Conv2d(3, 3, 1)
            self.conv1 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv0.weight)
            nn.init.normal_(self.conv1.weight)

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  *args, **kwargs):
            """load checkpoints."""

            # Names of some parameters in has been changed.
            # No version (or version < 2) means a v1 checkpoint: rename its
            # keys in place (conv1->conv0, conv2->conv1) before delegating.
            version = local_metadata.get('version', None)
            if version is None or version < 2:
                state_dict_keys = list(state_dict.keys())
                convert_map = {'conv1': 'conv0', 'conv2': 'conv1'}
                for k in state_dict_keys:
                    for ori_str, new_str in convert_map.items():
                        if k.startswith(prefix + ori_str):
                            new_key = k.replace(ori_str, new_str)
                            state_dict[new_key] = state_dict[k]
                            del state_dict[k]

            super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                          *args, **kwargs)

    # Snapshot initial weights so loading can be verified against them.
    model_v1 = ModelV1()
    model_v1_conv0_weight = model_v1.conv1.weight.detach()
    model_v1_conv1_weight = model_v1.conv2.weight.detach()
    model_v2 = ModelV2()
    model_v2_conv0_weight = model_v2.conv0.weight.detach()
    model_v2_conv1_weight = model_v2.conv1.weight.detach()
    ckpt_v1_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v1.pth')
    ckpt_v2_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v2.pth')

    # Save checkpoint
    save_checkpoint(model_v1, ckpt_v1_path)
    save_checkpoint(model_v2, ckpt_v2_path)

    # test load v1 model: keys are migrated, so v1's conv1/conv2 weights land
    # in v2's conv0/conv1
    load_checkpoint(model_v2, ckpt_v1_path)
    assert torch.allclose(model_v2.conv0.weight, model_v1_conv0_weight)
    assert torch.allclose(model_v2.conv1.weight, model_v1_conv1_weight)

    # test load v2 model: no migration, weights load under their own names
    load_checkpoint(model_v2, ckpt_v2_path)
    assert torch.allclose(model_v2.conv0.weight, model_v2_conv0_weight)
    assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)
Beispiel #14
0
 def load_adv_checkpoint(self, filename, map_location='cpu', strict=False):
     """Log and load a checkpoint into the adversarial model; returns the
     loaded checkpoint from mmcv's ``load_checkpoint``."""
     self.logger.info('load checkpoint from %s', filename)
     return load_checkpoint(self.adv_model,
                            filename,
                            map_location=map_location,
                            strict=strict,
                            logger=self.logger)
Beispiel #15
0
 def load_checkpoint(self, filename, map_location="cpu", strict=False):
     """Log and load a checkpoint into the wrapped model; returns the loaded
     checkpoint from mmcv's ``load_checkpoint``."""
     self.logger.info("load checkpoint from %s", filename)
     return load_checkpoint(self.model,
                            filename,
                            map_location=map_location,
                            strict=strict,
                            logger=self.logger)
Beispiel #16
0
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.dir is not None:
        if args.dir.startswith('//'):
            cfg.work_dir = args.dir[2:]
        else:
            localhost = get_localhost().split('.')[0]
            # results from server saved to /private
            if 'gpu' in localhost:
                output_dir = '/private/huangchenxi/mmdet/outputs'
            else:
                output_dir = 'work_dirs'

            if args.dir.endswith('-c'):
                args.dir = args.dir[:-2]
                args.resume_from = search_and_delete(os.path.join(
                    output_dir, args.dir),
                                                     prefix=cfg.work_dir,
                                                     suffix=localhost)
            cfg.work_dir += time.strftime("_%m%d_%H%M") + '_' + localhost
            cfg.work_dir = os.path.join(output_dir, args.dir, cfg.work_dir)

    if args.workers_per_gpu != -1:
        cfg.data['workers_per_gpu'] = args.workers_per_gpu

    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus

    if args.profiler or args.speed:
        cfg.data.imgs_per_gpu = 1

    if cfg.resume_from or cfg.load_from:
        cfg.model['pretrained'] = None

    if args.test:
        cfg.data.train['ann_file'] = cfg.data.val['ann_file']
        cfg.data.train['img_prefix'] = cfg.data.val['img_prefix']

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        num_gpus = args.gpus
        rank = 0
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        num_gpus = torch.cuda.device_count()
        rank, _ = get_dist_info()

    if cfg.optimizer['type'] == 'SGD':
        cfg.optimizer['lr'] *= num_gpus * cfg.data.imgs_per_gpu / 256
    else:
        cfg.optimizer['lr'] *= ((num_gpus / 8) * (cfg.data.imgs_per_gpu / 2))

    # init logger before other steps
    logger = get_root_logger(nlogger, cfg.log_level)
    if rank == 0:
        logger.set_logger_dir(cfg.work_dir, 'd')
    logger.info("Config: ------------------------------------------\n" +
                cfg.text)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    model = build_detector(cfg.model,
                           train_cfg=cfg.train_cfg,
                           test_cfg=cfg.test_cfg)
    if rank == 0:
        # describe_vars(model)
        writer = set_writer(cfg.work_dir)
        # try:
        #     # describe_features(model.backbone)
        #     writer.add_graph(model, torch.zeros((1, 3, 800, 800)))
        # except (NotImplementedError, TypeError):
        #     logger.warn("Add graph failed.")
        # except Exception as e:
        #     logger.warn("Add graph failed:", e)

    if not args.graph and not args.profiler and not args.speed:
        if distributed:
            model = MMDistributedDataParallel(model.cuda())
        else:
            model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

        if isinstance(cfg.data.train, list):
            for t in cfg.data.train:
                logger.info("loading training set: " + str(t.ann_file))
            train_dataset = [build_dataset(t) for t in cfg.data.train]
            CLASSES = train_dataset[0].CLASSES
        else:
            logger.info("loading training set: " +
                        str(cfg.data.train.ann_file))
            train_dataset = build_dataset(cfg.data.train)
            logger.info("{} images loaded!".format(len(train_dataset)))
            CLASSES = train_dataset.CLASSES
        if cfg.checkpoint_config is not None:
            # save mmdet version, config file content and class names in
            # checkpoints as meta data
            cfg.checkpoint_config.meta = dict(mmdet_version=__version__,
                                              config=cfg.text,
                                              CLASSES=CLASSES)
        # add an attribute for visualization convenience
        if hasattr(model, 'module'):
            model.module.CLASSES = CLASSES
        else:
            model.CLASSES = CLASSES
        train_detector(model,
                       train_dataset,
                       cfg,
                       distributed=distributed,
                       validate=args.validate,
                       logger=logger,
                       runner_attr_dict={'task_name': args.dir})
    else:
        from mmcv.runner.checkpoint import load_checkpoint
        from mmdet.datasets import build_dataloader
        from mmdet.core.utils.model_utils import register_hooks
        from mmdet.apis.train import parse_losses

        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
        if args.profiler == 'test' or args.speed == 'test':
            model.eval()
            dataset = build_dataset(cfg.data.test)
        else:
            model.train()
            dataset = build_dataset(cfg.data.train)

        if cfg.load_from and (args.profiler or args.speed):
            logger.info('load checkpoint from %s', cfg.load_from)
            load_checkpoint(model,
                            cfg.load_from,
                            map_location='cpu',
                            strict=True)

        data_loader = build_dataloader(dataset,
                                       cfg.data.imgs_per_gpu,
                                       cfg.data.workers_per_gpu,
                                       cfg.gpus,
                                       dist=False,
                                       shuffle=False)

        if args.graph:
            id_dict = {}
            for name, parameter in model.named_parameters():
                id_dict[id(parameter)] = name

        for i, data_batch in enumerate(data_loader):
            if args.graph:
                outputs = model(**data_batch)
                loss, log_vars = parse_losses(outputs)
                get_dot = register_hooks(loss, id_dict)
                loss.backward()
                dot = get_dot()
                dot.save('graph.dot')
                break
            elif args.profiler:
                with torch.autograd.profiler.profile(use_cuda=True) as prof:
                    if args.profiler == 'train':
                        outputs = model(**data_batch)
                        loss, log_vars = parse_losses(outputs)
                        loss.backward()
                    else:
                        with torch.no_grad():
                            model(**data_batch, return_loss=False)

                    if i == 20:
                        prof.export_chrome_trace('./trace.json')
                        logger.info(prof)
                        break
            elif args.speed:
                if args.speed == 'train':
                    start = time.perf_counter()
                    outputs = model(**data_batch)
                    loss, log_vars = parse_losses(outputs)
                    loss.backward()
                    torch.cuda.synchronize()
                    end = time.perf_counter()
                else:
                    start = time.perf_counter()
                    with torch.no_grad():
                        model(**data_batch, return_loss=False)
                    end = time.perf_counter()
                logger.info("{:.3f} s/iter, {:.1f} iters/s".format(
                    end - start, 1. / (end - start)))