def main(config):
    """Run sequential few-shot segmentation on the DAVIS split and save PNGs.

    For every batch yielded by the test loader, the first entry of
    ``batch[0]`` is used as the support sample.  Each entry of ``batch[0]``
    is then segmented in order; the first query is conditioned on the
    support mask and every later query on the previous prediction.  Each
    prediction is written as a palettised PNG to
    ``./result/<num>/<class_name>/<idx:05d>.png``.

    Args:
        config: Mapping providing ``palette_dir``, ``task`` (``n_shots``,
            ``n_ways``), ``seed``, ``gpu_id``, ``model``, ``train``,
            ``snapshots``, ``dataset``, ``input_size``, ``n_runs``,
            ``path`` and ``batch_size``.

    Raises:
        ValueError: If ``config['dataset']`` is not ``'davis'``.
    """
    num = 100000  # sub-directory of ./result that receives the output

    # Parse the palette with a whitespace split instead of slicing off the
    # last character: the slice silently dropped a digit when the final line
    # had no trailing newline, and blank lines crashed int('').
    palette_path = config['palette_dir']
    with open(palette_path) as f:
        palette_rows = [[int(value) for value in line.split()]
                        for line in f if line.strip()]
    # reshape(768) still enforces exactly 768 values in the palette file.
    palette = list(np.asarray(palette_rows).reshape(768))

    n_shots = config['task']['n_shots']
    n_ways = config['task']['n_ways']

    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=config['gpu_id'])
    torch.set_num_threads(1)

    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[config['gpu_id']])
    if config['train']:
        model.load_state_dict(
            torch.load(config['snapshots'], map_location='cpu'))
    model.eval()

    data_name = config['dataset']
    if data_name != 'davis':
        raise ValueError('Wrong config for dataset!')
    make_data = davis2017_test

    labels = CLASS_LABELS[data_name]['val']
    list_label = sorted(labels)
    transforms = Compose([Resize(size=config['input_size'])])

    with torch.no_grad():
        for run in range(config['n_runs']):
            set_seed(config['seed'] + run)
            dataset = make_data(
                base_dir=config['path'][data_name]['data_dir'],
                split=config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels)
            testloader = DataLoader(dataset,
                                    batch_size=config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)

            for iteration, batch in enumerate(testloader):
                class_name = batch[2]
                out_dir = f'./result/{num}/{class_name[0]}'
                # exist_ok avoids the check-then-create race.
                os.makedirs(out_dir, exist_ok=True)

                all_class_data = batch[0]
                support = all_class_data[0]
                class_ids = support['obj_ids']
                # NOTE(review): assumes the loader yields batches in the same
                # order as the sorted label list -- confirm against dataset.
                label_key = list_label[iteration]

                support_images = [[support['image'].cuda()
                                   for _ in range(n_shots)]
                                  for _ in range(n_ways)]
                support_mask = support['label'][label_key]
                support_fg_mask = [[get_fg_mask(support_mask, class_ids[way])
                                    for _ in range(n_shots)]
                                   for way in range(len(class_ids))]
                support_bg_mask = [[get_bg_mask(support_mask, class_ids)
                                    for _ in range(n_shots)]
                                   for _ in range(n_ways)]
                s_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                             for way in support_fg_mask]
                s_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                             for way in support_bg_mask]

                pred_mask = None
                for idx, frame in enumerate(all_class_data):
                    query_images = [frame['image'].cuda()
                                    for _ in range(n_ways)]

                    # First query uses the support mask as the prior; later
                    # queries reuse the previous prediction.
                    if idx == 0:
                        pre_mask = [support_mask.float().cuda()]
                    else:
                        pre_mask = [pred_mask]

                    query_pred, _ = model(support_images, s_fg_mask,
                                          s_bg_mask, query_images, pre_mask)
                    pred = query_pred.argmax(dim=1, keepdim=True)
                    pred = pred.data.cpu().numpy()
                    img = pred[0, 0]
                    img_e = Image.fromarray(
                        img.astype('float32')).convert('P')

                    # Resize the prediction to the network input size so it
                    # can serve as the prior for the next query.
                    pred_mask = tr_F.resize(img_e, config['input_size'],
                                            interpolation=Image.NEAREST)
                    pred_mask = torch.Tensor(np.array(pred_mask))
                    pred_mask = torch.unsqueeze(pred_mask, dim=0)
                    pred_mask = pred_mask.float().cuda()

                    img_e.putpalette(palette)
                    img_e.save(
                        os.path.join(out_dir, '{:05d}.png'.format(idx)))
