Code Example #1
File: ssm_train.py  Project: wolfworld6/CALD
def train_one_epoch(task_model, task_optimizer, data_loader, device, cycle, epoch, print_freq):
    task_model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('task_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Cycle:[{}] Epoch: [{}]'.format(cycle, epoch)

    task_lr_scheduler = None

    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        task_lr_scheduler = utils.warmup_lr_scheduler(task_optimizer, warmup_iters, warmup_factor)

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        task_loss_dict = task_model(images, targets)
        task_losses = sum(loss for loss in task_loss_dict.values())
        # reduce losses over all GPUs for logging purposes
        task_loss_dict_reduced = utils.reduce_dict(task_loss_dict)
        task_losses_reduced = sum(loss.cpu() for loss in task_loss_dict_reduced.values())
        task_loss_value = task_losses_reduced.item()
        losses = task_losses
        if not math.isfinite(task_loss_value):
            print("Loss is {}, stopping training".format(task_loss_value))
            sys.exit(1)

        task_optimizer.zero_grad()
        losses.backward()
        task_optimizer.step()
        if task_lr_scheduler is not None:
            task_lr_scheduler.step()
        metric_logger.update(task_loss=task_losses_reduced)
        metric_logger.update(task_lr=task_optimizer.param_groups[0]["lr"])
    return metric_logger
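
Note: every example on this page assumes a `warmup_lr_scheduler` helper that is not shown. A minimal sketch, assuming the linear-warmup helper from the torchvision detection references that projects like these typically vendor in `utils`:

import torch

def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
    # Linearly ramp each param group's lr from warmup_factor * base_lr
    # up to base_lr over the first warmup_iters calls to .step().
    def f(x):
        if x >= warmup_iters:
            return 1
        alpha = float(x) / warmup_iters
        return warmup_factor * (1 - alpha) + alpha

    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)

LambdaLR multiplies the base learning rate by f(step), which is why the examples call the scheduler's .step() once per iteration and only construct it for epoch 0.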
Code Example #2
File: train.py  Project: isrc-cas/domain-adaption
def train_one_epoch(model,
                    optimizer,
                    data_loader,
                    device,
                    epoch,
                    print_freq=10,
                    writer=None):
    global global_step
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 500
        warmup_iters = min(500, len(data_loader) - 1)
        lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters,
                                           warmup_factor)

    for images, img_metas, targets in metric_logger.log_every(
            data_loader, print_freq, header):
        global_step += 1
        images = images.to(device)
        targets = [t.to(device) for t in targets]

        loss_dict, _ = model(images, img_metas, targets)
        losses = sum(list(loss_dict.values()))

        loss_dict_reduced = dist_utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()
        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if global_step % print_freq == 0:
            if writer:
                for k, v in loss_dict_reduced.items():
                    writer.add_scalar('losses/{}'.format(k),
                                      v,
                                      global_step=global_step)
                writer.add_scalar('losses/total_loss',
                                  losses_reduced,
                                  global_step=global_step)
                writer.add_scalar('lr',
                                  optimizer.param_groups[0]['lr'],
                                  global_step=global_step)
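
This variant logs to TensorBoard through an optional writer and counts iterations in a module-level global_step. A minimal caller sketch, assuming the surrounding script defines model, optimizer, data_loader, device, and num_epochs:

from torch.utils.tensorboard import SummaryWriter

global_step = 0  # incremented inside train_one_epoch via `global global_step`

writer = SummaryWriter(log_dir='runs/experiment')  # hypothetical log dir
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch,
                    print_freq=10, writer=writer)
writer.close()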
Code Example #3
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters,
                                                 warmup_factor)

    for images, targets in metric_logger.log_every(data_loader, print_freq,
                                                   header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced.detach().cpu(),
                             **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        del losses
        del loss_value
        del loss_dict_reduced
        del losses_reduced
        torch.cuda.empty_cache()
    return metric_logger
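
Functions like this pair with the evaluate functions below in the standard torchvision-style detection loop. A sketch under the assumption of a torchvision Faster R-CNN and pre-built loaders (data_loader, data_loader_test, device, and num_epochs are placeholders):

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(num_classes=91)
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# epoch-level schedule; the warmup inside train_one_epoch only covers epoch 0
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    lr_scheduler.step()
    evaluate(model, data_loader_test, device=device)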
Code Example #4
def evaluate(model, data_loader, device):
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(coco, iou_types)
    for image, targets in metric_logger.log_every(data_loader, 100, header):
        image = list(img.to(device) for img in image)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(image)

        outputs = [{k: v.to(cpu_device)
                    for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        res = {
            target["image_id"].item(): output
            for target, output in zip(targets, outputs)
        }
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time,
                             evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    # NOTE: the line below has been observed to cause out-of-memory errors
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    print("coco_evaluator.accumulate()")
    coco_evaluator.accumulate()
    print("coco_evaluator.summarize()")
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator
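
The _get_iou_types helper used above selects which COCO metrics to compute from the model class. A sketch matching the torchvision detection references this code follows:

import torch
import torchvision

def _get_iou_types(model):
    # unwrap DistributedDataParallel so isinstance checks see the real model
    model_without_ddp = model
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_without_ddp = model.module
    iou_types = ["bbox"]
    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        iou_types.append("keypoints")
    return iou_types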
Code Example #5
def evaluate(model, data_loader, device):
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    results = []
    for images, targets in metric_logger.log_every(data_loader, 10, header):
        images = list(img.to(device) for img in images)

        torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(images)
        outputs = [{k: v.to(cpu_device)
                    for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        evaluator_time = time.time()
        non_empt_el = [i for i, t in enumerate(targets) if len(t) > 0]
        targets = [t for i, t in enumerate(targets) if i in non_empt_el]
        outputs = [o for i, o in enumerate(outputs) if i in non_empt_el]
        for i, o in enumerate(outputs):
            # convert boxes in place from (x1, y1, x2, y2) to COCO (x, y, w, h)
            o['boxes'][:, 2:4] -= o['boxes'][:, 0:2]
            areas = (o['boxes'][:, 2] * o['boxes'][:, 3]).tolist()
            boxes = o['boxes'].tolist()
            scores = o['scores'].tolist()
            labels = o['labels'].tolist()
            temp = [{
                'bbox': b,
                'area': a,
                'category_id': l,
                'score': s,
                'image_id': targets[i]['image_id']
            } for b, a, l, s in zip(boxes, areas, labels, scores)]
            results = list(itertools.chain(results, temp))
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time,
                             evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    # accumulate predictions from all images
    return results
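
The returned results list is already in COCO detection format (xywh boxes plus image_id, category_id, and score), so it can be scored directly with pycocotools. A sketch, assuming a ground-truth annotation file path ann_file and that each image_id is a plain int:

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO(ann_file)            # ground-truth annotations
coco_dt = coco_gt.loadRes(results)  # detections returned by evaluate()

coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()  # prints the twelve standard COCO AP/AR numbers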
Code Example #6
def train_one_epoch(task_model, task_optimizer, ll_model, ll_optimizer,
                    data_loader, device, cycle, epoch, print_freq):
    task_model.train()
    ll_model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'task_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'll_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Cycle:[{}] Epoch: [{}]'.format(cycle, epoch)

    task_lr_scheduler = None
    ll_lr_scheduler = None

    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        task_lr_scheduler = utils.warmup_lr_scheduler(task_optimizer,
                                                      warmup_iters,
                                                      warmup_factor)
        ll_lr_scheduler = utils.warmup_lr_scheduler(ll_optimizer, warmup_iters,
                                                    warmup_factor)

    for images, targets in metric_logger.log_every(data_loader, print_freq,
                                                   header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        features, task_loss_dict = task_model(images, targets)
        if 'faster' in args.model:
            _task_losses = sum(loss for loss in task_loss_dict.values())
            task_loss_dict['loss_objectness'] = torch.mean(
                task_loss_dict['loss_objectness'])
            task_loss_dict['loss_rpn_box_reg'] = torch.mean(
                task_loss_dict['loss_rpn_box_reg'])
            task_loss_dict['loss_classifier'] = torch.mean(
                task_loss_dict['loss_classifier'])
            task_loss_dict['loss_box_reg'] = torch.mean(
                task_loss_dict['loss_box_reg'])
            task_losses = sum(loss for loss in task_loss_dict.values())
            # reduce losses over all GPUs for logging purposes
            task_loss_dict_reduced = utils.reduce_dict(task_loss_dict)
            task_losses_reduced = sum(
                loss.cpu() for loss in task_loss_dict_reduced.values())
            task_loss_value = task_losses_reduced.item()
            if epoch >= args.task_epochs:
                # After args.task_epochs epochs, stop gradients from the loss
                # prediction module from propagating into the target model.
                features['0'] = features['0'].detach()
                features['1'] = features['1'].detach()
                features['2'] = features['2'].detach()
                features['3'] = features['3'].detach()
            ll_pred = ll_model(features).cuda()
        elif 'retina' in args.model:
            _task_losses = sum(
                torch.stack(loss[1]) for loss in task_loss_dict.values())
            task_loss_dict['classification'] = task_loss_dict[
                'classification'][0]
            task_loss_dict['bbox_regression'] = task_loss_dict[
                'bbox_regression'][0]
            task_losses = sum(loss for loss in task_loss_dict.values())
            task_loss_dict_reduced = utils.reduce_dict(task_loss_dict)
            task_losses_reduced = sum(
                loss.cpu() for loss in task_loss_dict_reduced.values())
            task_loss_value = task_losses_reduced.item()
            if epoch >= args.task_epochs:
                # After args.task_epochs epochs, stop gradients from the loss
                # prediction module from propagating into the target model.
                _features = dict()
                _features['0'] = features[0].detach()
                _features['1'] = features[1].detach()
                _features['2'] = features[2].detach()
                _features['3'] = features[3].detach()
            else:
                _features = dict()
                _features['0'] = features[0]
                _features['1'] = features[1]
                _features['2'] = features[2]
                _features['3'] = features[3]
            ll_pred = ll_model(_features).cuda()
        ll_pred = ll_pred.view(ll_pred.size(0))
        ll_loss = args.ll_weight * LossPredLoss(
            ll_pred, _task_losses, margin=MARGIN)
        losses = task_losses + ll_loss
        if not math.isfinite(task_loss_value):
            print("Loss is {}, stopping training".format(task_loss_value))
            print(task_loss_dict_reduced)
            sys.exit(1)

        task_optimizer.zero_grad()
        ll_optimizer.zero_grad()
        losses.backward()
        task_optimizer.step()
        ll_optimizer.step()
        if task_lr_scheduler is not None:
            task_lr_scheduler.step()
        if ll_lr_scheduler is not None:
            ll_lr_scheduler.step()
        metric_logger.update(task_loss=task_losses_reduced)
        metric_logger.update(task_lr=task_optimizer.param_groups[0]["lr"])
        metric_logger.update(ll_loss=ll_loss.item())
        metric_logger.update(ll_lr=ll_optimizer.param_groups[0]["lr"])
    return metric_logger
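
LossPredLoss is the margin ranking objective from "Learning Loss for Active Learning", which trains ll_model to predict the per-sample task loss. A sketch following the widely used reference implementation; treat the details as an assumption about this project's version:

import torch

def LossPredLoss(input, target, margin=1.0, reduction='mean'):
    # pair sample i with sample N-1-i and ask the predictor to rank each
    # pair the same way the true task losses are ranked
    assert len(input) % 2 == 0, 'the batch size is not even.'
    input = (input - input.flip(0))[:len(input) // 2]
    target = (target - target.flip(0))[:len(target) // 2]
    target = target.detach()  # no gradient through the true losses
    one = 2 * torch.sign(torch.clamp(target, min=0)) - 1  # +1 or -1 per pair
    loss = torch.clamp(margin - one * input, min=0)
    if reduction == 'mean':
        loss = loss.sum() / input.size(0)
    return loss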
Code Example #7
def train_one_epoch(task_model, task_optimizer, vae, vae_optimizer,
                    discriminator, discriminator_optimizer, labeled_dataloader,
                    unlabeled_dataloader, device, cycle, epoch, print_freq):
    # despite its name, this infinite generator is reused for both the
    # labeled and the unlabeled loaders (only the images are consumed)
    def read_unlabeled_data(dataloader):
        while True:
            for images, _ in dataloader:
                yield list(image.to(device) for image in images)

    labeled_data = read_unlabeled_data(labeled_dataloader)
    unlabeled_data = read_unlabeled_data(unlabeled_dataloader)
    task_model.train()
    vae.train()
    discriminator.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'task_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Cycle:[{}] Epoch: [{}]'.format(cycle, epoch)

    task_lr_scheduler = None
    vae_lr_scheduler = None
    discriminator_lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(labeled_dataloader) - 1)

        task_lr_scheduler = utils.warmup_lr_scheduler(task_optimizer,
                                                      warmup_iters,
                                                      warmup_factor)
        vae_lr_scheduler = utils.warmup_lr_scheduler(vae_optimizer,
                                                     warmup_iters,
                                                     warmup_factor)
        discriminator_lr_scheduler = utils.warmup_lr_scheduler(
            discriminator_optimizer, warmup_iters, warmup_factor)

    for images, targets in metric_logger.log_every(labeled_dataloader,
                                                   print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        task_loss_dict = task_model(images, targets)
        task_losses = sum(loss for loss in task_loss_dict.values())
        # reduce losses over all GPUs for logging purposes
        task_loss_dict_reduced = utils.reduce_dict(task_loss_dict)
        task_losses_reduced = sum(loss.cpu()
                                  for loss in task_loss_dict_reduced.values())
        task_loss_value = task_losses_reduced.item()
        losses = task_losses
        if not math.isfinite(task_loss_value):
            print("Loss is {}, stopping training".format(task_loss_value))
            print(task_loss_dict_reduced)
            sys.exit(1)
        task_optimizer.zero_grad()
        losses.backward()
        task_optimizer.step()
        if task_lr_scheduler is not None:
            task_lr_scheduler.step()
        metric_logger.update(task_loss=task_losses_reduced)
        metric_logger.update(task_lr=task_optimizer.param_groups[0]["lr"])

    for i in range(len(labeled_dataloader)):
        unlabeled_imgs = next(unlabeled_data)
        labeled_imgs = next(labeled_data)
        recon, z, mu, logvar = vae(labeled_imgs)
        unsup_loss = vae_loss(labeled_imgs, recon, mu, logvar, 1)
        unlab_recon, unlab_z, unlab_mu, unlab_logvar = vae(unlabeled_imgs)
        transductive_loss = vae_loss(unlabeled_imgs, unlab_recon, unlab_mu,
                                     unlab_logvar, 1)

        labeled_preds = discriminator(mu)
        unlabeled_preds = discriminator(unlab_mu)

        lab_real_preds = torch.ones(len(labeled_imgs)).cuda()
        unlab_real_preds = torch.ones(len(unlabeled_imgs)).cuda()

        if not len(labeled_preds.shape) == len(lab_real_preds.shape):
            dsc_loss = bce_loss(
                labeled_preds, lab_real_preds.unsqueeze(1)) + bce_loss(
                    unlabeled_preds, unlab_real_preds.unsqueeze(1))
        else:
            dsc_loss = bce_loss(labeled_preds, lab_real_preds) + bce_loss(
                unlabeled_preds, unlab_real_preds)
        total_vae_loss = unsup_loss + transductive_loss + dsc_loss
        vae_optimizer.zero_grad()
        total_vae_loss.backward()
        vae_optimizer.step()

        # Discriminator step
        with torch.no_grad():
            _, _, mu, _ = vae(labeled_imgs)
            _, _, unlab_mu, _ = vae(unlabeled_imgs)

        labeled_preds = discriminator(mu)
        unlabeled_preds = discriminator(unlab_mu)

        lab_real_preds = torch.ones(len(labeled_imgs)).cuda()
        unlab_fake_preds = torch.zeros(len(unlabeled_imgs)).cuda()

        if not len(labeled_preds.shape) == len(lab_real_preds.shape):
            dsc_loss = bce_loss(
                labeled_preds, lab_real_preds.unsqueeze(1)) + bce_loss(
                    unlabeled_preds, unlab_fake_preds.unsqueeze(1))
        else:
            dsc_loss = bce_loss(labeled_preds, lab_real_preds) + bce_loss(
                unlabeled_preds, unlab_fake_preds)
        discriminator_optimizer.zero_grad()
        dsc_loss.backward()
        discriminator_optimizer.step()

        if vae_lr_scheduler is not None:
            vae_lr_scheduler.step()
        if discriminator_lr_scheduler is not None:
            discriminator_lr_scheduler.step()
        if i == len(labeled_dataloader) - 1:
            print('vae_loss: {} dis_loss:{}'.format(total_vae_loss, dsc_loss))

    return metric_logger
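
vae_loss and bce_loss come from the VAAL-style (Variational Adversarial Active Learning) half of this snippet. A sketch of the usual definitions, assuming the project mirrors the VAAL reference code and that the VAE consumes batched image tensors:

import torch
import torch.nn as nn

bce_loss = nn.BCELoss()

def vae_loss(x, recon, mu, logvar, beta):
    mse = nn.MSELoss()(recon, x)
    # KL divergence between the approximate posterior N(mu, sigma^2) and N(0, I)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * kld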
Code Example #8
def train_one_epoch(model,
                    optimizer,
                    data_loader,
                    device,
                    epoch,
                    print_freq=10):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters,
                                                 warmup_factor)
    writer = SummaryWriter(log_dir=log_dir)
    i = 0
    for images, targets in metric_logger.log_every(data_loader, print_freq,
                                                   header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        i += 1

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        # log per-iteration losses at global step epoch * len(data_loader) + i
        writer.add_scalars(
            'train', {
                'loss_all': loss_value,
                'loss_box_reg': loss_dict_reduced['loss_box_reg'].item(),
                'loss_classifier': loss_dict_reduced['loss_classifier'].item(),
                'loss_objectness': loss_dict_reduced['loss_objectness'].item(),
                'loss_rpn_box_reg': loss_dict_reduced['loss_rpn_box_reg'].item()
            }, epoch * len(data_loader) + i)
        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
Code Example #9
def train_one_epoch(model,
                    optimizer,
                    train_loader,
                    target_loader,
                    device,
                    epoch,
                    dis_model,
                    dis_optimizer,
                    print_freq=10,
                    writer=None,
                    test_func=None,
                    save_func=None):
    global global_step
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.0e}'))
    metric_logger.add_meter(
        'LAMBDA', utils.SmoothedValue(window_size=1, fmt='{value:.3f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_schedulers = []
    if epoch == 0:
        warmup_factor = 1. / 500
        warmup_iters = min(500, len(train_loader) - 1)
        lr_schedulers = [
            warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor),
            warmup_lr_scheduler(dis_optimizer, warmup_iters, warmup_factor),
        ]

    target_loader_iter = iter(target_loader)
    for images, img_metas, targets in metric_logger.log_every(
            train_loader, print_freq, header):
        global_step += 1
        images = images.to(device)
        targets = [t.to(device) for t in targets]

        try:
            t_images, t_img_metas, _ = next(target_loader_iter)
        except StopIteration:
            target_loader_iter = iter(target_loader)
            t_images, t_img_metas, _ = next(target_loader_iter)

        t_images = t_images.to(device)

        loss_dict, outputs = model(images, img_metas, targets, t_images,
                                   t_img_metas)
        adv_loss = loss_dict.pop('adv_loss')
        loss_dict_for_log = dict(**loss_dict, **adv_loss)

        det_loss = sum(list(loss_dict.values()))
        ada_loss = sum(list(adv_loss.values()))

        LAMBDA = cosine_scheduler(cfg.ADV.LAMBDA_FROM, cfg.ADV.LAMBDA_TO,
                                  global_step)
        losses = det_loss + ada_loss * LAMBDA
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        loss_dict_reduced = dist_utils.reduce_dict(loss_dict_for_log)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        for lr_scheduler in lr_schedulers:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(LAMBDA=LAMBDA)

        if global_step % print_freq == 0:
            if writer:
                for k, v in loss_dict_reduced.items():
                    writer.add_scalar('losses/{}'.format(k),
                                      v,
                                      global_step=global_step)
                writer.add_scalar('losses/total_loss',
                                  losses_reduced,
                                  global_step=global_step)
                writer.add_scalar('lr',
                                  optimizer.param_groups[0]['lr'],
                                  global_step=global_step)
                writer.add_scalar('LAMBDA', LAMBDA, global_step=global_step)

        if global_step % (2000 // max(1, (dist_utils.get_world_size() // 2))
                          ) == 0 and test_func is not None:
            updated = test_func()
            if updated:
                save_func('best.pth', 'mAP: {:.4f}'.format(best_mAP))
            print('Best mAP: {:.4f}'.format(best_mAP))
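
cosine_scheduler anneals the adversarial weight LAMBDA from cfg.ADV.LAMBDA_FROM toward cfg.ADV.LAMBDA_TO as global_step grows. Its definition is project-specific; a hypothetical sketch with an assumed total_steps horizon (not part of the call shown above):

import math

def cosine_scheduler(start, end, step, total_steps=20000):
    # cosine interpolation: returns `start` at step 0 and settles at `end`
    # once `step` reaches `total_steps` (total_steps is an assumption here)
    t = min(step / total_steps, 1.0)
    return end + (start - end) * 0.5 * (1.0 + math.cos(math.pi * t))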
Code Example #10
File: eval.py  Project: isrc-cas/domain-adaption
def do_evaluation(model, data_loader, device, types, output_dir, iteration=None, viz=False):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    dataset = data_loader.dataset
    header = 'Testing {}:'.format(dataset.dataset_name)
    results_dict = {}
    has_mask = False
    for images, img_metas, targets in metric_logger.log_every(data_loader, 10, header):
        assert len(targets) == 1
        images = images.to(device)

        model_time = time.time()
        det = model(images, img_metas)[0]
        boxes, scores, labels = det['boxes'], det['scores'], det['labels']

        model_time = time.time() - model_time

        img_meta = img_metas[0]
        scale_factor = img_meta['scale_factor']
        img_info = img_meta['img_info']

        if viz:
            import matplotlib.pyplot as plt
            import matplotlib.patches as patches
            plt.switch_backend('TKAgg')
            image = de_normalize(images[0], img_meta)
            plt.subplot(122)
            plt.imshow(image)
            plt.title('Predict')
            for i, ((x1, y1, x2, y2), label) in enumerate(zip(boxes.tolist(), labels.tolist())):
                if scores[i] > 0.65:
                    rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, facecolor='none', edgecolor='g')
                    category_id = dataset.label2cat[label]
                    plt.text(x1, y1, '{}:{:.2f}'.format(dataset.CLASSES[category_id], scores[i]), color='r')
                    plt.gca().add_patch(rect)

            plt.subplot(121)
            plt.imshow(image)
            plt.title('GT')
            for i, ((x1, y1, x2, y2), label) in enumerate(zip(targets[0]['boxes'].tolist(), targets[0]['labels'].tolist())):
                category_id = dataset.label2cat[label]
                rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, facecolor='none', edgecolor='g')
                plt.text(x1, y1, '{}'.format(dataset.CLASSES[category_id]))
                plt.gca().add_patch(rect)
            plt.show()

        boxes /= scale_factor
        result = {}

        if 'masks' in det:
            has_mask = True
            (w, h) = img_meta['origin_img_shape']
            masks = paste_masks_in_image(det['masks'], boxes, (h, w))
            rles = []
            for mask in masks.cpu().numpy():
                mask = mask >= 0.5
                mask = mask_util.encode(np.array(mask[0][:, :, None], order='F', dtype='uint8'))[0]
                # "counts" is an array encoded by mask_util as a byte-stream. Python3's
                # json writer which always produces strings cannot serialize a bytestream
                # unless you decode it. Thankfully, utf-8 works out (which is also what
                # the pycocotools/_mask.pyx does).
                mask['counts'] = mask['counts'].decode('utf-8')
                rles.append(mask)
            result['masks'] = rles

        boxes = boxes.tolist()
        labels = labels.tolist()
        labels = [dataset.label2cat[label] for label in labels]
        scores = scores.tolist()

        result['boxes'] = boxes
        result['scores'] = scores
        result['labels'] = labels

        # save_visualization(dataset, img_meta, result, output_dir)

        results_dict.update({
            img_info['id']: result
        })
        metric_logger.update(model_time=model_time)

    if get_world_size() > 1:
        dist.barrier()

    predictions = _accumulate_predictions_from_multiple_gpus(results_dict)
    if not is_main_process():
        return {}
    results = {}
    if has_mask:
        result = coco_evaluation(dataset, predictions, output_dir, iteration=iteration)
        results.update(result)
    if 'voc' in types:
        result = voc_evaluation(dataset, predictions, output_dir, iteration=iteration, use_07_metric=False)
        results.update(result)
    return results
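
The mask branch stores each instance as a COCO run-length encoding whose counts byte-stream has been decoded to UTF-8 so the result is JSON-serializable. To recover a binary mask from such an entry, re-encode counts to bytes first; a small sketch using pycocotools:

import pycocotools.mask as mask_util

rle = dict(result['masks'][0])                 # one instance's RLE, copied
rle['counts'] = rle['counts'].encode('utf-8')  # back to bytes for pycocotools
binary_mask = mask_util.decode(rle)            # (H, W) uint8 array of 0/1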