Example #1
0
def main(_run, _config, _log):
    """Train a few-shot segmentation model, with periodic evaluation and a
    final multi-run test (Sacred entry point).

    Builds FewshotSegPartResnet or FewshotSegResnet, trains it on episodic
    VOC-style batches, evaluates on the held-out label set every
    `evaluate_interval` steps (keeping the best checkpoint by meanIoU), and
    finally reloads `best.pth` and reports mean/std IoU over 5 seeded runs.

    Args:
        _run: Sacred Run object; its first observer's dir receives source
            copies, snapshots, and logged scalars.
        _config: Sacred config dict (dataset paths, task setup, optimizer,
            intervals, checkpoint dir, ...).
        _log: Sacred logger.
    """

    logdir = f'{_run.observers[0].dir}/'
    print(logdir)
    # NOTE(review): `category` (the 20 VOC class names) is defined but never
    # used anywhere in this function.
    category = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]

    # Mirror the experiment's source files into the observer directory and
    # remove Sacred's default `_sources` copy.
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    data_name = _config['dataset']
    # 20 foreground classes for VOC, 80 otherwise (presumably COCO).
    max_label = 20 if data_name == 'VOC' else 80

    # Reproducibility + CUDA setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    print(_config['ckpt_dir'])
    tbwriter = SummaryWriter(osp.join(_config['ckpt_dir']))

    # TensorBoard tag names keyed by the metric names collected below.
    training_tags = {
        'loss': "ATraining/total_loss",
        "query_loss": "ATraining/query_loss",
        'aligned_loss': "ATraining/aligned_loss",
        'base_loss': "ATraining/base_loss",
    }
    infer_tags = {
        'mean_iou': "MeanIoU/mean_iou",
        "mean_iou_binary": "MeanIoU/mean_iou_binary",
    }

    _log.info('###### Create model ######')

    # Part-based prototype model vs. plain prototype model.
    if _config['model']['part']:
        model = FewshotSegPartResnet(
            pretrained_path=_config['path']['init_path'], cfg=_config)
        _log.info('Model: FewshotSegPartResnet')

    else:
        model = FewshotSegResnet(pretrained_path=_config['path']['init_path'],
                                 cfg=_config)
        _log.info('Model: FewshotSegResnet')

    # Single-GPU DataParallel wrapper; submodules are reached via `.module`.
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = voc_fewshot
    # Training uses the fold's own labels; evaluation below uses 'all' minus
    # this set.
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'],
                        n_unlabel=_config['task']['n_unlabels'],
                        cfg=_config)
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=_config['num_workers'],
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    if _config['fix']:
        # Freeze the backbone's early stages: only layer3/layer4 receive
        # parameter groups, so conv1/bn1/layer1/layer2 are never updated.
        print('Optimizer: fix')
        optimizer = torch.optim.SGD(
            params=[
                {
                    "params": model.module.encoder.layer3.parameters(),
                    "lr": _config['optim']['lr'],
                    "weight_decay": _config['optim']['weight_decay']
                },
                {
                    "params": model.module.encoder.layer4.parameters(),
                    "lr": _config['optim']['lr'],
                    "weight_decay": _config['optim']['weight_decay']
                },
            ],
            momentum=_config['optim']['momentum'],
        )
    else:
        print('Optimizer: Not fix')
        optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    # Cumulative sums; divided by (i_iter + 1) below, so reported values are
    # running averages since step 0, not per-interval averages.
    log_loss = {'loss': 0, 'align_loss': 0, 'base_loss': 0}
    _log.info('###### Training ######')

    highest_iou = 0
    metrics = {}

    for i_iter, sample_batched in enumerate(trainloader):
        if _config['fix']:
            # Keep frozen stages in eval mode so BatchNorm running stats stay
            # fixed (eval() alone does not stop gradients; the optimizer
            # simply has no groups for these parameters).
            model.module.encoder.conv1.eval()
            model.module.encoder.bn1.eval()
            model.module.encoder.layer1.eval()
            model.module.encoder.layer2.eval()

        # Eval-only mode: skip training entirely and jump to the final test.
        if _config['eval']:
            if i_iter == 0:
                break
        # Prepare input
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot[f'fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot[f'bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]

        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ],
                                 dim=0)  #1*417*417

        # Placeholder: base_loss is always zero here but kept in the total so
        # `base_loss_scaler` plumbing stays in place.
        base_loss = torch.zeros(1).to(torch.device('cuda'))
        # Forward and Backward
        optimizer.zero_grad()
        query_pred, _, align_loss = model(support_images, support_fg_mask,
                                          support_bg_mask, query_images)
        query_loss = criterion(query_pred,
                               query_labels)  #1*3*417*417, 1*417*417
        loss = query_loss + align_loss * _config[
            'align_loss_scaler'] + base_loss * _config['base_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy()
        base_loss = base_loss.detach().data.cpu().numpy()

        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss
        log_loss['base_loss'] += base_loss

        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            base_loss = log_loss['base_loss'] / (i_iter + 1)

            print(
                f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}, base_loss: {base_loss}'
            )
            _log.info(
                f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}, base_loss: {base_loss}'
            )

            # NOTE(review): 'query_loss' stores the raw last-step value while
            # the others are running averages.
            metrics['loss'] = loss
            metrics['query_loss'] = query_loss
            metrics['aligned_loss'] = align_loss
            metrics['base_loss'] = base_loss

            # for k, v in metrics.items():
            #     tbwriter.add_scalar(training_tags[k], v, i_iter)

        if (i_iter + 1) % _config['evaluate_interval'] == 0:
            _log.info('###### Evaluation begins ######')
            print(_config['ckpt_dir'])

            model.eval()

            # Evaluate on the classes NOT used for training (the held-out
            # fold).
            labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
                _config['label_sets']]
            transforms = [Resize(size=_config['input_size'])]
            transforms = Compose(transforms)

            # NOTE(review): Metric is built with n_runs=_config['n_runs'] but
            # the loop below performs exactly one run — confirm intended.
            metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
            with torch.no_grad():
                for run in range(1):
                    _log.info(f'### Run {run + 1} ###')
                    set_seed(_config['seed'] + run)

                    _log.info(f'### Load data ###')
                    dataset = make_data(
                        base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['infer_max_iters'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'],
                        n_unlabel=_config['task']['n_unlabels'],
                        cfg=_config)
                    testloader = DataLoader(dataset,
                                            batch_size=_config['batch_size'],
                                            shuffle=False,
                                            num_workers=_config['num_workers'],
                                            pin_memory=True,
                                            drop_last=False)
                    _log.info(f"Total # of Data: {len(dataset)}")

                    for sample_batched in tqdm.tqdm(testloader):
                        label_ids = list(sample_batched['class_ids'])
                        support_images = [[
                            shot.cuda() for shot in way
                        ] for way in sample_batched['support_images']]
                        suffix = 'mask'
                        support_fg_mask = [[
                            shot[f'fg_{suffix}'].float().cuda() for shot in way
                        ] for way in sample_batched['support_mask']]
                        support_bg_mask = [[
                            shot[f'bg_{suffix}'].float().cuda() for shot in way
                        ] for way in sample_batched['support_mask']]

                        query_images = [
                            query_image.cuda()
                            for query_image in sample_batched['query_images']
                        ]
                        query_labels = torch.cat([
                            query_label.cuda()
                            for query_label in sample_batched['query_labels']
                        ],
                                                 dim=0)
                        query_pred, _, _ = model(support_images,
                                                 support_fg_mask,
                                                 support_bg_mask, query_images)
                        # Only the first query of the batch is scored.
                        curr_iou = metric.record(query_pred.argmax(dim=1)[0],
                                                 query_labels[0],
                                                 labels=label_ids,
                                                 n_run=run)

                    classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                        n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

                    _run.log_scalar('classIoU', classIoU.tolist())
                    _run.log_scalar('meanIoU', meanIoU.tolist())
                    _run.log_scalar('classIoU_binary',
                                    classIoU_binary.tolist())
                    _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
                    _log.info(f'classIoU: {classIoU}')
                    _log.info(f'meanIoU: {meanIoU}')
                    _log.info(f'classIoU_binary: {classIoU_binary}')
                    _log.info(f'meanIoU_binary: {meanIoU_binary}')

                    print(
                        f'meanIoU: {meanIoU}, meanIoU_binary: {meanIoU_binary}'
                    )

                    metrics = {}
                    metrics['mean_iou'] = meanIoU
                    metrics['mean_iou_binary'] = meanIoU_binary

                    for k, v in metrics.items():
                        tbwriter.add_scalar(infer_tags[k], v, i_iter)

                    # Keep the best-so-far checkpoint by (multi-class) meanIoU.
                    if meanIoU > highest_iou:
                        print(
                            f'The highest iou is in iter: {i_iter} : {meanIoU}, save: {_config["ckpt_dir"]}/best.pth'
                        )
                        highest_iou = meanIoU
                        torch.save(
                            model.state_dict(),
                            os.path.join(f'{_config["ckpt_dir"]}/best.pth'))
                    else:
                        print(
                            f'The highest iou is in iter: {i_iter} : {meanIoU}'
                        )
            # Also save a per-interval snapshot regardless of best-ness.
            torch.save(model.state_dict(),
                       os.path.join(f'{_config["ckpt_dir"]}/{i_iter + 1}.pth'))
        # NOTE(review): runs every iteration (not only after eval); restores
        # train mode after an evaluation pass and is a no-op otherwise.
        model.train()

    print(_config['ckpt_dir'])

    # Final test: reload the best checkpoint and average IoU over 5 seeded
    # runs on the held-out label set.
    _log.info(' --------- Testing begins ---------')
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    transforms = Compose(transforms)
    ckpt = os.path.join(f'{_config["ckpt_dir"]}/best.pth')
    print(f'{_config["ckpt_dir"]}/best.pth')

    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    model.eval()

    metric = Metric(max_label=max_label, n_runs=5)
    with torch.no_grad():
        for run in range(5):
            n_iter = 0
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['infer_max_iters'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'],
                n_unlabel=_config['task']['n_unlabels'],
                cfg=_config)
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=_config['num_workers'],
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")
            for sample_batched in tqdm.tqdm(testloader):
                label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                suffix = 'mask'
                support_fg_mask = [[
                    shot[f'fg_{suffix}'].float().cuda() for shot in way
                ] for way in sample_batched['support_mask']]
                support_bg_mask = [[
                    shot[f'bg_{suffix}'].float().cuda() for shot in way
                ] for way in sample_batched['support_mask']]

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat([
                    query_label.cuda()
                    for query_label in sample_batched['query_labels']
                ],
                                         dim=0)
                query_pred, _, _ = model(support_images, support_fg_mask,
                                         support_bg_mask, query_images)
                curr_iou = metric.record(query_pred.argmax(dim=1)[0],
                                         query_labels[0],
                                         labels=label_ids,
                                         n_run=run)
                n_iter += 1

            # Per-run IoU summaries.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Aggregate mean/std across all 5 runs (no n_run argument).
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )

    _run.log_scalar('meanIoU', meanIoU.tolist())
    _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())

    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())

    _log.info('----- Final Result -----')
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')

    _log.info("## ------------------------------------------ ##")
    _log.info(f'###### Setting: {_run.observers[0].dir} ######')

    _log.info(
        "Running {num_run} runs, meanIoU:{miou:.4f}, meanIoU_binary:{mbiou:.4f} "
        "meanIoU_std:{miou_std:.4f}, meanIoU_binary_std:{mbiou_std:.4f}".
        format(num_run=5,
               miou=meanIoU,
               mbiou=meanIoU_binary,
               miou_std=meanIoU_std,
               mbiou_std=meanIoU_std_binary))
    _log.info(f"Current setting is {_run.observers[0].dir}")

    print(
        "Running {num_run} runs, meanIoU:{miou:.4f}, meanIoU_binary:{mbiou:.4f} "
        "meanIoU_std:{miou_std:.4f}, meanIoU_binary_std:{mbiou_std:.4f}".
        format(num_run=5,
               miou=meanIoU,
               mbiou=meanIoU_binary,
               miou_std=meanIoU_std,
               mbiou_std=meanIoU_std_binary))
    print(f"Current setting is {_run.observers[0].dir}")
    print(_config['ckpt_dir'])
    print(logdir)
