def _archive_experiment_sources(_run):
    """Store the experiment's source files in the first observer's directory.

    Does nothing when the run has no observers. Otherwise creates
    ``snapshots/`` in the observer directory, saves every file listed in
    ``_run.experiment_info['sources']`` under ``source/`` and finally removes
    the observer's ``_sources`` directory.
    """
    if not _run.observers:
        return
    observer = _run.observers[0]
    os.makedirs(f'{observer.dir}/snapshots', exist_ok=True)
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(f'{observer.dir}/source/{source_file}'),
                    exist_ok=True)
        observer.save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{observer.basedir}/_sources')


def _episode_to_cuda(sample_batched):
    """Move one episode's support/query inputs to the current CUDA device.

    Returns ``(support_images, support_fg_mask, support_bg_mask,
    query_images)``. The support entries are nested lists (way -> shot); the
    masks are cast to float. Query labels are left to the caller because
    training casts them to long while evaluation does not.
    """
    support_images = [[shot.cuda() for shot in way]
                      for way in sample_batched['support_images']]
    support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                       for way in sample_batched['support_mask']]
    support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                       for way in sample_batched['support_mask']]
    query_images = [
        query_image.cuda() for query_image in sample_batched['query_images']
    ]
    return support_images, support_fg_mask, support_bg_mask, query_images


def _evaluate_fewshot_run(model, metric, run_idx, labels, transforms,
                          make_data, data_name, _run, _config, _log):
    """Run one evaluation pass and record it in ``metric`` under ``run_idx``.

    Re-seeds with ``_config['seed'] + run_idx``, builds a fresh dataset of
    ``_config['infer_max_iters']`` episodes over ``labels``, feeds every
    episode through ``model`` without gradients and logs the per-run IoU
    figures to both ``_run`` and ``_log``.

    The caller is responsible for putting ``model`` into eval mode.

    Returns:
        ``(meanIoU, meanIoU_binary)`` for this run.
    """
    with torch.no_grad():
        _log.info(f'### Run {run_idx + 1} ###')
        # NOTE: this also resets the RNG state of an ongoing training loop
        # when called for periodic evaluation (unchanged from the original).
        set_seed(_config['seed'] + run_idx)

        _log.info(f'### Load data ###')
        dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                            split=_config['path'][data_name]['data_split'],
                            transforms=transforms,
                            to_tensor=ToTensorNormalize(),
                            labels=labels,
                            max_iters=_config['infer_max_iters'],
                            n_ways=_config['task']['n_ways'],
                            n_shots=_config['task']['n_shots'],
                            n_queries=_config['task']['n_queries'],
                            n_unlabel=_config['task']['n_unlabels'],
                            cfg=_config)
        testloader = DataLoader(dataset,
                                batch_size=_config['batch_size'],
                                shuffle=False,
                                num_workers=_config['num_workers'],
                                pin_memory=True,
                                drop_last=False)
        _log.info(f"Total # of Data: {len(dataset)}")

        for sample_batched in tqdm.tqdm(testloader):
            label_ids = list(sample_batched['class_ids'])
            (support_images, support_fg_mask, support_bg_mask,
             query_images) = _episode_to_cuda(sample_batched)
            query_labels = torch.cat([
                query_label.cuda()
                for query_label in sample_batched['query_labels']
            ],
                                     dim=0)
            query_pred, _, _ = model(support_images, support_fg_mask,
                                     support_bg_mask, query_images)
            # Only the first element of the batch is scored.
            metric.record(query_pred.argmax(dim=1)[0],
                          query_labels[0],
                          labels=label_ids,
                          n_run=run_idx)

        classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                            n_run=run_idx)
        classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
            n_run=run_idx)

    per_run = (
        ('classIoU', classIoU),
        ('meanIoU', meanIoU),
        ('classIoU_binary', classIoU_binary),
        ('meanIoU_binary', meanIoU_binary),
    )
    for name, value in per_run:
        _run.log_scalar(name, value.tolist())
    for name, value in per_run:
        _log.info(f'{name}: {value}')

    return meanIoU, meanIoU_binary


def _train_fewshot(model, trainloader, optimizer, scheduler, criterion,
                   tbwriter, evaluate, max_label, _config, _log):
    """Run the training loop with periodic evaluation and checkpointing.

    Args:
        model: ``nn.DataParallel``-wrapped network; called as
            ``model(support_images, fg_masks, bg_masks, query_images)`` and
            expected to return a 3-tuple whose first item is the query
            prediction and whose last item is the alignment loss.
        trainloader: iterable of training episodes.
        optimizer, scheduler, criterion: standard PyTorch training objects;
            the scheduler is stepped once per iteration.
        tbwriter: TensorBoard writer receiving the evaluation mean IoUs.
        evaluate: callable ``(metric, run_idx) -> (meanIoU, meanIoU_binary)``.
        max_label: forwarded to ``Metric``.
        _config: experiment configuration.
        _log: logger.
    """
    infer_tags = {
        'mean_iou': "MeanIoU/mean_iou",
        "mean_iou_binary": "MeanIoU/mean_iou_binary",
    }
    ckpt_dir = _config['ckpt_dir']
    log_loss = {'loss': 0, 'align_loss': 0, 'base_loss': 0}
    highest_iou = 0

    for i_iter, sample_batched in enumerate(trainloader):
        if _config['fix']:
            # Early encoder stages are not optimised in 'fix' mode; keep them
            # in eval mode for every training step.
            encoder = model.module.encoder
            for frozen in (encoder.conv1, encoder.bn1, encoder.layer1,
                           encoder.layer2):
                frozen.eval()

        # Evaluation-only runs skip training altogether.
        if _config['eval'] and i_iter == 0:
            break

        # Prepare input
        (support_images, support_fg_mask, support_bg_mask,
         query_images) = _episode_to_cuda(sample_batched)
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ],
                                 dim=0)

        # The base loss is a constant zero placeholder in this script.
        base_loss = torch.zeros(1).to(torch.device('cuda'))

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, _, align_loss = model(support_images, support_fg_mask,
                                          support_bg_mask, query_images)
        query_loss = criterion(query_pred, query_labels)
        loss = (query_loss + align_loss * _config['align_loss_scaler'] +
                base_loss * _config['base_loss_scaler'])
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Accumulate losses; the reported figures are running averages.
        log_loss['loss'] += query_loss.detach().data.cpu().numpy()
        log_loss['align_loss'] += align_loss.detach().data.cpu().numpy()
        log_loss['base_loss'] += base_loss.detach().data.cpu().numpy()

        if (i_iter + 1) % _config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            base_loss = log_loss['base_loss'] / (i_iter + 1)
            message = (
                f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}, base_loss: {base_loss}'
            )
            print(message)
            _log.info(message)

        if (i_iter + 1) % _config['evaluate_interval'] == 0:
            _log.info('###### Evaluation begins ######')
            print(ckpt_dir)

            model.eval()
            metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
            meanIoU, meanIoU_binary = evaluate(metric, 0)

            print(f'meanIoU: {meanIoU}, meanIoU_binary: {meanIoU_binary}')

            for key, value in (('mean_iou', meanIoU),
                               ('mean_iou_binary', meanIoU_binary)):
                tbwriter.add_scalar(infer_tags[key], value, i_iter)

            if meanIoU > highest_iou:
                print(
                    f'The highest iou is in iter: {i_iter} : {meanIoU}, save: {ckpt_dir}/best.pth'
                )
                highest_iou = meanIoU
                torch.save(model.state_dict(), f'{ckpt_dir}/best.pth')
            else:
                print(f'The highest iou is in iter: {i_iter} : {meanIoU}')

            torch.save(model.state_dict(), f'{ckpt_dir}/{i_iter + 1}.pth')
            # Back to training mode; frozen layers are re-set to eval at the
            # top of the next iteration.
            model.train()


