コード例 #1
0
def main(_run, _config, _log):
    """Train a few-shot segmentation model, validate periodically, then test it.

    Phases:

    1. Training on episodes drawn from the classes
       ``CLASS_LABELS[dataset][label_sets]``. Every
       ``_config['evaluate_interval']`` iterations the model is validated on
       the held-out classes (``'all'`` minus the training label set) with a
       single seeded run. Whenever the validation mean IoU improves, the
       weights are saved to ``<ckpt_dir>/best.pth``; a checkpoint
       ``<ckpt_dir>/<iteration>.pth`` is written after every validation.
    2. Testing: ``<ckpt_dir>/best.pth`` is reloaded and evaluated on the
       held-out classes over ``n_test_runs`` (5) seeded runs.
    3. Reporting: per-run and mean/std IoU values are logged through
       ``_run.log_scalar``, ``_log`` and ``print``.

    When ``_config['eval']`` is truthy, training is skipped and an existing
    ``<ckpt_dir>/best.pth`` is tested directly.

    Args:
        _run: Experiment run object (looks like a Sacred ``Run`` -- TODO
            confirm). If it has observers, the first one receives copies of
            the experiment's source files; metrics go to ``log_scalar``.
        _config: Nested configuration mapping; every key read below must be
            present.
        _log: Logger for progress and result messages.

    Raises:
        FileNotFoundError: If ``<ckpt_dir>/best.pth`` is missing when the
            test phase starts (for example because no validation pass ran).
    """
    # A run may have no observer attached, so every observer access is
    # guarded; ``observer_dir`` and ``logdir`` are None in that case.
    observer_dir = _run.observers[0].dir if _run.observers else None
    logdir = f'{observer_dir}/' if observer_dir is not None else None
    print(logdir)

    if _run.observers:
        os.makedirs(f'{observer_dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{observer_dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        # Sources were just copied under ``source/``; the observer's own
        # ``_sources`` directory is deleted (presumably a duplicate copy).
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    data_name = _config['dataset']
    # NOTE(review): every dataset other than 'VOC' is sized as COCO (80
    # classes), yet ``make_data`` below is always ``voc_fewshot`` -- confirm.
    max_label = 20 if data_name == 'VOC' else 80
    # Number of seeded runs in the final test phase.
    n_test_runs = 5

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    print(_config['ckpt_dir'])
    tbwriter = SummaryWriter(_config['ckpt_dir'])

    # Tags for TensorBoard logging of training losses; that logging is
    # currently disabled (see the commented-out loop in the training phase).
    training_tags = {
        'loss': "ATraining/total_loss",
        "query_loss": "ATraining/query_loss",
        'aligned_loss': "ATraining/aligned_loss",
        'base_loss': "ATraining/base_loss",
    }
    infer_tags = {
        'mean_iou': "MeanIoU/mean_iou",
        "mean_iou_binary": "MeanIoU/mean_iou_binary",
    }

    _log.info('###### Create model ######')
    if _config['model']['part']:
        model = FewshotSegPartResnet(
            pretrained_path=_config['path']['init_path'], cfg=_config)
        _log.info('Model: FewshotSegPartResnet')
    else:
        model = FewshotSegResnet(pretrained_path=_config['path']['init_path'],
                                 cfg=_config)
        _log.info('Model: FewshotSegResnet')

    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    model.train()

    _log.info('###### Load data ######')
    make_data = voc_fewshot
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    train_transforms = Compose(
        [Resize(size=_config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=train_transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'],
                        n_unlabel=_config['task']['n_unlabels'],
                        cfg=_config)
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=_config['num_workers'],
                             pin_memory=True,
                             drop_last=True)

    # Validation and testing both use the classes held out from training,
    # without random augmentation.
    test_labels = CLASS_LABELS[data_name]['all'] - labels
    test_transforms = Compose([Resize(size=_config['input_size'])])

    def evaluate_run(metric, run):
        """Evaluate ``model`` on the held-out classes for one seeded run.

        Builds a fresh episodic loader seeded with ``_config['seed'] + run``,
        records every episode's prediction in ``metric`` under index ``run``,
        then logs the run's class/mean IoU (standard and binary).

        Callers put ``model`` in eval mode and wrap the call in
        ``torch.no_grad()``.

        Returns:
            Tuple ``(meanIoU, meanIoU_binary)`` for this run.
        """
        _log.info(f'### Run {run + 1} ###')
        # NOTE(review): this reseeds the global RNG. When called during
        # training, the remaining training iterations presumably continue
        # from this reseeded state -- confirm against set_seed.
        set_seed(_config['seed'] + run)

        _log.info('### Load data ###')
        test_dataset = make_data(
            base_dir=_config['path'][data_name]['data_dir'],
            split=_config['path'][data_name]['data_split'],
            transforms=test_transforms,
            to_tensor=ToTensorNormalize(),
            labels=test_labels,
            max_iters=_config['infer_max_iters'],
            n_ways=_config['task']['n_ways'],
            n_shots=_config['task']['n_shots'],
            n_queries=_config['task']['n_queries'],
            n_unlabel=_config['task']['n_unlabels'],
            cfg=_config)
        testloader = DataLoader(test_dataset,
                                batch_size=_config['batch_size'],
                                shuffle=False,
                                num_workers=_config['num_workers'],
                                pin_memory=True,
                                drop_last=False)
        _log.info(f"Total # of Data: {len(test_dataset)}")

        for batch in tqdm.tqdm(testloader):
            label_ids = list(batch['class_ids'])
            support_images = [[shot.cuda() for shot in way]
                              for way in batch['support_images']]
            support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                               for way in batch['support_mask']]
            support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                               for way in batch['support_mask']]
            query_images = [image.cuda() for image in batch['query_images']]
            query_labels = torch.cat(
                [label.cuda() for label in batch['query_labels']], dim=0)

            query_pred, _, _ = model(support_images, support_fg_mask,
                                     support_bg_mask, query_images)
            # Only the first query of each batch is scored.
            metric.record(query_pred.argmax(dim=1)[0],
                          query_labels[0],
                          labels=label_ids,
                          n_run=run)

        classIoU, meanIoU = metric.get_mIoU(labels=sorted(test_labels),
                                            n_run=run)
        classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

        _run.log_scalar('classIoU', classIoU.tolist())
        _run.log_scalar('meanIoU', meanIoU.tolist())
        _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
        _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
        _log.info(f'classIoU: {classIoU}')
        _log.info(f'meanIoU: {meanIoU}')
        _log.info(f'classIoU_binary: {classIoU_binary}')
        _log.info(f'meanIoU_binary: {meanIoU_binary}')
        return meanIoU, meanIoU_binary

    _log.info('###### Set optimizer ######')
    if _config['fix']:
        print('Optimizer: fix')
        # Only encoder layer3/layer4 are optimized in this mode.
        optimizer = torch.optim.SGD(
            params=[
                {
                    "params": model.module.encoder.layer3.parameters(),
                    "lr": _config['optim']['lr'],
                    "weight_decay": _config['optim']['weight_decay']
                },
                {
                    "params": model.module.encoder.layer4.parameters(),
                    "lr": _config['optim']['lr'],
                    "weight_decay": _config['optim']['weight_decay']
                },
            ],
            momentum=_config['optim']['momentum'],
        )
    else:
        print('Optimizer: Not fix')
        optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    # The scheduler is stepped once per iteration, so milestones are counted
    # in iterations.
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    log_loss = {'loss': 0, 'align_loss': 0, 'base_loss': 0}
    _log.info('###### Training ######')

    best_ckpt = os.path.join(_config['ckpt_dir'], 'best.pth')
    highest_iou = 0
    metrics = {}

    for i_iter, sample_batched in enumerate(trainloader):
        # Evaluation-only mode: skip training and go straight to testing.
        if _config['eval']:
            break

        if _config['fix']:
            # Keep the early encoder stages (not in the optimizer) in eval
            # mode; ``model.train()`` is called after every iteration.
            model.module.encoder.conv1.eval()
            model.module.encoder.bn1.eval()
            model.module.encoder.layer1.eval()
            model.module.encoder.layer2.eval()

        # Prepare input
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ], dim=0)

        # No base-class loss is computed in this script; it is a zero term.
        base_loss = torch.zeros(1).to(torch.device('cuda'))

        # Forward and backward
        optimizer.zero_grad()
        query_pred, _, align_loss = model(support_images, support_fg_mask,
                                          support_bg_mask, query_images)
        query_loss = criterion(query_pred, query_labels)
        loss = (query_loss
                + align_loss * _config['align_loss_scaler']
                + base_loss * _config['base_loss_scaler'])
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Accumulate detached CPU copies of the losses for running averages.
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy()
        base_loss = base_loss.detach().data.cpu().numpy()
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss
        log_loss['base_loss'] += base_loss

        if (i_iter + 1) % _config['print_interval'] == 0:
            # Averages over all iterations so far; 'loss' tracks the query loss.
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            base_loss = log_loss['base_loss'] / (i_iter + 1)

            message = (f'step {i_iter+1}: loss: {loss}, '
                       f'align_loss: {align_loss}, base_loss: {base_loss}')
            print(message)
            _log.info(message)

            metrics['loss'] = loss
            metrics['query_loss'] = query_loss
            metrics['aligned_loss'] = align_loss
            metrics['base_loss'] = base_loss

            # for k, v in metrics.items():
            #     tbwriter.add_scalar(training_tags[k], v, i_iter)

        if (i_iter + 1) % _config['evaluate_interval'] == 0:
            _log.info('###### Evaluation begins ######')
            print(_config['ckpt_dir'])
            model.eval()

            metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
            with torch.no_grad():
                # Validation uses a single seeded run (run index 0).
                meanIoU, meanIoU_binary = evaluate_run(metric, run=0)

            print(f'meanIoU: {meanIoU}, meanIoU_binary: {meanIoU_binary}')
            metrics = {
                'mean_iou': meanIoU,
                'mean_iou_binary': meanIoU_binary,
            }
            for k, v in metrics.items():
                tbwriter.add_scalar(infer_tags[k], v, i_iter)

            if meanIoU > highest_iou:
                print(
                    f'The highest iou is in iter: {i_iter} : {meanIoU}, save: {_config["ckpt_dir"]}/best.pth'
                )
                highest_iou = meanIoU
                torch.save(model.state_dict(), best_ckpt)
            else:
                print(
                    f'The highest iou is still {highest_iou}, iou in iter: {i_iter} : {meanIoU}'
                )
            torch.save(model.state_dict(),
                       os.path.join(_config['ckpt_dir'], f'{i_iter + 1}.pth'))
        model.train()

    # Training is over; flush and release the TensorBoard event file.
    tbwriter.close()

    print(_config['ckpt_dir'])

    _log.info(' --------- Testing begins ---------')
    print(f'{_config["ckpt_dir"]}/best.pth')
    if not os.path.isfile(best_ckpt):
        raise FileNotFoundError(
            f'{best_ckpt} not found: best.pth is only written when a '
            f'validation pass improves the mean IoU.')

    model.load_state_dict(torch.load(best_ckpt, map_location='cpu'))
    model.eval()

    metric = Metric(max_label=max_label, n_runs=n_test_runs)
    with torch.no_grad():
        for run in range(n_test_runs):
            evaluate_run(metric, run)

    # Aggregate over all runs: mean and standard deviation.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(test_labels))
    (classIoU_binary, classIoU_std_binary, meanIoU_binary,
     meanIoU_std_binary) = metric.get_mIoU_binary()

    _run.log_scalar('meanIoU', meanIoU.tolist())
    _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())

    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())

    _log.info('----- Final Result -----')
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')

    _log.info("## ------------------------------------------ ##")
    _log.info(f'###### Setting: {observer_dir} ######')

    summary = (
        "Running {num_run} runs, meanIoU:{miou:.4f}, meanIoU_binary:{mbiou:.4f} "
        "meanIoU_std:{miou_std:.4f}, meanIoU_binary_std:{mbiou_std:.4f}".format(
            num_run=n_test_runs,
            miou=meanIoU,
            mbiou=meanIoU_binary,
            miou_std=meanIoU_std,
            mbiou_std=meanIoU_std_binary))
    _log.info(summary)
    _log.info(f"Current setting is {observer_dir}")

    print(summary)
    print(f"Current setting is {observer_dir}")
    print(_config['ckpt_dir'])
    print(logdir)