Example #2
0
def main(_run, _config, _log):
    """Train a few-shot segmentation model on ScanNet episodes (Sacred entry point).

    Copies experiment sources into the observer directory, seeds and
    configures CUDA, builds a DataParallel FewShotSeg model, trains it with
    SGD + MultiStepLR on episodic (coords, images, masks) batches, and saves
    weight snapshots every `save_pred_every` steps plus once at the end.

    Args:
        _run: Sacred Run object; its first observer's dir receives source
            copies, logged scalars, and model snapshots.
        _config: Sacred config dict (dataset paths, task setup, optimizer,
            intervals, ...).
        _log: Sacred logger.
    """
    # Mirror the experiment's source files into the observer directory and
    # remove Sacred's default `_sources` copy.
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility + CUDA setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    # Single-GPU DataParallel wrapper.
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'ScanNet':
        make_data = scannet_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=20,
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    # Pre-initialized so the final save still works if the loader is empty.
    i_iter = 0
    # Cumulative sums; divided by (i_iter + 1) below, so printed values are
    # running averages since step 0.
    log_loss = {'loss': 0, 'align_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: move the episodic batch to GPU as [way][shot] lists.
        support_coords = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_coords']]
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]

        query_coords = [
            query_coord.cuda()
            for query_coord in sample_batched['query_coords']
        ]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ],
                                 dim=0)

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_coords, support_images,
                                       support_fg_mask, support_bg_mask,
                                       query_coords, query_images)
        # [None, ...] restores the batch dimension expected by the criterion.
        query_loss = criterion(query_pred, query_labels[None, ...])
        loss = query_loss + align_loss * _config['align_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss. `.detach()` alone is enough; the legacy `.data` access
        # was redundant and is deprecated autograd API.
        query_loss = query_loss.detach().cpu().numpy()
        # `align_loss` may be a plain 0 (alignment disabled) rather than a
        # tensor, so guard before converting.
        align_loss = align_loss.detach().cpu().numpy(
        ) if align_loss != 0 else 0
        _run.log_scalar('loss', query_loss)
        _run.log_scalar('align_loss', align_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}')

        if (i_iter + 1) % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            torch.save(
                model.state_dict(),
                os.path.join(f'{_run.observers[0].dir}/snapshots',
                             f'{i_iter + 1}.pth'))

    _log.info('###### Saving final model ######')
    torch.save(
        model.state_dict(),
        os.path.join(f'{_run.observers[0].dir}/snapshots',
                     f'{i_iter + 1}.pth'))
Example #3
0
def main(_run, _config, _log):
    """Train a few-shot segmentation model on VOC/COCO episodes (Sacred entry point).

    Copies experiment sources into the observer directory, seeds and
    configures CUDA, builds a DataParallel FewShotSeg model, trains it with
    SGD + MultiStepLR on episodic few-shot batches, and saves weight
    snapshots every `save_pred_every` steps plus once at the end.

    Args:
        _run: Sacred Run object; its first observer's dir receives source
            copies, logged scalars, and model snapshots.
        _config: Sacred config dict (dataset paths, task setup, optimizer,
            intervals, ...).
        _log: Sacred logger.
    """
    # Mirror the experiment's source files into the observer directory and
    # remove Sacred's default `_sources` copy.
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility + CUDA setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'], cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'],])
    model.train()

    # NOTE: an experimental U-2-Net saliency branch was removed from this
    # function; the model is called with an empty `pred` list below, which
    # disables that path.

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
    elif data_name == 'COCO':
        make_data = coco_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']),
                          RandomMirror()])
    dataset = make_data(
        base_dir=_config['path'][data_name]['data_dir'],
        split=_config['path'][data_name]['data_split'],
        transforms=transforms,
        to_tensor=ToTensorNormalize(),
        labels=labels,
        max_iters=_config['n_steps'] * _config['batch_size'],
        n_ways=_config['task']['n_ways'],
        n_shots=_config['task']['n_shots'],
        n_queries=_config['task']['n_queries']
    )
    trainloader = DataLoader(
        dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True
    )

    _log.info('###### Set optimizer ######')
    print(_config['mode'])
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer, milestones=_config['lr_milestones'], gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    # Pre-initialized so the final save still works if the loader is empty.
    i_iter = 0
    # Cumulative sums; divided by (i_iter + 1) below, so printed values are
    # running averages since step 0.
    log_loss = {'loss': 0, 'align_loss': 0, 'dist_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: move the episodic batch to GPU as [way][shot] lists.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]

        query_images = [query_image.cuda()
                        for query_image in sample_batched['query_images']]
        query_labels = torch.cat(
            [query_label.long().cuda() for query_label in sample_batched['query_labels']], dim=0)
        # Forward and Backward
        optimizer.zero_grad()

        # Saliency predictions disabled (see NOTE above).
        pred = []

        query_pred, align_loss, dist_loss = model(support_images, support_fg_mask, support_bg_mask,
                                       query_images, pred)
        query_loss = criterion(query_pred, query_labels)
        # NOTE: alignment scaler is hard-coded to 0.2 here;
        # _config['align_loss_scaler'] is intentionally ignored.
        loss = query_loss + dist_loss + align_loss * 0.2
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss. `.detach()` alone is enough; the legacy `.data` access
        # was redundant and is deprecated autograd API. dist_loss/align_loss
        # may be plain 0 (branch disabled) rather than tensors.
        query_loss = query_loss.detach().cpu().numpy()
        dist_loss = dist_loss.detach().cpu().numpy() if dist_loss != 0 else 0
        align_loss = align_loss.detach().cpu().numpy() if align_loss != 0 else 0
        _run.log_scalar('loss', query_loss)
        _run.log_scalar('align_loss', align_loss)
        _run.log_scalar('dist_loss', dist_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss
        log_loss['dist_loss'] += dist_loss

        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            # Fix: report dist_loss as a running average like the other
            # losses (previously the raw last-step value was printed).
            dist_loss = log_loss['dist_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}, dist_loss: {dist_loss}')

        if (i_iter + 1) % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            torch.save(model.state_dict(),
                       os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth'))

    _log.info('###### Saving final model ######')
    torch.save(model.state_dict(),
               os.path.join(f'{_run.observers[0].dir}/snapshots', f'{i_iter + 1}.pth'))
Example #4
0
def main(cfg, gpus):
    """Train a binary objectness segmentation network and periodically
    evaluate it with few-shot segmentation metrics.

    Args:
        cfg: experiment configuration node (MODEL / TRAIN / VAL / TASK /
            DATASET sections, yacs-style attribute access).
        gpus: list of GPU ids; only ``gpus[0]`` is used (DataParallel is
            commented out below).

    Side effects: trains ``net_objectness``/``net_decoder`` in place,
    writes checkpoints via ``checkpoint`` and prints progress/metrics.
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    # Binary decoder: num_class=2 (object vs. background).
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    # NLL over log-probabilities; 255 marks ignored pixels.
    crit = nn.NLLLoss(ignore_index=255)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on fold classes; hold out the complement for validation.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    exclude_labels = labels_val
    transforms = Compose([Resize(size=cfg.DATASET.input_size), RandomMirror()])
    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    net_objectness.cuda()
    net_decoder.cuda()

    # Set up optimizers
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    # train(False) freezes batch-norm statistics when fix_bn is set.
    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)

    best_iou = 0
    # main loop
    tic = time.time()

    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)

        data_time.update(time.time() - tic)
        net_objectness.zero_grad()
        net_decoder.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        feat = net_objectness(feed_dict['img_data'], return_feature_maps=True)
        pred = net_decoder(feat)
        loss = crit(pred, feed_dict['seg_label'])
        acc = pixel_acc(pred, feed_dict['seg_label'])
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            if optimizer:  # crit has no optimizer slot
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                net_objectness.eval()
                net_decoder.eval()
                # Switch decoder to softmax inference mode for evaluation.
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)

                    print(f'### Load validation data ###')
                    dataset_val = make_data(base_dir=cfg.DATASET.data_dir,
                                            split=cfg.DATASET.data_split,
                                            transforms=transforms,
                                            to_tensor=ToTensorNormalize(),
                                            labels=labels_val,
                                            max_iters=cfg.VAL.n_iters *
                                            cfg.VAL.n_batch,
                                            n_ways=cfg.TASK.n_ways,
                                            n_shots=cfg.TASK.n_shots,
                                            n_queries=cfg.TASK.n_queries,
                                            permute=cfg.VAL.permute_labels,
                                            exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # Fix: report the size of the validation set, not the
                    # training set, under the validation label.
                    print(f"Total # of validation Data: {len(dataset_val)}")

                    #for sample_batched in tqdm.tqdm(testloader):
                    for sample_batched in testloader:
                        feed_dict = data_preprocess(sample_batched,
                                                    cfg,
                                                    is_val=True)
                        if data_name == 'COCO':
                            # Map COCO category ids to contiguous label ids.
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in sample_batched['class_ids']
                            ]
                        else:
                            label_ids = list(sample_batched['class_ids'])

                        feat = net_objectness(feed_dict['img_data'],
                                              return_feature_maps=True)
                        query_pred = net_decoder(
                            feat, segSize=cfg.DATASET.input_size)
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(feed_dict['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)

                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

            # Aggregate mean/std over all validation runs.
            classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                labels=sorted(labels_val))
            classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
            )

            print('----- Evaluation Result -----')
            print(f'best meanIoU_binary: {best_iou}')
            print(f'meanIoU mean: {meanIoU}')
            print(f'meanIoU std: {meanIoU_std}')
            print(f'meanIoU_binary mean: {meanIoU_binary}')
            print(f'meanIoU_binary std: {meanIoU_std_binary}')

            # Model selection on binary mIoU.
            if meanIoU_binary > best_iou:
                best_iou = meanIoU_binary
                checkpoint(nets, history, cfg, 'best')
            net_objectness.train(not cfg.TRAIN.fix_bn)
            net_decoder.train(not cfg.TRAIN.fix_bn)
            net_decoder.use_softmax = False

    print('Training Done!')
Пример #5
0
def main(config):
    """Train a few-shot video segmentation model on DAVIS 2017.

    Args:
        config: dict-like experiment configuration ('snapshots',
            'palette_dir', 'seed', 'model', 'dataset', 'path', 'task',
            'optim', 'lr_milestones', 'ignore_label', 'n_steps',
            'batch_size', 'print_interval', 'save_pred_every', ...).

    Side effects: trains the model in place and saves intermediate and
    final state dicts under ``config['snapshots']``.
    """
    if not os.path.exists(config['snapshots']):
        os.makedirs(config['snapshots'])
    # Load the PNG palette (768 ints = 256 RGB triplets) used for
    # colorizing predictions in the (commented-out) visualization code.
    palette_path = config['palette_dir']
    with open(palette_path) as f:
        palette = f.readlines()
    palette = list(np.asarray([[int(p) for p in pal[0:-1].split(' ')] for pal in palette]).reshape(768))
    snap_shots_dir = config['snapshots']
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    # NOTE(review): GPU id is hard-coded to 2 here and below — consider
    # reading it from config instead.
    torch.cuda.set_device(2)
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True

    torch.set_num_threads(1)

    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[2])
    model.train()

    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][config['label_sets']]
    transforms = Compose([Resize(size=config['input_size']),
                          RandomMirror()])
    dataset = make_data(
        base_dir=config['path'][data_name]['data_dir'],
        split=config['path'][data_name]['data_split'],
        transforms=transforms,
        to_tensor=ToTensorNormalize(),
        labels=labels,
        max_iters=config['n_steps'] * config['batch_size'],
        n_ways=config['task']['n_ways'],
        n_shots=config['task']['n_shots'],
        n_queries=config['task']['n_queries']
    )
    trainloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True
    )

    optimizer = torch.optim.SGD(model.parameters(), **config['optim'])
    scheduler = MultiStepLR(optimizer, milestones=config['lr_milestones'], gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0}

    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: nested lists indexed [way][shot], moved to GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_fg_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_bg_mask']]
        img_size = sample_batched['img_size']
        # support_label_t = [[shot.float().cuda() for shot in way]
        #                    for way in sample_batched['support_bg_mask']]
        query_images = [query_image.cuda()
                        for query_image in sample_batched['query_images']]
        query_labels = torch.cat(
            [query_label.long().cuda() for query_label in sample_batched['query_labels']], dim=0)
        pre_masks = [query_label.float().cuda() for query_label in sample_batched['query_masks']]
        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_images, support_fg_mask, support_bg_mask,
                                       query_images, pre_masks)
        # Upsample prediction to the original image size before the loss.
        query_pred = F.interpolate(query_pred, size=img_size, mode="bilinear")
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss + align_loss * config['align_loss_scaler']

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (align_loss may be the literal 0 when alignment is off).
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0

        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # print loss and take snapshots
        if (i_iter + 1) % config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter + 1}: loss: {loss}, align_loss: {align_loss}')

        if (i_iter + 1) % config['save_pred_every'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(f'{snap_shots_dir}', f'{i_iter + 1}.pth'))

    # Final snapshot after the last iteration.
    torch.save(model.state_dict(),
               os.path.join(f'{snap_shots_dir}', f'{i_iter + 1}.pth'))
Пример #6
0
def main(config):
    """Train a few-shot video segmentation model on DAVIS 2017
    (variant without mask upsampling/palette handling).

    Args:
        config: dict-like experiment configuration ('snapshots', 'seed',
            'model', 'dataset', 'path', 'task', 'optim', 'lr_milestones',
            'ignore_label', 'n_steps', 'batch_size', 'print_interval',
            'save_pred_every', ...).

    Side effects: trains the model in place and saves intermediate and
    final state dicts under ``config['snapshots']``.
    """
    if not os.path.exists(config['snapshots']):
        os.makedirs(config['snapshots'])
    snap_shots_dir = config['snapshots']
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    # NOTE(review): GPU id is hard-coded to 2 here and below — consider
    # reading it from config instead.
    torch.cuda.set_device(2)
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True

    torch.set_num_threads(1)

    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[2])
    model.train()

    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][config['label_sets']]
    transforms = Compose([Resize(size=config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=config['path'][data_name]['data_dir'],
                        split=config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=config['n_steps'] * config['batch_size'],
                        n_ways=config['task']['n_ways'],
                        n_shots=config['task']['n_shots'],
                        n_queries=config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=config['batch_size'],
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=True)

    optimizer = torch.optim.SGD(model.parameters(), **config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0}

    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: nested lists indexed [way][shot], moved to GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_fg_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_bg_mask']]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ],
                                 dim=0)
        # print(query_labels.shape)
        pre_masks = [
            query_label.float().cuda()
            for query_label in sample_batched['query_masks']
        ]
        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_images, support_fg_mask,
                                       support_bg_mask, query_images,
                                       pre_masks)
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss + align_loss * config['align_loss_scaler']

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (align_loss may be the literal 0 when alignment is off).
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy(
        ) if align_loss != 0 else 0
        # _run.log_scalar('loss', query_loss)
        # _run.log_scalar('align_loss', align_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # print loss and take snapshots
        if (i_iter + 1) % config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter + 1}: loss: {loss}, align_loss: {align_loss}')

        if (i_iter + 1) % config['save_pred_every'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(f'{snap_shots_dir}', f'{i_iter + 1}.pth'))

    # Final snapshot after the last iteration.
    torch.save(model.state_dict(),
               os.path.join(f'{snap_shots_dir}', f'{i_iter + 1}.pth'))
Пример #7
0
def main(cfg, gpus):
    """Train the memory-attention few-shot segmentation module and
    periodically evaluate it on held-out classes.

    Args:
        cfg: experiment configuration node (MODEL / TRAIN / VAL / TASK /
            DATASET sections, yacs-style attribute access).
        gpus: list of GPU ids; only ``gpus[0]`` is used (DataParallel is
            commented out below).

    Side effects: trains ``segmentation_module`` in place, writes
    checkpoints via ``checkpoint`` and prints progress/metrics.
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways + 1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    # n_ways + 1 output classes (ways plus background).
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways + 1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    # Optional frozen objectness branch (only when both weight paths
    # are configured); its parameters are excluded from training.
    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        '''net_objectness = ModelBuilder.build_objectness(
            arch='resnet50_deeplab',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='aspp_few_shot',
            input_dim=2048,
            fc_dim=256,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            dropout_rate=0.5,
            use_dropout=True)'''
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    # NLL over log-probabilities; 255 marks ignored pixels.
    crit = nn.NLLLoss(ignore_index=255)

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query,
        net_enc_memory,
        net_att_query,
        net_att_memory,
        net_decoder,
        net_projection,
        net_objectness,
        net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        random_memory_bias=cfg.MODEL.random_memory_bias,
        random_memory_nobias=cfg.MODEL.random_memory_nobias,
        random_scale=cfg.MODEL.random_scale,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness_debug import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness_debug import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on fold classes; hold out the complement for validation.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    if cfg.DATASET.exclude_labels:
        exclude_labels = labels_val
    else:
        exclude_labels = []
    transforms = Compose([Resize(size=cfg.DATASET.input_size), RandomMirror()])
    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels,
                        use_ignore=cfg.use_ignore)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, net_projection, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    # train(False) freezes batch-norm statistics when fix_bn is set; the
    # objectness branch always stays in eval mode.
    segmentation_module.train(not cfg.TRAIN.fix_bn)
    if net_objectness and net_objectness_decoder:
        net_objectness.eval()
        net_objectness_decoder.eval()

    best_iou = 0
    # main loop
    tic = time.time()

    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)

        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        #print(batch_data)
        loss, acc = segmentation_module(feed_dict)
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            if optimizer:  # some nets (e.g. crit) have no optimizer slot
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))

            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                segmentation_module.eval()
                # Switch decoder to softmax inference mode for evaluation.
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)

                    print(f'### Load validation data ###')
                    dataset_val = make_data(base_dir=cfg.DATASET.data_dir,
                                            split=cfg.DATASET.data_split,
                                            transforms=transforms,
                                            to_tensor=ToTensorNormalize(),
                                            labels=labels_val,
                                            max_iters=cfg.VAL.n_iters *
                                            cfg.VAL.n_batch,
                                            n_ways=cfg.TASK.n_ways,
                                            n_shots=cfg.TASK.n_shots,
                                            n_queries=cfg.TASK.n_queries,
                                            permute=cfg.VAL.permute_labels,
                                            exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # Fix: report the size of the validation set, not the
                    # training set, under the validation label.
                    print(f"Total # of validation Data: {len(dataset_val)}")

                    #for sample_batched in tqdm.tqdm(testloader):
                    for sample_batched in testloader:
                        feed_dict = data_preprocess(sample_batched,
                                                    cfg,
                                                    is_val=True)
                        if data_name == 'COCO':
                            # Map COCO category ids to contiguous label ids.
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in sample_batched['class_ids']
                            ]
                        else:
                            label_ids = list(sample_batched['class_ids'])

                        query_pred = segmentation_module(
                            feed_dict, segSize=cfg.DATASET.input_size)
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(feed_dict['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)

                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

            # Aggregate mean/std over all validation runs.
            classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                labels=sorted(labels_val))
            classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
            )

            print('----- Evaluation Result -----')
            print(f'best meanIoU mean: {best_iou}')
            print(f'meanIoU mean: {meanIoU}')
            print(f'meanIoU std: {meanIoU_std}')
            print(f'meanIoU_binary mean: {meanIoU_binary}')
            print(f'meanIoU_binary std: {meanIoU_std_binary}')

            checkpoint(nets, history, cfg, 'latest')

            # Model selection on multi-class mIoU.
            if meanIoU > best_iou:
                best_iou = meanIoU
                checkpoint(nets, history, cfg, 'best')
            segmentation_module.train(not cfg.TRAIN.fix_bn)
            if net_objectness and net_objectness_decoder:
                net_objectness.eval()
                net_objectness_decoder.eval()
            net_decoder.use_softmax = False

    print('Training Done!')