def main(_run, _config, _log):
    """Train a few-shot segmentation model, then test its best checkpoint.

    Steps: archive the experiment sources, build the model and the training
    data, train with periodic evaluation on the held-out label set (saving
    ``best.pth`` and numbered snapshots in ``_config['ckpt_dir']``), reload
    ``best.pth`` and report mean/std IoU over five test runs.

    When ``_config['eval']`` is truthy the training loop exits before its
    first step, so an existing ``best.pth`` is expected in the checkpoint
    directory.

    Args:
        _run: Sacred run object (observers, ``experiment_info``,
            ``log_scalar``).
        _config: experiment configuration mapping.
        _log: logger.
    """
    n_test_runs = 5

    # The observer directory is only used for reporting; tolerate runs
    # without observers instead of failing with an IndexError.
    observer_dir = _run.observers[0].dir if _run.observers else None
    logdir = f'{observer_dir}/'
    print(logdir)

    _archive_experiment_sources(_run)

    data_name = _config['dataset']
    max_label = 20 if data_name == 'VOC' else 80

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    print(_config['ckpt_dir'])

    _log.info('###### Create model ######')
    if _config['model']['part']:
        model = FewshotSegPartResnet(
            pretrained_path=_config['path']['init_path'], cfg=_config)
        _log.info('Model: FewshotSegPartResnet')
    else:
        model = FewshotSegResnet(pretrained_path=_config['path']['init_path'],
                                 cfg=_config)
        _log.info('Model: FewshotSegResnet')

    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    model.train()

    _log.info('###### Load data ######')
    # NOTE(review): the VOC factory is used for every dataset name even
    # though max_label distinguishes VOC from non-VOC — confirm intended.
    make_data = voc_fewshot
    train_labels = CLASS_LABELS[data_name][_config['label_sets']]
    # Evaluation and testing use the classes that were held out of training.
    eval_labels = CLASS_LABELS[data_name]['all'] - train_labels
    train_transforms = Compose(
        [Resize(size=_config['input_size']),
         RandomMirror()])
    eval_transforms = Compose([Resize(size=_config['input_size'])])

    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=train_transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=train_labels,
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'],
                        n_unlabel=_config['task']['n_unlabels'],
                        cfg=_config)
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=_config['num_workers'],
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    if _config['fix']:
        print('Optimizer: fix')
        # Only the last two encoder stages are optimised in 'fix' mode.
        encoder = model.module.encoder
        optimizer = torch.optim.SGD(
            params=[{
                "params": layer.parameters(),
                "lr": _config['optim']['lr'],
                "weight_decay": _config['optim']['weight_decay']
            } for layer in (encoder.layer3, encoder.layer4)],
            momentum=_config['optim']['momentum'],
        )
    else:
        print('Optimizer: Not fix')
        optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    def evaluate(metric, run_idx):
        """Evaluate ``model`` on the held-out classes for one run."""
        return _evaluate_fewshot_run(model, metric, run_idx, eval_labels,
                                     eval_transforms, make_data, data_name,
                                     _run, _config, _log)

    _log.info('###### Training ######')
    tbwriter = SummaryWriter(_config['ckpt_dir'])
    try:
        _train_fewshot(model, trainloader, optimizer, scheduler, criterion,
                       tbwriter, evaluate, max_label, _config, _log)
    finally:
        # Flush pending events even if training is interrupted.
        tbwriter.close()

    print(_config['ckpt_dir'])

    _log.info(' --------- Testing begins ---------')
    ckpt = f'{_config["ckpt_dir"]}/best.pth'
    print(ckpt)

    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    model.eval()

    metric = Metric(max_label=max_label, n_runs=n_test_runs)
    for run_idx in range(n_test_runs):
        evaluate(metric, run_idx)

    # Without n_run the metric returns mean and std across the recorded runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(eval_labels))
    (classIoU_binary, classIoU_std_binary, meanIoU_binary,
     meanIoU_std_binary) = metric.get_mIoU_binary()

    for name, value in (
        ('meanIoU', meanIoU),
        ('meanIoU_binary', meanIoU_binary),
        ('final_classIoU', classIoU),
        ('final_classIoU_std', classIoU_std),
        ('final_meanIoU', meanIoU),
        ('final_meanIoU_std', meanIoU_std),
        ('final_classIoU_binary', classIoU_binary),
        ('final_classIoU_std_binary', classIoU_std_binary),
        ('final_meanIoU_binary', meanIoU_binary),
        ('final_meanIoU_std_binary', meanIoU_std_binary),
    ):
        _run.log_scalar(name, value.tolist())

    _log.info('----- Final Result -----')
    for title, value in (
        ('classIoU mean', classIoU),
        ('classIoU std', classIoU_std),
        ('meanIoU mean', meanIoU),
        ('meanIoU std', meanIoU_std),
        ('classIoU_binary mean', classIoU_binary),
        ('classIoU_binary std', classIoU_std_binary),
        ('meanIoU_binary mean', meanIoU_binary),
        ('meanIoU_binary std', meanIoU_std_binary),
    ):
        _log.info(f'{title}: {value}')

    _log.info("## ------------------------------------------ ##")
    _log.info(f'###### Setting: {observer_dir} ######')

    summary = (
        "Running {num_run} runs, meanIoU:{miou:.4f}, meanIoU_binary:{mbiou:.4f} "
        "meanIoU_std:{miou_std:.4f}, meanIoU_binary_std:{mbiou_std:.4f}".
        format(num_run=n_test_runs,
               miou=meanIoU,
               mbiou=meanIoU_binary,
               miou_std=meanIoU_std,
               mbiou_std=meanIoU_std_binary))
    _log.info(summary)
    _log.info(f"Current setting is {observer_dir}")

    print(summary)
    print(f"Current setting is {observer_dir}")
    print(_config['ckpt_dir'])
    print(logdir)
# Example #2
# 0
def main(config):
    """Segment DAVIS sequences frame by frame and save paletted PNG masks.

    For every batch from the test loader, the first frame and its mask are
    used as the support set. Each frame of the sequence is then used as the
    query, and the model also receives the previous frame's predicted mask
    (the support mask for the first frame). Predictions are written to
    ``./result/<num>/<class_name>/<frame index>.png``.

    Args:
        config: configuration mapping. Keys read here: ``palette_dir``,
            ``task``, ``seed``, ``gpu_id``, ``model``, ``train``,
            ``snapshots``, ``dataset``, ``input_size``, ``n_runs``, ``path``
            and ``batch_size``.

    Raises:
        ValueError: if ``config['dataset']`` is not ``'davis'``.
    """
    num = 100000  # name of the sub-directory of ./result/ that is written

    # One "R G B" row per palette entry. Splitting on arbitrary whitespace
    # (instead of chopping the last character) keeps the final value intact
    # when the file does not end with a newline; blank lines are skipped.
    with open(config['palette_dir']) as f:
        palette_rows = [[int(value) for value in line.split()] for line in f
                        if line.strip()]
    palette = list(np.asarray(palette_rows).reshape(768))

    n_shots = config['task']['n_shots']
    n_ways = config['task']['n_ways']
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=config['gpu_id'])
    torch.set_num_threads(1)
    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[config['gpu_id']])
    # NOTE(review): weights are only restored when config['train'] is truthy;
    # otherwise the model runs with its freshly constructed weights — confirm.
    if config['train']:
        model.load_state_dict(
            torch.load(config['snapshots'], map_location='cpu'))
    model.eval()

    data_name = config['dataset']
    if data_name != 'davis':
        raise ValueError('Wrong config for dataset!')
    make_data = davis2017_test

    labels = CLASS_LABELS[data_name]['val']
    list_label = sorted(labels)
    transforms = Compose([Resize(size=config['input_size'])])

    with torch.no_grad():
        for run in range(config['n_runs']):
            set_seed(config['seed'] + run)
            dataset = make_data(base_dir=config['path'][data_name]['data_dir'],
                                split=config['path'][data_name]['data_split'],
                                transforms=transforms,
                                to_tensor=ToTensorNormalize(),
                                labels=labels)

            testloader = DataLoader(dataset,
                                    batch_size=config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)

            for iteration, batch in enumerate(testloader):
                all_class_data = batch[0]
                class_name = batch[2]
                out_dir = f'./result/{num}/{class_name[0]}'
                os.makedirs(out_dir, exist_ok=True)

                # The first frame of the sequence is the support example.
                first_frame = all_class_data[0]
                class_ids = first_frame['obj_ids']
                support_images = [[
                    first_frame['image'].cuda() for _ in range(n_shots)
                ] for _ in range(n_ways)]
                # The label dict is indexed by the iteration-th sorted label.
                support_mask = first_frame['label'][list_label[iteration]]
                # NOTE(review): foreground masks have one "way" per object id
                # while images/background masks have n_ways — confirm these
                # always agree.
                support_fg_mask = [[
                    get_fg_mask(support_mask, class_id)
                    for _ in range(n_shots)
                ] for class_id in class_ids]
                support_bg_mask = [[
                    get_bg_mask(support_mask, class_ids)
                    for _ in range(n_shots)
                ] for _ in range(n_ways)]
                s_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                             for way in support_fg_mask]
                s_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                             for way in support_bg_mask]

                pred_mask = None
                for idx, frame_data in enumerate(all_class_data):
                    query_images = [
                        frame_data['image'].cuda() for _ in range(n_ways)
                    ]

                    # Feed the previous prediction back in; the first frame
                    # uses the ground-truth support mask instead.
                    if idx == 0:
                        pre_mask = [support_mask.float().cuda()]
                    else:
                        pre_mask = [pred_mask]

                    query_pred, _ = model(support_images, s_fg_mask, s_bg_mask,
                                          query_images, pre_mask)
                    pred = query_pred.argmax(dim=1, keepdim=True)
                    pred = pred.data.cpu().numpy()
                    img = pred[0, 0]
                    img_e = Image.fromarray(img.astype('float32')).convert('P')

                    # Resized copy of the prediction, used as the next
                    # frame's previous mask.
                    pred_mask = tr_F.resize(img_e,
                                            config['input_size'],
                                            interpolation=Image.NEAREST)
                    pred_mask = torch.Tensor(np.array(pred_mask))
                    pred_mask = torch.unsqueeze(pred_mask, dim=0)
                    pred_mask = pred_mask.float().cuda()

                    img_e.putpalette(palette)
                    img_e.save(
                        os.path.join(out_dir, '{:05d}.png'.format(idx)))