コード例 #2
0
def main(_run, _config, _log):
    """Test a pretrained few-shot segmentation encoder on held-out classes.

    Evaluates ``Encoder`` over ``_config['n_runs']`` seeded runs on the
    classes ``CLASS_LABELS[dataset]['all']`` minus the training label set.
    It then logs per-run and mean/std IoU values (standard and binary)
    through ``_run.log_scalar`` and ``_log``.

    Args:
        _run: Experiment run object (looks like a Sacred ``Run`` -- TODO
            confirm). If it has observers, the first one receives copies of
            the experiment's source files; metrics go to ``log_scalar``.
        _config: Nested configuration mapping. Notable switches:
            ``notrain`` skips loading ``snapshot``; ``scribble`` reads
            ``fg_scribble`` instead of ``fg_mask`` from the support
            annotations; ``bbox`` derives support masks via ``get_bbox``.
        _log: Logger for progress and result messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither 'VOC' nor 'COCO'.
    """
    # A run may have no observer attached; only save sources when one exists.
    if _run.observers:
        observer = _run.observers[0]
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(
                os.path.dirname(f'{observer.dir}/source/{source_file}'),
                exist_ok=True)
            observer.save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{observer.basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = Encoder(pretrained_path=_config['path']['init_path'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    # Key suffix of the support annotation to use: 'fg_scribble' / 'fg_mask'.
    suffix = 'scribble' if _config['scribble'] else 'mask'

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                label_sets=_config['label_sets'],
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if data_name == 'COCO':
                    # Position of each COCO category id in getCatIds(), 1-based.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])

                # Flatten the [way][shot] structure into a single batch.
                support_images = torch.cat([
                    torch.cat([shot.cuda() for shot in way], dim=0)
                    for way in sample_batched['support_images']
                ], dim=0)

                if _config['bbox']:
                    # Box-derived masks built from each shot's fg mask and
                    # instance map; only the foreground part is used below.
                    support_fg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, _ = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                else:
                    support_fg_mask = [[
                        shot[f'fg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                # Both branches are flattened the same way as support_images.
                support_fg_mask = torch.cat(
                    [torch.cat(way, dim=0) for way in support_fg_mask], dim=0)

                query_images = torch.cat([
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ], dim=0)
                query_labels = torch.cat([
                    query_label.long().cuda()
                    for query_label in sample_batched['query_labels']
                ], dim=0)

                query_pred = model(support_images, query_images,
                                   support_fg_mask)
                # Upsample logits to the query resolution before scoring.
                query_pred = F.interpolate(query_pred,
                                           size=query_images.shape[-2:],
                                           mode='bilinear',
                                           align_corners=False)

                # Only the first query of each batch is scored.
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Aggregate over all runs: mean and standard deviation.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    (classIoU_binary, classIoU_std_binary, meanIoU_binary,
     meanIoU_std_binary) = metric.get_mIoU_binary()

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
コード例 #3
0
def main(cfg, gpus):
    """Train an objectness encoder/decoder pair with periodic few-shot validation.

    Training episodes draw from ``CLASS_LABELS[name][fold_idx]`` and exclude
    the held-out classes. The nets are checkpointed every
    ``cfg.TRAIN.save_freq`` iterations. Every ``cfg.TRAIN.eval_freq``
    iterations the model is validated on the held-out classes over
    ``cfg.VAL.n_runs`` seeded runs, and the nets are checkpointed as
    ``'best'`` whenever the mean binary IoU improves.

    Args:
        cfg: Attribute-style config with ``MODEL``, ``TRAIN``, ``VAL``,
            ``DATASET`` and ``TASK`` sections (presumably a yacs
            ``CfgNode`` -- TODO confirm).
        gpus: Sequence of GPU ids; only ``gpus[0]`` is used.

    Raises:
        ValueError: If ``cfg.DATASET.name`` is neither 'VOC' nor 'COCO'.
    """
    # Network builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    crit = nn.NLLLoss(ignore_index=255)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Training uses the fold's classes; validation uses the remaining ones,
    # which are also excluded from the training episodes.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - labels
    exclude_labels = labels_val

    # Per-channel mean (matches the common ImageNet statistics) on a 0-255
    # scale; used as the padding value for rotation and cropping.
    value_scale = 255
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]

    train_transform = Compose([
        transforms.ToNumpy(),
        transforms.RandScale([0.9, 1.1]),
        transforms.RandRotate([-10, 10], padding=mean, ignore_label=0),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
                        crop_type='rand',
                        padding=mean,
                        ignore_label=0)
    ])

    val_transform = Compose([
        transforms.ToNumpy(),
        transforms.Resize_pad(size=cfg.DATASET.input_size[0])
    ])

    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=train_transform,
                        to_tensor=transforms.ToTensorNormalize_noresize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    net_objectness.cuda()
    net_decoder.cuda()

    # Set up optimizers
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)

    best_iou = 0
    tic = time.time()

    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)

        data_time.update(time.time() - tic)
        net_objectness.zero_grad()
        net_decoder.zero_grad()

        adjust_learning_rate(optimizers, i_iter, cfg)

        # Forward pass
        feat = net_objectness(feed_dict['img_data'], return_feature_maps=True)
        pred = net_decoder(feat)
        loss = crit(pred, feed_dict['seg_label']).mean()
        acc = pixel_acc(pred, feed_dict['seg_label']).mean()

        # Backward pass; create_optimizers may return falsy entries to skip.
        loss.backward()
        for optimizer in optimizers:
            if optimizer:
                optimizer.step()

        # Measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # Update running loss and accuracy (accuracy in percent)
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                net_objectness.eval()
                net_decoder.eval()
                # Softmax is enabled only for validation and reset below.
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    # NOTE(review): reseeding here also resets the global RNG
                    # for the rest of training -- confirm this is intended.
                    set_seed(cfg.VAL.seed + run)

                    print(f'### Load validation data ###')
                    dataset_val = make_data(
                        base_dir=cfg.DATASET.data_dir,
                        split='val',
                        transforms=val_transform,
                        to_tensor=transforms.ToTensorNormalize_noresize(),
                        labels=labels_val,
                        max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.VAL.permute_labels,
                        exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    print(f"Total # of validation Data: {len(dataset_val)}")

                    for val_batch in testloader:
                        val_feed = data_preprocess(val_batch, cfg, is_val=True)
                        if data_name == 'COCO':
                            # Position of each COCO category id, 1-based.
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in val_batch['class_ids']
                            ]
                        else:
                            label_ids = list(val_batch['class_ids'])

                        feat = net_objectness(val_feed['img_data'],
                                              return_feature_maps=True)
                        query_pred = net_decoder(
                            feat, segSize=cfg.DATASET.input_size)
                        # Only the first sample of each batch is scored.
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(val_feed['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)

                    _, meanIoU = metric.get_mIoU(labels=sorted(labels_val),
                                                 n_run=run)
                    _, meanIoU_binary = metric.get_mIoU_binary(n_run=run)
                    print(f'Run {run + 1}: meanIoU: {meanIoU}, '
                          f'meanIoU_binary: {meanIoU_binary}')

            # Aggregate over all validation runs: mean and standard deviation.
            classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                labels=sorted(labels_val))
            (classIoU_binary, classIoU_std_binary, meanIoU_binary,
             meanIoU_std_binary) = metric.get_mIoU_binary()

            print('----- Evaluation Result -----')
            print(f'best meanIoU_binary: {best_iou}')
            print(f'meanIoU mean: {meanIoU}')
            print(f'meanIoU std: {meanIoU_std}')
            print(f'meanIoU_binary mean: {meanIoU_binary}')
            print(f'meanIoU_binary std: {meanIoU_std_binary}')

            # Model selection is based on the binary mean IoU.
            if meanIoU_binary > best_iou:
                best_iou = meanIoU_binary
                checkpoint(nets, history, cfg, 'best')
            net_objectness.train(not cfg.TRAIN.fix_bn)
            net_decoder.train(not cfg.TRAIN.fix_bn)
            net_decoder.use_softmax = False

        # Restart the timer so that checkpointing/validation time is not
        # charged to the next iteration's data-loading and batch times.
        tic = time.time()

    print('Training Done!')
コード例 #4
0
def main(cfg, gpus):
    """Evaluate an objectness encoder + two-class decoder on few-shot episodes.

    Builds ``net_objectness`` and a ``num_class=2`` decoder from ``cfg.MODEL``,
    then for each of ``cfg.VAL.n_runs`` seeded runs samples test episodes,
    predicts a 2-class map for the first query of each batch and records it
    into a ``Metric``. Finally prints the class-wise and binary mIoU (mean and
    standard deviation across runs).

    Args:
        cfg: experiment config; reads the ``MODEL``, ``TRAIN``, ``DATASET``,
            ``TASK`` and ``VAL`` sections, plus ``cfg.DIR`` for visualizations.
        gpus: sequence of GPU ids; only ``gpus[0]`` is used.

    Raises:
        ValueError: if ``cfg.DATASET.name`` is neither ``'VOC'`` nor ``'COCO'``.
    """
    torch.cuda.set_device(gpus[0])

    # Network builders
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout,
        use_softmax=True)

    net_objectness.cuda()
    net_objectness.eval()

    net_decoder.cuda()
    net_decoder.eval()

    print('###### Prepare data ######')
    data_name = cfg.DATASET.name
    # ``cfg.VAL.test_with_classes`` selects which loader module supplies the
    # episodes.
    if data_name == 'VOC':
        if cfg.VAL.test_with_classes:
            from dataloaders.customized import voc_fewshot
        else:
            from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        if cfg.VAL.test_with_classes:
            from dataloaders.customized import coco_fewshot
        else:
            from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
        # Used below to resolve COCO image file names for visualization.
        split = cfg.DATASET.data_split + '2014'
        annFile = f'{cfg.DATASET.data_dir}/annotations/instances_{split}.json'
        cocoapi = COCO(annFile)
    else:
        raise ValueError('Wrong config for dataset!')
    # Evaluate on every class outside the ``cfg.TASK.fold_idx`` class set.
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    val_transforms = Compose([
        transforms.ToNumpy(),
        transforms.Resize_pad(size=cfg.DATASET.input_size[0]),
    ])

    print('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
    with torch.no_grad():
        for run in range(cfg.VAL.n_runs):
            print(f'### Run {run + 1} ###')
            # Seed before the dataset is built so each run samples a
            # reproducible, run-specific set of episodes.
            set_seed(cfg.VAL.seed + run)

            print('### Load data ###')
            dataset = make_data(
                base_dir=cfg.DATASET.data_dir,
                split=cfg.DATASET.data_split,
                transforms=val_transforms,
                to_tensor=transforms.ToTensorNormalize_noresize(),
                labels=labels,
                max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                n_ways=cfg.TASK.n_ways,
                n_shots=cfg.TASK.n_shots,
                n_queries=cfg.TASK.n_queries,
                permute=cfg.VAL.permute_labels,
            )
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=cfg.VAL.n_batch,
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            print(f"Total # of Data: {len(dataset)}")

            for count, sample_batched in enumerate(tqdm.tqdm(testloader)):
                feed_dict = data_preprocess(sample_batched, cfg)
                if data_name == 'COCO':
                    # Map COCO category ids to 1-based positions in the
                    # category-id list.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])

                seg_label = feed_dict['seg_label']
                feat = net_objectness(feed_dict['img_data'],
                                      return_feature_maps=True)
                # Upsample the prediction to the label's spatial size so the
                # two maps always align in ``metric.record`` (a fixed
                # (473, 473) only matched one particular input size).
                query_pred = net_decoder(
                    feat, segSize=(seg_label.shape[-2], seg_label.shape[-1]))
                # Only the first element of each batch is scored/visualized.
                pred_mask = np.array(query_pred.argmax(dim=1)[0].cpu())

                metric.record(pred_mask,
                              np.array(seg_label[0].cpu()),
                              labels=label_ids,
                              n_run=run)

                if cfg.VAL.visualize:
                    query_name = sample_batched['query_ids'][0][0]
                    if data_name == 'VOC':
                        img = imread(
                            os.path.join(cfg.DATASET.data_dir, 'JPEGImages',
                                         query_name + '.jpg'))
                    else:
                        img_meta = cocoapi.loadImgs(int(query_name))[0]
                        img = imread(
                            os.path.join(cfg.DATASET.data_dir, split,
                                         img_meta['file_name']))
                    visualize_result(
                        (img, as_numpy(seg_label[0].cpu()), '%05d' % count),
                        as_numpy(pred_mask),
                        os.path.join(cfg.DIR, 'result'))

            _, run_mean_iou = metric.get_mIoU(labels=sorted(labels),
                                              n_run=run)
            _, run_mean_iou_binary = metric.get_mIoU_binary(n_run=run)
            print(f'Run {run + 1} meanIoU: {run_mean_iou}, '
                  f'meanIoU_binary: {run_mean_iou_binary}')

    # Aggregate over all runs: (per-class, per-class std, mean, mean std).
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )

    print('----- Final Result -----')
    print('final_classIoU', classIoU.tolist())
    print('final_classIoU_std', classIoU_std.tolist())
    print('final_meanIoU', meanIoU.tolist())
    print('final_meanIoU_std', meanIoU_std.tolist())
    print('final_classIoU_binary', classIoU_binary.tolist())
    print('final_classIoU_std_binary', classIoU_std_binary.tolist())
    print('final_meanIoU_binary', meanIoU_binary.tolist())
    print('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    print(f'classIoU mean: {classIoU}')
    print(f'classIoU std: {classIoU_std}')
    print(f'meanIoU mean: {meanIoU}')
    print(f'meanIoU std: {meanIoU_std}')
    print(f'classIoU_binary mean: {classIoU_binary}')
    print(f'classIoU_binary std: {classIoU_std_binary}')
    print(f'meanIoU_binary mean: {meanIoU_binary}')
    print(f'meanIoU_binary std: {meanIoU_std_binary}')
コード例 #5
0
def main(cfg, gpus):
    """Evaluate the attention-based few-shot segmentation model.

    Builds the query/memory encoders, attention heads, projection, decoder and
    (optionally) a frozen objectness branch, wraps them in
    ``SegmentationAttentionSeparateModule`` and runs ``cfg.VAL.n_runs`` seeded
    test runs. Predictions are produced at the resolution of
    ``seg_label_noresize`` and optionally averaged over several input scales.
    Finally prints the class-wise and binary mIoU (mean and std across runs).

    Args:
        cfg: experiment config; reads the ``MODEL``, ``TRAIN``, ``DATASET``,
            ``TASK`` and ``VAL`` sections plus the flags ``is_debug``,
            ``eval_att_voting``, ``multi_scale_test`` and ``DIR``.
        gpus: sequence of GPU ids; ``gpus[0]`` becomes the current device and
            all ids are given to ``nn.DataParallel``.

    Raises:
        ValueError: if ``cfg.DATASET.name`` is neither ``'VOC'`` nor ``'COCO'``.
    """
    torch.cuda.set_device(gpus[0])

    # Network builders
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways + 1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways + 1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout,
        use_softmax=True)
    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        # Optional objectness branch (HRNetV2 encoder + C1 decoder), frozen.
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    crit = nn.NLLLoss(ignore_index=255)

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query,
        net_enc_memory,
        net_att_query,
        net_att_memory,
        net_decoder,
        net_projection,
        net_objectness,
        net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        # Debug mode makes the module also return its intermediate tensors,
        # which the attention-voting evaluation below needs.
        debug=cfg.is_debug or cfg.eval_att_voting,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.
        objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.
        linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)

    segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    segmentation_module.cuda()
    segmentation_module.eval()

    print('###### Prepare data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
        # Used below to resolve COCO image file names for visualization.
        split = cfg.DATASET.data_split + '2014'
        annFile = f'{cfg.DATASET.data_dir}/annotations/instances_{split}.json'
        cocoapi = COCO(annFile)
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    test_transforms = Compose([Resize_test(size=cfg.DATASET.input_size)])

    scales = [224, 328, 424] if cfg.multi_scale_test else [328]

    print('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
    with torch.no_grad():
        for run in range(cfg.VAL.n_runs):
            print(f'### Run {run + 1} ###')
            # Seed before the dataset is built so each run samples a
            # reproducible, run-specific set of episodes.
            set_seed(cfg.VAL.seed + run)

            print('### Load data ###')
            dataset = make_data(
                base_dir=cfg.DATASET.data_dir,
                split=cfg.DATASET.data_split,
                transforms=test_transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                n_ways=cfg.TASK.n_ways,
                n_shots=cfg.TASK.n_shots,
                n_queries=cfg.TASK.n_queries,
                permute=cfg.VAL.permute_labels,
                exclude_labels=[]
            )
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset, batch_size=cfg.VAL.n_batch, shuffle=False,
                                    num_workers=1, pin_memory=True, drop_last=False)
            print(f"Total # of Data: {len(dataset)}")

            for count, sample_batched in enumerate(tqdm.tqdm(testloader)):
                feed_dict = data_preprocess(sample_batched, cfg)
                if data_name == 'COCO':
                    # Map COCO category ids to 1-based positions in the
                    # category-id list.
                    label_ids = [coco_cls_ids.index(x) + 1 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])

                seg_label = feed_dict['seg_label_noresize']
                # Every prediction (all scales, both paths) is produced at the
                # label's resolution so it can be averaged and scored.
                seg_size = (seg_label.shape[1], seg_label.shape[2])
                debug_tag = None
                if cfg.is_debug:
                    debug_tag = '%04d-%s-%s' % (
                        count, sample_batched['query_ids'][0][0],
                        sample_batched['support_ids'][0][0][0])

                # Keep the original input: each scale must be resampled from
                # it, not from the previous scale's already-resized tensor.
                img_data = feed_dict['img_data']
                query_pred_final = None
                for scale in scales:
                    if len(scales) > 1:
                        feed_dict['img_data'] = nn.functional.interpolate(
                            img_data.cuda(), size=(scale, scale), mode='bilinear')
                    if cfg.eval_att_voting or cfg.is_debug:
                        (query_pred, qread, qval, _, _, _, p, _,
                         _) = segmentation_module(feed_dict, segSize=seg_size)
                        if cfg.eval_att_voting:
                            # Replace the decoder output by a "vote": support
                            # masks, resized to the qread grid, multiplied by
                            # the attention matrix ``p`` (assumed square,
                            # (H*W, H*W) — the view below relies on it).
                            height, width = qread.shape[-2], qread.shape[-1]
                            assert p.shape[0] == height * width
                            img_refs_mask_resize = nn.functional.interpolate(
                                feed_dict['img_refs_mask'][0].cuda(),
                                size=(height, width), mode='nearest')
                            img_refs_mask_resize_flat = img_refs_mask_resize[:, 0, :, :].view(
                                img_refs_mask_resize.shape[0], -1)
                            mask_voting_flat = torch.mm(img_refs_mask_resize_flat, p)
                            mask_voting = mask_voting_flat.view(
                                mask_voting_flat.shape[0], height, width)
                            mask_voting = torch.unsqueeze(mask_voting, 0)
                            # Upsample to the label size, like the decoder
                            # path, so metric.record compares equal shapes.
                            query_pred = nn.functional.interpolate(
                                mask_voting[:, 0:-1], size=seg_size,
                                mode='bilinear', align_corners=False)
                            if cfg.is_debug:
                                np.save(f'debug/img_refs_mask-{debug_tag}.npy',
                                        img_refs_mask_resize.detach().cpu().float().numpy())
                                np.save(f'debug/query_pred-{debug_tag}.npy',
                                        query_pred.detach().cpu().float().numpy())
                        if cfg.is_debug:
                            np.save(f'debug/qread-{debug_tag}.npy',
                                    qread.detach().cpu().float().numpy())
                            np.save(f'debug/qval-{debug_tag}.npy',
                                    qval.detach().cpu().float().numpy())
                    else:
                        query_pred = segmentation_module(feed_dict, segSize=seg_size)
                    # Average the per-scale predictions.
                    scaled_pred = query_pred / len(scales)
                    if query_pred_final is None:
                        query_pred_final = scaled_pred
                    else:
                        query_pred_final = query_pred_final + scaled_pred
                query_pred = query_pred_final
                # Only the first element of each batch is scored/visualized.
                pred_mask = np.array(query_pred.argmax(dim=1)[0].cpu())
                metric.record(pred_mask,
                              np.array(seg_label[0].cpu()),
                              labels=label_ids, n_run=run)

                if cfg.VAL.visualize:
                    query_name = sample_batched['query_ids'][0][0]
                    if data_name == 'VOC':
                        img = imread(os.path.join(cfg.DATASET.data_dir, 'JPEGImages', query_name + '.jpg'))
                    else:
                        img_meta = cocoapi.loadImgs(int(query_name))[0]
                        img = imread(os.path.join(cfg.DATASET.data_dir, split, img_meta['file_name']))
                    visualize_result(
                        (img, as_numpy(seg_label[0].cpu()), '%05d' % count),
                        as_numpy(pred_mask),
                        os.path.join(cfg.DIR, 'result')
                    )

            _, run_mean_iou = metric.get_mIoU(labels=sorted(labels), n_run=run)
            _, run_mean_iou_binary = metric.get_mIoU_binary(n_run=run)
            print(f'Run {run + 1} meanIoU: {run_mean_iou}, '
                  f'meanIoU_binary: {run_mean_iou_binary}')

    # Aggregate over all runs: (per-class, per-class std, mean, mean std).
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

    print('----- Final Result -----')
    print('final_classIoU', classIoU.tolist())
    print('final_classIoU_std', classIoU_std.tolist())
    print('final_meanIoU', meanIoU.tolist())
    print('final_meanIoU_std', meanIoU_std.tolist())
    print('final_classIoU_binary', classIoU_binary.tolist())
    print('final_classIoU_std_binary', classIoU_std_binary.tolist())
    print('final_meanIoU_binary', meanIoU_binary.tolist())
    print('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    print(f'classIoU mean: {classIoU}')
    print(f'classIoU std: {classIoU_std}')
    print(f'meanIoU mean: {meanIoU}')
    print(f'meanIoU std: {meanIoU_std}')
    print(f'classIoU_binary mean: {classIoU_binary}')
    print(f'classIoU_binary std: {classIoU_std_binary}')
    print(f'meanIoU_binary mean: {meanIoU_binary}')
    print(f'meanIoU_binary std: {meanIoU_std_binary}')
コード例 #6
0
def main(cfg, gpus):
    """Train the attention-based few-shot segmentation model with validation.

    Builds the query/memory encoders, attention heads, projection, decoder and
    (optionally) a frozen objectness branch, then trains on episodes drawn from
    the ``cfg.TASK.fold_idx`` classes. Every ``cfg.TRAIN.eval_freq`` iterations
    it evaluates on the remaining classes over ``cfg.VAL.n_runs`` seeded runs,
    saves a ``'latest'`` checkpoint, and a ``'best'`` one when the mean mIoU
    improves. Periodic checkpoints are also written every
    ``cfg.TRAIN.save_freq`` iterations.

    Args:
        cfg: experiment config; reads the ``MODEL``, ``TRAIN``, ``DATASET``,
            ``TASK`` and ``VAL`` sections plus ``cfg.use_ignore``.
        gpus: sequence of GPU ids; only ``gpus[0]`` is used.

    Raises:
        ValueError: if ``cfg.DATASET.name`` is neither ``'VOC'`` nor ``'COCO'``.
    """
    torch.cuda.set_device(gpus[0])

    # Network builders
    print('###### Create model ######')
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways + 1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways + 1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        # Optional objectness branch (HRNetV2 encoder + C1 decoder). It is
        # frozen and not handed to the optimizers below.
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    crit = nn.NLLLoss(ignore_index=255)

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query,
        net_enc_memory,
        net_att_query,
        net_att_memory,
        net_decoder,
        net_projection,
        net_objectness,
        net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        random_memory_bias=cfg.MODEL.random_memory_bias,
        random_memory_nobias=cfg.MODEL.random_memory_nobias,
        random_scale=cfg.MODEL.random_scale,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.
        objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.
        linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness_debug import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness_debug import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on the fold's classes; validate on all the other classes.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    # When enabled, the validation classes are passed to the training loader
    # as ``exclude_labels``.
    exclude_labels = labels_val if cfg.DATASET.exclude_labels else []
    train_transforms = Compose(
        [Resize(size=cfg.DATASET.input_size), RandomMirror()])
    # Validation must be deterministic: no random augmentation (the training
    # pipeline above includes RandomMirror).
    val_transforms = Compose([Resize(size=cfg.DATASET.input_size)])
    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=train_transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels,
                        use_ignore=cfg.use_ignore)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    segmentation_module.cuda()

    # Set up optimizers (the objectness branch is intentionally not included).
    nets = (net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, net_projection, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    segmentation_module.train(not cfg.TRAIN.fix_bn)
    # ``train()`` above also flipped the frozen objectness branch; restore it.
    if net_objectness and net_objectness_decoder:
        net_objectness.eval()
        net_objectness_decoder.eval()

    best_iou = 0
    tic = time.time()

    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)

        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()

        adjust_learning_rate(optimizers, i_iter, cfg)

        # Forward pass
        loss, acc = segmentation_module(feed_dict)
        loss = loss.mean()
        acc = acc.mean()

        # Backward pass
        loss.backward()
        for optimizer in optimizers:
            if optimizer:
                optimizer.step()

        # Measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # Update running averages of loss and accuracy (accuracy in percent).
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                segmentation_module.eval()
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    # Seed before the dataset is built so each run samples a
                    # reproducible, run-specific set of episodes.
                    set_seed(cfg.VAL.seed + run)

                    print('### Load validation data ###')
                    dataset_val = make_data(base_dir=cfg.DATASET.data_dir,
                                            split=cfg.DATASET.data_split,
                                            transforms=val_transforms,
                                            to_tensor=ToTensorNormalize(),
                                            labels=labels_val,
                                            max_iters=cfg.VAL.n_iters *
                                            cfg.VAL.n_batch,
                                            n_ways=cfg.TASK.n_ways,
                                            n_shots=cfg.TASK.n_shots,
                                            n_queries=cfg.TASK.n_queries,
                                            permute=cfg.VAL.permute_labels,
                                            exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    print(f"Total # of validation Data: {len(dataset_val)}")

                    for val_batch in testloader:
                        val_feed = data_preprocess(val_batch, cfg, is_val=True)
                        if data_name == 'COCO':
                            # Map COCO category ids to 1-based positions in
                            # the category-id list.
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in val_batch['class_ids']
                            ]
                        else:
                            label_ids = list(val_batch['class_ids'])

                        query_pred = segmentation_module(
                            val_feed, segSize=cfg.DATASET.input_size)
                        # Only the first element of each batch is scored.
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(val_feed['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)

                    # Per-run summaries; the values are not used here.
                    metric.get_mIoU(labels=sorted(labels_val), n_run=run)
                    metric.get_mIoU_binary(n_run=run)

            # Aggregate over runs: (per-class, per-class std, mean, mean std).
            classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                labels=sorted(labels_val))
            classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
            )

            print('----- Evaluation Result -----')
            print(f'best meanIoU mean: {best_iou}')
            print(f'meanIoU mean: {meanIoU}')
            print(f'meanIoU std: {meanIoU_std}')
            print(f'meanIoU_binary mean: {meanIoU_binary}')
            print(f'meanIoU_binary std: {meanIoU_std_binary}')

            checkpoint(nets, history, cfg, 'latest')

            if meanIoU > best_iou:
                best_iou = meanIoU
                checkpoint(nets, history, cfg, 'best')
            # Back to training mode; keep the frozen objectness branch in eval.
            segmentation_module.train(not cfg.TRAIN.fix_bn)
            if net_objectness and net_objectness_decoder:
                net_objectness.eval()
                net_objectness_decoder.eval()
            net_decoder.use_softmax = False

    print('Training Done!')
コード例 #7
0
def main(_run, _config, _log):
    """Evaluate a trained few-shot segmentation model and export support features.

    For each of ``_config['n_runs']`` runs this builds the few-shot test
    episodes, records query-prediction IoU through ``Metric`` and writes the
    third model output (``supp_fts``, presumably per-shot support features) to
    ``<observer dir>/features/features_run_<k>.csv``. Columns are ``id``,
    ``label``, then one column per feature dimension; rows are de-duplicated
    by ``id``. After all runs every CSV is re-read and de-duplicated across
    runs, UMAP and t-SNE plots are saved, and aggregated IoU statistics are
    logged.

    Args:
        _run: Sacred-style run object. ``_run.observers[0]`` is required:
            sources, feature CSVs and plots are written into its directory.
        _config: Experiment config dict (dataset, paths, task, seed, ...).
        _log: Logger supplied by the experiment framework.

    Raises:
        ValueError: if ``_config['dataset']`` is neither ``'VOC'`` nor ``'COCO'``.
    """
    obs_dir = _run.observers[0].dir
    os.makedirs(f'{obs_dir}/features', exist_ok=True)
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(f'{obs_dir}/source/{source_file}'),
                    exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'], cfg=_config['model'], task=_config['task'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        model.load_state_dict(torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')

    # Test on the classes that are NOT in the configured label set.
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size'])])

    n_ways = _config['task']['n_ways']
    n_shots = _config['task']['n_shots']

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)
            features_dfs = []

            _log.info('### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=n_ways,
                n_shots=n_shots,
                n_queries=_config['task']['n_queries']
            )

            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()

            testloader = DataLoader(dataset, batch_size=_config['batch_size'], shuffle=False,
                                    num_workers=1, pin_memory=True, drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if data_name == 'COCO':
                    # COCO category ids are remapped to contiguous 1-based ids (plain ints).
                    label_ids = [coco_cls_ids.index(x) + 1 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])

                # 'support_ids' is a flat list indexed way-major (way * n_shots + shot);
                # [0] takes the first element of the batch.
                support_ids = [[sample_batched['support_ids'][way * n_shots + shot][0]
                                for shot in range(n_shots)]
                               for way in range(n_ways)]
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]

                support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                                   for way in sample_batched['support_mask']]
                support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                                   for way in sample_batched['support_mask']]

                query_images = [query_image.cuda()
                                for query_image in sample_batched['query_images']]
                query_labels = torch.cat(
                    [query_label.cuda() for query_label in sample_batched['query_labels']], dim=0)

                query_pred, _, supp_fts = model(support_images, support_fg_mask, support_bg_mask,
                                                query_images)

                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids, n_run=run)

                # One DataFrame per way: feature rows + class label + support image ids.
                for i, label_id in enumerate(label_ids):
                    lbl_df = pd.DataFrame(torch.cat(supp_fts[i]).cpu().numpy())
                    # int() accepts both the one-element tensors of the VOC branch and
                    # the plain ints of the COCO branch (ints have no .item()).
                    lbl_df['label'] = int(label_id)
                    lbl_df['id'] = pd.Series(support_ids[i])
                    features_dfs.append(lbl_df)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels), n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

            _log.info('### Exporting features CSV')
            features_df = pd.concat(features_dfs)
            features_df = features_df.drop_duplicates(subset=['id'])
            # Identifying columns first, then the feature dimensions in their original order.
            feat_cols = [c for c in features_df.columns if c not in ('id', 'label')]
            features_df = features_df[['id', 'label'] + feat_cols]
            features_df.to_csv(f'{obs_dir}/features/features_run_{run+1}.csv', index=False)

    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

    _log.info('###### Saving features visualization ######')
    all_fts = pd.concat([pd.read_csv(f'{obs_dir}/features/features_run_{k+1}.csv')
                         for k in range(_config['n_runs'])])
    # The same support image can appear in several runs; keep one row per id.
    all_fts = all_fts.drop_duplicates(subset=['id'])

    _log.info('### Obtaining Umap visualization ###')
    plot_umap(all_fts, f'{obs_dir}/features/Umap_fts.png')

    _log.info('### Obtaining TSNE visualization ###')
    plot_tsne(all_fts, f'{obs_dir}/features/TSNE_fts.png')

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
コード例 #8
0
def main(_run, _config, _log):
    """Validate a few-shot medical segmentation model and save per-scan predictions.

    Reloads ``FewShotSeg`` from ``_config['reload_model_path']``. For every
    test label (the ``'pa_all'`` label group minus ``_config['label_sets']``)
    it segments each query scan, skipping the support scans. A slice is scored
    in ``Metric`` only when its ``z_id`` lies within ``_config['z_margin']`` of
    the scan's ``[z_min, z_max]`` range. Dice / precision / recall are then
    logged. When an observer is attached, each predicted volume is written to
    ``interm_preds/scan_<id>_label_<lb>.nii.gz``.

    Args:
        _run: Sacred-style run object; observers are optional.
        _config: Experiment config dict.
        _log: Logger supplied by the experiment framework.

    Returns:
        1 when validation finishes.

    Raises:
        NotImplementedError: for the ``'C0_Superpix'`` dataset.
        ValueError: for an unknown dataset, or a non-negative ``scan_per_load``.
    """
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info(f'###### Reload model {_config["reload_model_path"]} ######')
    model = FewShotSeg(pretrained_path=_config['reload_model_path'],
                       cfg=_config['model'])
    model = model.cuda()
    model.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'SABS_Superpix':
        baseset_name = 'SABS'
        max_label = 13
    elif data_name == 'C0_Superpix':
        raise NotImplementedError
    elif data_name == 'CHAOST2_Superpix':
        baseset_name = 'CHAOST2'
        max_label = 4
    else:
        raise ValueError(f'Dataset: {data_name} not found')

    test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP'][
        'pa_all'] - DATASET_INFO[baseset_name]['LABEL_GROUP'][
            _config["label_sets"]]

    # The whole dataset is loaded at once; chunked loading is not supported here.
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (_config['scan_per_load'] < 0):
        raise ValueError(
            f"scan_per_load must be negative (load all scans), got {_config['scan_per_load']}")

    _log.info(
        f'###### Labels excluded in training : {[lb for lb in _config["exclude_cls_list"]]} ######'
    )
    _log.info(
        f'###### Unseen labels evaluated in testing: {[lb for lb in test_labels]} ######'
    )

    if baseset_name == 'SABS':  # for CT the normalisation comes from the training set
        tr_parent = SuperpixelDataset(  # base dataset
            which_dataset=baseset_name,
            base_dir=_config['path'][data_name]['data_dir'],
            idx_split=_config['eval_fold'],
            mode='train',
            min_fg=str(
                _config["min_fg_data"]),  # dummy entry for superpixel dataset
            transforms=None,
            nsup=_config['task']['n_shots'],
            scan_per_load=_config['scan_per_load'],
            exclude_list=_config["exclude_cls_list"],
            superpix_scale=_config["superpix_scale"],
            fix_length=_config["max_iters_per_load"] if
            (data_name == 'C0_Superpix') or
            (data_name == 'CHAOST2_Superpix') else None)
        norm_func = tr_parent.norm_func
    else:
        norm_func = get_normalize_op(modality='MR', fids=None)

    te_dataset, te_parent = med_fewshot_val(
        dataset_name=baseset_name,
        base_dir=_config['path'][baseset_name]['data_dir'],
        idx_split=_config['eval_fold'],
        scan_per_load=_config['scan_per_load'],
        act_labels=test_labels,
        npart=_config['task']['npart'],
        nsup=_config['task']['n_shots'],
        extern_normalize_func=norm_func)

    testloader = DataLoader(te_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=False,
                            drop_last=False)

    _log.info('###### Set validation nodes ######')
    mar_val_metric_node = Metric(
        max_label=max_label,
        n_scans=len(te_dataset.dataset.pid_curr_load) -
        _config['task']['n_shots'])

    _log.info('###### Starting validation ######')
    model.eval()
    mar_val_metric_node.reset()

    with torch.no_grad():
        save_pred_buffer = {}  # str(class label) -> {scan_id: prediction volume}

        for curr_lb in test_labels:
            te_dataset.set_curr_cls(curr_lb)
            support_batched = te_parent.get_support(
                curr_class=curr_lb,
                class_idx=[curr_lb],
                scan_idx=_config["support_idx"],
                npart=_config['task']['npart'])

            # way(1 for now) x part x [shot x C x H x W]
            support_images = [[shot.cuda() for shot in way]
                              for way in support_batched['support_images']]
            support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                               for way in support_batched['support_mask']]
            support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                               for way in support_batched['support_mask']]

            curr_scan_count = -1  # index of the current query scan
            _lb_buffer = {}  # scan_id -> prediction volume

            for sample_batched in testloader:
                # Query batch size is 1, so element 0 is the only slice.
                _scan_id = sample_batched["scan_id"][0]
                if _scan_id in te_parent.potential_support_sid:
                    continue  # never use a support scan as a query
                if sample_batched["is_start"]:
                    # First slice of a new scan: allocate an (H, W, Z) volume of NaN.
                    # ITK arrays are (Z, H, W), hence outsize[0] is the depth.
                    ii = 0
                    curr_scan_count += 1
                    outsize = te_dataset.dataset.info_by_scan[_scan_id][
                        "array_size"]
                    _pred = np.zeros((256, 256, outsize[0]))
                    _pred.fill(np.nan)

                # Which chunk (part) of the volume this query slice belongs to,
                # used to pick the matching support part.
                q_part = sample_batched["part_assign"]
                query_images = [sample_batched['image'].cuda()]
                query_labels = torch.cat([sample_batched['label'].cuda()],
                                         dim=0)

                # [way, [part, [shot x C x H x W]]] -> way(1) x shot x [B(1) x C x H x W]
                sup_img_part = [[shot.unsqueeze(0) for shot in support_images[0][q_part]]]
                sup_fgm_part = [[shot.unsqueeze(0) for shot in support_fg_mask[0][q_part]]]
                sup_bgm_part = [[shot.unsqueeze(0) for shot in support_bg_mask[0][q_part]]]

                query_pred, _, _, _ = model(
                    sup_img_part,
                    sup_fgm_part,
                    sup_bgm_part,
                    query_images,
                    isval=True,
                    val_wsize=_config["val_wsize"])

                query_pred = np.array(query_pred.argmax(dim=1)[0].cpu())
                _pred[..., ii] = query_pred.copy()

                # Score only slices within z_margin of the labelled z-range.
                if (sample_batched["z_id"] - sample_batched["z_max"] <=
                        _config['z_margin']) and (
                            sample_batched["z_id"] - sample_batched["z_min"] >=
                            -1 * _config['z_margin']):
                    mar_val_metric_node.record(query_pred,
                                               np.array(query_labels[0].cpu()),
                                               labels=[curr_lb],
                                               n_scan=curr_scan_count)

                ii += 1
                if sample_batched["is_end"]:
                    if baseset_name != 'C0':
                        _lb_buffer[_scan_id] = _pred.transpose(
                            2, 0, 1)  # H, W, Z -> Z, H, W
                    else:
                        _lb_buffer[_scan_id] = _pred

            save_pred_buffer[str(curr_lb)] = _lb_buffer

        # The interm_preds directory only exists when an observer is attached.
        if _run.observers:
            pred_dir = f'{_run.observers[0].dir}/interm_preds'
            for curr_lb, _preds in save_pred_buffer.items():
                for _scan_id, _pred in _preds.items():
                    _pred *= float(curr_lb)  # binary mask -> label value
                    itk_pred = convert_to_sitk(
                        _pred, te_dataset.dataset.info_by_scan[_scan_id])
                    fid = os.path.join(pred_dir,
                                       f'scan_{_scan_id}_label_{curr_lb}.nii.gz')
                    sitk.WriteImage(itk_pred, fid, True)
                    _log.info(f'###### {fid} has been saved ######')
        else:
            _log.warning('No observer attached; predictions are not saved')

        del save_pred_buffer

    # compute dice scores by scan
    m_classDice, _, m_meanDice, _, m_rawDice = mar_val_metric_node.get_mDice(
        labels=sorted(test_labels), n_scan=None, give_raw=True)

    m_classPrec, _, m_meanPrec, _, m_classRec, _, m_meanRec, _, m_rawPrec, m_rawRec = mar_val_metric_node.get_mPrecRecall(
        labels=sorted(test_labels), n_scan=None, give_raw=True)

    mar_val_metric_node.reset()  # reset this calculation node

    # write validation result to log file
    _run.log_scalar('mar_val_batches_classDice', m_classDice.tolist())
    _run.log_scalar('mar_val_batches_meanDice', m_meanDice.tolist())
    _run.log_scalar('mar_val_batches_rawDice', m_rawDice.tolist())

    _run.log_scalar('mar_val_batches_classPrec', m_classPrec.tolist())
    _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec.tolist())
    _run.log_scalar('mar_val_batches_rawPrec', m_rawPrec.tolist())

    # NOTE(review): these two keys use a different prefix ('mar_val_al_') from the
    # others; kept unchanged so existing log consumers keep working.
    _run.log_scalar('mar_val_batches_classRec', m_classRec.tolist())
    _run.log_scalar('mar_val_al_batches_meanRec', m_meanRec.tolist())
    _run.log_scalar('mar_val_al_batches_rawRec', m_rawRec.tolist())

    _log.info(f'mar_val batches classDice: {m_classDice}')
    _log.info(f'mar_val batches meanDice: {m_meanDice}')

    _log.info(f'mar_val batches classPrec: {m_classPrec}')
    _log.info(f'mar_val batches meanPrec: {m_meanPrec}')

    _log.info(f'mar_val batches classRec: {m_classRec}')
    _log.info(f'mar_val batches meanRec: {m_meanRec}')

    print("============ ============")

    _log.info(f'End of validation')
    return 1
