Ejemplo n.º 1
0
def main(config):
    """Run few-shot video object segmentation on DAVIS-2017 and save masks.

    For each sequence produced by the test loader, the first frame and its
    mask form the support set.  Frames are then segmented in order: the
    support mask is passed to the model as ``pre_mask`` for frame 0 and the
    previous frame's prediction is passed for every later frame.  Each
    prediction is written as a palettized PNG to
    ``./result/<num>/<class_name>/<frame index:05d>.png``.

    Args:
        config: Experiment configuration mapping.  Keys read here:
            ``palette_dir``, ``task`` (``n_shots``, ``n_ways``), ``seed``,
            ``gpu_id``, ``model``, ``train``, ``snapshots``, ``dataset``,
            ``input_size``, ``n_runs``, ``path`` and ``batch_size``.

    Raises:
        ValueError: If ``config['dataset']`` is not ``'davis'``.
    """
    num = 100000  # name of the output sub-directory under ./result
    palette_path = config['palette_dir']
    with open(palette_path) as f:
        # Assumed one "R G B" triple per line (768 values in total).
        # ``split()`` copes with a last line that lacks its newline (the old
        # ``pal[0:-1]`` slice silently dropped its final digit) and with
        # blank or padded lines (which made ``int('')`` raise).
        palette = [[int(p) for p in line.split()] for line in f
                   if line.strip()]
    palette = list(np.asarray(palette).reshape(768))

    n_shots = config['task']['n_shots']
    n_ways = config['task']['n_ways']
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=config['gpu_id'])
    torch.set_num_threads(1)
    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[config['gpu_id']])
    # Weights are loaded only when config['train'] is truthy.
    # NOTE(review): sibling scripts gate this on ``not notrain``; confirm
    # that this flag has the intended polarity.
    if config['train']:
        model.load_state_dict(
            torch.load(config['snapshots'], map_location='cpu'))
    model.eval()

    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_test
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name]['val']
    # The i-th batch selects its mask via the i-th label in sorted order.
    list_label = sorted(labels)
    transforms = Compose([Resize(size=config['input_size'])])

    with torch.no_grad():
        for run in range(config['n_runs']):
            set_seed(config['seed'] + run)
            dataset = make_data(base_dir=config['path'][data_name]['data_dir'],
                                split=config['path'][data_name]['data_split'],
                                transforms=transforms,
                                to_tensor=ToTensorNormalize(),
                                labels=labels)

            testloader = DataLoader(dataset,
                                    batch_size=config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)

            for iteration, batch in enumerate(testloader):
                class_name = batch[2]
                out_dir = f'./result/{num}/{class_name[0]}'
                # exist_ok avoids the exists()/makedirs() race.
                os.makedirs(out_dir, exist_ok=True)
                all_class_data = batch[0]
                label_key = list_label[iteration]
                class_ids = all_class_data[0]['obj_ids']

                # The first frame of the sequence is the support example.
                support_images = [[
                    all_class_data[0]['image'].cuda() for _ in range(n_shots)
                ] for _ in range(n_ways)]
                support_mask = all_class_data[0]['label'][label_key]
                support_fg_mask = [[
                    get_fg_mask(support_mask, class_ids[way])
                    for _ in range(n_shots)
                ] for way in range(len(class_ids))]
                support_bg_mask = [[
                    get_bg_mask(support_mask, class_ids)
                    for _ in range(n_shots)
                ] for _ in range(n_ways)]
                s_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                             for way in support_fg_mask]
                s_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                             for way in support_bg_mask]

                for idx, frame in enumerate(all_class_data):
                    query_images = [frame['image'].cuda()
                                    for _ in range(n_ways)]

                    # Frame 0 is guided by the support mask; every later
                    # frame by the prediction made for the previous frame.
                    if idx == 0:
                        pre_mask = [support_mask.float().cuda()]
                    else:
                        pre_mask = [pred_mask]

                    query_pred, _ = model(support_images, s_fg_mask, s_bg_mask,
                                          query_images, pre_mask)
                    pred = query_pred.argmax(dim=1, keepdim=True)
                    pred = pred.data.cpu().numpy()
                    img_e = Image.fromarray(
                        pred[0, 0].astype('float32')).convert('P')
                    pred_mask = tr_F.resize(img_e,
                                            config['input_size'],
                                            interpolation=Image.NEAREST)
                    pred_mask = torch.Tensor(np.array(pred_mask))
                    pred_mask = torch.unsqueeze(pred_mask, dim=0)
                    pred_mask = pred_mask.float().cuda()
                    img_e.putpalette(palette)
                    img_e.save(os.path.join(out_dir, '{:05d}.png'.format(idx)))
