def test_medical(device, model=None, test_organ='liver'):
    # The medical evaluation pipeline is only defined for the U-Net encoder.
    assert cfg['panet']['backbone'] == 'unet'
    dataset_name = 'ircadb'

    if model is None:
        model = PANetFewShotSeg(
            in_channels=cfg[dataset_name]['channels'],
            pretrained_path=None,
            cfg={'align': True},
            encoder_type=cfg['panet']['backbone']).to(device)
        load_state(cfg[dataset_name]['model_name'], model)
    model.eval()

    transforms = Compose([
        Resize(size=cfg['panet']['unet_inp_size']),
    ])

    test_dataset = get_visceral_medical_few_shot_dataset(
        organs=[test_organ],
        patient_ids=range(16, 20),
        shots=5,
        mode='test',
        transforms=transforms)
    testloader = DataLoader(test_dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=1,
                            pin_memory=True,
                            drop_last=True)

    metric = Metric(max_label=20, n_runs=1)

    for i_iter, (support, query) in enumerate(tqdm(testloader, position=0, leave=True)):
        # One support and one query sample per episode, so no inner loop is needed.
        support_images = [[]]
        support_fg_mask = [[]]
        support_bg_mask = [[]]
        support_bg = support[1] == 0
        support_fg = support[1] == 1
        support_images[0].append(support[0].to(device))
        support_fg_mask[0].append(support_fg.float().to(device))
        support_bg_mask[0].append(support_bg.float().to(device))

        query_images = []
        query_labels = []
        query_images.append(query[0].to(device))
        query_labels.append(query[1].to(device))
        query_labels = torch.cat(query_labels, dim=0).long().to(device)

        query_pred, _ = model(support_images, support_fg_mask, support_bg_mask,
                              query_images)

        # Visualize the support image next to its mask.
        print("Support ", i_iter)
        plt.subplot(1, 2, 1)
        try:
            plt.imshow(
                np.moveaxis(support[0].squeeze().cpu().detach().numpy(), 0, 2))
        except np.AxisError:
            # Single-channel image: show it without moving a channel axis.
            plt.imshow(support[0].squeeze().cpu().detach().numpy())
        plt.subplot(1, 2, 2)
        plt.imshow(support[1].squeeze())
        plt.show()

        # Visualize the query image, its ground-truth mask, and the prediction.
        print("Query ", i_iter)
        plt.subplot(1, 3, 1)
        try:
            plt.imshow(
                np.moveaxis(query[0].squeeze().cpu().detach().numpy(), 0, 2))
        except np.AxisError:
            plt.imshow(query[0].squeeze().cpu().detach().numpy())
        plt.subplot(1, 3, 2)
        plt.imshow(query[1].squeeze())
        plt.subplot(1, 3, 3)
        plt.imshow(np.array(query_pred.argmax(dim=1)[0].cpu()))
        metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                      np.array(query_labels[0].cpu()),
                      n_run=0)
        plt.show()

    classIoU, meanIoU = metric.get_mIoU(n_run=0)
    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=0)

    print('classIoU: {}'.format(classIoU))
    print('meanIoU: {}'.format(meanIoU))
    print('classIoU_binary: {}'.format(classIoU_binary))
    print('meanIoU_binary: {}'.format(meanIoU_binary))
def test_panet(device, model=None, dataset_name='voc', test_organ='liver'):
    if model is None:
        # pretrained_path='../data/vgg16-397923af.pth'
        model = PANetFewShotSeg(
            in_channels=cfg[dataset_name]['channels'],
            pretrained_path=None,
            cfg={'align': True},
            encoder_type=cfg['panet']['backbone']).to(device)
        load_state(cfg[dataset_name]['model_name'], model)
    model.eval()

    # Input size depends on the backbone used for the chosen dataset.
    if dataset_name == 'voc':
        transforms = Compose([
            Resize(size=cfg['panet']['vgg_inp_size']),
        ])
    elif dataset_name == 'ircadb':
        transforms = Compose([
            Resize(size=cfg['panet']['unet_inp_size']),
        ])

    if dataset_name == 'voc':
        test_dataset = get_pascal_few_shot_datasets(
            range(16, 21), cfg['panet']['test_iterations'], cfg['nshot'],
            cfg['nquery'], transforms)
    elif dataset_name == 'ircadb':
        test_dataset = get_ircadb_few_shot_datasets(
            organs=[test_organ],
            patient_ids=range(16, 21),
            iterations=cfg['panet']['test_iterations'],
            N_shot=cfg['nshot'],
            N_query=cfg['nquery'],
            transforms=transforms)

    testloader = DataLoader(test_dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=1,
                            pin_memory=True,
                            drop_last=True)

    metric = Metric(max_label=20, n_runs=1)

    for i_iter, (support, query) in enumerate(testloader):
        # Each support entry is (image, foreground mask, background mask).
        support_images = [[]]
        support_fg_mask = [[]]
        support_bg_mask = [[]]
        for i in range(len(support)):
            support_images[0].append(support[i][0].to(device))
            support_fg_mask[0].append(support[i][1].to(device))
            support_bg_mask[0].append(support[i][2].to(device))

        # Each query entry is (image, label).
        query_images = []
        query_labels = []
        for i in range(len(query)):
            query_images.append(query[i][0].to(device))
            query_labels.append(query[i][1].to(device))
        query_labels = torch.cat(query_labels, dim=0).long().to(device)

        query_pred, _ = model(support_images, support_fg_mask, support_bg_mask,
                              query_images)

        # Visualize every support image next to its foreground mask.
        print("Support ", i_iter)
        for i in range(len(support)):
            plt.subplot(1, 2 * len(support), 2 * i + 1)
            try:
                plt.imshow(
                    np.moveaxis(support[i][0].squeeze().cpu().detach().numpy(), 0, 2))
            except np.AxisError:
                # Single-channel image: show it without moving a channel axis.
                plt.imshow(support[i][0].squeeze().cpu().detach().numpy())
            plt.subplot(1, 2 * len(support), 2 * i + 2)
            plt.imshow(support[i][1].squeeze())
        plt.show()

        # Visualize every query image, its ground-truth mask, and the prediction.
        print("Query ", i_iter)
        for i in range(len(query)):
            plt.subplot(1, 3 * len(query), 3 * i + 1)
            try:
                plt.imshow(
                    np.moveaxis(query[i][0].squeeze().cpu().detach().numpy(), 0, 2))
            except np.AxisError:
                plt.imshow(query[i][0].squeeze().cpu().detach().numpy())
            plt.subplot(1, 3 * len(query), 3 * i + 2)
            plt.imshow(query[i][1].squeeze())
            plt.subplot(1, 3 * len(query), 3 * i + 3)
            plt.imshow(np.array(query_pred.argmax(dim=1)[i].cpu()))
            metric.record(np.array(query_pred.argmax(dim=1)[i].cpu()),
                          np.array(query_labels[i].cpu()),
                          n_run=0)
        plt.show()

    classIoU, meanIoU = metric.get_mIoU(n_run=0)
    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=0)

    print('classIoU: {}'.format(classIoU))
    print('meanIoU: {}'.format(meanIoU))
    print('classIoU_binary: {}'.format(classIoU_binary))
    print('meanIoU_binary: {}'.format(meanIoU_binary))
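
# Minimal invocation sketch (an illustrative addition, not part of the original
# module): it assumes the module-level cfg has been populated and that the
# checkpoint named by cfg[<dataset>]['model_name'] exists on disk.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if cfg['panet']['backbone'] == 'unet':
        # test_medical asserts the U-Net backbone, so only run it in that case.
        test_medical(device, test_organ='liver')
    else:
        test_panet(device, dataset_name='voc')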