コード例 #9
0
def main(_run, _config, _log):
    """Evaluate a few-shot segmentation model on VOC or COCO test episodes.

    Loads ``FewShotSeg`` (optionally from ``_config['snapshot']``). For each of
    ``_config['n_runs']`` runs it builds test episodes and records the query
    predictions in ``Metric``. Support masks come from the dense masks, the
    scribbles (``_config['scribble']``) or bounding boxes
    (``_config['bbox']``). Per-run and final mean/std IoU are logged.

    Args:
        _run: Sacred-style run object; ``_run.observers[0]`` must exist (sources
            are copied into its directory).
        _config: Experiment config dict.
        _log: Logger supplied by the experiment framework.

    Raises:
        ValueError: if ``_config['dataset']`` is neither ``'VOC'`` nor ``'COCO'``.
    """
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
        _log.info(f"Loaded snapshot: {_config['snapshot']}")
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Test on the classes that are NOT in the configured label set.
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)

            _log.info('### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if data_name == 'COCO':
                    # COCO category ids are remapped to contiguous 1-based ids.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                suffix = 'scribble' if _config['scribble'] else 'mask'

                if _config['bbox']:
                    # Derive fg/bg masks from the instance bounding box of each shot.
                    support_fg_mask = []
                    support_bg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        bg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, bg_mask = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                            bg_masks.append(bg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                        support_bg_mask.append(bg_masks)
                else:
                    support_fg_mask = [[
                        shot[f'fg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                    support_bg_mask = [[
                        shot[f'bg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat([
                    query_label.cuda()
                    for query_label in sample_batched['query_labels']
                ], dim=0)

                # The model's fifth argument presumably carried an external saliency
                # prior (previously from U^2-Net, now disabled); pass it empty.
                pred = []

                query_pred, _, _ = model(support_images, support_fg_mask,
                                         support_bg_mask, query_images, pred)
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
コード例 #10
0
def main(cfg, gpus):
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    crit = nn.NLLLoss(ignore_index=255)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        max_label = 20
    elif data_name == 'COCO':
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    exclude_labels = labels_val

    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = [
        transform.RandScale([0.9, 1.1]),
        transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=255),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ]
    train_transform = transform.Compose(train_transform)
    train_data = dataset.SemData(split=cfg.TASK.fold_idx, shot=cfg.TASK.n_shots, data_root=cfg.DATASET.data_dir,
                                data_list=cfg.DATASET.train_list, transform=train_transform, mode='train', \
                                use_coco=False, use_split_coco=False)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=cfg.TRAIN.n_batch,
                                               shuffle=(train_sampler is None),
                                               num_workers=cfg.TRAIN.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_transform = transform.Compose([
        transform.Resize(size=cfg.DATASET.input_size[0]),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    val_data = dataset.SemData(split=cfg.TASK.fold_idx,
                               shot=cfg.TASK.n_shots,
                               data_root=cfg.DATASET.data_dir,
                               data_list=cfg.DATASET.val_list,
                               transform=val_transform,
                               mode='val',
                               use_coco=False,
                               use_split_coco=False)
    val_sampler = None
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=cfg.VAL.n_batch,
                                             shuffle=False,
                                             num_workers=cfg.TRAIN.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    net_objectness.cuda()
    net_decoder.cuda()

    # Set up optimizers
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)

    best_iou = 0
    # main loop
    tic = time.time()
    i_iter = -1
    print('###### Training ######')
    for epoch in range(0, 200):
        for _, (input, target) in enumerate(train_loader):
            # Prepare input
            i_iter += 1
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            data_time.update(time.time() - tic)
            net_objectness.zero_grad()
            net_decoder.zero_grad()

            # adjust learning rate
            adjust_learning_rate(optimizers, i_iter, cfg)

            # forward pass
            feat = net_objectness(input, return_feature_maps=True)
            pred = net_decoder(feat, segSize=cfg.DATASET.input_size)
            loss = crit(pred, target)
            acc = pixel_acc(pred, target)
            loss = loss.mean()
            acc = acc.mean()

            # Backward
            loss.backward()
            for optimizer in optimizers:
                if optimizer:
                    optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - tic)
            tic = time.time()

            # update average loss and acc
            ave_total_loss.update(loss.data.item())
            ave_acc.update(acc.data.item() * 100)

            # calculate accuracy, and display
            if i_iter % cfg.TRAIN.disp_iter == 0:
                print(
                    'Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                    'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                    'Ave_Accuracy: {:4.2f}, Accuracy:{:4.2f}, Ave_Loss: {:.6f}, Loss: {:.6f}'
                    .format(i_iter, i_iter, cfg.TRAIN.n_iters,
                            batch_time.average(), data_time.average(),
                            cfg.TRAIN.running_lr_encoder,
                            cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                            acc.data.item() * 100, ave_total_loss.average(),
                            loss.data.item()))

                history['train']['iter'].append(i_iter)
                history['train']['loss'].append(loss.data.item())
                history['train']['acc'].append(acc.data.item())

            if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
                checkpoint(nets, history, cfg, i_iter + 1)

            if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
                metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
                with torch.no_grad():
                    print('----Evaluation----')
                    net_objectness.eval()
                    net_decoder.eval()
                    net_decoder.use_softmax = True
                    #for run in range(cfg.VAL.n_runs):
                    for run in range(3):
                        print(f'### Run {run + 1} ###')
                        set_seed(cfg.VAL.seed + run)

                        print(f'### Load validation data ###')

                        #for sample_batched in tqdm.tqdm(testloader):
                        for (input, target, _) in val_loader:
                            input = input.cuda(non_blocking=True)
                            target = target.cuda(non_blocking=True)
                            feat = net_objectness(input,
                                                  return_feature_maps=True)
                            query_pred = net_decoder(
                                feat, segSize=cfg.DATASET.input_size)
                            metric.record(np.array(
                                query_pred.argmax(dim=1)[0].cpu()),
                                          np.array(target[0].cpu()),
                                          labels=None,
                                          n_run=run)

                classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
                )

                print('----- Evaluation Result -----')
                print(f'best meanIoU_binary: {best_iou}')
                print(f'meanIoU_binary mean: {meanIoU_binary}')
                print(f'meanIoU_binary std: {meanIoU_std_binary}')

                if meanIoU_binary > best_iou:
                    best_iou = meanIoU_binary
                    checkpoint(nets, history, cfg, 'best')
                net_objectness.train(not cfg.TRAIN.fix_bn)
                net_decoder.train(not cfg.TRAIN.fix_bn)
                net_decoder.use_softmax = False

    print('Training Done!')