Пример #8
0
def main(_run, _config, _log):
    """Train a few-shot segmentation Encoder with an auxiliary multi-class loss.

    Prepares the sacred run directory, seeds RNGs and configures CUDA, builds
    the model / dataloader / optimizer, then runs the training loop.  Each
    step combines a binary (foreground/background) segmentation loss on both
    query and support predictions with a multi-class auxiliary loss, scaled
    by ``_config['mcl_loss_scaler']``.  Losses are logged to sacred and model
    snapshots are written every ``save_pred_every`` steps.

    Args:
        _run: sacred ``Run`` object; its first observer provides the output
            directory used for source snapshots and model checkpoints.
        _config: sacred configuration dictionary (dataset, paths, task setup,
            optimizer settings, logging intervals, ...).
        _log: sacred logger.
    """
    if _run.observers:
        # Mirror the experiment sources into the run directory and remove
        # sacred's duplicate `_sources` folder.
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility / device setup.  cudnn.benchmark is safe here because
    # input sizes are fixed by the Resize transform below.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = Encoder(pretrained_path=_config['path']['init_path'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
    elif data_name == 'COCO':
        make_data = coco_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        label_sets=_config['label_sets'],
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'mcl_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Support images, support fg/bg masks and support multi-class labels,
        # flattened from the [ways][shots] nesting into one batch dimension.
        # (Dict keys were written as no-op f-strings; plain strings now.)
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_images = torch.cat(
            [torch.cat(way, dim=0) for way in support_images], dim=0)
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_fg_mask = torch.cat(
            [torch.cat(way, dim=0) for way in support_fg_mask], dim=0)
        support_label_mc = [[shot['label_ori'].long().cuda() for shot in way]
                            for way in sample_batched['support_mask']]
        support_label_mc = torch.cat(
            [torch.cat(way, dim=0) for way in support_label_mc], dim=0)

        # Query images, query mask labels and query multi-class labels.
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_images = torch.cat(query_images, dim=0)
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ],
                                 dim=0)
        query_label_mc = [
            n_queries['label_ori'].long().cuda()
            for n_queries in sample_batched['query_masks']
        ]
        query_label_mc = torch.cat(query_label_mc, dim=0)

        optimizer.zero_grad()
        query_pred, support_pred_mc, query_pred_mc, support_pred = model(
            support_images, query_images, support_fg_mask)
        # Upsample every prediction to its label's spatial resolution before
        # computing the cross-entropy losses.
        query_pred = F.interpolate(query_pred,
                                   size=query_images.shape[-2:],
                                   mode='bilinear')
        support_pred = F.interpolate(support_pred,
                                     size=support_images.shape[-2:],
                                     mode='bilinear')
        support_pred_mc = F.interpolate(support_pred_mc,
                                        size=support_images.shape[-2:],
                                        mode='bilinear')
        query_pred_mc = F.interpolate(query_pred_mc,
                                      size=query_images.shape[-2:],
                                      mode='bilinear')

        # Binary (fg/bg) loss on both query and support predictions, plus
        # the multi-class auxiliary loss on both streams.
        binary_loss = criterion(query_pred, query_labels) + criterion(
            support_pred,
            support_fg_mask.long().cuda())
        mcl_loss = criterion(support_pred_mc, support_label_mc) + criterion(
            query_pred_mc, query_label_mc)

        loss = binary_loss + mcl_loss * _config['mcl_loss_scaler']
        loss.backward()

        optimizer.step()
        scheduler.step()

        # Log loss.  `.data` was removed: `.detach()` already yields a tensor
        # outside the autograd graph, and `.data` access is deprecated.
        binary_loss = binary_loss.detach().cpu().numpy()
        mcl_loss = mcl_loss.detach().cpu().numpy() if mcl_loss != 0 else 0
        _run.log_scalar('loss', binary_loss)
        _run.log_scalar('mcl_loss', mcl_loss)
        log_loss['loss'] += binary_loss
        log_loss['mcl_loss'] += mcl_loss

        # Print running-mean losses and take snapshots at fixed intervals.
        if (i_iter + 1) % _config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            mcl_loss = log_loss['mcl_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}, mcl_loss: {mcl_loss}')

        if (i_iter + 1) % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            torch.save(
                model.state_dict(),
                os.path.join(f'{_run.observers[0].dir}/snapshots',
                             f'{i_iter + 1}.pth'))

    _log.info('###### Saving final model ######')
    torch.save(
        model.state_dict(),
        os.path.join(f'{_run.observers[0].dir}/snapshots',
                     f'{i_iter + 1}.pth'))