Ejemplo n.º 2
0
def main(_run, _config, _log):
    """Evaluate a FewShotSeg model on VOC or COCO few-shot episodes.

    Copies the experiment's source files into the first observer's
    directory, builds and (unless ``notrain``) loads the model, then runs
    ``n_runs`` seeded test passes.  Per-run and aggregate (mean/std) class
    and binary mIoU are logged through ``_run.log_scalar`` and ``_log``.

    Args:
        _run: Run object (presumably a Sacred ``Run``) providing
            ``experiment_info``, ``observers`` and ``log_scalar``.
        _config: Experiment configuration mapping.
        _log: Logger used for progress and result messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                    exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'], cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        model.load_state_dict(torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        # COCO label ids are built below as ``coco_cls_ids.index(x) + 1``,
        # i.e. 1..80, so the metric must cover all 80 classes (was 6).
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][_config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)

            _log.info('### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries']
            )
            if _config['dataset'] == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset, batch_size=_config['batch_size'], shuffle=False,
                                    num_workers=1, pin_memory=True, drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    # Map COCO category ids to contiguous 1-based indices.
                    label_ids = [coco_cls_ids.index(x) + 1 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                suffix = 'scribble' if _config['scribble'] else 'mask'

                if _config['bbox']:
                    # Replace dense support masks with bounding-box masks.
                    support_fg_mask = []
                    support_bg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        bg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, bg_mask = get_bbox(shot['fg_mask'],
                                                        sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                            bg_masks.append(bg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                        support_bg_mask.append(bg_masks)
                else:
                    support_fg_mask = [[shot[f'fg_{suffix}'].float().cuda() for shot in way]
                                       for way in sample_batched['support_mask']]
                    support_bg_mask = [[shot[f'bg_{suffix}'].float().cuda() for shot in way]
                                       for way in sample_batched['support_mask']]

                query_images = [query_image.cuda()
                                for query_image in sample_batched['query_images']]
                query_labels = torch.cat(
                    [query_label.cuda() for query_label in sample_batched['query_labels']], dim=0)

                query_pred, _ = model(support_images, support_fg_mask, support_bg_mask,
                                      query_images)

                # Only the first query of the batch is scored.
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids, n_run=run)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels), n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Without ``n_run`` the metric returns mean and std across all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
Ejemplo n.º 3
0
def main(_run, _config, _log):
    """Evaluate a FewShotSeg model (3-output variant) on VOC/COCO episodes.

    Copies the experiment's source files into the first observer's
    directory, builds and (unless ``notrain``) loads the model, then runs
    ``n_runs`` seeded test passes.  Per-run and aggregate (mean/std) class
    and binary mIoU are logged through ``_run.log_scalar`` and ``_log``.

    Args:
        _run: Run object (presumably a Sacred ``Run``) providing
            ``experiment_info``, ``observers`` and ``log_scalar``.
        _config: Experiment configuration mapping.
        _log: Logger used for progress and result messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        # Report through the run logger (was a bare debug print that also
        # fired when no snapshot was loaded).
        _log.info(f"Loading snapshot: {_config['snapshot']}")
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)

            _log.info('### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if _config['dataset'] == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    # Map COCO category ids to contiguous 1-based indices.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                suffix = 'scribble' if _config['scribble'] else 'mask'

                if _config['bbox']:
                    # Replace dense support masks with bounding-box masks.
                    support_fg_mask = []
                    support_bg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        bg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, bg_mask = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                            bg_masks.append(bg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                        support_bg_mask.append(bg_masks)
                else:
                    support_fg_mask = [[
                        shot[f'fg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                    support_bg_mask = [[
                        shot[f'bg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat([
                    query_label.cuda()
                    for query_label in sample_batched['query_labels']
                ], dim=0)

                # Empty extra input: the U^2-Net saliency path that appears
                # to have filled it is disabled, so its per-iteration
                # GPU->CPU->GPU tensor copies (which had no consumer) were
                # removed.
                pred = []

                query_pred, _, _ = model(support_images, support_fg_mask,
                                         support_bg_mask, query_images, pred)
                # Only the first query of the batch is scored.
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Without ``n_run`` the metric returns mean and std across all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    (classIoU_binary, classIoU_std_binary,
     meanIoU_binary, meanIoU_std_binary) = metric.get_mIoU_binary()

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
Ejemplo n.º 4
0
def main(_run, _config, _log):
    """Evaluate a FewShotSeg model and export its support-set features.

    Besides logging per-run and aggregate (mean/std) class and binary mIoU,
    every run writes the support features returned by the model to
    ``<observer dir>/features/features_run_<k>.csv``.  Each CSV has columns
    ``id``, ``label`` followed by the feature values, de-duplicated by
    ``id``.  After all runs, UMAP and t-SNE plots of the pooled features are
    saved in the same directory.

    Args:
        _run: Run object (presumably a Sacred ``Run``) providing
            ``experiment_info``, ``observers`` and ``log_scalar``.
        _config: Experiment configuration mapping.
        _log: Logger used for progress and result messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    os.makedirs(f'{_run.observers[0].dir}/features', exist_ok=True)
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                    exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'], cfg=_config['model'], task=_config['task'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        model.load_state_dict(torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')

    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size'])])
    n_ways = _config['task']['n_ways']
    n_shots = _config['task']['n_shots']
    features_dir = f'{_run.observers[0].dir}/features'

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)
            features_dfs = []

            _log.info('### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=n_ways,
                n_shots=n_shots,
                n_queries=_config['task']['n_queries']
            )

            if _config['dataset'] == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()

            testloader = DataLoader(dataset, batch_size=_config['batch_size'], shuffle=False,
                                    num_workers=1, pin_memory=True, drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    # Python ints (1-based contiguous indices).
                    label_ids = [coco_cls_ids.index(x) + 1 for x in sample_batched['class_ids']]
                else:
                    # Collated batch elements, not Python ints.
                    label_ids = list(sample_batched['class_ids'])

                # Support ids are stored flat (way-major); regroup per way,
                # keeping the first batch element of each entry.
                support_ids = [[sample_batched['support_ids'][way * n_shots + shot][0]
                                for shot in range(n_shots)]
                               for way in range(n_ways)]
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]

                support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                                   for way in sample_batched['support_mask']]
                support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                                   for way in sample_batched['support_mask']]

                query_images = [query_image.cuda()
                                for query_image in sample_batched['query_images']]
                query_labels = torch.cat(
                    [query_label.cuda() for query_label in sample_batched['query_labels']], dim=0)

                query_pred, _, supp_fts = model(support_images, support_fg_mask, support_bg_mask,
                                                query_images)

                # Only the first query of the batch is scored.
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids, n_run=run)

                # One row per support shot of each way.
                for i, label_id in enumerate(label_ids):
                    lbl_df = pd.DataFrame(torch.cat(supp_fts[i]).cpu().numpy())
                    # int() accepts both the COCO Python ints and 1-element
                    # tensors; ``.item()`` raised AttributeError for COCO.
                    lbl_df['label'] = int(label_id)
                    lbl_df['id'] = pd.Series(support_ids[i])
                    features_dfs.append(lbl_df)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels), n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

            _log.info('### Exporting features CSV')
            features_df = pd.concat(features_dfs)
            features_df = features_df.drop_duplicates(subset=['id'])
            # Move the trailing 'id' and 'label' columns to the front.
            cols = list(features_df)
            cols = [cols[-1], cols[-2]] + cols[:-2]
            features_df = features_df[cols]
            features_df.to_csv(f'{features_dir}/features_run_{run+1}.csv', index=False)

    # Without ``n_run`` the metric returns mean and std across all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

    _log.info('###### Saving features visualization ######')
    all_fts = pd.concat([pd.read_csv(f'{features_dir}/features_run_{k+1}.csv')
                         for k in range(_config['n_runs'])])
    all_fts = all_fts.drop_duplicates(subset=['id'])

    _log.info('### Obtaining Umap visualization ###')
    plot_umap(all_fts, f'{features_dir}/Umap_fts.png')

    _log.info('### Obtaining TSNE visualization ###')
    plot_tsne(all_fts, f'{features_dir}/TSNE_fts.png')

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
Ejemplo n.º 5
0
def main(_run, _config, _log):
    """Evaluate a trained FewShotSeg model on the test split and report Dice.

    Workflow (as implemented below):

    1. Copy every experiment source file into ``<observer dir>/source/`` and
       remove ``<observer basedir>/_sources``.
    2. Seed RNGs, enable cuDNN benchmarking, select GPU ``_config['gpu_id']``,
       build ``FewShotSeg`` wrapped in ``nn.DataParallel`` and, unless
       ``_config['notrain']`` is truthy, load weights from
       ``_config['snapshot']``.
    3. Run the model once per test sample (batch size 1), take the per-pixel
       argmax over dim 1 as the prediction, and group input/prediction/label
       slices by the subject index returned by
       ``ts_dataset.get_test_subj_idx``.
    4. For each subject, concatenate its slices along axis 0 and compute
       ``dice = 2 * sum(label * pred) / (sum(pred) + sum(label))``.
       Optionally log overlays and scores to TensorBoard
       (``_config['record']``) and write gt/pred/input volumes as NIfTI files
       (``_config['save_sample']``).
    5. Print the per-subject Dice list and its mean.

    Args:
        _run: run object; ``experiment_info['sources']`` and ``observers[0]``
            (``dir``, ``basedir``, ``save_file``) are used. Presumably a
            Sacred ``Run`` -- confirm against the experiment definition.
        _config: experiment configuration mapping (keys are read below).
        _log: logger used for progress messages.

    Note:
        If a subject's prediction and label volumes are both empty, the Dice
        denominator is 0 and numpy yields ``nan`` (with a RuntimeWarning),
        which then propagates into the reported mean.
    """
    # Snapshot the experiment sources next to the run's outputs.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    max_label = 1

    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(
        dataset=ts_dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=False,
        drop_last=False)

    if _config['record']:
        _log.info('###### define tensorboard writer #####')
        board_name = f'board/test_{_config["board"]}_{date()}'
        writer = SummaryWriter(board_name)

    _log.info('###### Testing begins ######')
    length = len(testloader)
    img_lists = []
    pred_lists = []
    label_lists = []

    # One list of per-sample records for every subject in the test set.
    saves = {subj_idx: [] for subj_idx in range(len(ts_dataset.get_cnts()))}

    with torch.no_grad():
        batch_i = 0  # the loader uses batch_size=1, so only index 0 exists

        for i, sample_test in enumerate(testloader):
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            fnames = sample_test['q_fname']

            s_x_orig = sample_test['s_x'].cuda(
            )  # [B, Support, slice_num=1, 1, 256, 256]
            s_x = s_x_orig.squeeze(2)  # [B, Support, 1, 256, 256]
            s_y_fg_orig = sample_test['s_y'].cuda(
            )  # [B, Support, slice_num, 1, 256, 256]
            s_y_fg = s_y_fg_orig.squeeze(2)  # [B, Support, 1, 256, 256]
            s_y_fg = s_y_fg.squeeze(2)  # [B, Support, 256, 256]
            s_y_bg = torch.ones_like(s_y_fg) - s_y_fg
            q_x_orig = sample_test['q_x'].cuda()  # [B, slice_num, 1, 256, 256]
            q_x = q_x_orig.squeeze(1)  # [B, 1, 256, 256]
            q_y_orig = sample_test['q_y'].cuda()  # [B, slice_num, 1, 256, 256]

            # The model takes nested lists: [way][shot] for support, [query].
            n_shot = _config["n_shot"]
            s_xs = [[s_x[:, shot, ...] for shot in range(n_shot)]]
            s_y_fgs = [[s_y_fg[:, shot, ...] for shot in range(n_shot)]]
            s_y_bgs = [[s_y_bg[:, shot, ...] for shot in range(n_shot)]]
            q_xs = [q_x]
            q_yhat, _ = model(s_xs, s_y_fgs, s_y_bgs, q_xs)
            # Hard prediction: class index per pixel, with a channel dim kept.
            q_yhat = q_yhat.argmax(dim=1).unsqueeze(1)

            img_list = [q_x_orig[batch_i, 0].cpu().numpy()]
            pred_list = [q_yhat[batch_i].cpu().numpy()]
            label_list = [q_y_orig[batch_i, 0].cpu().numpy()]

            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list, fnames])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t", end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)

    n_subjects = len(saves)
    print("start computing dice similarities ... total ", n_subjects)
    dice_similarities = []
    for subj_idx in range(n_subjects):
        imgs, preds, labels = [], [], []
        save_subj = saves[subj_idx]
        # Flatten this subject's per-sample slices in test order; ``fnames``
        # ends up holding the last sample's file names.
        for _, _, img_list, pred_list, label_list, fnames in save_subj:
            imgs.extend(img_list)
            preds.extend(pred_list)
            labels.extend(label_list)

        img_arr = np.concatenate(imgs, axis=0)
        pred_arr = np.concatenate(preds, axis=0)
        label_arr = np.concatenate(labels, axis=0)
        dice = np.sum(label_arr * pred_arr) * 2.0 / (np.sum(pred_arr) +
                                                     np.sum(label_arr))
        dice_similarities.append(dice)
        print(f"computing dice scores {subj_idx}/{n_subjects}")

        if _config['record']:
            frames = []
            for frame_id in range(len(save_subj)):
                frames += overlay_color(torch.tensor(imgs[frame_id]),
                                        torch.tensor(preds[frame_id]).float(),
                                        torch.tensor(labels[frame_id]))
            visual = make_grid(frames, normalize=True, nrow=5)
            # Key both entries by subject; the leftover inner-loop index used
            # previously made different subjects share a scalar tag.
            writer.add_image(f"test/{subj_idx}", visual, subj_idx)
            writer.add_scalar(f'dice_score/{subj_idx}', dice)

        if _config['save_sample']:
            ## only for internal test (BCV - MICCAI2015)
            sup_idx = _config['s_idx']
            target = _config['target']
            save_name = _config['save_name']
            save_dir = f"../sample/panet_organ{target}_sup{sup_idx}_{save_name}"
            for sub_dir in ("gt", "pred", "input"):
                os.makedirs(os.path.join(save_dir, sub_dir), exist_ok=True)

            # The original scan is read only to copy its voxel spacing.
            subj_name = fnames[0][0].split("/")[-2]
            if target == 14:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Cervix/RawData/Training/img"
                orig_fname = f"{src_dir}/{subj_name}-Image.nii.gz"
            else:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Abdomen/RawData/Training/img"
                orig_fname = f"{src_dir}/img{subj_name}.nii.gz"

            itk = sitk.ReadImage(orig_fname)
            orig_spacing = itk.GetSpacing()

            # Ground truth is scaled by 2.0 before saving; the reason is not
            # evident from this code (NOTE(review): confirm intent).
            label_arr = label_arr * 2.0
            itk = sitk.GetImageFromArray(label_arr)
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/gt/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(pred_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/pred/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(img_arr)
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/input/{subj_idx}.nii.gz")

    print(f"test result \n n : {len(dice_similarities)}, mean dice score : \
    {np.mean(dice_similarities)} \n dice similarities : {dice_similarities}")

    if _config['record']:
        writer.add_scalar('dice_score/mean', np.mean(dice_similarities))
        # Flush pending events and release the event file.
        writer.close()