def main(_run, _config, _log):
    """Train a few-shot segmentation model on the Visceral dataset.

    Archives the experiment sources, builds the model and training data,
    loads fixed support/query volumes for validation, validates once before
    training and then every ``_config['val_interval']`` steps, keeps the best
    validation snapshot as ``snapshots/best_val.pth`` and finally saves the
    last model as ``snapshots/<n_iterations>.pth``.

    Args:
        _run: Sacred run object; its first observer's directory receives the
            validation logs and the snapshots.
        _config: experiment configuration mapping.
        _log: logger.

    Raises:
        IndexError: if the run has no observers (there would be nowhere to
            write validation logs or snapshots).
        ValueError: if ``_config['dataset']`` is not ``'Visceral'``.
    """
    # Everything below writes into the observer directory, so fail before
    # any expensive setup rather than at the first validation call.
    if not _run.observers:
        raise IndexError(
            'main() needs at least one observer: validation logs and '
            'snapshots are written to its directory')

    observer = _run.observers[0]
    run_dir = observer.dir
    snapshot_dir = f'{run_dir}/snapshots'
    val_log_path = os.path.join(f'{run_dir}/val_logs', 'val_logs.txt')

    os.makedirs(snapshot_dir, exist_ok=True)
    os.makedirs(f'{run_dir}/val_logs', exist_ok=True)
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(f'{run_dir}/source/{source_file}'),
                    exist_ok=True)
        observer.save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{observer.basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    # Build the model for the GPU selected above instead of a hard-coded
    # "cuda:2", so the device matches model.cuda() and DataParallel below.
    model = FewShotSeg(device=torch.device('cuda',
                                           torch.cuda.current_device()),
                       n_grid=8,
                       overlap=True,
                       overlap_out='max')
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name != 'Visceral':
        raise ValueError('Wrong config for dataset!')
    make_data = visceral_fewshot

    labels = CLASS_LABELS[data_name][_config['label_sets']]
    # Validation uses the classes that are held out of training.
    val_labels = CLASS_LABELS[data_name]['all'] - labels
    transforms = Compose([
        Resize(size=_config['input_size']),
        RandomBrightness(),
        RandomContrast(),
        RandomGamma()
    ])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['n_steps'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=True)

    # Fixed validation split: one support volume, nineteen query volumes.
    support_volumes = ['10000132_1_CTce_ThAb.nii.gz']
    query_volumes = [
        '10000100_1_CTce_ThAb.nii.gz', '10000104_1_CTce_ThAb.nii.gz',
        '10000105_1_CTce_ThAb.nii.gz', '10000106_1_CTce_ThAb.nii.gz',
        '10000108_1_CTce_ThAb.nii.gz', '10000109_1_CTce_ThAb.nii.gz',
        '10000110_1_CTce_ThAb.nii.gz', '10000111_1_CTce_ThAb.nii.gz',
        '10000112_1_CTce_ThAb.nii.gz', '10000113_1_CTce_ThAb.nii.gz',
        '10000127_1_CTce_ThAb.nii.gz', '10000128_1_CTce_ThAb.nii.gz',
        '10000129_1_CTce_ThAb.nii.gz', '10000130_1_CTce_ThAb.nii.gz',
        '10000131_1_CTce_ThAb.nii.gz', '10000133_1_CTce_ThAb.nii.gz',
        '10000134_1_CTce_ThAb.nii.gz', '10000135_1_CTce_ThAb.nii.gz',
        '10000136_1_CTce_ThAb.nii.gz'
    ]

    # NOTE(review): machine-specific absolute paths; consider moving them
    # into the configuration.
    volumes_path = "/home/qinji/PANet_Visceral/eval_dataset/volumes/"
    segs_path = "/home/qinji/PANet_Visceral/eval_dataset/segmentations/"

    val_label_list = list(val_labels)
    support_vol_dict, support_mask_dict = load_vol_and_mask(
        support_volumes, volumes_path, segs_path, labels=val_label_list)
    query_vol_dict, query_mask_dict = load_vol_and_mask(
        query_volumes, volumes_path, segs_path, labels=val_label_list)
    print('Successfully Load eval data!')

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    i_iter = 0  # keeps the final snapshot name defined for an empty loader
    log_loss = {'loss': 0}
    _log.info('###### Training ######')
    best_val_dice = 0

    # Baseline validation before any training step (step index -1).
    # NOTE: `eval` here is the project's validation routine, which shadows
    # the builtin of the same name; it is called with eight arguments.
    model.eval()
    eval(model, -1, support_vol_dict, support_mask_dict, query_vol_dict,
         query_mask_dict, list(val_labels), val_log_path)
    model.train()

    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]

        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ],
                                 dim=0)

        # Forward and Backward
        optimizer.zero_grad()
        query_pred = model(support_images, support_fg_mask, support_bg_mask,
                           query_images)
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss
        query_loss = query_loss.detach().data.cpu().numpy()
        _run.log_scalar('loss', query_loss)
        log_loss['loss'] += query_loss

        # Report the running average loss on screen and in the log file.
        if (i_iter + 1) % _config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}')
            with open(val_log_path, 'a') as f:
                f.write(f'step {i_iter+1}: loss: {loss}\n')

        if (i_iter + 1) % _config['val_interval'] == 0:
            _log.info('###### Validing ######')
            model.eval()
            average_labels_dice = eval(model, i_iter, support_vol_dict,
                                       support_mask_dict, query_vol_dict,
                                       query_mask_dict, list(val_labels),
                                       val_log_path)
            model.train()
            print(f'step {i_iter+1}: val_average_dice: {average_labels_dice}')

            if average_labels_dice > best_val_dice:
                _log.info('###### Taking snapshot ######')
                best_val_dice = average_labels_dice
                torch.save(model.state_dict(),
                           os.path.join(snapshot_dir, 'best_val.pth'))

    _log.info('###### Saving final model ######')
    torch.save(model.state_dict(),
               os.path.join(snapshot_dir, f'{i_iter + 1}.pth'))