def main(_run, _config, _log):
    """Evaluate a few-shot segmentation snapshot and log IoU statistics.

    Copies the experiment sources through the first observer, builds the
    model, then for ``_config['n_runs']`` runs draws episodes from the VOC or
    COCO few-shot dataset and records the prediction for the first query of
    each batch.  Per-run and aggregated class/mean IoU (plain and binary) are
    sent to ``_run.log_scalar`` and ``_log.info``.

    Args:
        _run: Experiment run object; ``observers``, ``experiment_info`` and
            ``log_scalar`` are used.
        _config: Configuration mapping for the experiment.
        _log: Logger used for progress and result messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    observer = _run.observers[0]
    for source_file, _ in _run.experiment_info['sources']:
        destination = f'{observer.dir}/source/{source_file}'
        os.makedirs(os.path.dirname(destination), exist_ok=True)
        observer.save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{observer.basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    network = FewShotSeg(pretrained_path=_config['path']['init_path'],
                         cfg=_config['model'])
    model = nn.DataParallel(network.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        state = torch.load(_config['snapshot'], map_location='cpu')
        model.load_state_dict(state)
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data, max_label = voc_fewshot, 20
    elif data_name == 'COCO':
        # NOTE(review): sibling evaluation scripts use 80 here -- confirm
        # that 6 is intended for this setup.
        make_data, max_label = coco_fewshot, 6
    else:
        raise ValueError('Wrong config for dataset!')

    all_labels = CLASS_LABELS[data_name]['all']
    held_out = CLASS_LABELS[data_name][_config['label_sets']]
    labels = all_labels - held_out

    transform_steps = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transform_steps.append(
            DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transform_steps)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'],
            )
            if _config['dataset'] == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    # Map COCO category ids to 1-based positions in the
                    # category id list.
                    label_ids = [coco_cls_ids.index(x) + 1
                                 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])

                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]

                suffix = 'scribble' if _config['scribble'] else 'mask'
                if _config['bbox']:
                    # Derive foreground/background masks from instance boxes.
                    support_fg_mask, support_bg_mask = [], []
                    ways = sample_batched['support_mask']
                    for way_idx, way in enumerate(ways):
                        way_fg, way_bg = [], []
                        for shot_idx, shot in enumerate(way):
                            inst = sample_batched['support_inst'][way_idx][
                                shot_idx]
                            fg_box, bg_box = get_bbox(shot['fg_mask'], inst)
                            way_fg.append(fg_box.float().cuda())
                            way_bg.append(bg_box.float().cuda())
                        support_fg_mask.append(way_fg)
                        support_bg_mask.append(way_bg)
                else:
                    support_fg_mask = [
                        [shot[f'fg_{suffix}'].float().cuda() for shot in way]
                        for way in sample_batched['support_mask']
                    ]
                    support_bg_mask = [
                        [shot[f'bg_{suffix}'].float().cuda() for shot in way]
                        for way in sample_batched['support_mask']
                    ]

                query_images = [image.cuda()
                                for image in sample_batched['query_images']]
                gpu_labels = [label.cuda()
                              for label in sample_batched['query_labels']]
                query_labels = torch.cat(gpu_labels, dim=0)

                query_pred, _ = model(support_images, support_fg_mask,
                                      support_bg_mask, query_images)

                # Only element 0 of the prediction and labels is scored.
                predicted = np.array(query_pred.argmax(dim=1)[0].cpu())
                expected = np.array(query_labels[0].cpu())
                metric.record(predicted, expected,
                              labels=label_ids, n_run=run)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                n_run=run)

            run_scores = (
                ('classIoU', classIoU),
                ('meanIoU', meanIoU),
                ('classIoU_binary', classIoU_binary),
                ('meanIoU_binary', meanIoU_binary),
            )
            for name, value in run_scores:
                _run.log_scalar(name, value.tolist())
            for name, value in run_scores:
                _log.info(f'{name}: {value}')

    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    (classIoU_binary, classIoU_std_binary,
     meanIoU_binary, meanIoU_std_binary) = metric.get_mIoU_binary()

    _log.info('----- Final Result -----')
    # (scalar name, log label, value), kept in the reporting order.
    final_scores = (
        ('final_classIoU', 'classIoU mean', classIoU),
        ('final_classIoU_std', 'classIoU std', classIoU_std),
        ('final_meanIoU', 'meanIoU mean', meanIoU),
        ('final_meanIoU_std', 'meanIoU std', meanIoU_std),
        ('final_classIoU_binary', 'classIoU_binary mean', classIoU_binary),
        ('final_classIoU_std_binary', 'classIoU_binary std',
         classIoU_std_binary),
        ('final_meanIoU_binary', 'meanIoU_binary mean', meanIoU_binary),
        ('final_meanIoU_std_binary', 'meanIoU_binary std',
         meanIoU_std_binary),
    )
    for scalar_name, _, value in final_scores:
        _run.log_scalar(scalar_name, value.tolist())
    for _, label, value in final_scores:
        _log.info(f'{label}: {value}')
def main(_run, _config, _log):
    """Evaluate a few-shot segmentation snapshot and log IoU statistics.

    Copies the experiment sources through the first observer, builds the
    model, then for ``_config['n_runs']`` runs draws episodes from the VOC or
    COCO few-shot dataset and records the prediction for the first query of
    each batch.  The model is called with a fifth positional argument, which
    is an empty list here.  Per-run and aggregated class/mean IoU (plain and
    binary) are sent to ``_run.log_scalar`` and ``_log.info``.

    Args:
        _run: Experiment run object; ``observers``, ``experiment_info`` and
            ``log_scalar`` are used.
        _config: Configuration mapping for the experiment.
        _log: Logger used for progress and result messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    observer = _run.observers[0]
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{observer.dir}/source/{source_file}'),
            exist_ok=True)
        observer.save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{observer.basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
        # Goes through the experiment logger instead of stray debug prints.
        _log.info(f"Loaded snapshot: {_config['snapshot']}")
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    labels = (CLASS_LABELS[data_name]['all']
              - CLASS_LABELS[data_name][_config['label_sets']])

    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)

            _log.info('### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if data_name == 'COCO':
                    # Map COCO category ids to 1-based positions in the
                    # category id list.
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])

                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]

                suffix = 'scribble' if _config['scribble'] else 'mask'
                if _config['bbox']:
                    # Derive foreground/background masks from instance boxes.
                    support_fg_mask = []
                    support_bg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        bg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, bg_mask = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                            bg_masks.append(bg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                        support_bg_mask.append(bg_masks)
                else:
                    support_fg_mask = [[
                        shot[f'fg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                    support_bg_mask = [[
                        shot[f'bg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat([
                    query_label.cuda()
                    for query_label in sample_batched['query_labels']
                ], dim=0)

                # The U2-Net branch that would have produced a prior for the
                # fifth model argument is disabled, so its per-batch input
                # staging (deprecated Variable wrappers plus a GPU->CPU->GPU
                # copy) was dead work and has been removed.
                pred = []
                query_pred, _, _ = model(support_images, support_fg_mask,
                                         support_bg_mask, query_images, pred)

                # Only element 0 of the prediction and labels is scored.
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    (classIoU_binary, classIoU_std_binary,
     meanIoU_binary, meanIoU_std_binary) = metric.get_mIoU_binary()

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary',
                    classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(_run, _config, _log):
    """Evaluate a snapshot, export support features, and log IoU statistics.

    In addition to the IoU bookkeeping, the third output of the model is
    collected per batch into data frames with ``label`` and ``id`` columns.
    After each run the de-duplicated rows are written to
    ``features/features_run_<n>.csv`` under the first observer's directory.
    After all runs the CSVs are read back, merged, and passed to
    ``plot_umap`` and ``plot_tsne``.

    Args:
        _run: Experiment run object; ``observers``, ``experiment_info`` and
            ``log_scalar`` are used.
        _config: Configuration mapping for the experiment.
        _log: Logger used for progress and result messages.

    Raises:
        ValueError: If ``_config['dataset']`` is neither ``'VOC'`` nor
            ``'COCO'``.
    """
    observer = _run.observers[0]
    features_dir = f'{observer.dir}/features'
    os.makedirs(features_dir, exist_ok=True)
    for source_file, _ in _run.experiment_info['sources']:
        destination = f'{observer.dir}/source/{source_file}'
        os.makedirs(os.path.dirname(destination), exist_ok=True)
        observer.save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{observer.basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    network = FewShotSeg(pretrained_path=_config['path']['init_path'],
                         cfg=_config['model'],
                         task=_config['task'])
    model = nn.DataParallel(network.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        state = torch.load(_config['snapshot'], map_location='cpu')
        model.load_state_dict(state)
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data, max_label = voc_fewshot, 20
    elif data_name == 'COCO':
        make_data, max_label = coco_fewshot, 80
    else:
        raise ValueError('Wrong config for dataset!')

    all_labels = CLASS_LABELS[data_name]['all']
    held_out = CLASS_LABELS[data_name][_config['label_sets']]
    labels = all_labels - held_out
    transforms = Compose([Resize(size=_config['input_size'])])

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)
            features_dfs = []

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'],
            )
            if _config['dataset'] == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    # Map COCO category ids to 1-based positions in the
                    # category id list.
                    label_ids = [coco_cls_ids.index(x) + 1
                                 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])

                # 'support_ids' is indexed as a flat list; regroup it into
                # one list of ids per way.
                support_ids = []
                for way in range(_config['task']['n_ways']):
                    way_ids = []
                    for shot in range(_config['task']['n_shots']):
                        flat = way * _config['task']['n_shots'] + shot
                        way_ids.append(
                            sample_batched['support_ids'][flat][0])
                    support_ids.append(way_ids)

                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                support_fg_mask = [
                    [shot['fg_mask'].float().cuda() for shot in way]
                    for way in sample_batched['support_mask']
                ]
                support_bg_mask = [
                    [shot['bg_mask'].float().cuda() for shot in way]
                    for way in sample_batched['support_mask']
                ]

                query_images = [image.cuda()
                                for image in sample_batched['query_images']]
                gpu_labels = [label.cuda()
                              for label in sample_batched['query_labels']]
                query_labels = torch.cat(gpu_labels, dim=0)

                query_pred, _, supp_fts = model(support_images,
                                                support_fg_mask,
                                                support_bg_mask,
                                                query_images)

                # Only element 0 of the prediction and labels is scored.
                predicted = np.array(query_pred.argmax(dim=1)[0].cpu())
                expected = np.array(query_labels[0].cpu())
                metric.record(predicted, expected,
                              labels=label_ids, n_run=run)

                # One data frame per label, tagged with label and ids.
                for way_idx, label_id in enumerate(label_ids):
                    way_features = torch.cat(supp_fts[way_idx]).cpu().numpy()
                    way_df = pd.DataFrame(way_features)
                    way_df['label'] = label_id.item()
                    way_df['id'] = pd.Series(support_ids[way_idx])
                    features_dfs.append(way_df)

            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                n_run=run)

            run_scores = (
                ('classIoU', classIoU),
                ('meanIoU', meanIoU),
                ('classIoU_binary', classIoU_binary),
                ('meanIoU_binary', meanIoU_binary),
            )
            for name, value in run_scores:
                _run.log_scalar(name, value.tolist())
            for name, value in run_scores:
                _log.info(f'{name}: {value}')

            _log.info('### Exporting features CSV')
            features_df = pd.concat(features_dfs)
            features_df = features_df.drop_duplicates(subset=['id'])
            # Move the last two columns (id, then label) to the front.
            *feature_cols, label_col, id_col = list(features_df)
            features_df = features_df[[id_col, label_col] + feature_cols]
            features_df.to_csv(
                f'{features_dir}/features_run_{run+1}.csv', index=False)

    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    (classIoU_binary, classIoU_std_binary,
     meanIoU_binary, meanIoU_std_binary) = metric.get_mIoU_binary()

    _log.info('###### Saving features visualization ######')
    per_run_features = []
    for run_idx in range(_config['n_runs']):
        per_run_features.append(
            pd.read_csv(f'{features_dir}/features_run_{run_idx+1}.csv'))
    all_fts = pd.concat(per_run_features)
    all_fts = all_fts.drop_duplicates(subset=['id'])

    _log.info('### Obtaining Umap visualization ###')
    plot_umap(all_fts, f'{features_dir}/Umap_fts.png')
    _log.info('### Obtaining TSNE visualization ###')
    plot_tsne(all_fts, f'{features_dir}/TSNE_fts.png')

    _log.info('----- Final Result -----')
    # (scalar name, log label, value), kept in the reporting order.
    final_scores = (
        ('final_classIoU', 'classIoU mean', classIoU),
        ('final_classIoU_std', 'classIoU std', classIoU_std),
        ('final_meanIoU', 'meanIoU mean', meanIoU),
        ('final_meanIoU_std', 'meanIoU std', meanIoU_std),
        ('final_classIoU_binary', 'classIoU_binary mean', classIoU_binary),
        ('final_classIoU_std_binary', 'classIoU_binary std',
         classIoU_std_binary),
        ('final_meanIoU_binary', 'meanIoU_binary mean', meanIoU_binary),
        ('final_meanIoU_std_binary', 'meanIoU_binary std',
         meanIoU_std_binary),
    )
    for scalar_name, _, value in final_scores:
        _run.log_scalar(scalar_name, value.tolist())
    for _, label, value in final_scores:
        _log.info(f'{label}: {value}')
def main(_run, _config, _log):
    """Segment every test sample and report a Dice score per subject.

    Copies the experiment sources through the first observer, builds the
    model, and runs it over the test split returned by ``meta_data`` with a
    batch size of 1.  Results are grouped by the subject index reported by
    ``ts_dataset.get_test_subj_idx``, concatenated per subject, and scored
    with the Dice coefficient.  When ``_config['record']`` is set, overlays
    and scores are written to a tensorboard ``SummaryWriter``.  When
    ``_config['save_sample']`` is set, label, prediction and input volumes
    are written as ``.nii.gz`` files.

    Args:
        _run: Experiment run object; ``observers`` and ``experiment_info``
            are used.
        _config: Configuration mapping for the experiment.
        _log: Logger used for progress messages.
    """
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id']])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Load data ######')
    # Only the third split returned by meta_data (the test set) is used.
    _, _, ts_dataset = meta_data(_config)
    testloader = DataLoader(dataset=ts_dataset,
                            batch_size=1,
                            shuffle=False,
                            pin_memory=False,
                            drop_last=False)

    if _config['record']:
        _log.info('###### define tensorboard writer #####')
        board_name = f'board/test_{_config["board"]}_{date()}'
        writer = SummaryWriter(board_name)

    _log.info('###### Testing begins ######')
    length = len(testloader)
    # One result list per subject, keyed by subject index.
    saves = {subj_idx: [] for subj_idx in range(len(ts_dataset.get_cnts()))}

    batch_i = 0  # the loader uses batch_size=1, so only element 0 is read
    with torch.no_grad():
        for i, sample_test in enumerate(testloader):
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            fnames = sample_test['q_fname']

            # Shape notes below are the original author's annotations.
            s_x = sample_test['s_x'].cuda()   # [B, Support, slice_num=1, 1, 256, 256]
            s_x = s_x.squeeze(2)              # [B, Support, 1, 256, 256]
            s_y_fg = sample_test['s_y'].cuda()  # [B, Support, slice_num, 1, 256, 256]
            s_y_fg = s_y_fg.squeeze(2)        # [B, Support, 1, 256, 256]
            s_y_fg = s_y_fg.squeeze(2)        # [B, Support, 256, 256]
            s_y_bg = torch.ones_like(s_y_fg) - s_y_fg
            q_x_orig = sample_test['q_x'].cuda()  # [B, slice_num, 1, 256, 256]
            q_x = q_x_orig.squeeze(1)             # [B, 1, 256, 256]
            q_y_orig = sample_test['q_y'].cuda()  # [B, slice_num, 1, 256, 256]

            # Single way; one entry per support shot.
            shots = range(_config["n_shot"])
            s_xs = [[s_x[:, shot, ...] for shot in shots]]
            s_y_fgs = [[s_y_fg[:, shot, ...] for shot in shots]]
            s_y_bgs = [[s_y_bg[:, shot, ...] for shot in shots]]

            q_yhat, _ = model(s_xs, s_y_fgs, s_y_bgs, [q_x])
            q_yhat = q_yhat.argmax(dim=1).unsqueeze(1)

            img_list = [q_x_orig[batch_i, 0].cpu().numpy()]
            pred_list = [q_yhat[batch_i].cpu().numpy()]
            label_list = [q_y_orig[batch_i, 0].cpu().numpy()]
            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list, fnames])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t",
                  end='\r')

    print("start computing dice similarities ... total ", len(saves))
    dice_similarities = []
    for subj_idx in range(len(saves)):
        save_subj = saves[subj_idx]
        imgs, preds, labels = [], [], []
        for _, _, img_list, pred_list, label_list, _ in save_subj:
            imgs.extend(img_list)
            preds.extend(pred_list)
            labels.extend(label_list)

        img_arr = np.concatenate(imgs, axis=0)
        pred_arr = np.concatenate(preds, axis=0)
        label_arr = np.concatenate(labels, axis=0)

        dice = (np.sum(label_arr * pred_arr) * 2.0
                / (np.sum(pred_arr) + np.sum(label_arr)))
        dice_similarities.append(dice)
        # The total used to be hard-coded as 10.
        print(f"computing dice scores {subj_idx}/{len(saves)}", end='\n')

        if _config['record']:
            frames = []
            for frame_id in range(len(save_subj)):
                frames += overlay_color(
                    torch.tensor(imgs[frame_id]),
                    torch.tensor(preds[frame_id]).float(),
                    torch.tensor(labels[frame_id]))
            visual = make_grid(frames, normalize=True, nrow=5)
            # NOTE(review): the step/tag below is the index of the subject's
            # last entry (previously a leaked loop variable); subj_idx may
            # have been intended for the scalar tag -- confirm.
            last_entry = len(save_subj) - 1
            writer.add_image(f"test/{subj_idx}", visual, last_entry)
            writer.add_scalar(f'dice_score/{last_entry}', dice)

        if _config['save_sample']:
            # Only for internal test (BCV - MICCAI2015).
            sup_idx = _config['s_idx']
            target = _config['target']
            save_name = _config['save_name']
            save_dir = (f"../sample/panet_organ{target}_sup{sup_idx}"
                        f"_{save_name}")
            for sub_dir in ("gt", "pred", "input"):
                # exist_ok tolerates reruns without hiding real failures
                # such as permission errors.
                os.makedirs(os.path.join(save_dir, sub_dir), exist_ok=True)

            # The subject name comes from the file names stored with the
            # subject's last entry.
            fnames = save_subj[-1][5]
            subj_name = fnames[0][0].split("/")[-2]
            if target == 14:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Cervix/RawData/Training/img"
                orig_fname = f"{src_dir}/{subj_name}-Image.nii.gz"
            else:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Abdomen/RawData/Training/img"
                orig_fname = f"{src_dir}/img{subj_name}.nii.gz"

            # Reuse the voxel spacing of the source scan for every volume.
            orig_spacing = sitk.ReadImage(orig_fname).GetSpacing()
            label_arr = label_arr * 2.0
            volumes = (
                ("gt", label_arr),
                ("pred", pred_arr.astype(float)),
                ("input", img_arr),
            )
            for sub_dir, volume in volumes:
                itk = sitk.GetImageFromArray(volume)
                itk.SetSpacing(orig_spacing)
                sitk.WriteImage(itk,
                                f"{save_dir}/{sub_dir}/{subj_idx}.nii.gz")

    mean_dice = np.mean(dice_similarities)
    # Built without a backslash continuation so source indentation no longer
    # leaks into the printed message.
    print(f"test result \n n : {len(dice_similarities)}, "
          f"mean dice score : {mean_dice} \n "
          f"dice similarities : {dice_similarities}")

    if _config['record']:
        writer.add_scalar('dice_score/mean', mean_dice)