Code Example #1
def test_timer_init():
    timer = mmcv.Timer(start=False)
    assert not timer.is_running
    timer.start()
    assert timer.is_running
    timer = mmcv.Timer()
    assert timer.is_running
Code Example #2
def test_timer_context(capsys):
    with mmcv.Timer():
        time.sleep(1)
    out, _ = capsys.readouterr()
    assert abs(float(out) - 1) < 1e-2
    with mmcv.Timer(print_tmpl='time: {:.1f}s'):
        time.sleep(1)
    out, _ = capsys.readouterr()
    assert out == 'time: 1.0s\n'
Code Example #3
def test_timer_run():
    timer = mmcv.Timer()
    time.sleep(1)
    assert abs(timer.since_start() - 1) < 1e-2
    time.sleep(1)
    assert abs(timer.since_last_check() - 1) < 1e-2
    assert abs(timer.since_start() - 2) < 1e-2
    timer = mmcv.Timer(False)
    with pytest.raises(mmcv.TimerError):
        timer.since_start()
    with pytest.raises(mmcv.TimerError):
        timer.since_last_check()
Code Example #4
def evaluate(results, evals, topk=(1, 5, 10)):
    clips = accumulate_by_key(results, 'clip_ele_embed')
    syns = accumulate_by_key(results, 'syn_ele_embed')
    clip_len = accumulate_by_key(results, 'clip_ele_len')
    syn_len = accumulate_by_key(results, 'syn_ele_len')

    if 'basic' in evals:
        if clip_len is None:
            score_basic = clamp_cdist(clips, syns)
        else:
            score_basic = get_score_cosine_similarity(clips, syns,
                                                      clip_len.tolist(),
                                                      syn_len.tolist())
        stat = calc_stat(score_basic, topk, 'Basic')
        print_stat(stat)

    # Evaluate the optimization-based matching methods (efm / bm)
    for eval_method in evals:
        if eval_method == 'basic':
            continue
        with mmcv.Timer('Evaluating {}...'.format(eval_method)):
            seq_score = clamp_cdist(clips, syns)
            if eval_method == 'efm':
                title = 'EFW'
                score_optim, _, _ = get_score_efw(seq_score, clips, syns,
                                                  clip_len.tolist(),
                                                  syn_len.tolist())
            elif eval_method == 'bm':
                title = 'BM'
                score_optim, _, _ = get_score_bm(seq_score, clips, syns,
                                                 clip_len.tolist(),
                                                 syn_len.tolist())
            stat = calc_stat(score_optim, topk, title)
            print_stat(stat)
Code Example #5
def main():
    args = parse_args()
    for beta in [0.005, 0.01]:
        if beta == 0.005:
            cityscapes_path = '/home/wangyu/env/mmdetection_train/mmdetection/data/foggy_cityscapes_weak'
        elif beta == 0.01:
            cityscapes_path = '/home/wangyu/env/mmdetection_train/mmdetection/data/foggy_cityscapes_median'
        out_dir = osp.join(cityscapes_path, 'annotations')
        mmcv.mkdir_or_exist(out_dir)

        img_dir = osp.join(cityscapes_path, args.img_dir)
        gt_dir = osp.join(cityscapes_path, args.gt_dir)

        set_name = dict(train='instancesonly_filtered_gtFine_train.json',
                        val='instancesonly_filtered_gtFine_val.json')

        for split, json_name in set_name.items():
            print(f'Converting {split} into {json_name}')
            with mmcv.Timer(
                    print_tmpl='It takes {}s to convert Cityscapes annotation'
            ):
                files = collect_files(osp.join(img_dir, split),
                                      osp.join(gt_dir, split), beta)
                image_infos = collect_annotations(files, nproc=args.nproc)
                cvt_annotations(image_infos, osp.join(out_dir, json_name))
Code Example #6
def main():
    args = parse_args()
    root_path = args.root_path
    with mmcv.Timer(print_tmpl='It takes {}s to convert BID annotation'):
        files = collect_files(
            osp.join(root_path, 'imgs'), osp.join(root_path, 'annotations'))
        image_infos = collect_annotations(files, nproc=args.nproc)
        generate_ann(root_path, image_infos, args.preserve_vertical,
                     args.val_ratio, args.format)
Code Example #7
File: imgur_converter.py  Project: open-mmlab/mmocr
def main():
    args = parse_args()
    root_path = args.root_path

    for split in ['train', 'val', 'test']:
        print(f'Processing {split} set...')
        with mmcv.Timer(print_tmpl='It takes {}s to convert IMGUR annotation'):
            anno_infos = collect_imgur_info(
                root_path, f'imgur5k_annotations_{split}.json')
            generate_ann(root_path, split, anno_infos, args.format)
Code Example #8
File: funsd_converter.py  Project: HqWei/mmocr
def main():
    args = parse_args()
    root_path = args.root_path

    for split in ['training', 'test']:
        print(f'Processing {split} set...')
        with mmcv.Timer(print_tmpl='It takes {}s to convert FUNSD annotation'):
            files = collect_files(osp.join(root_path, 'imgs'),
                                  osp.join(root_path, 'annotations', split))
            image_infos = collect_annotations(files, nproc=args.nproc)
            generate_ann(root_path, split, image_infos, args.preserve_vertical,
                         args.format)
Code Example #9
File: lv_converter.py  Project: HqWei/mmocr
def main():
    args = parse_args()
    root_path = args.root_path

    for split in ['train', 'val', 'test']:
        print(f'Processing {split} set...')
        with mmcv.Timer(print_tmpl='It takes {}s to convert LV annotation'):
            files = collect_files(osp.join(root_path, 'imgs', split))
            image_infos = collect_annotations(files, nproc=args.nproc)
            convert_annotations(
                image_infos, osp.join(root_path,
                                      'instances_' + split + '.json'))
Code Example #10
def main():
    args = parse_args()
    root_path = args.root_path
    ratio = args.val_ratio

    trn_files, val_files = collect_files(
        osp.join(root_path, 'imgs'), osp.join(root_path, 'annotations'), ratio)

    # Train set
    trn_infos = collect_annotations(trn_files, nproc=args.nproc)
    with mmcv.Timer(
            print_tmpl='It takes {}s to convert KAIST Training annotation'):
        convert_annotations(trn_infos,
                            osp.join(root_path, 'instances_training.json'))

    # Val set
    if len(val_files) > 0:
        val_infos = collect_annotations(val_files, nproc=args.nproc)
        with mmcv.Timer(
                print_tmpl='It takes {}s to convert KAIST Val annotation'):
            convert_annotations(val_infos,
                                osp.join(root_path, 'instances_val.json'))
Code Example #11
File: mtwi_converter.py  Project: open-mmlab/mmocr
def main():
    args = parse_args()
    root_path = args.root_path
    ratio = args.val_ratio

    trn_files, val_files = collect_files(osp.join(root_path, 'imgs'),
                                         osp.join(root_path, 'annotations'),
                                         ratio)

    # Train set
    trn_infos = collect_annotations(trn_files, nproc=args.nproc)
    with mmcv.Timer(
            print_tmpl='It takes {}s to convert MTWI Training annotation'):
        generate_ann(root_path, 'training', trn_infos, args.preserve_vertical,
                     args.format)

    # Val set
    if len(val_files) > 0:
        val_infos = collect_annotations(val_files, nproc=args.nproc)
        with mmcv.Timer(
                print_tmpl='It takes {}s to convert MTWI Val annotation'):
            generate_ann(root_path, 'val', val_infos, args.preserve_vertical,
                         args.format)
Code Example #12
def main():
    args = parse_args()
    root_path = args.root_path
    split_info = mmcv.load(
        osp.join(root_path, 'annotations', 'train_valid_test_split.json'))
    split_info['training'] = split_info.pop('train')
    split_info['val'] = split_info.pop('valid')
    for split in ['training', 'val', 'test']:
        print(f'Processing {split} set...')
        with mmcv.Timer(print_tmpl='It takes {}s to convert NAF annotation'):
            files = collect_files(osp.join(root_path, 'imgs'),
                                  osp.join(root_path, 'annotations'),
                                  split_info[split])
            image_infos = collect_annotations(files, nproc=args.nproc)
            generate_ann(root_path, split, image_infos, args.preserve_vertical,
                         args.format)
Code Example #13
def main():
    args = parse_args()
    root_path = args.root_path
    with mmcv.Timer(print_tmpl='It takes {}s to convert BID annotation'):
        files = collect_files(osp.join(root_path, 'imgs'),
                              osp.join(root_path, 'annotations'))
        image_infos = collect_annotations(files, nproc=args.nproc)
        if args.val_ratio:
            image_infos = split_train_val_list(image_infos, args.val_ratio)
            splits = ['training', 'val']
        else:
            image_infos = [image_infos]
            splits = ['training']
        for i, split in enumerate(splits):
            convert_annotations(
                image_infos[i],
                osp.join(root_path, 'instances_' + split + '.json'))
Code Example #14
    def _f(*args, **kwargs):
        import mmcv
        timer = mmcv.Timer()
        import inspect
        from .utils import identify
        ident_name = identify((inspect.getsource(func), args, kwargs))
        if ident_name not in ICACHE:
            # logger.warning('{} not in CACHE: "{}"'.format(ident_name, ICACHE))
            result = func(*args, **kwargs)
            logger.info('Imemoize {}, init:  runtime: {:0.2f} s'.format(
                func.__name__, timer.since_last_check()))
            ICACHE[ident_name] = result
        else:
            result = ICACHE[ident_name]
            logger.info('Imemoize {} reuse, runtime: {:0.2f} s'.format(
                func.__name__, timer.since_last_check()))

        return result
Code Example #15
def main():
    args = parse_args()
    root_path = args.root_path
    img_dir = osp.join(root_path, 'imgs')
    gt_dir = osp.join(root_path, 'annotations')

    set_name = {}
    for split in ['training', 'test']:
        set_name.update({split: split + '_label' + '.txt'})
        assert osp.exists(osp.join(img_dir, split))

    for split, ann_name in set_name.items():
        print(f'Converting {split} into {ann_name}')
        with mmcv.Timer(
                print_tmpl='It takes {}s to convert totaltext annotation'):
            files = collect_files(osp.join(img_dir, split),
                                  osp.join(gt_dir, split))
            image_infos = collect_annotations(files, nproc=args.nproc)
            generate_ann(root_path, split, image_infos)
Code Example #16
File: img2coco.py  Project: HaoweiGis/Deeplearning
def main():
    args = parse_args()
    work_path = args.work_path
    out_dir = args.out_dir if args.out_dir else work_path
    mmcv.mkdir_or_exist(out_dir)

    set_name = dict(
        train='instancesonly_filtered_train.json',
        val='instancesonly_filtered_val.json',
        # test='instancesonly_filtered_test.json'
    )

    for split, json_name in set_name.items():
        print(f'Converting {split} into {json_name}')
        with mmcv.Timer(print_tmpl='It takes {}s to convert coco annotation'):
            out_json = img2coco(osp.join(work_path, split))
            #     with open('json_name'.format(ROOT_DIR), 'w') as output_json_file:
            # json.dump(coco_output, output_json_file)
            mmcv.dump(out_json, osp.join(out_dir, json_name))
Code Example #17
def main():
    args = parse_args()
    cityscapes_path = args.cityscapes_path
    out_dir = args.out_dir if args.out_dir else cityscapes_path
    mmcv.mkdir_or_exist(out_dir)

    img_dir = osp.join(cityscapes_path, args.img_dir)
    gt_dir = osp.join(cityscapes_path, args.gt_dir)

    set_name = dict(train='instancesonly_filtered_gtFine_train.json',
                    val='instancesonly_filtered_gtFine_val.json',
                    test='instancesonly_filtered_gtFine_test.json')

    for split, json_name in set_name.items():
        print(f'Converting {split} into {json_name}')
        with mmcv.Timer(
                print_tmpl='It took {}s to convert Cityscapes annotation'):
            files = collect_files(osp.join(img_dir, split),
                                  osp.join(gt_dir, split))
            image_infos = collect_annotations(files, nproc=args.nproc)
            cvt_annotations(image_infos, osp.join(out_dir, json_name))
Code Example #18
File: ctw1500_converter.py  Project: xyzhu8/mmocr
def main():
    args = parse_args()
    root_path = args.root_path
    out_dir = args.out_dir if args.out_dir else root_path
    mmcv.mkdir_or_exist(out_dir)

    img_dir = osp.join(root_path, 'imgs')
    gt_dir = osp.join(root_path, 'annotations')

    set_name = {}
    for split in args.split_list:
        set_name.update({split: 'instances_' + split + '.json'})
        assert osp.exists(osp.join(img_dir, split))

    for split, json_name in set_name.items():
        print(f'Converting {split} into {json_name}')
        with mmcv.Timer(print_tmpl='It takes {}s to convert icdar annotation'):
            files = collect_files(osp.join(img_dir, split),
                                  osp.join(gt_dir, split), split)
            image_infos = collect_annotations(files, split, nproc=args.nproc)
            convert_annotations(image_infos, osp.join(out_dir, json_name))
Code Example #19
def main(args):
    train_loader = get_loader(args)
    n_data = len(train_loader.dataset)
    logger.info("length of training dataset: {}".format(n_data))

    model, model_ema = build_model(args)
    logger.info('{}'.format(model))
    contrast = MemorySeCo(128, args.nce_k, args.nce_t, args.nce_t_intra).cuda()
    criterion = NCESoftmaxLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.batch_size * dist.get_world_size() /
                                256 * args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = get_scheduler(optimizer, len(train_loader), args)
    model = DistributedDataParallel(model,
                                    device_ids=[args.local_rank],
                                    broadcast_buffers=args.broadcast_buffer)
    logger.info('Distributed Enabled')

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume)
        load_checkpoint(args, model, model_ema, contrast, optimizer, scheduler,
                        logger.info)

    # routine
    logger.info('Training')
    timer = mmcv.Timer()
    for epoch in range(args.start_epoch, args.epochs + 1):
        train_loader.sampler.set_epoch(epoch)
        loss = train_seco(epoch, train_loader, model, model_ema, contrast,
                          criterion, optimizer, scheduler, args)
        logger.info('epoch {}, total time {:.2f}, loss={}'.format(
            epoch, timer.since_last_check(), loss))
        if dist.get_rank() == 0:
            save_checkpoint(args, epoch, model, model_ema, contrast, optimizer,
                            scheduler, logger.info)
        dist.barrier()
Code Example #20
        print(task)

    # for i, task in enumerate(mmcv.track_iter_progress(tasks)):
    # do something like print
    # print(i)
    # print(task)

if flag_3:
    # Timing
    """
    It is convenient to compute the runtime of a code block with Timer.
    """
    import time
    import mmcv

    with mmcv.Timer():
        # simulate some code block
        # time.sleep(1)
        for _ in range(1000):
            a = 1
            b = 2
            c = a + b
            d = c ^ 9
            e = int(d / 888)
    """
    or try with since_start() and since_last_check(). 
    This former can return the runtime since the timer starts and the latter will return the time since the last time checked.
    """
    timer = mmcv.Timer()
    # code block 1 here
    time.sleep(2)
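    # The lines below are a minimal sketch (not part of the original snippet)
    # showing how since_start() and since_last_check() are typically combined;
    # the sleep durations are arbitrary and the printed values are approximate.
    print(timer.since_start())       # ~2.0: seconds since the timer started
    # code block 2 here
    time.sleep(1)
    print(timer.since_last_check())  # ~1.0: seconds since the last check
    print(timer.since_start())       # ~3.0: seconds since the timer started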
Code Example #21
    def after_step(self, current_step):
        # import ipdb; ipdb.set_trace()
        if current_step % self.trainer.config.TRAIN.eval_step == 0 and current_step != 0:
            self.trainer.logger.info('Start clustering the features')
            frame_num = self.trainer.config.DATASET.train_clip_length
            frame_step = self.trainer.config.DATASET.train_clip_step
            feature_record = []
            for video_name in self.trainer.cluster_dataset_keys:
                dataset = self.trainer.cluster_dataset_dict[video_name]
                data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1)
                # import ipdb; ipdb.set_trace()
                for data in data_loader:
                    future = data[:, 2, :, :, :].cuda()  # t+1 frame
                    current = data[:, 1, :, :, :].cuda()  # t frame
                    past = data[:, 0, :, :, :].cuda()  # t-1 frame
                    bboxs = self.trainer.get_batch_dets(current)
                    for index, bbox in enumerate(bboxs):
                        # import ipdb; ipdb.set_trace()
                        if bbox.numel() == 0:
                            # import ipdb; ipdb.set_trace()
                            # bbox = torch.zeros([1,4])
                            bbox = bbox.new_zeros([1,4])
                            # print('NO objects')
                            # continue
                        # import ipdb; ipdb.set_trace()
                        current_object, _ = multi_obj_grid_crop(current[index], bbox)
                        future_object, _ = multi_obj_grid_crop(future[index], bbox)
                        future2current = torch.stack([future_object, current_object], dim=1)

                        past_object, _ = multi_obj_grid_crop(past[index], bbox)
                        current2past = torch.stack([current_object, past_object], dim=1)

                        _, _, A_input = frame_gradient(future2current)
                        A_input = A_input.sum(1)
                        _, _, C_input = frame_gradient(current2past)
                        C_input = C_input.sum(1)
                        A_feature, _ = self.trainer.A(A_input)
                        B_feature, _ = self.trainer.B(current_object)
                        C_feature, _ = self.trainer.C(C_input)
                        
                        A_flatten_feature = A_feature.flatten(start_dim=1)
                        B_flatten_feature = B_feature.flatten(start_dim=1)
                        C_flatten_feature = C_feature.flatten(start_dim=1)
                        ABC_feature = torch.cat([A_flatten_feature, B_flatten_feature, C_flatten_feature], dim=1).detach()
                        # import ipdb; ipdb.set_trace()
                        ABC_feature_s = torch.chunk(ABC_feature, ABC_feature.size(0), dim=0)
                        # feature_record.extend(ABC_feature_s)
                        for abc_f in ABC_feature_s:
                            temp = abc_f.squeeze(0).cpu().numpy()
                            feature_record.append(temp)
                        # import ipdb; ipdb.set_trace()
                self.trainer.logger.info(f'Finish the video:{video_name}')
            self.trainer.logger.info(f'Finish extract feature, the sample:{len(feature_record)}')
            # model = KMeans(n_clusters=self.trainer.config.TRAIN.cluster.k)
            device = torch.device('cuda:0')
            cluster_input = torch.from_numpy(np.array(feature_record))
            time = mmcv.Timer()
            # import ipdb; ipdb.set_trace()
            cluster_centers = cluster_input.new_zeros(size=[self.trainer.config.TRAIN.cluster.k, 3072])
            for _ in range(10):
                cluster_ids_x, cluster_center = kmeans(X=cluster_input, num_clusters=self.trainer.config.TRAIN.cluster.k, distance='euclidean', device=device)
                cluster_centers += cluster_center
            # import ipdb; ipdb.set_trace()
            cluster_centers = cluster_centers / 10
            # model.fit(cluster_input)
            # pusedo_labels = model.predict(cluster_input)
            pusedo_labels = kmeans_predict(cluster_input, cluster_centers, 'euclidean', device=device).detach().cpu().numpy()
            print(f'The cluster time is :{time.since_start()/10} min')
            # import ipdb; ipdb.set_trace()
            # pusedo_labels = np.split(pusedo_labels, pusedo_labels.shape[0], 0)

            pusedo_dataset = os.path.join(self.trainer.config.TRAIN.pusedo_data_path, 'pusedo')
            if not os.path.exists(pusedo_dataset):
                os.mkdir(pusedo_dataset)
            
            np.savez_compressed(os.path.join(pusedo_dataset, f'{self.trainer.config.DATASET.name}_dummy.npz'), data=cluster_input, label=pusedo_labels)
            print(f'The save time is {time.since_last_check() / 60} min')
            # binary_labels = MultiLabelBinarizer().fit_transform(pusedo_labels)
            # self.trainer.ovr_model = OneVsRestClassifier(LinearSVC(random_state = 0)).fit(cluster_input,binary_labels)
            self.trainer.ovr_model = OneVsRestClassifier(LinearSVC(random_state = 0), n_jobs=16).fit(cluster_input, pusedo_labels)
            print(f'The train ovr: {time.since_last_check() / 60} min')
Code Example #22
print(test)

print('Gradcheck for carafe naive...')
test = gradcheck(CARAFENaive(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
print(test)

feat = torch.randn(2, 1024, 100, 100, requires_grad=True,
                   device='cuda:0').float()
mask = torch.randn(2, 25, 200, 200, requires_grad=True,
                   device='cuda:0').sigmoid().float()
loop_num = 500

time_forward = 0
time_backward = 0
bar = mmcv.ProgressBar(loop_num)
timer = mmcv.Timer()
for i in range(loop_num):
    x = carafe(feat.clone(), mask.clone(), 5, 1, 2)
    torch.cuda.synchronize()
    time_forward += timer.since_last_check()
    x.sum().backward(retain_graph=True)
    torch.cuda.synchronize()
    time_backward += timer.since_last_check()
    bar.update()
print('\nCARAFE time forward: {} ms/iter | time backward: {} ms/iter'.format(
    (time_forward + 1e-3) * 1e3 / loop_num,
    (time_backward + 1e-3) * 1e3 / loop_num))

time_naive_forward = 0
time_naive_backward = 0
bar = mmcv.ProgressBar(loop_num)
Code Example #23
def train_seco(epoch, train_loader, model, model_ema, contrast, criterion,
               optimizer, scheduler, args):
    model.train()
    set_bn_train(model_ema)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    timer = mmcv.Timer()
    for idx, (xq, x1, x2, x3, binary_order) in enumerate(train_loader):
        xq = xq.cuda(non_blocking=True)  # query
        x1 = x1.cuda(non_blocking=True)  # same frame diff aug
        x2 = x2.cuda(non_blocking=True)  # diff frame 1
        x3 = x3.cuda(non_blocking=True)  # diff frame 2
        binary_order = binary_order.cuda(non_blocking=True)
        # forward keys
        with torch.no_grad():
            x1_shuffled, x1_backward_inds = DistributedShuffle.forward_shuffle(
                x1)
            x2_shuffled, x2_backward_inds = DistributedShuffle.forward_shuffle(
                x2)
            x3_shuffled, x3_backward_inds = DistributedShuffle.forward_shuffle(
                x3)
            x1_feat_inter, x1_feat_intra, x1_feat_order = model_ema(
                x1_shuffled)
            x2_feat_inter, x2_feat_intra, x2_feat_order = model_ema(
                x2_shuffled)
            x3_feat_inter, x3_feat_intra, x3_feat_order = model_ema(
                x3_shuffled)
            x1_feat_inter_all, x1_feat_inter = DistributedShuffle.backward_shuffle(
                x1_feat_inter, x1_backward_inds)
            x1_feat_intra_all, x1_feat_intra = DistributedShuffle.backward_shuffle(
                x1_feat_intra, x1_backward_inds)
            x2_feat_inter_all, x2_feat_inter = DistributedShuffle.backward_shuffle(
                x2_feat_inter, x2_backward_inds)
            x2_feat_intra_all, x2_feat_intra = DistributedShuffle.backward_shuffle(
                x2_feat_intra, x2_backward_inds)
            x2_feat_order_all, x2_feat_order = DistributedShuffle.backward_shuffle(
                x2_feat_order, x2_backward_inds)
            x3_feat_inter_all, x3_feat_inter = DistributedShuffle.backward_shuffle(
                x3_feat_inter, x3_backward_inds)
            x3_feat_intra_all, x3_feat_intra = DistributedShuffle.backward_shuffle(
                x3_feat_intra, x3_backward_inds)
            x3_feat_order_all, x3_feat_order = DistributedShuffle.backward_shuffle(
                x3_feat_order, x3_backward_inds)
        # forward query
        xq_feat_inter, xq_feat_intra, xq_feat_order, xq_logit_order = model(
            xq,
            order_feat=torch.cat(
                [x2_feat_order.detach(),
                 x3_feat_order.detach()], dim=1))
        out_inter = contrast(
            xq_feat_inter,
            x1_feat_inter,
            x2_feat_inter,
            x3_feat_inter,
            torch.cat(
                [x1_feat_inter_all, x2_feat_inter_all, x3_feat_inter_all],
                dim=0),
            inter=True)
        # loss calc
        out_intra = contrast(xq_feat_intra,
                             x1_feat_intra,
                             x2_feat_intra,
                             x3_feat_intra,
                             None,
                             inter=False)
        loss_inter = criterion(out_inter)
        loss_intra = criterion(out_intra)
        loss_order = torch.nn.functional.cross_entropy(xq_logit_order,
                                                       binary_order)
        loss = loss_inter + loss_intra + loss_order
        # backward
        optimizer.zero_grad()
        loss.backward()
        # update params
        optimizer.step()
        scheduler.step()
        moment_update(model, model_ema, args.alpha)
        # update meters
        loss_meter.update(loss.item())
        batch_time.update(timer.since_last_check())
        # print info
        if idx % args.print_freq == 0:
            logger.info(
                'Train: [{:>3d}]/[{:>4d}/{:>4d}] BT={:>0.3f}/{:>0.3f} Loss={:>0.3f} {:>0.3f} {:>0.3f} {:>0.3f}/{:>0.3f}'
                .format(
                    epoch,
                    idx,
                    len(train_loader),
                    batch_time.val,
                    batch_time.avg,
                    loss.item(),
                    loss_inter.item(),
                    loss_intra.item(),
                    loss_order.item(),
                    loss_meter.avg,
                ))
    return loss_meter.avg