def main(_run, _config, _log):
    """Train ``FewShotSeg`` on VOC or COCO few-shot episodes.

    Each step optimises ``query_loss + dist_loss + 0.2 * align_loss``, where
    the last two terms are returned by the model. Running averages of all
    three terms are printed every ``_config['print_interval']`` steps, and a
    snapshot is written every ``_config['save_pred_every']`` steps plus once
    more after the loop ends.

    Args:
        _run: Experiment run object; ``observers``, ``experiment_info`` and
            ``log_scalar`` are used.
        _config: Nested configuration mapping.
        _log: Logger for progress messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    if _run.observers:
        # Mirror the experiment sources next to the snapshots, then drop the
        # shared ``_sources`` directory of the observer.
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'], cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'],])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
    elif data_name == 'COCO':
        make_data = coco_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']),
                          RandomMirror()])
    dataset = make_data(
        base_dir=_config['path'][data_name]['data_dir'],
        split=_config['path'][data_name]['data_split'],
        transforms=transforms,
        to_tensor=ToTensorNormalize(),
        labels=labels,
        max_iters=_config['n_steps'] * _config['batch_size'],
        n_ways=_config['task']['n_ways'],
        n_shots=_config['task']['n_shots'],
        n_queries=_config['task']['n_queries']
    )
    trainloader = DataLoader(
        dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True
    )

    _log.info('###### Set optimizer ######')
    print(_config['mode'])
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer, milestones=_config['lr_milestones'], gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    # ``i_iter`` is pre-initialised so the final snapshot name is defined even
    # if the loader yields nothing.
    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0, 'dist_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]

        query_images = [query_image.cuda()
                        for query_image in sample_batched['query_images']]
        query_labels = torch.cat(
            [query_label.long().cuda() for query_label in sample_batched['query_labels']], dim=0)

        # Forward and Backward
        optimizer.zero_grad()

        # The U-2-Net saliency branch is disabled, so an empty placeholder is
        # passed where the saliency prediction would go.
        pred = []

        query_pred, align_loss, dist_loss = model(support_images, support_fg_mask, support_bg_mask,
                                                  query_images, pred)
        query_loss = criterion(query_pred, query_labels)
        # NOTE: the alignment weight is fixed at 0.2 here;
        # ``_config['align_loss_scaler']`` is not consulted.
        loss = query_loss + dist_loss + align_loss * 0.2
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss; the auxiliary terms may be the plain number 0 rather than
        # a tensor, hence the ``!= 0`` guards.
        query_loss = query_loss.detach().data.cpu().numpy()
        dist_loss = dist_loss.detach().data.cpu().numpy() if dist_loss != 0 else 0
        align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0
        _run.log_scalar('loss', query_loss)
        _run.log_scalar('align_loss', align_loss)
        _run.log_scalar('dist_loss', dist_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss
        log_loss['dist_loss'] += dist_loss

        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            # All three figures are running averages over the steps so far
            # (dist_loss used to be the last step's value only).
            avg_loss = log_loss['loss'] / (i_iter + 1)
            avg_align_loss = log_loss['align_loss'] / (i_iter + 1)
            avg_dist_loss = log_loss['dist_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {avg_loss}, align_loss: {avg_align_loss}, '
                  f'dist_loss: {avg_dist_loss}')

        if (i_iter + 1) % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            torch.save(model.state_dict(),
                       os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth'))

    _log.info('###### Saving final model ######')
    torch.save(model.state_dict(),
               os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth'))
Beispiel #5
0
def main(_run, _config, _log):
    """Evaluate an ``Encoder`` few-shot segmentation model over several runs.

    For each of ``_config['n_runs']`` runs the seed is reset, a fresh
    episodic dataset is built from the classes in ``'all'`` minus the
    configured label set, and predictions are recorded into a ``Metric``.
    Per-run and aggregated (mean/std) IoU figures are logged.

    Args:
        _run: Experiment run object; ``observers``, ``experiment_info`` and
            ``log_scalar`` are used.
        _config: Nested configuration mapping.
        _log: Logger for progress messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    # Copy sources only when an observer is attached; indexing
    # ``_run.observers[0]`` unconditionally fails on observer-less runs.
    if _run.observers:
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(
                os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = Encoder(pretrained_path=_config['path']['init_path'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    # The snapshot is loaded unless the 'notrain' flag is set.
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            # A different seed per run so each run draws different episodes.
            set_seed(_config['seed'] + run)

            _log.info('### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                label_sets=_config['label_sets'],
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if _config['dataset'] == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    # Map COCO category ids to 1-based positions in the
                    # category-id list.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                support_images = torch.cat(
                    [torch.cat(way, dim=0) for way in support_images], dim=0)

                if _config['bbox']:
                    # NOTE(review): this branch leaves the masks as a nested
                    # per-way list, whereas the branch below concatenates
                    # them into one tensor - confirm the model accepts both.
                    support_fg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, _bg_mask = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                else:
                    support_fg_mask = [[
                        shot['fg_mask'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                    support_fg_mask = torch.cat(
                        [torch.cat(way, dim=0) for way in support_fg_mask],
                        dim=0)

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_images = torch.cat(query_images, dim=0)

                query_labels = torch.cat([
                    query_label.long().cuda()
                    for query_label in sample_batched['query_labels']
                ],
                                         dim=0)

                query_pred = model(support_images, query_images,
                                   support_fg_mask)
                # Resize the prediction to the query image's spatial size.
                query_pred = F.interpolate(query_pred,
                                           size=query_images.shape[-2:],
                                           mode='bilinear')

                # Only the first element of the batch is recorded.
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Without ``n_run`` the metric returns mean and std across all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
Beispiel #6
0
def main(cfg, gpus):
    """Train the objectness encoder/decoder pair with periodic validation.

    Training runs over episodic batches built from the fold selected by
    ``cfg.TASK.fold_idx``. Every ``cfg.TRAIN.eval_freq`` iterations the
    networks are evaluated on the remaining classes for ``cfg.VAL.n_runs``
    runs, and a 'best' checkpoint is written whenever the mean binary IoU
    improves.

    Args:
        cfg: Configuration object with ``MODEL``, ``TRAIN``, ``VAL``,
            ``DATASET`` and ``TASK`` sections.
        gpus: Sequence of GPU ids; only ``gpus[0]`` is used.

    Raises:
        ValueError: If ``cfg.DATASET.name`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    crit = nn.NLLLoss(ignore_index=255)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    # Validation uses every class that is not in the training fold; those
    # same classes are excluded from the training episodes.
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    exclude_labels = labels_val
    transforms = Compose([Resize(size=cfg.DATASET.input_size), RandomMirror()])
    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    net_objectness.cuda()
    net_decoder.cuda()

    # Set up optimizers
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)

    best_iou = 0
    # main loop
    tic = time.time()

    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)

        data_time.update(time.time() - tic)
        net_objectness.zero_grad()
        net_decoder.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        feat = net_objectness(feed_dict['img_data'], return_feature_maps=True)
        pred = net_decoder(feat)
        loss = crit(pred, feed_dict['seg_label'])
        acc = pixel_acc(pred, feed_dict['seg_label'])
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            # Entries may be falsy (e.g. no optimizer for a frozen part).
            if optimizer:
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                net_objectness.eval()
                net_decoder.eval()
                # Switched on for evaluation and restored to False below.
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)

                    print('### Load validation data ###')
                    dataset_val = make_data(base_dir=cfg.DATASET.data_dir,
                                            split=cfg.DATASET.data_split,
                                            transforms=transforms,
                                            to_tensor=ToTensorNormalize(),
                                            labels=labels_val,
                                            max_iters=cfg.VAL.n_iters *
                                            cfg.VAL.n_batch,
                                            n_ways=cfg.TASK.n_ways,
                                            n_shots=cfg.TASK.n_shots,
                                            n_queries=cfg.TASK.n_queries,
                                            permute=cfg.VAL.permute_labels,
                                            exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # Report the validation set size (the training set's
                    # length was printed here before).
                    print(f"Total # of validation Data: {len(dataset_val)}")

                    for sample_batched in testloader:
                        feed_dict = data_preprocess(sample_batched,
                                                    cfg,
                                                    is_val=True)
                        if data_name == 'COCO':
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in sample_batched['class_ids']
                            ]
                        else:
                            label_ids = list(sample_batched['class_ids'])

                        feat = net_objectness(feed_dict['img_data'],
                                              return_feature_maps=True)
                        query_pred = net_decoder(
                            feat, segSize=cfg.DATASET.input_size)
                        # Only the first element of the batch is recorded.
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(feed_dict['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)

                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

            # Without ``n_run`` the metric returns mean and std over all runs.
            classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                labels=sorted(labels_val))
            classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
            )

            print('----- Evaluation Result -----')
            # Printed before the update below, so this is the previous best.
            print(f'best meanIoU_binary: {best_iou}')
            print(f'meanIoU mean: {meanIoU}')
            print(f'meanIoU std: {meanIoU_std}')
            print(f'meanIoU_binary mean: {meanIoU_binary}')
            print(f'meanIoU_binary std: {meanIoU_std_binary}')

            if meanIoU_binary > best_iou:
                best_iou = meanIoU_binary
                checkpoint(nets, history, cfg, 'best')
            # Back to training mode for the next iterations.
            net_objectness.train(not cfg.TRAIN.fix_bn)
            net_decoder.train(not cfg.TRAIN.fix_bn)
            net_decoder.use_softmax = False

    print('Training Done!')
Beispiel #7
0
def main(_run, _config, _log):
    """Train ``FewShotSeg`` on ScanNet few-shot episodes (images + coords).

    The optimised loss is the query cross-entropy plus the model's alignment
    loss scaled by ``_config['align_loss_scaler']``. Running averages are
    printed every ``_config['print_interval']`` steps; snapshots are written
    every ``_config['save_pred_every']`` steps and once after the loop.

    Args:
        _run: Experiment run object; ``observers``, ``experiment_info`` and
            ``log_scalar`` are used.
        _config: Nested configuration mapping.
        _log: Logger for progress messages.

    Raises:
        ValueError: If ``_config['dataset']`` is not ``'ScanNet'``.
    """
    if _run.observers:
        observer = _run.observers[0]
        os.makedirs(f'{observer.dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            mirrored = f'{observer.dir}/source/{source_file}'
            os.makedirs(os.path.dirname(mirrored), exist_ok=True)
            observer.save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{observer.basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    net = FewShotSeg(pretrained_path=_config['path']['init_path'],
                     cfg=_config['model'])
    model = nn.DataParallel(net.cuda(), device_ids=[_config['gpu_id']])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name != 'ScanNet':
        raise ValueError('Wrong config for dataset!')
    make_data = scannet_fewshot
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']), RandomMirror()])
    dataset = make_data(
        base_dir=_config['path'][data_name]['data_dir'],
        split=_config['path'][data_name]['data_split'],
        transforms=transforms,
        to_tensor=ToTensorNormalize(),
        labels=labels,
        max_iters=_config['n_steps'] * _config['batch_size'],
        n_ways=_config['task']['n_ways'],
        n_shots=_config['task']['n_shots'],
        n_queries=_config['task']['n_queries'],
    )
    trainloader = DataLoader(
        dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=20,
        pin_memory=True,
        drop_last=True,
    )

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    def per_shot(ways, pick):
        """Apply ``pick`` to every shot, keeping the way/shot nesting."""
        return [[pick(shot) for shot in way] for way in ways]

    def save_snapshot(step):
        """Write the current model weights as ``<step>.pth``."""
        weights = model.state_dict()
        target = os.path.join(f'{_run.observers[0].dir}/snapshots',
                              f'{step}.pth')
        torch.save(weights, target)

    # Pre-set so the final snapshot name exists even for an empty loader.
    it = 0
    running_query = 0
    running_align = 0
    _log.info('###### Training ######')
    for it, batch in enumerate(trainloader):
        step = it + 1

        # Move the episode to the GPU.
        support_coords = per_shot(batch['support_coords'],
                                  lambda s: s.cuda())
        support_images = per_shot(batch['support_images'],
                                  lambda s: s.cuda())
        support_fg_mask = per_shot(batch['support_mask'],
                                   lambda s: s['fg_mask'].float().cuda())
        support_bg_mask = per_shot(batch['support_mask'],
                                   lambda s: s['bg_mask'].float().cuda())

        query_coords = [coord.cuda() for coord in batch['query_coords']]
        query_images = [image.cuda() for image in batch['query_images']]
        label_parts = [lbl.long().cuda() for lbl in batch['query_labels']]
        query_labels = torch.cat(label_parts, dim=0)

        # Forward / backward / update.
        optimizer.zero_grad()
        query_pred, align_term = model(support_coords, support_images,
                                       support_fg_mask, support_bg_mask,
                                       query_coords, query_images)
        query_term = criterion(query_pred, query_labels[None, ...])
        total = query_term + align_term * _config['align_loss_scaler']
        total.backward()
        optimizer.step()
        scheduler.step()

        # Convert to host values for logging; the alignment term may be a
        # plain 0, in which case it is logged as 0.
        query_value = query_term.detach().data.cpu().numpy()
        if align_term != 0:
            align_value = align_term.detach().data.cpu().numpy()
        else:
            align_value = 0
        _run.log_scalar('loss', query_value)
        _run.log_scalar('align_loss', align_value)
        running_query += query_value
        running_align += align_value

        # Periodic console report of the running averages.
        if step % _config['print_interval'] == 0:
            mean_query = running_query / step
            mean_align = running_align / step
            print(f'step {step}: loss: {mean_query}, align_loss: {mean_align}')

        if step % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            save_snapshot(step)

    _log.info('###### Saving final model ######')
    save_snapshot(it + 1)
Beispiel #8
0
def main(config):
    """Train ``FewShotSeg`` on DAVIS few-shot episodes with prior query masks.

    The optimised loss is the query cross-entropy plus the model's alignment
    loss scaled by ``config['align_loss_scaler']``. Running averages are
    printed every ``config['print_interval']`` steps; snapshots are written
    to ``config['snapshots']`` every ``config['save_pred_every']`` steps and
    once after the loop.

    Args:
        config: Nested configuration mapping.

    Raises:
        ValueError: If ``config['dataset']`` is not ``'davis'``.
    """
    snap_shots_dir = config['snapshots']
    # ``exist_ok=True`` avoids the check-then-create race of a separate
    # ``os.path.exists`` test.
    os.makedirs(snap_shots_dir, exist_ok=True)

    # NOTE: the GPU index is hard-coded rather than read from ``config``.
    gpu_id = 2
    torch.cuda.set_device(gpu_id)
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True

    torch.set_num_threads(1)

    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[gpu_id])
    model.train()

    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][config['label_sets']]
    transforms = Compose([Resize(size=config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=config['path'][data_name]['data_dir'],
                        split=config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=config['n_steps'] * config['batch_size'],
                        n_ways=config['task']['n_ways'],
                        n_shots=config['task']['n_shots'],
                        n_queries=config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=config['batch_size'],
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=True)

    optimizer = torch.optim.SGD(model.parameters(), **config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=config['ignore_label'])

    # Pre-set so the final snapshot name exists even for an empty loader.
    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0}

    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_fg_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_bg_mask']]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ],
                                 dim=0)
        # Masks from the batch's 'query_masks' entry, passed to the model as
        # an extra input.
        pre_masks = [
            query_mask.float().cuda()
            for query_mask in sample_batched['query_masks']
        ]

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_images, support_fg_mask,
                                       support_bg_mask, query_images,
                                       pre_masks)
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss + align_loss * config['align_loss_scaler']

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss; the alignment term may be a plain 0 rather than a tensor.
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy(
        ) if align_loss != 0 else 0
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # print loss and take snapshots
        if (i_iter + 1) % config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter + 1}: loss: {loss}, align_loss: {align_loss}')

        if (i_iter + 1) % config['save_pred_every'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(f'{snap_shots_dir}', f'{i_iter + 1}.pth'))

    torch.save(model.state_dict(),
               os.path.join(f'{snap_shots_dir}', f'{i_iter + 1}.pth'))
Beispiel #9
0
def main(config):
    """Train a FewShotSeg model on DAVIS few-shot episodes and save snapshots.

    Args:
        config: Experiment configuration indexed with ``[]``. Keys read here:
            'snapshots', 'palette_dir', 'seed', 'model', 'dataset',
            'label_sets', 'input_size', 'path', 'n_steps', 'batch_size',
            'task', 'optim', 'lr_milestones', 'ignore_label',
            'align_loss_scaler', 'print_interval' and 'save_pred_every'.

    Raises:
        ValueError: If ``config['dataset']`` is not ``'davis'``.
    """
    snap_shots_dir = config['snapshots']
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(snap_shots_dir, exist_ok=True)

    # Colour palette: whitespace-separated integers, flattened to 768 values.
    # It is only consumed by the disabled debug visualisation further down.
    # split() (instead of slicing off the last character and splitting on a
    # single space) keeps the final value intact when the file has no trailing
    # newline and tolerates trailing spaces; blank lines are skipped.
    with open(config['palette_dir']) as f:
        palette_rows = [[int(p) for p in line.split()]
                        for line in f if line.strip()]
    palette = list(np.asarray(palette_rows).reshape(768))

    # NOTE(review): the GPU index is hard-coded; it is named once here so the
    # current device and the DataParallel device list cannot drift apart.
    device_id = 2
    torch.cuda.set_device(device_id)
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True

    torch.set_num_threads(1)

    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[device_id])
    model.train()

    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][config['label_sets']]
    transforms = Compose([Resize(size=config['input_size']),
                          RandomMirror()])
    dataset = make_data(
        base_dir=config['path'][data_name]['data_dir'],
        split=config['path'][data_name]['data_split'],
        transforms=transforms,
        to_tensor=ToTensorNormalize(),
        labels=labels,
        max_iters=config['n_steps'] * config['batch_size'],
        n_ways=config['task']['n_ways'],
        n_shots=config['task']['n_shots'],
        n_queries=config['task']['n_queries']
    )
    trainloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True
    )

    optimizer = torch.optim.SGD(model.parameters(), **config['optim'])
    scheduler = MultiStepLR(optimizer, milestones=config['lr_milestones'], gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=config['ignore_label'])

    # Pre-initialised so the final snapshot name is defined even if the
    # loader yields no batches.
    i_iter = 0
    # Running sums; divided by the step count when printed.
    log_loss = {'loss': 0, 'align_loss': 0}

    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: move every support/query tensor to the GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_fg_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_bg_mask']]
        img_size = sample_batched['img_size']
        query_images = [query_image.cuda()
                        for query_image in sample_batched['query_images']]
        query_labels = torch.cat(
            [query_label.long().cuda()
             for query_label in sample_batched['query_labels']], dim=0)
        pre_masks = [query_mask.float().cuda()
                     for query_mask in sample_batched['query_masks']]

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_images, support_fg_mask,
                                       support_bg_mask, query_images,
                                       pre_masks)
        # Resize the prediction to the size reported by the batch before
        # comparing it with the labels.
        query_pred = F.interpolate(query_pred, size=img_size, mode="bilinear")
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss + align_loss * config['align_loss_scaler']

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss
        query_loss = query_loss.detach().data.cpu().numpy()
        # The model may return a plain 0 when the alignment loss is disabled.
        align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0

        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # print loss and take snapshots
        if (i_iter + 1) % config['print_interval'] == 0:
            avg_loss = log_loss['loss'] / (i_iter + 1)
            avg_align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter + 1}: loss: {avg_loss}, align_loss: {avg_align_loss}')
            # Disabled debug visualisation (kept for reference; uses palette):
            # if len(support_fg_mask)>1:
            #     pred = query_pred.argmax(dim=1, keepdim=True)
            #     pred = pred.data.cpu().numpy()
            #     img = pred[0, 0]
            #     for i in range(img.shape[0]):
            #         for j in range(img.shape[1]):
            #             if img[i][j] > 0:
            #                 print(f'{img[i][j]} {len(support_fg_mask)}')
            #
            #     img_e = Image.fromarray(img.astype('float32')).convert('P')
            #     img_e.putpalette(palette)
            #     img_e.save(os.path.join(config['path']['davis']['data_dir'], '{:05d}.png'.format(i_iter)))

        if (i_iter + 1) % config['save_pred_every'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(snap_shots_dir, f'{i_iter + 1}.pth'))

    # Final snapshot after the last step.
    torch.save(model.state_dict(),
               os.path.join(snap_shots_dir, f'{i_iter + 1}.pth'))
Beispiel #10
0
def main(cfg, gpus):
    """Evaluate the attention-based few-shot segmentation module.

    Builds the sub-networks from ``cfg.MODEL``, wraps them in a
    ``SegmentationAttentionSeparateModule`` and runs ``cfg.VAL.n_runs``
    evaluation runs on the labels held out from fold ``cfg.TASK.fold_idx``,
    printing the aggregated IoU statistics at the end.

    Args:
        cfg: Configuration object with attribute access; the sections read
            here are MODEL, TRAIN, TASK, DATASET and VAL plus ``DIR``,
            ``is_debug`` and ``eval_att_voting``.
        gpus: Sequence of CUDA device indices. The first becomes the current
            device; all are passed to ``nn.DataParallel``.

    Raises:
        ValueError: If ``cfg.DATASET.name`` is neither 'VOC' nor 'COCO'.
    """
    torch.cuda.set_device(gpus[0])

    # Network Builders
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways+1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.projection_input_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways+1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout,
        use_softmax=True)

    crit = nn.NLLLoss(ignore_index=255)

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query,
        net_enc_memory,
        net_att_query,
        net_att_memory,
        net_decoder,
        net_projection,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        # Debug mode makes the module return its intermediate tensors too.
        debug=cfg.is_debug or cfg.eval_att_voting,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.linear_classifier_support_only)

    segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    segmentation_module.cuda()
    segmentation_module.eval()

    print('###### Prepare data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Evaluate on the labels that are NOT in the selected fold.
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    transforms = Compose([Resize(size=cfg.DATASET.input_size)])

    if cfg.is_debug:
        # np.save does not create missing directories; without this the first
        # debug dump fails with FileNotFoundError on a fresh checkout.
        os.makedirs('debug', exist_ok=True)

    def dump_debug(tag, tensor, index, batch):
        """Save `tensor` as debug/<tag>-<index>-<query id>-<support id>.npy."""
        path = 'debug/%s-%04d-%s-%s.npy' % (
            tag, index, batch['query_ids'][0][0],
            batch['support_ids'][0][0][0])
        np.save(path, tensor.detach().cpu().float().numpy())

    print('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
    with torch.no_grad():
        for run in range(cfg.VAL.n_runs):
            print(f'### Run {run + 1} ###')
            set_seed(cfg.VAL.seed + run)

            print('### Load data ###')
            dataset = make_data(
                base_dir=cfg.DATASET.data_dir,
                split=cfg.DATASET.data_split,
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                n_ways=cfg.TASK.n_ways,
                n_shots=cfg.TASK.n_shots,
                n_queries=cfg.TASK.n_queries,
                permute=cfg.VAL.permute_labels
            )
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset, batch_size=cfg.VAL.n_batch, shuffle=False,
                                    num_workers=1, pin_memory=True, drop_last=False)
            print(f"Total # of Data: {len(dataset)}")

            # Running sample index, used in debug and visualisation file names.
            count = 0

            for sample_batched in tqdm.tqdm(testloader):
                feed_dict = data_preprocess(sample_batched, cfg)
                if data_name == 'COCO':
                    # Map COCO category ids to 1-based positions in the
                    # category list.
                    label_ids = [coco_cls_ids.index(x) + 1 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])

                if cfg.eval_att_voting or cfg.is_debug:
                    # In debug mode the module returns intermediates as well;
                    # qk_b, mk_b, mv_b, feature_enc and feature_memory are
                    # unpacked but not used below.
                    (query_pred, qread, qval, qk_b, mk_b, mv_b, p,
                     feature_enc, feature_memory) = segmentation_module(
                         feed_dict, segSize=cfg.DATASET.input_size)
                    if cfg.eval_att_voting:
                        # Replace the decoder output with a prediction voted
                        # from the support masks through the matrix p.
                        height, width = qread.shape[-2], qread.shape[-1]
                        assert p.shape[0] == height*width
                        img_refs_mask_resize = nn.functional.interpolate(
                            feed_dict['img_refs_mask'][0].cuda(),
                            size=(height, width), mode='nearest')
                        img_refs_mask_resize_flat = img_refs_mask_resize[:,0,:,:].view(
                            img_refs_mask_resize.shape[0], -1)
                        mask_voting_flat = torch.mm(img_refs_mask_resize_flat, p)
                        mask_voting = mask_voting_flat.view(
                            mask_voting_flat.shape[0], height, width)
                        mask_voting = torch.unsqueeze(mask_voting, 0)
                        # The last channel of the voted mask is dropped.
                        query_pred = nn.functional.interpolate(
                            mask_voting[:,0:-1], size=cfg.DATASET.input_size,
                            mode='bilinear', align_corners=False)
                        if cfg.is_debug:
                            dump_debug('img_refs_mask', img_refs_mask_resize,
                                       count, sample_batched)
                            dump_debug('query_pred', query_pred,
                                       count, sample_batched)
                    if cfg.is_debug:
                        dump_debug('qread', qread, count, sample_batched)
                        dump_debug('qval', qval, count, sample_batched)
                        dump_debug('p', p, count, sample_batched)
                else:
                    query_pred = segmentation_module(feed_dict, segSize=cfg.DATASET.input_size)

                # Only the first element of the batch is scored.
                pred_label = np.array(query_pred.argmax(dim=1)[0].cpu())
                metric.record(pred_label,
                              np.array(feed_dict['seg_label'][0].cpu()),
                              labels=label_ids, n_run=run)

                if cfg.VAL.visualize:
                    query_name = sample_batched['query_ids'][0][0]
                    img = imread(os.path.join(cfg.DATASET.data_dir, 'JPEGImages', query_name+'.jpg'))
                    img = imresize(img, cfg.DATASET.input_size)
                    visualize_result(
                        (img, as_numpy(feed_dict['seg_label'][0].cpu()), '%05d'%(count)),
                        as_numpy(pred_label),
                        os.path.join(cfg.DIR, 'result')
                    )
                count += 1

            # Per-run scores are computed but currently not reported.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels), n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

    # Aggregate over all runs (mean and standard deviation).
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

    print('----- Final Result -----')
    print('final_classIoU', classIoU.tolist())
    print('final_classIoU_std', classIoU_std.tolist())
    print('final_meanIoU', meanIoU.tolist())
    print('final_meanIoU_std', meanIoU_std.tolist())
    print('final_classIoU_binary', classIoU_binary.tolist())
    print('final_classIoU_std_binary', classIoU_std_binary.tolist())
    print('final_meanIoU_binary', meanIoU_binary.tolist())
    print('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    print(f'classIoU mean: {classIoU}')
    print(f'classIoU std: {classIoU_std}')
    print(f'meanIoU mean: {meanIoU}')
    print(f'meanIoU std: {meanIoU_std}')
    print(f'classIoU_binary mean: {classIoU_binary}')
    print(f'classIoU_binary std: {classIoU_std_binary}')
    print(f'meanIoU_binary mean: {meanIoU_binary}')
    print(f'meanIoU_binary std: {meanIoU_std_binary}')
Beispiel #11
0
def main(cfg, gpus):
    """Train the attention-based few-shot segmentation module.

    Builds the sub-networks from ``cfg.MODEL``, trains on the labels of fold
    ``cfg.TASK.fold_idx`` and, every ``cfg.TRAIN.eval_freq`` iterations,
    evaluates on the complementary labels, saving 'latest' and 'best'
    checkpoints.

    Args:
        cfg: Configuration object with attribute access; the sections read
            here are MODEL, TRAIN, TASK, DATASET and VAL plus ``use_ignore``.
        gpus: Sequence of CUDA device indices. Only ``gpus[0]`` is used
            (the DataParallel wrapping is disabled).

    Raises:
        ValueError: If ``cfg.DATASET.name`` is neither 'VOC' nor 'COCO'.
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways + 1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways + 1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        # Frozen objectness branch (hrnetv2 encoder + c1_nodropout decoder).
        # A resnet50_deeplab / aspp_few_shot variant was used here previously.
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    crit = nn.NLLLoss(ignore_index=255)

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query,
        net_enc_memory,
        net_att_query,
        net_att_memory,
        net_decoder,
        net_projection,
        net_objectness,
        net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        random_memory_bias=cfg.MODEL.random_memory_bias,
        random_memory_nobias=cfg.MODEL.random_memory_nobias,
        random_scale=cfg.MODEL.random_scale,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness_debug import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness_debug import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on the fold's labels, validate on the complementary ones.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    if cfg.DATASET.exclude_labels:
        exclude_labels = labels_val
    else:
        exclude_labels = []
    # NOTE(review): the same transforms (including RandomMirror) are reused
    # for the validation dataset below - confirm this is intended.
    transforms = Compose([Resize(size=cfg.DATASET.input_size), RandomMirror()])
    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels,
                        use_ignore=cfg.use_ignore)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, net_projection, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    segmentation_module.train(not cfg.TRAIN.fix_bn)
    if net_objectness and net_objectness_decoder:
        # The frozen objectness branch always stays in eval mode.
        net_objectness.eval()
        net_objectness_decoder.eval()

    best_iou = 0
    # main loop
    tic = time.time()

    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)

        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        loss, acc = segmentation_module(feed_dict)
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            # Entries may be falsy for networks without an optimizer.
            if optimizer:
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                segmentation_module.eval()
                # Switched on for evaluation and back off afterwards.
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)

                    print('### Load validation data ###')
                    dataset_val = make_data(base_dir=cfg.DATASET.data_dir,
                                            split=cfg.DATASET.data_split,
                                            transforms=transforms,
                                            to_tensor=ToTensorNormalize(),
                                            labels=labels_val,
                                            max_iters=cfg.VAL.n_iters *
                                            cfg.VAL.n_batch,
                                            n_ways=cfg.TASK.n_ways,
                                            n_shots=cfg.TASK.n_shots,
                                            n_queries=cfg.TASK.n_queries,
                                            permute=cfg.VAL.permute_labels,
                                            exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # Report the validation set size (the original printed
                    # the training dataset's length here).
                    print(f"Total # of validation Data: {len(dataset_val)}")

                    # Distinct names so the training batch of the outer loop
                    # is not rebound by the validation loop.
                    for val_batch in testloader:
                        val_feed = data_preprocess(val_batch,
                                                   cfg,
                                                   is_val=True)
                        if data_name == 'COCO':
                            # Map COCO category ids to 1-based positions in
                            # the category list.
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in val_batch['class_ids']
                            ]
                        else:
                            label_ids = list(val_batch['class_ids'])

                        query_pred = segmentation_module(
                            val_feed, segSize=cfg.DATASET.input_size)
                        # Only the first element of the batch is scored.
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(val_feed['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)

                    # Per-run scores are computed but currently not reported.
                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

            # Aggregate over all validation runs.
            classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                labels=sorted(labels_val))
            (classIoU_binary, classIoU_std_binary, meanIoU_binary,
             meanIoU_std_binary) = metric.get_mIoU_binary()

            print('----- Evaluation Result -----')
            # Printed before the update, so this is the previous best.
            print(f'best meanIoU mean: {best_iou}')
            print(f'meanIoU mean: {meanIoU}')
            print(f'meanIoU std: {meanIoU_std}')
            print(f'meanIoU_binary mean: {meanIoU_binary}')
            print(f'meanIoU_binary std: {meanIoU_std_binary}')

            checkpoint(nets, history, cfg, 'latest')

            if meanIoU > best_iou:
                best_iou = meanIoU
                checkpoint(nets, history, cfg, 'best')

            # Restore training mode and keep the frozen branch in eval mode.
            segmentation_module.train(not cfg.TRAIN.fix_bn)
            if net_objectness and net_objectness_decoder:
                net_objectness.eval()
                net_objectness_decoder.eval()
            net_decoder.use_softmax = False

    print('Training Done!')
Beispiel #12
0
def main(_run, _config, _log):
    """Test a few-shot segmentation model and export its support features.

    The experiment sources are archived under the first observer's directory,
    a ``FewShotSeg`` model is built (its weights are loaded from
    ``_config['snapshot']`` unless ``_config['notrain']`` is truthy) and
    ``_config['n_runs']`` test runs are executed. Each run records IoU
    metrics and writes the support features returned by the model to
    ``<observer dir>/features/features_run_<k>.csv``. Afterwards the per-run
    CSV files are read back, merged, de-duplicated on the ``id`` column and
    passed to ``plot_umap`` and ``plot_tsne``.

    Args:
        _run: Run object; ``observers[0]``, ``experiment_info`` and
            ``log_scalar`` are used (presumably a Sacred run -- confirm).
        _config: Configuration mapping, read through the keys used below.
        _log: Logger used for progress and result messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither 'VOC' nor 'COCO'.
    """
    observer = _run.observers[0]
    features_dir = f'{observer.dir}/features'
    os.makedirs(features_dir, exist_ok=True)
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(f'{observer.dir}/source/{source_file}'),
                    exist_ok=True)
        observer.save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{observer.basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'], task=_config['task'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'], ])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    is_coco = data_name == 'COCO'

    # Evaluate on the classes that are NOT part of the configured label set.
    labels = (CLASS_LABELS[data_name]['all']
              - CLASS_LABELS[data_name][_config['label_sets']])
    transforms = Compose([Resize(size=_config['input_size'])])

    n_runs = _config['n_runs']
    n_ways = _config['task']['n_ways']
    n_shots = _config['task']['n_shots']

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=n_runs)
    with torch.no_grad():
        for run in range(n_runs):
            _log.info(f'### Run {run + 1} ###')
            # A different seed per run so every run samples other episodes.
            set_seed(_config['seed'] + run)
            features_dfs = []

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=n_ways,
                n_shots=n_shots,
                n_queries=_config['task']['n_queries']
            )

            if is_coco:
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()

            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False, num_workers=1,
                                    pin_memory=True, drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if is_coco:
                    # Map COCO category ids to their 1-based list position.
                    label_ids = [coco_cls_ids.index(x) + 1
                                 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])

                # 'support_ids' is indexed flat (way * n_shots + shot); only
                # the first element of each entry is kept.
                support_ids = [
                    [sample_batched['support_ids'][way * n_shots + shot][0]
                     for shot in range(n_shots)]
                    for way in range(n_ways)
                ]
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]

                support_fg_mask = [
                    [shot['fg_mask'].float().cuda() for shot in way]
                    for way in sample_batched['support_mask']
                ]
                support_bg_mask = [
                    [shot['bg_mask'].float().cuda() for shot in way]
                    for way in sample_batched['support_mask']
                ]

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat(
                    [query_label.cuda()
                     for query_label in sample_batched['query_labels']],
                    dim=0)

                query_pred, _, supp_fts = model(support_images,
                                                support_fg_mask,
                                                support_bg_mask,
                                                query_images)

                # Only the first prediction/label of the batch is scored.
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids, n_run=run)

                # One feature row per support shot, tagged with label and id.
                for i, label_id in enumerate(label_ids):
                    lbl_df = pd.DataFrame(
                        torch.cat(supp_fts[i]).cpu().numpy())
                    # On the COCO path label ids are plain ints, which have
                    # no .item(); int() accepts those as well as the
                    # single-element values produced on the VOC path.
                    lbl_df['label'] = int(label_id)
                    lbl_df['id'] = pd.Series(support_ids[i])
                    features_dfs.append(lbl_df)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

            _log.info('### Exporting features CSV')
            features_df = pd.concat(features_dfs)
            features_df = features_df.drop_duplicates(subset=['id'])
            # Move the two trailing columns ('id', 'label') to the front.
            cols = list(features_df)
            cols = [cols[-1], cols[-2]] + cols[:-2]
            features_df = features_df[cols]
            features_df.to_csv(f'{features_dir}/features_run_{run + 1}.csv',
                               index=False)

    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    (classIoU_binary, classIoU_std_binary,
     meanIoU_binary, meanIoU_std_binary) = metric.get_mIoU_binary()

    _log.info('###### Saving features visualization ######')
    all_fts = pd.concat([
        pd.read_csv(f'{features_dir}/features_run_{run + 1}.csv')
        for run in range(n_runs)
    ])
    all_fts = all_fts.drop_duplicates(subset=['id'])

    _log.info('### Obtaining Umap visualization ###')
    plot_umap(all_fts, f'{features_dir}/Umap_fts.png')

    _log.info('### Obtaining TSNE visualization ###')
    plot_tsne(all_fts, f'{features_dir}/TSNE_fts.png')

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
# Example #13
# 0
def main(_run, _config, _log):
    """Train the ``Encoder`` model on few-shot segmentation episodes.

    When the run has observers, the experiment sources are archived and a
    ``snapshots`` directory is created under the first observer's directory.
    The model is optimised with SGD and a ``MultiStepLR`` schedule on the sum
    of a binary loss (query + support) and a multi-class loss scaled by
    ``_config['mcl_loss_scaler']``. Running loss averages are printed every
    ``_config['print_interval']`` steps and the model state dict is saved
    every ``_config['save_pred_every']`` steps and once more after training.

    Args:
        _run: Run object providing ``observers``, ``experiment_info`` and
            ``log_scalar``.
        _config: Configuration mapping, read through the keys used below.
        _log: Logger used for progress messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither 'VOC' nor 'COCO'.
    """

    def stack_ways(nested):
        # Concatenate the shots of every way, then all ways, along dim 0.
        return torch.cat([torch.cat(way, dim=0) for way in nested], dim=0)

    def upsample(logits, reference):
        # Resize logits to the spatial size of the reference images.
        return F.interpolate(logits,
                             size=reference.shape[-2:],
                             mode='bilinear')

    def save_snapshot(step):
        # Store the current weights as '<observer dir>/snapshots/<step>.pth'.
        target = os.path.join(f'{_run.observers[0].dir}/snapshots',
                              f'{step}.pth')
        torch.save(model.state_dict(), target)

    if _run.observers:
        first_observer = _run.observers[0]
        os.makedirs(f'{first_observer.dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            archived = f'{first_observer.dir}/source/{source_file}'
            os.makedirs(os.path.dirname(archived), exist_ok=True)
            first_observer.save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{first_observer.basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    encoder = Encoder(pretrained_path=_config['path']['init_path'])
    model = nn.DataParallel(encoder.cuda(), device_ids=[_config['gpu_id']])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
    elif data_name == 'COCO':
        make_data = coco_fewshot
    else:
        raise ValueError('Wrong config for dataset!')

    task_cfg = _config['task']
    data_paths = _config['path'][data_name]
    dataset = make_data(
        base_dir=data_paths['data_dir'],
        split=data_paths['data_split'],
        transforms=Compose(
            [Resize(size=_config['input_size']), RandomMirror()]),
        to_tensor=ToTensorNormalize(),
        labels=CLASS_LABELS[data_name][_config['label_sets']],
        label_sets=_config['label_sets'],
        max_iters=_config['n_steps'] * _config['batch_size'],
        n_ways=task_cfg['n_ways'],
        n_shots=task_cfg['n_shots'],
        n_queries=task_cfg['n_queries'],
    )
    trainloader = DataLoader(dataset, batch_size=_config['batch_size'],
                             shuffle=True, num_workers=1, pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer, milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    # Keeps its initial value when the loader yields no batches, so the final
    # snapshot is then named '1.pth'.
    i_iter = 0
    log_loss = {'loss': 0, 'mcl_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        step = i_iter + 1
        support_masks = sample_batched['support_mask']

        # Support side: images, foreground masks and multi-class labels.
        support_images = stack_ways(
            [[shot.cuda() for shot in way]
             for way in sample_batched['support_images']])
        support_fg_mask = stack_ways(
            [[shot['fg_mask'].float().cuda() for shot in way]
             for way in support_masks])
        support_label_mc = stack_ways(
            [[shot['label_ori'].long().cuda() for shot in way]
             for way in support_masks])

        # Query side: images, binary labels and multi-class labels.
        query_images = torch.cat(
            [image.cuda() for image in sample_batched['query_images']],
            dim=0)
        query_labels = torch.cat(
            [label.long().cuda() for label in sample_batched['query_labels']],
            dim=0)
        query_label_mc = torch.cat(
            [mask['label_ori'].long().cuda()
             for mask in sample_batched['query_masks']],
            dim=0)

        optimizer.zero_grad()
        query_pred, support_pred_mc, query_pred_mc, support_pred = model(
            support_images, query_images, support_fg_mask)
        query_pred = upsample(query_pred, query_images)
        support_pred = upsample(support_pred, support_images)
        support_pred_mc = upsample(support_pred_mc, support_images)
        query_pred_mc = upsample(query_pred_mc, query_images)

        query_term = criterion(query_pred, query_labels)
        support_term = criterion(support_pred, support_fg_mask.long().cuda())
        binary_loss = query_term + support_term
        mcl_loss = (criterion(support_pred_mc, support_label_mc)
                    + criterion(query_pred_mc, query_label_mc))

        total_loss = binary_loss + mcl_loss * _config['mcl_loss_scaler']
        total_loss.backward()

        optimizer.step()
        scheduler.step()

        # Bookkeeping: per-step scalars plus running sums for the averages.
        binary_value = binary_loss.detach().data.cpu().numpy()
        if mcl_loss != 0:
            mcl_value = mcl_loss.detach().data.cpu().numpy()
        else:
            mcl_value = 0
        _run.log_scalar('loss', binary_value)
        _run.log_scalar('mcl_loss', mcl_value)
        log_loss['loss'] += binary_value
        log_loss['mcl_loss'] += mcl_value

        # Periodic console report of the running averages.
        if step % _config['print_interval'] == 0:
            avg_loss = log_loss['loss'] / step
            avg_mcl_loss = log_loss['mcl_loss'] / step
            print(f'step {step}: loss: {avg_loss}, mcl_loss: {avg_mcl_loss}')

        if step % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            save_snapshot(step)

    _log.info('###### Saving final model ######')
    save_snapshot(i_iter + 1)
# Example #14
# 0
def main(_run, _config, _log):
    """Test a few-shot segmentation model over several runs.

    The experiment sources are archived under the first observer's directory,
    a ``FewShotSeg`` model is built (its weights are loaded from
    ``_config['snapshot']`` unless ``_config['notrain']`` is truthy) and
    ``_config['n_runs']`` test runs are executed. Support annotations are
    taken from masks, scribbles or bounding boxes depending on
    ``_config['scribble']`` and ``_config['bbox']``. Per-run and final IoU
    statistics are logged through ``_run.log_scalar`` and ``_log``.

    Args:
        _run: Run object; ``observers[0]``, ``experiment_info`` and
            ``log_scalar`` are used (presumably a Sacred run -- confirm).
        _config: Configuration mapping, read through the keys used below.
        _log: Logger used for progress and result messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither 'VOC' nor 'COCO'.
    """
    observer = _run.observers[0]
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{observer.dir}/source/{source_file}'),
            exist_ok=True)
        observer.save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{observer.basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    # Debug output of the configured snapshot path.
    print("Snapshotttt")
    print(_config['snapshot'])
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    is_coco = data_name == 'COCO'

    # Evaluate on the classes that are NOT part of the configured label set.
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    # Key suffix selecting scribble or full-mask support annotations.
    suffix = 'scribble' if _config['scribble'] else 'mask'
    use_bbox = _config['bbox']

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            # A different seed per run so every run samples other episodes.
            set_seed(_config['seed'] + run)

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if is_coco:
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if is_coco:
                    # Map COCO category ids to their 1-based list position.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]

                if use_bbox:
                    # Derive fg/bg masks from bounding boxes via get_bbox.
                    support_fg_mask = []
                    support_bg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        bg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, bg_mask = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                            bg_masks.append(bg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                        support_bg_mask.append(bg_masks)
                else:
                    support_fg_mask = [[
                        shot[f'fg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                    support_bg_mask = [[
                        shot[f'bg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat([
                    query_label.cuda()
                    for query_label in sample_batched['query_labels']
                ],
                                         dim=0)

                # NOTE(review): this argument looks like the slot for a
                # U2Net saliency prediction, which is disabled here; the
                # model receives an empty placeholder instead -- confirm
                # with the model's forward signature.
                pred = []

                query_pred, _, _ = model(support_images, support_fg_mask,
                                         support_bg_mask, query_images, pred)
                # Only the first prediction/label of the batch is scored.
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    (classIoU_binary, classIoU_std_binary,
     meanIoU_binary, meanIoU_std_binary) = metric.get_mIoU_binary()

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')