def main(_run, _config, _log):
    """Evaluate a trained FewShotSegmentorDoubleSDnet on the test split.

    For every test slice, conditioner weights are computed per support shot,
    averaged across shots, and fed to the segmentor; predictions are grouped
    per subject and a Dice score is computed per subject volume.

    Args:
        _run: Sacred Run object (observers used for source backup / output dir).
        _config: Sacred config dict (seed, gpu_id, snapshot, dataset, ...).
        _log: Sacred logger.
    """
    # Back up the experiment sources into the observer dir, then drop the
    # temporary _sources folder Sacred created.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                    exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility + device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    settings = Settings()
    common_params, data_params, net_params, train_params, eval_params = \
        settings['COMMON'], settings['DATA'], settings['NETWORK'], \
        settings['TRAINING'], settings['EVAL']

    _log.info('###### Create model ######')
    model = fs.FewShotSegmentorDoubleSDnet(net_params).cuda()
    # model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'],])
    if not _config['notrain']:
        # Load the trained snapshot unless explicitly testing an untrained net.
        model.load_state_dict(torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    max_label = 1
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(
        dataset=ts_dataset,
        batch_size=1,
        shuffle=False,
        # num_workers=_config['n_work'],
        pin_memory=False,  # True
        drop_last=False
    )
    if _config['record']:
        _log.info('###### define tensorboard writer #####')
        board_name = f'board/test_{_config["board"]}_{date()}'
        writer = SummaryWriter(board_name)

    _log.info('###### Testing begins ######')
    # metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    img_cnt = 0
    # length = len(all_samples)
    length = len(testloader)
    img_lists = []
    pred_lists = []
    label_lists = []
    # saves maps subject index -> list of [subj_idx, idx, imgs, preds, labels]
    # accumulated slice-by-slice so Dice can be computed per subject volume.
    saves = {}
    for subj_idx in range(len(ts_dataset.get_cnts())):
        saves[subj_idx] = []
    with torch.no_grad():
        loss_valid = 0
        batch_i = 0  # use only 1 batch size for testing
        for i, sample_test in enumerate(testloader):  # even for upward, down for downward
            # NOTE(review): subj_idx is rebound here, shadowing the loop
            # variable used to initialise `saves` above — intentional reuse.
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list = []
            pred_list = []
            label_list = []
            preds = []
            s_x = sample_test['s_x'].cuda()  # [B, Support, slice_num=1, 1, 256, 256]
            X = s_x.squeeze(2)  # [B, Support, 1, 256, 256]
            s_y = sample_test['s_y'].cuda()  # [B, Support, slice_num, 1, 256, 256]
            Y = s_y.squeeze(2)  # [B, Support, 1, 256, 256]
            Y = Y.squeeze(2)  # [B, Support, 256, 256]
            q_x = sample_test['q_x'].cuda()  # [B, slice_num, 1, 256, 256]
            query_input = q_x.squeeze(1)  # [B, 1, 256, 256]
            q_y = sample_test['q_y'].cuda()  # [B, slice_num, 1, 256, 256]
            y2 = q_y.squeeze(1)  # [B, 1, 256, 256]
            y2 = y2.squeeze(1)  # [B, 256, 256]
            y2 = y2.type(torch.LongTensor).cuda()
            # One conditioner pass per support shot; the per-shot weights are
            # averaged below before being handed to the segmentor.
            entire_weights = []
            for shot_id in range(_config["n_shot"]):
                input1 = X[:, shot_id, ...]  # use 1 shot at first
                y1 = Y[:, shot_id, ...]  # use 1 shot at first
                condition_input = torch.cat((input1, y1.unsqueeze(1)), dim=1)
                weights = model.conditioner(condition_input)  # 2, 10, [B, channel=1, w, h]
                entire_weights.append(weights)
            # pdb.set_trace()
            # avg_weights mirrors the conditioner output structure: the first
            # list gets 9 averaged weight maps plus a trailing None (10 slots,
            # matching the "2, 10" comment above); the second list is unused.
            avg_weights = [[], [None, None, None, None]]
            for k in range(9):
                weight_cat = torch.cat([weights[0][k] for weights in entire_weights], dim=1)
                avg_weight = torch.mean(weight_cat, dim=1, keepdim=True)
                avg_weights[0].append(avg_weight)
            # NOTE(review): placement of this append (outside the k-loop) is
            # inferred from the 10-slot structure — confirm against original.
            avg_weights[0].append(None)
            output = model.segmentor(query_input, avg_weights)
            q_yhat = output.argmax(dim=1)
            q_yhat = q_yhat.unsqueeze(1)
            preds.append(q_yhat)
            img_list.append(q_x[batch_i, 0].cpu().numpy())
            pred_list.append(q_yhat[batch_i].cpu().numpy())
            label_list.append(q_y[batch_i, 0].cpu().numpy())
            saves[subj_idx].append([subj_idx, idx, img_list, pred_list, label_list])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t", end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)

        print("start computing dice similarities ... total ", len(saves))
        dice_similarities = []
        for subj_idx in range(len(saves)):
            imgs, preds, labels = [], [], []
            save_subj = saves[subj_idx]
            for i in range(len(save_subj)):
                # print(len(save_subj), len(save_subj)-q_slice_n+1, q_slice_n, i)
                subj_idx, idx, img_list, pred_list, label_list = save_subj[i]
                # print(subj_idx, idx, is_reverse, len(img_list))
                # print(i, is_reverse, is_reverse_next, is_flip)
                for j in range(len(img_list)):
                    imgs.append(img_list[j])
                    preds.append(pred_list[j])
                    labels.append(label_list[j])
            # pdb.set_trace()
            img_arr = np.concatenate(imgs, axis=0)
            pred_arr = np.concatenate(preds, axis=0)
            label_arr = np.concatenate(labels, axis=0)
            # pdb.set_trace()
            # print(ts_dataset.slice_cnts[subj_idx] , len(imgs))
            # pdb.set_trace()
            # Volume-level Dice: 2*|P∩L| / (|P|+|L|) over the stacked slices.
            dice = np.sum([label_arr * pred_arr]) * 2.0 / (np.sum(pred_arr) + np.sum(label_arr))
            dice_similarities.append(dice)
            print(f"computing dice scores {subj_idx}/{10}", end='\n')
            if _config['record']:
                frames = []
                for frame_id in range(0, len(save_subj)):
                    frames += overlay_color(torch.tensor(imgs[frame_id]),
                                            torch.tensor(preds[frame_id]).float(),
                                            torch.tensor(labels[frame_id]))
                visual = make_grid(frames, normalize=True, nrow=5)
                # NOTE(review): `i` here is the leftover index from the inner
                # slice loop, not the subject index — confirm this is intended.
                writer.add_image(f"test/{subj_idx}", visual, i)
                writer.add_scalar(f'dice_score/{i}', dice)
        print(f"test result \n n : {len(dice_similarities)}, mean dice score : \
{np.mean(dice_similarities)} \n dice similarities : {dice_similarities}")
        if _config['record']:
            writer.add_scalar(f'dice_score/mean', np.mean(dice_similarities))
def main(config):
    """Run DAVIS-2017 few-shot video-object-segmentation inference.

    Propagates the first-frame support mask through each video: frame t's
    prediction becomes frame t+1's previous-mask input. Predictions are saved
    as palettised PNGs under ./result/<num>/<class_name>/.

    Args:
        config: dict with palette_dir, task (n_shots/n_ways), seed, gpu_id,
            model cfg, snapshot path, dataset paths, input_size, n_runs, ...
    """
    num = 100000  # run tag used in the output directory name
    palette_path = config['palette_dir']
    with open(palette_path) as f:
        palette = f.readlines()
    # Flatten the text palette file into the 768-entry (256*RGB) list PIL
    # expects for putpalette().
    palette = list(
        np.asarray([[int(p) for p in pal[0:-1].split(' ')]
                    for pal in palette]).reshape(768))
    n_shots = config['task']['n_shots']
    n_ways = config['task']['n_ways']
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=config['gpu_id'])
    torch.set_num_threads(1)
    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[config['gpu_id'], ])
    # NOTE(review): sibling scripts gate snapshot loading on `not notrain`;
    # here it loads when config['train'] is truthy — confirm the flag's
    # meaning in this script's config.
    if config['train']:
        model.load_state_dict(
            torch.load(config['snapshots'], map_location='cpu'))
    model.eval()
    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_test
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name]['val']
    list_label = []
    for i in labels:
        list_label.append(i)
    list_label = sorted(list_label)
    transforms = [Resize(size=config['input_size'])]
    transforms = Compose(transforms)
    with torch.no_grad():
        for run in range(config['n_runs']):
            set_seed(config['seed'] + run)
            dataset = make_data(base_dir=config['path'][data_name]['data_dir'],
                                split=config['path'][data_name]['data_split'],
                                transforms=transforms,
                                to_tensor=ToTensorNormalize(),
                                labels=labels)
            testloader = DataLoader(dataset,
                                    batch_size=config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            for iteration, batch in enumerate(testloader):
                # batch = (per-frame data list, frame count, class name, ...)
                class_name = batch[2]
                if not os.path.exists(f'./result/{num}/{class_name[0]}'):
                    os.makedirs(f'./result/{num}/{class_name[0]}')
                num_frame = batch[1]
                all_class_data = batch[0]
                class_ids = all_class_data[0]['obj_ids']
                # Support is always the first frame, replicated per way/shot.
                support_images = [[
                    all_class_data[0]['image'].cuda() for _ in range(n_shots)
                ] for _ in range(n_ways)]
                support_mask = all_class_data[0]['label'][
                    list_label[iteration]]
                # NOTE(review): fg masks iterate over len(class_ids) ways while
                # bg masks iterate over n_ways — confirm these always agree.
                support_fg_mask = [[
                    get_fg_mask(support_mask, class_ids[way])
                    for shot in range(n_shots)
                ] for way in range(len(class_ids))]
                support_bg_mask = [[
                    get_bg_mask(support_mask, class_ids)
                    for _ in range(n_shots)
                ] for _ in range(n_ways)]
                s_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                             for way in support_fg_mask]
                s_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                             for way in support_bg_mask]
                # print(f'fg_mask {s_bg_mask[0][0].shape}')
                # print(f'bg_mask {s_bg_mask[0][0].shape}')
                # print(support_mask.shape)
                for idx, data in enumerate(all_class_data):
                    query_images = [
                        all_class_data[idx]['image'].cuda()
                        for i in range(n_ways)
                    ]
                    query_labels = torch.cat([
                        query_label.cuda() for query_label in [
                            all_class_data[idx]['label'][
                                list_label[iteration]],
                        ]
                    ])
                    # print(f'query_image{query_images[0].shape}')
                    # Previous-mask input: ground-truth support mask for the
                    # first frame, then the model's own previous prediction.
                    if idx > 0:
                        pre_mask = [
                            pred_mask,
                        ]
                    elif idx == 0:
                        pre_mask = [
                            support_mask.float().cuda(),
                        ]
                    query_pred, _ = model(support_images, s_fg_mask, s_bg_mask,
                                          query_images, pre_mask)
                    pred = query_pred.argmax(dim=1, keepdim=True)
                    pred = pred.data.cpu().numpy()
                    img = pred[0, 0]
                    img_e = Image.fromarray(img.astype('float32')).convert('P')
                    # Resize prediction back to model input size for the next
                    # frame's previous-mask input (nearest keeps labels crisp).
                    pred_mask = tr_F.resize(img_e, config['input_size'],
                                            interpolation=Image.NEAREST)
                    pred_mask = torch.Tensor(np.array(pred_mask))
                    pred_mask = torch.unsqueeze(pred_mask, dim=0)
                    pred_mask = pred_mask.float().cuda()
                    img_e.putpalette(palette)
                    # print(os.path.join(f'./result/{class_name[0]}/', '{:05d}.png'.format(idx)))
                    # print(batch[3][idx])
                    img_e.save(
                        os.path.join(f'./result/{num}/{class_name[0]}/',
                                     '{:05d}.png'.format(idx)))
def main(_run, _config, _log):
    """Train FewShotSeg on the Visceral CT dataset with periodic validation.

    Validates on held-out volumes every `val_interval` steps, snapshotting the
    model whenever the average validation Dice improves; a final snapshot is
    written after the last step.

    Args:
        _run: Sacred Run object (observers for snapshots, logs, source backup).
        _config: Sacred config dict.
        _log: Sacred logger.
    """
    if _run.observers:
        # Prepare output dirs and back up sources into the observer dir.
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        os.makedirs(f'{_run.observers[0].dir}/val_logs', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    # NOTE(review): device is hard-coded to cuda:2 while DataParallel uses
    # _config['gpu_id'] — these can disagree; confirm intended device.
    model = FewShotSeg(device=torch.device("cuda:2"),
                       n_grid=8,
                       overlap=True,
                       overlap_out='max')
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'Visceral':
        make_data = visceral_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    #labels=set([1,2])
    # Validation uses the classes NOT trained on (few-shot protocol).
    val_labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = Compose([
        Resize(size=_config['input_size']),
        #RandomAffine(),
        RandomBrightness(),
        RandomContrast(),
        RandomGamma()
    ])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['n_steps'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=True)
    # Fixed evaluation volumes (one support volume, many query volumes).
    support_volumes = ['10000132_1_CTce_ThAb.nii.gz']
    query_volumes = [
        '10000100_1_CTce_ThAb.nii.gz', '10000104_1_CTce_ThAb.nii.gz',
        '10000105_1_CTce_ThAb.nii.gz', '10000106_1_CTce_ThAb.nii.gz',
        '10000108_1_CTce_ThAb.nii.gz', '10000109_1_CTce_ThAb.nii.gz',
        '10000110_1_CTce_ThAb.nii.gz', '10000111_1_CTce_ThAb.nii.gz',
        '10000112_1_CTce_ThAb.nii.gz', '10000113_1_CTce_ThAb.nii.gz',
        '10000127_1_CTce_ThAb.nii.gz', '10000128_1_CTce_ThAb.nii.gz',
        '10000129_1_CTce_ThAb.nii.gz', '10000130_1_CTce_ThAb.nii.gz',
        '10000131_1_CTce_ThAb.nii.gz', '10000133_1_CTce_ThAb.nii.gz',
        '10000134_1_CTce_ThAb.nii.gz', '10000135_1_CTce_ThAb.nii.gz',
        '10000136_1_CTce_ThAb.nii.gz'
    ]
    # NOTE(review): absolute user-specific paths — should come from config.
    volumes_path = "/home/qinji/PANet_Visceral/eval_dataset/volumes/"
    segs_path = "/home/qinji/PANet_Visceral/eval_dataset/segmentations/"
    support_vol_dict, support_mask_dict = load_vol_and_mask(
        support_volumes, volumes_path, segs_path, labels=list(val_labels))
    query_vol_dict, query_mask_dict = load_vol_and_mask(
        query_volumes, volumes_path, segs_path, labels=list(val_labels))
    print('Successfully Load eval data!')

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0}
    _log.info('###### Training ######')
    best_val_dice = 0
    # Baseline evaluation before any training (step -1).
    # NOTE: `eval` here is a project helper that shadows the builtin eval().
    model.eval()
    eval(model, -1, support_vol_dict, support_mask_dict, query_vol_dict,
         query_mask_dict, list(val_labels),
         os.path.join(f'{_run.observers[0].dir}/val_logs', 'val_logs.txt'))
    model.train()
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: nested [way][shot] lists moved to GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot[f'fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot[f'bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ], dim=0)
        # Forward and Backward
        optimizer.zero_grad()
        query_pred = model(support_images, support_fg_mask, support_bg_mask,
                           query_images)
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        # Log loss
        query_loss = query_loss.detach().data.cpu().numpy()
        _run.log_scalar('loss', query_loss)
        log_loss['loss'] += query_loss
        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            # Running mean over all steps so far (log_loss is never reset).
            loss = log_loss['loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}')
            with open(
                    os.path.join(f'{_run.observers[0].dir}/val_logs',
                                 'val_logs.txt'), 'a') as f:
                f.write(f'step {i_iter+1}: loss: {loss}\n')
        if (i_iter + 1) % _config['val_interval'] == 0:
            _log.info('###### Validing ######')
            model.eval()
            average_labels_dice = eval(
                model, i_iter, support_vol_dict, support_mask_dict,
                query_vol_dict, query_mask_dict, list(val_labels),
                os.path.join(f'{_run.observers[0].dir}/val_logs',
                             'val_logs.txt'))
            model.train()
            print(f'step {i_iter+1}: val_average_dice: {average_labels_dice}')
            if average_labels_dice > best_val_dice:
                _log.info('###### Taking snapshot ######')
                best_val_dice = average_labels_dice
                torch.save(
                    model.state_dict(),
                    os.path.join(f'{_run.observers[0].dir}/snapshots',
                                 'best_val.pth'))
    _log.info('###### Saving final model ######')
    torch.save(
        model.state_dict(),
        os.path.join(f'{_run.observers[0].dir}/snapshots',
                     f'{i_iter + 1}.pth'))
def main(_run, _config, _log):
    """Train FewShotSeg on VOC/COCO with query, alignment and distance losses.

    total loss = query_loss + dist_loss + 0.2 * align_loss; snapshots are
    written every `save_pred_every` steps and once more at the end.

    Args:
        _run: Sacred Run object (observers for snapshots and source backup).
        _config: Sacred config dict.
        _log: Sacred logger.
    """
    if _run.observers:
        # Prepare output dirs and back up sources into the observer dir.
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'],])
    model.train()

    # Using Saliency
    # u2_model_dir = '/content/gdrive/My Drive/Research/U-2-Net/saved_models/'+ 'u2netp' + '/' + 'u2netp.pth'
    # u2_net = U2NETP(3,1)
    # u2_net.load_state_dict(torch.load(u2_model_dir))
    # if torch.cuda.is_available():
    #     u2_net.cuda()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
    elif data_name == 'COCO':
        make_data = coco_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']),
                          RandomMirror()])
    dataset = make_data(
        base_dir=_config['path'][data_name]['data_dir'],
        split=_config['path'][data_name]['data_split'],
        transforms=transforms,
        to_tensor=ToTensorNormalize(),
        labels=labels,
        max_iters=_config['n_steps'] * _config['batch_size'],
        n_ways=_config['task']['n_ways'],
        n_shots=_config['task']['n_shots'],
        n_queries=_config['task']['n_queries']
    )
    trainloader = DataLoader(
        dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True
    )

    _log.info('###### Set optimizer ######')
    print(_config['mode'])
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'], gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0, 'dist_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: nested [way][shot] lists moved to GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot[f'fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot[f'bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        query_images = [query_image.cuda()
                        for query_image in sample_batched['query_images']]
        query_labels = torch.cat(
            [query_label.long().cuda()
             for query_label in sample_batched['query_labels']], dim=0)

        # Forward and Backward
        optimizer.zero_grad()
        # with torch.no_grad():
        #     # u2net
        #     inputs = query_images[0].type(torch.FloatTensor)
        #     labels = query_labels.type(torch.FloatTensor)
        #     if torch.cuda.is_available():
        #         inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
        #                                                                                     requires_grad=False)
        #     else:
        #         inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
        #     d1,d2,d3,d4,d5,d6,d7= u2_net(inputs_v)
        #     # normalization
        #     pred = d1[:,0,:,:]
        #     pred = normPRED(pred)
        # Saliency prior disabled: an empty list is passed instead.
        pred = []
        query_pred, align_loss, dist_loss = model(support_images,
                                                  support_fg_mask,
                                                  support_bg_mask,
                                                  query_images, pred)
        query_loss = criterion(query_pred, query_labels)
        # NOTE(review): align scaler is hard-coded to 0.2, bypassing
        # _config['align_loss_scaler'] (see trailing comment).
        loss = query_loss + dist_loss + align_loss * 0.2  #_config['align_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (the `!= 0` guards handle losses returned as plain zeros).
        query_loss = query_loss.detach().data.cpu().numpy()
        dist_loss = dist_loss.detach().data.cpu().numpy() if dist_loss != 0 else 0
        align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0
        _run.log_scalar('loss', query_loss)
        _run.log_scalar('align_loss', align_loss)
        _run.log_scalar('dist_loss', dist_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss
        log_loss['dist_loss'] += dist_loss

        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            # Running means over all steps so far (log_loss is never reset).
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}, dist_loss: {dist_loss}')

        if (i_iter + 1) % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            torch.save(model.state_dict(),
                       os.path.join(f'{_run.observers[0].dir}/snapshots',
                                    f'{i_iter + 1}.pth'))

    _log.info('###### Saving final model ######')
    torch.save(model.state_dict(),
               os.path.join(f'{_run.observers[0].dir}/snapshots',
                            f'{i_iter + 1}.pth'))
def main(_run, _config, _log):
    """Train MedicalFSS with per-slice BCE loss and per-epoch validation.

    Each epoch runs the full train loader, then the validation loader under
    no_grad; the checkpoint path is 'lowest.pth' when validation loss improves
    and 'last.pth' otherwise.

    Args:
        _run: Sacred Run object (observers for snapshots and source backup).
        _config: Sacred config dict.
        _log: Sacred logger.
    """
    if _run.observers:
        # Prepare output dirs and back up sources into the observer dir.
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')
    print(f"experiment : {_run.experiment_info['name']} , ex_ID : {_run._id}")

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    device = torch.device(f"cuda:{_config['gpu_id']}")
    # torch.cuda.set_device(device=_config['gpu_id'])
    # torch.set_num_threads(1)
    model = MedicalFSS(_config, device).to(device)

    _log.info('###### Load data ######')
    make_data = meta_data
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    trainloader = DataLoader(
        dataset=tr_dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=_config['n_work'],
        pin_memory=False,  #True load data while training gpu
        drop_last=True)
    validationloader = DataLoader(
        dataset=val_dataset,
        batch_size=1,
        # batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=_config['n_work'],
        pin_memory=False,  #True
        drop_last=False)
    # all_samples = test_loader_Spleen(split=1)  # for iterative validation

    _log.info('###### Set optimizer ######')
    print(_config['optim'])
    # optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    optimizer = torch.optim.Adam(list(model.parameters()),
                                 _config['optim']['lr'])
    # scheduler = MultiStepLR(optimizer, milestones=_config['lr_milestones'], gamma=0.1)
    # NOTE(review): scheduler is created but never stepped in this function.
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    # criterion_ce = nn.CrossEntropyLoss()
    # criterion = losses.DiceLoss()
    criterion = nn.BCELoss()
    if _config['record']:
        ## tensorboard visualization
        _log.info('###### define tensorboard writer #####')
        _log.info(f'##### board/train_{_config["board"]}_{date()}')
        writer = SummaryWriter(f'board/train_{_config["board"]}_{date()}')

    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0}
    min_val_loss = 100000.0  # best (lowest) validation loss seen so far
    min_iter = 0
    min_epoch = 0
    iter_n_train, iter_n_val = len(trainloader), len(validationloader)
    _log.info('###### Training ######')
    q_slice_n = _config['q_slice']  # number of query slices per sample
    blank = torch.zeros([1, 256, 256]).to(device)  # empty mask for overlays
    iter_print = _config['iter_print']
    for i_epoch in range(_config['n_steps']):
        loss_epoch = 0
        ## training stage
        for i_iter, sample_train in enumerate(trainloader):
            preds = []
            loss_per_video = 0.0
            optimizer.zero_grad()
            s_x = sample_train['s_x'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            s_y = sample_train['s_y'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            preds = model(s_x)
            for frame_id in range(q_slice_n):
                s_yi = s_y[:, 0, frame_id, 0, :, :]  # [B, 1, 256, 256]
                yhati = preds[frame_id]
                # pdb.set_trace()
                # loss = criterion(F.softmax(yhati, dim=1), s_yi2)+criterion_ce(F.softmax(yhati, dim=1), s_yi2)
                # loss = criterion(F.softmax(yhati, dim=1), s_yi2)
                loss = criterion(yhati, s_yi)
                loss_per_video += loss
                # NOTE(review): this appends to the same `preds` object being
                # indexed above (model output), duplicating its entries —
                # looks like a leftover; confirm intent.
                preds.append(yhati)
            loss_per_video.backward()
            optimizer.step()
            loss_epoch += loss_per_video
            if iter_print:
                print(
                    f"train, iter:{i_iter}/{iter_n_train}, iter_loss:{loss_per_video}",
                    end='\r')
            if _config['record'] and i_iter == 0:
                batch_i = 0
                frames = []
                for frame_id in range(0, q_slice_n):
                    # query_pred = output.argmax(dim=1)
                    # NOTE(review): `q_x` is not defined anywhere in this
                    # function (samples only carry 's_x'/'s_y') — this branch
                    # would raise NameError; confirm against the original.
                    frames += overlay_color(q_x[batch_i, frame_id],
                                            preds[frame_id][batch_i].round(),
                                            s_y[batch_i, frame_id],
                                            scale=_config['scale'])
                for frame_id in range(0, q_slice_n):
                    frames += overlay_color(s_x[batch_i, 0, frame_id], blank,
                                            s_y[batch_i, 0, frame_id],
                                            scale=_config['scale'])
                visual = make_grid(frames, normalize=True, nrow=5)
                writer.add_image("train/visual", visual, i_epoch)
        with torch.no_grad():
            ## validation stage
            loss_valid = 0
            preds = []
            for i_iter, sample_valid in enumerate(validationloader):
                loss_per_video = 0.0
                optimizer.zero_grad()  # no-op under no_grad; kept as-is
                s_x = sample_valid['s_x'].to(
                    device)  # [B, slice_num, 1, 256, 256]
                s_y = sample_valid['s_y'].to(
                    device)  # [B, slice_num, 1, 256, 256]
                preds = model(s_x)
                for frame_id in range(q_slice_n):
                    s_yi = s_y[:, 0, frame_id, 0, :, :]  # [B, 1, 256, 256]
                    # s_yi2 = s_yi.squeeze(1)  # [B, 256, 256]
                    yhati = preds[frame_id]
                    # loss = criterion(F.softmax(yhati, dim=1), s_yi2) + criterion_ce(F.softmax(yhati, dim=1), s_yi2)
                    # loss = criterion(F.softmax(yhati, dim=1), s_yi2)
                    loss = criterion(yhati, s_yi)
                    loss_per_video += loss
                    preds.append(yhati)  # NOTE(review): same duplication as in training
                loss_valid += loss_per_video
                if iter_print:
                    print(
                        f"valid, iter:{i_iter}/{iter_n_val}, iter_loss:{loss_per_video}",
                        end='\r')
                if _config['record'] and i_iter == 0:
                    batch_i = 0
                    frames = []
                    for frame_id in range(0, q_slice_n):
                        # NOTE(review): `q_x` undefined here as well.
                        frames += overlay_color(
                            q_x[batch_i, frame_id],
                            preds[frame_id][batch_i].round(),
                            s_y[batch_i, frame_id],
                            scale=_config['scale'])
                    for frame_id in range(0, q_slice_n):
                        frames += overlay_color(s_x[batch_i, 0, frame_id],
                                                blank,
                                                s_y[batch_i, 0, frame_id],
                                                scale=_config['scale'])
                    visual = make_grid(frames, normalize=True, nrow=5)
                    writer.add_image("valid/visual", visual, i_epoch)
        # Checkpoint selection: 'lowest.pth' on improvement, else 'last.pth'.
        if min_val_loss > loss_valid:
            min_epoch = i_epoch
            min_val_loss = loss_valid
            print(
                f"train - epoch:{i_epoch}/{_config['n_steps']}, epoch_loss:{loss_epoch} valid_loss:{loss_valid} \t => model saved",
                end='\n')
            save_fname = f'{_run.observers[0].dir}/snapshots/lowest.pth'
        else:
            print(
                f"train - epoch:{i_epoch}/{_config['n_steps']}, epoch_loss:{loss_epoch} valid_loss:{loss_valid} - min epoch:{min_epoch}",
                end='\n')
            save_fname = f'{_run.observers[0].dir}/snapshots/last.pth'
        _run.log_scalar("training.loss", float(loss_epoch), i_epoch)
        _run.log_scalar("validation.loss", float(loss_valid), i_epoch)
        _run.log_scalar("min_epoch", min_epoch, i_epoch)
        if _config['record']:
            writer.add_scalar('loss/train_loss', loss_epoch, i_epoch)
            writer.add_scalar('loss/valid_loss', loss_valid, i_epoch)
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, save_fname)
    # NOTE(review): `writer` only exists when _config['record'] is truthy;
    # this unguarded close would raise NameError otherwise — confirm.
    writer.close()
def main(config):
    """Train FewShotSeg on DAVIS-2017 with query + alignment losses.

    Uses previous-frame masks ('query_masks') as an extra model input.
    Snapshots go to config['snapshots'] every `save_pred_every` steps and
    once more at the end.

    Args:
        config: plain dict (not Sacred) with dataset paths, task settings,
            optimizer settings, and snapshot directory.
    """
    if not os.path.exists(config['snapshots']):
        os.makedirs(config['snapshots'])
    snap_shots_dir = config['snapshots']
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    # NOTE(review): GPU index hard-coded to 2 here and in device_ids below;
    # sibling scripts read it from config — confirm intended device.
    torch.cuda.set_device(2)
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.set_num_threads(1)
    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[2])
    model.train()
    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][config['label_sets']]
    transforms = Compose([Resize(size=config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=config['path'][data_name]['data_dir'],
                        split=config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=config['n_steps'] * config['batch_size'],
                        n_ways=config['task']['n_ways'],
                        n_shots=config['task']['n_shots'],
                        n_queries=config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=config['batch_size'],
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=True)
    optimizer = torch.optim.SGD(model.parameters(), **config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=config['ignore_label'])
    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0}
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: nested [way][shot] lists moved to GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot[f'fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_fg_mask']]
        support_bg_mask = [[shot[f'bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_bg_mask']]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ], dim=0)
        # print(query_labels.shape)
        # Previous-frame masks fed as extra guidance to the model.
        pre_masks = [
            query_label.float().cuda()
            for query_label in sample_batched['query_masks']
        ]
        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_images, support_fg_mask,
                                       support_bg_mask, query_images,
                                       pre_masks)
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss + align_loss * config['align_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()
        # Log loss (`!= 0` guard handles an align loss returned as plain 0).
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy(
        ) if align_loss != 0 else 0
        # _run.log_scalar('loss', query_loss)
        # _run.log_scalar('align_loss', align_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss
        # print loss and take snapshots
        if (i_iter + 1) % config['print_interval'] == 0:
            # Running means over all steps so far (log_loss is never reset).
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter + 1}: loss: {loss}, align_loss: {align_loss}')
        if (i_iter + 1) % config['save_pred_every'] == 0:
            torch.save(model.state_dict(),
                       os.path.join(f'{snap_shots_dir}', f'{i_iter + 1}.pth'))
    torch.save(model.state_dict(),
               os.path.join(f'{snap_shots_dir}', f'{i_iter + 1}.pth'))
# Ray-based dataset encoding script fragment.
# NOTE(review): this chunk appears truncated — the per-dataset loop body ends
# right after loading the datasets; the encoding/consumption of `encode`
# presumably follows in the part of the file not visible here.
ray.init()


@ray.remote
def encode(image, label, transform):
    """Ray task: apply *transform* to *image* and return the encoded pair.

    Returns a dict with 'encoded_image' (via encode_image) and 'label'.
    """
    image = transform(image)
    return {
        'encoded_image': encode_image(image),
        'label': label,
    }


if __name__ == '__main__':
    set_seed(args.seed)
    # Same resize transform is used for train and test splits below.
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=Image.CUBIC),
    ])
    timer = Timer()
    data_info = {}
    for dataset_name in [
            'Food-101', 'CUB-200', 'Caltech-256', 'DTD', 'Flowers-102', 'Pet',
            'Cars', 'Dogs'
    ]:
        # get_dataset_func returns a loader callable for the named dataset.
        train_dataset, validation_dataset, test_dataset, classes = get_dataset_func(
            dataset_name)(args.data_dir, transform, transform)
def main(_run, _config, _log):
    """Train the Encoder on VOC/COCO with binary + multi-class (MCL) losses.

    Support and query tensors are flattened across ways/shots before the
    forward pass; predictions are upsampled back to input resolution and
    supervised with both a binary mask loss and a multi-class label loss,
    combined as binary_loss + mcl_loss * mcl_loss_scaler.

    Args:
        _run: Sacred Run object (observers for snapshots and source backup).
        _config: Sacred config dict.
        _log: Sacred logger.
    """
    if _run.observers:
        # Prepare output dirs and back up sources into the observer dir.
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = Encoder(pretrained_path=_config['path']['init_path'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
    elif data_name == 'COCO':
        make_data = coco_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']),
                          RandomMirror()])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        label_sets=_config['label_sets'],
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=1,
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'mcl_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        #support image,support mask label and support multi-class label
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        # Flatten [way][shot] nesting into a single batch dimension.
        support_images = torch.cat(
            [torch.cat(way, dim=0) for way in support_images], dim=0)
        support_fg_mask = [[shot[f'fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_fg_mask = torch.cat(
            [torch.cat(way, dim=0) for way in support_fg_mask], dim=0)
        support_label_mc = [[shot[f'label_ori'].long().cuda() for shot in way]
                            for way in sample_batched['support_mask']]
        support_label_mc = torch.cat(
            [torch.cat(way, dim=0) for way in support_label_mc], dim=0)
        #query image,query mask label and query multi-class label
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_images = torch.cat(query_images, dim=0)
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ], dim=0)
        query_label_mc = [
            n_queries[f'label_ori'].long().cuda()
            for n_queries in sample_batched['query_masks']
        ]
        query_label_mc = torch.cat(query_label_mc, dim=0)

        optimizer.zero_grad()
        query_pred, support_pred_mc, query_pred_mc, support_pred = model(
            support_images, query_images, support_fg_mask)
        # Upsample all predictions back to the input resolution.
        query_pred = F.interpolate(query_pred,
                                   size=query_images.shape[-2:],
                                   mode='bilinear')
        support_pred = F.interpolate(support_pred,
                                     size=support_images.shape[-2:],
                                     mode='bilinear')
        support_pred_mc = F.interpolate(support_pred_mc,
                                        size=support_images.shape[-2:],
                                        mode='bilinear')
        query_pred_mc = F.interpolate(query_pred_mc,
                                      size=query_images.shape[-2:],
                                      mode='bilinear')
        # Binary fg/bg loss on both query and support predictions.
        binary_loss = criterion(query_pred, query_labels) + criterion(
            support_pred, support_fg_mask.long().cuda())
        # Multi-class label loss on both streams.
        mcl_loss = criterion(support_pred_mc, support_label_mc) + criterion(
            query_pred_mc, query_label_mc)
        loss = binary_loss + mcl_loss * _config['mcl_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (`!= 0` guard handles an mcl loss returned as plain 0).
        binary_loss = binary_loss.detach().data.cpu().numpy()
        mcl_loss = mcl_loss.detach().data.cpu().numpy() if mcl_loss != 0 else 0
        _run.log_scalar('loss', binary_loss)
        _run.log_scalar('mcl_loss', mcl_loss)
        log_loss['loss'] += binary_loss
        log_loss['mcl_loss'] += mcl_loss

        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            # Running means over all steps so far (log_loss is never reset).
            loss = log_loss['loss'] / (i_iter + 1)
            mcl_loss = log_loss['mcl_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}, mcl_loss: {mcl_loss}')

        if (i_iter + 1) % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            torch.save(
                model.state_dict(),
                os.path.join(f'{_run.observers[0].dir}/snapshots',
                             f'{i_iter + 1}.pth'))

    _log.info('###### Saving final model ######')
    torch.save(
        model.state_dict(),
        os.path.join(f'{_run.observers[0].dir}/snapshots',
                     f'{i_iter + 1}.pth'))
def main(_run, _config, _log):
    """Evaluate a FewShotSeg model and export support features per run.

    Sacred entry point; ``_run``, ``_config`` and ``_log`` are injected by the
    Sacred experiment framework. For each of ``n_runs`` evaluation runs this:
    records IoU metrics on the query predictions, dumps the support prototype
    features of every episode to a per-run CSV, and finally plots Umap/TSNE
    projections of the deduplicated features.

    NOTE(review): indentation was reconstructed from a collapsed source dump;
    nesting of the per-run CSV export inside the run loop follows the use of
    the loop variables ``run``/``features_dfs`` — confirm against the original.
    """
    # Output folder for the per-run feature CSVs.
    os.makedirs(f'{_run.observers[0].dir}/features', exist_ok=True)
    # Archive the experiment's source files into the observer directory,
    # then drop Sacred's default _sources copy to avoid duplication.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(os.path.dirname(
            f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility + device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'], task=_config['task'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'], ])
    if not _config['notrain']:
        # Load the trained snapshot unless explicitly testing an untrained net.
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Test on the classes held out from training (set difference).
    labels = CLASS_LABELS[data_name]['all'] - \
        CLASS_LABELS[data_name][_config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            # Different seed per run so episode sampling differs across runs.
            set_seed(_config['seed'] + run)
            features_dfs = []

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries']
            )
            if _config['dataset'] == 'COCO':
                # COCO category ids are non-contiguous; keep the id list to
                # remap them to contiguous labels below.
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    label_ids = [coco_cls_ids.index(x) + 1
                                 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])
                # Flat support_ids list is reshaped into [way][shot] order.
                support_ids = [[sample_batched['support_ids'][
                    way * _config['task']['n_shots'] + shot][0]
                    for shot in range(_config['task']['n_shots'])]
                    for way in range(_config['task']['n_ways'])]
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                support_fg_mask = [[shot[f'fg_mask'].float().cuda()
                                    for shot in way]
                                   for way in sample_batched['support_mask']]
                support_bg_mask = [[shot[f'bg_mask'].float().cuda()
                                    for shot in way]
                                   for way in sample_batched['support_mask']]

                query_images = [query_image.cuda()
                                for query_image in
                                sample_batched['query_images']]
                query_labels = torch.cat(
                    [query_label.cuda()for query_label in
                     sample_batched['query_labels']], dim=0)

                # Third output: support prototype features, exported below.
                query_pred, _, supp_fts = model(support_images,
                                                support_fg_mask,
                                                support_bg_mask,
                                                query_images)

                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids, n_run=run)

                # Save features row
                for i, label_id in enumerate(label_ids):
                    lbl_df = pd.DataFrame(
                        torch.cat(supp_fts[i]).cpu().numpy())
                    lbl_df['label'] = label_id.item()
                    lbl_df['id'] = pd.Series(support_ids[i])
                    features_dfs.append(lbl_df)

            # Per-run metrics.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

            _log.info('### Exporting features CSV')
            features_df = pd.concat(features_dfs)
            features_df = features_df.drop_duplicates(subset=['id'])
            # Move the trailing 'id' and 'label' columns to the front.
            cols = list(features_df)
            cols = [cols[-1], cols[-2]] + cols[:-2]
            features_df = features_df[cols]
            features_df.to_csv(
                f'{_run.observers[0].dir}/features/features_run_{run+1}.csv',
                index=False)

    # Aggregate mean/std over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, \
        meanIoU_std_binary = metric.get_mIoU_binary()

    _log.info('###### Saving features visualization ######')
    all_fts = pd.concat(
        [pd.read_csv(f'{_run.observers[0].dir}/features/'
                     f'features_run_{run+1}.csv')
         for run in range(_config['n_runs'])])
    all_fts = all_fts.drop_duplicates(subset=['id'])
    _log.info('### Obtaining Umap visualization ###')
    plot_umap(all_fts, f'{_run.observers[0].dir}/features/Umap_fts.png')
    _log.info('### Obtaining TSNE visualization ###')
    plot_tsne(all_fts, f'{_run.observers[0].dir}/features/TSNE_fts.png')

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary',
                    classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary',
                    meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(_run, _config, _log):
    """Aggregate metrics and features of previous runs across label-set splits.

    Sacred entry point. Reads the ``features_run_*.csv`` files and
    ``metrics.json`` produced by earlier test runs (one experiment directory
    per label-set split under ``./runs/``), writes a combined ``metrics.csv``
    row, and renders Umap/TSNE plots for every split combination.

    NOTE(review): ``ex`` is presumably the module-level Sacred Experiment
    (its ``path`` names the experiment) — confirm against the imports.
    """
    # Code for visualization works only for four sets right now,
    # If we add a dataset with more sets, we should change the visualize script
    os.makedirs(f'{_run.observers[0].dir}/TSNE', exist_ok=True)
    os.makedirs(f'{_run.observers[0].dir}/Umap', exist_ok=True)
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')
    set_seed(_config['seed'])

    _log.info('###### Prepare data ######')
    dataset = _config['dataset']
    task = _config['task']
    model = _config['model']
    # Reconstruct the run-directory name of each split's test experiment;
    # only model flags that are enabled appear in the name.
    lbl_sets_paths = [
        '_'.join(
            [f'{ex.path}'] + [
                dataset,
            ] + ['[test]'] + [key for key, value in model.items() if value] +
            [f'{task["n_ways"]}way_{task["n_shots"]}shot_split{label_set}'])
        for label_set in range(_config['n_sets'])
    ]

    metrics = []
    set_dframes = []
    split_paths = []
    for lbl_set_path in lbl_sets_paths:
        experiment_id = _config['experiment_id']
        if experiment_id == 'last':
            # 'last' selects the most recent run directory.
            # NOTE(review): relies on os.listdir ordering — verify sorted ids.
            experiment_id = os.listdir(f'./runs/{lbl_set_path}')[-1]
        split_paths.append(f'{lbl_set_path}/{experiment_id}/')
        # Obtaining features
        set_dframes.append(
            pd.concat([
                pd.read_csv(
                    f'./runs/{lbl_set_path}/{experiment_id}/features/'
                    f'features_run_{run+1}.csv'
                ) for run in range(_config['n_runs'])
            ]))
        # Obtaining metrics
        with open(f'./runs/{lbl_set_path}/{experiment_id}/'
                  'metrics.json') as f:
            data = json.load(f)
        metrics.append(data['final_meanIoU']['values'][0])

    _log.info('###### Saving metrics from all sets ######')
    experiment_name = '_'.join([f'{ex.path}'] + [
        dataset,
    ] + [key for key, value in model.items() if value] +
        [f'{task["n_ways"]}way_{task["n_shots"]}shot'])
    # Save in experiment folder
    columns = [
        'Name',
        'Details',
        'Split1',
        'Split2',
        'Split3',
        'Split4',
        'MeanIoU',
        'n_ways',
        'n_shots',
        'n_queries',
        'TestSeed',
        'TrainingSeed',
        'Split1_path',
        'Split2_path',
        'Split3_path',
        'Split4_path',
    ]
    # One summary row: per-split mIoU, their mean, task config and run paths.
    row = [experiment_name, _config['details']] \
        + metrics \
        + [sum(metrics) / len(metrics), task["n_ways"], task["n_shots"],
           task["n_queries"], _config['seed'], _config['train_seed']] \
        + split_paths
    features_df = pd.DataFrame([row], columns=columns)
    features_df.to_csv(f'{_run.observers[0].dir}/metrics.csv', index=False)
    # Also append the row to a global metrics.csv in the working directory.
    if os.path.exists('metrics.csv'):
        results = pd.read_csv('metrics.csv')
        results = pd.concat([results, features_df], ignore_index=True)
        results.sort_values(by=['Name', 'MeanIoU'], inplace=True)
    else:
        results = features_df
    results.to_csv('metrics.csv', index=False)

    _log.info('###### Plotting visualization with Umap and TSNE ######')
    # Hard-coded for exactly four splits (see the note at the top).
    combinations = [[0], [1], [2], [3], [0, 1], [0, 2], [0, 3], [1, 2],
                    [1, 3], [2, 3], [0, 1, 2], [1, 2, 3], [0, 1, 2, 3]]
    for combination in combinations:
        curr_fts = pd.concat([set_dframes[comb] for comb in combination])
        curr_fts = curr_fts.drop_duplicates(subset=['id'])
        comb_str = '-'.join(map(str, combination))
        _log.info(f'### Obtaining Umap for label sets {comb_str} ###')
        plot_umap(curr_fts,
                  f'{_run.observers[0].dir}/Umap/set_{comb_str}_Umap.png')
        # TSNE
        _log.info(f'### Obtaining TSNE for label sets {comb_str} ###')
        plot_tsne(curr_fts,
                  f'{_run.observers[0].dir}/TSNE/set_{comb_str}_TSNE.png')
def main(_run, _config, _log):
    """Test a FewShotSeg model with mask/scribble/bbox support annotations.

    Sacred entry point. Runs ``n_runs`` evaluation passes over VOC or COCO
    episodes and records class/binary IoU. Contains a disabled U-2-Net
    saliency branch (kept as commented-out code); the model currently
    receives an empty ``pred`` list in its place.
    """
    # Archive experiment sources into the observer dir, drop Sacred's copy.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility + device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
        print("Snapshotttt")
        print(_config['snapshot'])
    model.eval()

    # Disabled U-2-Net loading (kept for reference).
    # u2_model_dir = '/content/gdrive/My Drive/Research/U-2-Net/saved_models/'+ 'u2net' + '/' + 'u2net_bce_itr_3168_train_1.523160_tar_0.203136.pth'
    # u2_net = U2NET(3,1)
    # u2_net.load_state_dict(torch.load(u2_model_dir))
    # if torch.cuda.is_available():
    #     u2_net.cuda()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Evaluate on the classes excluded from the training label set.
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            # Per-run seed so episodes differ between runs.
            set_seed(_config['seed'] + run)

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if _config['dataset'] == 'COCO':
                # Non-contiguous COCO category ids -> contiguous labels.
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                if _config['dataset'] == 'COCO':
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                # Select annotation type: scribbles or dense masks.
                suffix = 'scribble' if _config['scribble'] else 'mask'

                if _config['bbox']:
                    # Derive fg/bg masks from instance bounding boxes.
                    support_fg_mask = []
                    support_bg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        bg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, bg_mask = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                            bg_masks.append(bg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                        support_bg_mask.append(bg_masks)
                else:
                    support_fg_mask = [[
                        shot[f'fg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                    support_bg_mask = [[
                        shot[f'bg_{suffix}'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat([
                    query_label.cuda()
                    for query_label in sample_batched['query_labels']
                ], dim=0)

                # u2net
                # Inputs prepared for the (disabled) U-2-Net branch below;
                # Variable() is legacy autograd API kept as-is.
                inputs = query_images[0].type(torch.FloatTensor)
                labels_v = query_labels.type(torch.FloatTensor)
                if torch.cuda.is_available():
                    inputs_v, labels_v = Variable(
                        inputs.cuda(),
                        requires_grad=False), Variable(labels_v.cuda(),
                                                       requires_grad=False)
                else:
                    inputs_v, labels_v = Variable(
                        inputs,
                        requires_grad=False), Variable(labels_v,
                                                       requires_grad=False)

                #d1,d2,d3,d4,d5,d6,d7= u2_net(inputs_v)
                # normalization
                # pred = d1[:,0,:,:]
                # pred = normPRED(pred)
                # With U-2-Net disabled, pass an empty prediction list.
                pred = []

                query_pred, _, _ = model(support_images, support_fg_mask,
                                         support_bg_mask, query_images, pred)
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

            # Per-run metrics.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                n_run=run)

            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Aggregate mean/std over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary',
                    classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary',
                    meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Train an objectness encoder/decoder with periodic binary-IoU eval.

    Args:
        cfg: hierarchical config object (MODEL / DATASET / TASK / TRAIN / VAL
            sections, yacs-style attribute access).
        gpus: list of GPU ids; only ``gpus[0]`` is used (DataParallel call is
            commented out).

    Trains for up to 200 epochs, logging loss/accuracy, saving checkpoints
    every ``save_freq`` iterations and evaluating every ``eval_freq``
    iterations; the best binary meanIoU checkpoint is saved as 'best'.
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,  # binary objectness: foreground vs background
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)
    # NLLLoss over log-probabilities; 255 marks ignore pixels.
    crit = nn.NLLLoss(ignore_index=255)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        max_label = 20
    elif data_name == 'COCO':
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        cfg.TASK.fold_idx]
    exclude_labels = labels_val

    # Normalization constants scaled to 0-255 pixel range.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    train_transform = [
        transform.RandScale([0.9, 1.1]),
        transform.RandRotate([-10, 10], padding=mean, ignore_label=255),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
                       crop_type='rand', padding=mean, ignore_label=255),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ]
    train_transform = transform.Compose(train_transform)
    train_data = dataset.SemData(split=cfg.TASK.fold_idx,
                                 shot=cfg.TASK.n_shots,
                                 data_root=cfg.DATASET.data_dir,
                                 data_list=cfg.DATASET.train_list,
                                 transform=train_transform,
                                 mode='train', \
                                 use_coco=False, use_split_coco=False)
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.TRAIN.n_batch,
        shuffle=(train_sampler is None),
        num_workers=cfg.TRAIN.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)
    val_transform = transform.Compose([
        transform.Resize(size=cfg.DATASET.input_size[0]),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])
    val_data = dataset.SemData(split=cfg.TASK.fold_idx,
                               shot=cfg.TASK.n_shots,
                               data_root=cfg.DATASET.data_dir,
                               data_list=cfg.DATASET.val_list,
                               transform=val_transform,
                               mode='val',
                               use_coco=False, use_split_coco=False)
    val_sampler = None
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=cfg.VAL.n_batch,
                                             shuffle=False,
                                             num_workers=cfg.TRAIN.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    net_objectness.cuda()
    net_decoder.cuda()

    # Set up optimizers
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Running meters for timing / loss / accuracy.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    # Keep BatchNorm frozen when fix_bn is set.
    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)

    best_iou = 0
    # main loop
    tic = time.time()

    i_iter = -1
    print('###### Training ######')
    for epoch in range(0, 200):
        for _, (input, target) in enumerate(train_loader):
            # Prepare input
            i_iter += 1
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            data_time.update(time.time() - tic)

            net_objectness.zero_grad()
            net_decoder.zero_grad()

            # adjust learning rate
            adjust_learning_rate(optimizers, i_iter, cfg)

            # forward pass
            feat = net_objectness(input, return_feature_maps=True)
            pred = net_decoder(feat, segSize=cfg.DATASET.input_size)
            loss = crit(pred, target)
            acc = pixel_acc(pred, target)
            loss = loss.mean()
            acc = acc.mean()

            # Backward
            loss.backward()
            for optimizer in optimizers:
                if optimizer:
                    optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - tic)
            tic = time.time()

            # update average loss and acc
            ave_total_loss.update(loss.data.item())
            ave_acc.update(acc.data.item() * 100)

            # calculate accuracy, and display
            if i_iter % cfg.TRAIN.disp_iter == 0:
                print(
                    'Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                    'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                    'Ave_Accuracy: {:4.2f}, Accuracy:{:4.2f}, Ave_Loss: {:.6f}, Loss: {:.6f}'
                    .format(i_iter, i_iter, cfg.TRAIN.n_iters,
                            batch_time.average(), data_time.average(),
                            cfg.TRAIN.running_lr_encoder,
                            cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                            acc.data.item() * 100, ave_total_loss.average(),
                            loss.data.item()))

                history['train']['iter'].append(i_iter)
                history['train']['loss'].append(loss.data.item())
                history['train']['acc'].append(acc.data.item())

            if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
                checkpoint(nets, history, cfg, i_iter + 1)

            if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
                # Periodic validation; fresh Metric each time.
                metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
                with torch.no_grad():
                    print('----Evaluation----')
                    net_objectness.eval()
                    net_decoder.eval()
                    net_decoder.use_softmax = True
                    #for run in range(cfg.VAL.n_runs):
                    # NOTE(review): run count hard-coded to 3 here instead of
                    # cfg.VAL.n_runs (see commented line above).
                    for run in range(3):
                        print(f'### Run {run + 1} ###')
                        set_seed(cfg.VAL.seed + run)

                        print(f'### Load validation data ###')
                        #for sample_batched in tqdm.tqdm(testloader):
                        for (input, target, _) in val_loader:
                            input = input.cuda(non_blocking=True)
                            target = target.cuda(non_blocking=True)
                            feat = net_objectness(input,
                                                  return_feature_maps=True)
                            query_pred = net_decoder(
                                feat, segSize=cfg.DATASET.input_size)
                            metric.record(np.array(
                                query_pred.argmax(dim=1)[0].cpu()),
                                np.array(target[0].cpu()),
                                labels=None,
                                n_run=run)

                    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
                    )
                    print('----- Evaluation Result -----')
                    print(f'best meanIoU_binary: {best_iou}')
                    print(f'meanIoU_binary mean: {meanIoU_binary}')
                    print(f'meanIoU_binary std: {meanIoU_std_binary}')
                    if meanIoU_binary > best_iou:
                        best_iou = meanIoU_binary
                        checkpoint(nets, history, cfg, 'best')

                # Restore training mode after evaluation.
                net_objectness.train(not cfg.TRAIN.fix_bn)
                net_decoder.train(not cfg.TRAIN.fix_bn)
                net_decoder.use_softmax = False

    print('Training Done!')
def main(_run, _config, _log):
    """Train a 1-way n-shot FewShotSeg model on the BCV medical dataset.

    Sacred entry point. Optimizes query cross-entropy plus a scaled
    prototype-alignment loss with SGD + MultiStepLR; logs scalars to Sacred,
    optionally writes visual overlays to tensorboard, and snapshots the model
    to ``snapshots/last.pth`` at every print interval.
    """
    if _run.observers:
        # Snapshot dir + archive sources into the observer directory.
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility + device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'BCV':
        make_data = meta_data
    else:
        print(f"data name : {data_name}")
        raise ValueError('Wrong config for dataset!')
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    trainloader = DataLoader(
        dataset=tr_dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=_config['n_work'],
        pin_memory=False,  #True load data while training gpu
        drop_last=True)

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    if _config['record']:  ## tensorboard visualization
        _log.info('###### define tensorboard writer #####')
        _log.info(f'##### board/train_{_config["board"]}_{date()}')
        writer = SummaryWriter(f'board/train_{_config["board"]}_{date()}')

    log_loss = {'loss': 0, 'align_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        # Squeeze out the singleton slice dimension from support/query stacks.
        s_x_orig = sample_batched['s_x'].cuda(
        )  # [B, Support, slice_num=1, 1, 256, 256]
        s_x = s_x_orig.squeeze(2)  # [B, Support, 1, 256, 256]
        s_y_fg_orig = sample_batched['s_y'].cuda(
        )  # [B, Support, slice_num, 1, 256, 256]
        s_y_fg = s_y_fg_orig.squeeze(2)  # [B, Support, 1, 256, 256]
        s_y_fg = s_y_fg.squeeze(2)  # [B, Support, 256, 256]
        # Background mask is the complement of the foreground mask.
        s_y_bg = torch.ones_like(s_y_fg) - s_y_fg
        q_x_orig = sample_batched['q_x'].cuda()  # [B, slice_num, 1, 256, 256]
        q_x = q_x_orig.squeeze(1)  # [B, 1, 256, 256]
        q_y_orig = sample_batched['q_y'].cuda()  # [B, slice_num, 1, 256, 256]
        q_y = q_y_orig.squeeze(1)  # [B, 1, 256, 256]
        q_y = q_y.squeeze(1).long()  # [B, 256, 256]

        # Re-pack as way x shot nested lists (single way here).
        s_xs = [[s_x[:, shot, ...] for shot in range(_config["n_shot"])]]
        s_y_fgs = [[s_y_fg[:, shot, ...] for shot in range(_config["n_shot"])]]
        s_y_bgs = [[s_y_bg[:, shot, ...] for shot in range(_config["n_shot"])]]
        q_xs = [q_x]
        """
        Args:
            supp_imgs: support images
                way x shot x [B x 1 x H x W], list of lists of tensors
            fore_mask: foreground masks for support images
                way x shot x [B x H x W], list of lists of tensors
            back_mask: background masks for support images
                way x shot x [B x H x W], list of lists of tensors
            qry_imgs: query images
                N x [B x 1 x H x W], list of tensors
            qry_pred: [B, 2, H, W]
        """
        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(s_xs, s_y_fgs, s_y_bgs,
                                       q_xs)  #[B, 2, w, h]
        query_loss = criterion(query_pred, q_y)
        loss = query_loss + align_loss * _config['align_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss
        query_loss = query_loss.detach().data.cpu().numpy()
        # align_loss may be a plain 0 when alignment is disabled.
        align_loss = align_loss.detach().data.cpu().numpy(
        ) if align_loss != 0 else 0
        _run.log_scalar('loss', query_loss)
        _run.log_scalar('align_loss', align_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            # Running averages since the start of training.
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}')

            if _config['record']:
                # Overlay query image / prediction / label for tensorboard.
                batch_i = 0
                frames = []
                query_pred = query_pred.argmax(dim=1)
                query_pred = query_pred.unsqueeze(1)
                frames += overlay_color(q_x_orig[batch_i, 0],
                                        query_pred[batch_i].float(),
                                        q_y_orig[batch_i, 0])
                visual = make_grid(frames, normalize=True, nrow=2)
                writer.add_image("train/visual", visual, i_iter)

            print(f"train - iter:{i_iter} \t => model saved", end='\n')
            save_fname = f'{_run.observers[0].dir}/snapshots/last.pth'
            torch.save(model.state_dict(), save_fname)
def main(_run, _config, _log):
    """Test a MedicalFSS model, optionally fine-tuning on target-organ support.

    Sacred entry point. When ``n_update`` is set, the model is first
    fine-tuned with BCE loss on support volumes (one support shot is held out
    by temporarily decrementing ``n_shot``); otherwise the snapshot is loaded
    and evaluated directly. Per-subject dice scores are then computed by
    stitching the center slices of sliding-window predictions.

    NOTE(review): structure reconstructed from a collapsed source dump; the
    exact extent of the fine-tuning ``if``/``else`` and the dedent point of
    the shared test path should be confirmed against the original file.
    """
    # Archive experiment sources into the observer dir, drop Sacred's copy.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Reproducibility + device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)
    device = torch.device(f"cuda:{_config['gpu_id']}")

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    q_slice_n = _config['q_slice']
    iter_print = _config['iter_print']

    if _config['record']:
        _log.info('###### define tensorboard writer #####')
        board_name = f'board/test_{_config["board"]}_{date()}'
        writer = SummaryWriter(board_name)

    if _config["n_update"]:
        _log.info('###### fine tuning with support data of target organ #####')
        # Hold one shot out during fine-tuning; restored after the loop.
        _config["n_shot"] = _config["n_shot"] - 1

        _log.info('###### Create model ######')
        model = MedicalFSS(_config, device).to(device)
        checkpoint = torch.load(_config['snapshot'], map_location='cpu')
        print("checkpoint keys : ", checkpoint.keys())
        # initializer.load_state_dict(checkpoint['initializer'])
        model.load_state_dict(checkpoint['model'])
        tr_dataset, val_dataset, ts_dataset = make_data(_config,
                                                        is_finetuning=True)
        trainloader = DataLoader(dataset=tr_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=False,
                                 drop_last=False)
        optimizer = torch.optim.Adam(list(model.parameters()),
                                     _config['optim']['lr'])
        # criterion_ce = nn.CrossEntropyLoss()
        # criterion = losses.DiceLoss()
        criterion = nn.BCELoss()

        for i_iter, sample_train in enumerate(trainloader):
            preds = []
            loss_per_video = 0.0
            optimizer.zero_grad()
            s_x = sample_train['s_x'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            s_y = sample_train['s_y'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            preds = model(s_x)
            # Accumulate BCE over every query slice before one backward pass.
            for frame_id in range(q_slice_n):
                s_yi = s_y[:, 0, frame_id, 0, :, :]  # [B, 1, 256, 256]
                yhati = preds[frame_id]
                # loss = criterion(F.softmax(yhati, dim=1), q_yi2)
                # loss = criterion(F.softmax(yhati, dim=1), q_yi2)+criterion_ce(F.softmax(yhati, dim=1), q_yi2)
                loss = criterion(yhati, s_yi)
                loss_per_video += loss
                # NOTE(review): appends to the same list being indexed above —
                # harmless while frame_id < the original length, but fragile.
                preds.append(yhati)
            loss_per_video.backward()
            optimizer.step()
            if iter_print:
                print(
                    f"train, iter:{i_iter}/{_config['n_update']}, iter_loss:{loss_per_video}",
                    end='\r')
        # Restore the original shot count for testing.
        _config["n_shot"] = _config["n_shot"] + 1
    else:
        _log.info('###### Create model ######')
        model = MedicalFSS(_config, device).to(device)
        checkpoint = torch.load(_config['snapshot'], map_location='cpu')
        print("checkpoint keys : ", checkpoint.keys())
        # initializer.load_state_dict(checkpoint['initializer'])
        model.load_state_dict(checkpoint['model'])

    model.n_shot = _config["n_shot"]
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(dataset=ts_dataset,
                            batch_size=1,
                            shuffle=False,
                            pin_memory=False,
                            drop_last=False)

    _log.info('###### Testing begins ######')
    # metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    img_cnt = 0
    # length = len(all_samples)
    length = len(testloader)
    blank = torch.zeros([1, 256, 256]).to(device)
    reversed_idx = list(reversed(range(q_slice_n)))
    ch = 256  # number of channels of embedding
    img_lists = []
    pred_lists = []
    label_lists = []
    # saves: per-subject list of [subj_idx, idx, imgs, preds, labels] records.
    saves = {}
    n_test = len(ts_dataset.q_cnts)
    for subj_idx in range(n_test):
        saves[subj_idx] = []

    with torch.no_grad():
        batch_idx = 0  # use only 1 batch size for testing
        for i, sample_test in enumerate(
                testloader):  # even for upward, down for downward
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list, pred_list, label_list, preds = [], [], [], []
            s_x = sample_test['s_x'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            s_y = sample_test['s_y'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            preds = model(s_x)
            for frame_id in range(q_slice_n):
                s_xi = s_x[:, 0, frame_id, :, :, :]  # only 1 shot in upperbound model
                s_yi = s_y[:, 0, frame_id, :, :, :]  # [B, 1, 256, 256]
                yhati = preds[frame_id]
                # pdb.set_trace()
                # Threshold probabilities to a hard binary mask via round().
                preds.append(yhati.round())
                img_list.append(s_xi[batch_idx].cpu().numpy())
                pred_list.append(yhati.round().cpu().numpy())
                label_list.append(s_yi[batch_idx].cpu().numpy())
            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list])
            if iter_print:
                print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t",
                      end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)
            # if _config['record']:
            #     frames = []
            #     for frame_id in range(0, q_x.size(1)):
            #         frames += overlay_color(q_x[batch_idx, frame_id], preds[frame_id-1][batch_idx].round(), q_y[batch_idx, frame_id], scale=_config['scale'])
            #     visual = make_grid(frames, normalize=True, nrow=5)
            #     writer.add_image(f"test/{subj_idx}/{idx}_query_image", visual, i)

    # Stitch sliding-window predictions per subject: full prefix at the first
    # window, full suffix at the last window, center slice elsewhere.
    center_idx = (q_slice_n // 2) + 1 - 1  # 5->2 index
    dice_similarities = []
    for subj_idx in range(n_test):
        imgs, preds, labels = [], [], []
        save_subj = saves[subj_idx]
        for i in range(len(save_subj)):
            subj_idx, idx, img_list, pred_list, label_list = save_subj[i]
            # if idx==(q_slice_n//2):
            if idx == 0:
                for j in range((q_slice_n // 2) + 1):  # 5//2 + 1 = 3
                    imgs.append(img_list[idx + j])
                    preds.append(pred_list[idx + j])
                    labels.append(label_list[idx + j])
            elif idx == (len(save_subj) - 1):
                # pdb.set_trace()
                for j in range((q_slice_n // 2) + 1):  # 5//2 + 1 = 3
                    imgs.append(img_list[center_idx + j])
                    preds.append(pred_list[center_idx + j])
                    labels.append(label_list[center_idx + j])
            else:
                imgs.append(img_list[center_idx])
                preds.append(pred_list[center_idx])
                labels.append(label_list[center_idx])
        # pdb.set_trace()
        img_arr = np.concatenate(imgs, axis=0)
        pred_arr = np.concatenate(preds, axis=0)
        label_arr = np.concatenate(labels, axis=0)
        # Dice = 2|P∩L| / (|P| + |L|) over the whole stitched volume.
        dice = np.sum([label_arr * pred_arr
                       ]) * 2.0 / (np.sum(pred_arr) + np.sum(label_arr))
        dice_similarities.append(dice)
        print(
            f"{len(imgs)} slice -> computing dice scores. {subj_idx}/{n_test}. {ts_dataset.q_cnts[subj_idx] }/{len(save_subj)} => {len(imgs)}",
            end='\r')

        if _config['record']:
            frames = []
            for frame_id in range(0, len(imgs)):
                frames += overlay_color(torch.tensor(imgs[frame_id]),
                                        torch.tensor(preds[frame_id]),
                                        torch.tensor(labels[frame_id]),
                                        scale=_config['scale'])
            print(len(frames))
            visual = make_grid(frames, normalize=True, nrow=5)
            writer.add_image(f"test/{subj_idx}", visual, i)
            writer.add_scalar(f'dice_score/{i}', dice)

    print(f"test result \n n : {len(dice_similarities)}, mean dice score : \
    {np.mean(dice_similarities)} \n dice similarities : {dice_similarities}")
    if _config['record']:
        writer.add_scalar(f'dice_score/mean', np.mean(dice_similarities))
def main(_run, _config, _log):
    """Train the support/query encoder-decoder few-shot segmentation model.

    Args:
        _run: sacred ``Run`` object; used to archive source files and decide
            where snapshots are written.
        _config: sacred config dict (seed, gpu_id, dataset, optim, record, ...).
        _log: sacred logger.

    Side effects: saves ``snapshots/last.pth`` under the observer directory
    every epoch and, when ``_config['record']`` is set, writes TensorBoard
    summaries under ``board/``.
    """
    # Archive experiment sources into the observer dir, then drop sacred's
    # default ``_sources`` copy.
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    device = torch.device(f"cuda:{_config['gpu_id']}")

    # Spatial size of encoder feature maps after ``n_pool`` 2x downsamplings.
    resize_dim = _config['input_size']
    encoded_h = int(resize_dim[0] / 2**_config['n_pool'])
    encoded_w = int(resize_dim[1] / 2**_config['n_pool'])

    s_encoder = SupportEncoder(_config['path']['init_path'], device)  # .to(device)
    q_encoder = QueryEncoder(_config['path']['init_path'], device)  # .to(device)
    decoder = Decoder(input_res=(encoded_h, encoded_w),
                      output_res=resize_dim).to(device)

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'prostate':
        make_data = meta_data
    else:
        raise ValueError('Wrong config for dataset!')
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    trainloader = DataLoader(
        dataset=tr_dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=_config['n_work'],
        pin_memory=False,  # True: load data while training gpu
        drop_last=True)

    _log.info('###### Set optimizer ######')
    print(_config['optim'])
    optimizer = torch.optim.Adam(
        list(s_encoder.parameters()) + list(q_encoder.parameters())
        + list(decoder.parameters()),
        _config['optim']['lr'])
    # NOTE(review): scheduler is created but never stepped anywhere in this
    # function -- confirm whether LR decay was intended.
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'], gamma=0.1)
    # NOTE(review): pos_weight is unused; nn.BCELoss takes no pos_weight
    # (only BCEWithLogitsLoss does). Kept to preserve original intent.
    pos_weight = torch.tensor([0.3, 1], dtype=torch.float).to(device)
    criterion = nn.BCELoss()

    if _config['record']:  # tensorboard visualization
        _log.info('###### define tensorboard writer #####')
        _log.info(f'##### board/train_{_config["board"]}_{date()}')
        writer = SummaryWriter(f'board/train_{_config["board"]}_{date()}')

    iter_n_train = len(trainloader)
    _log.info('###### Training ######')
    for i_epoch in range(_config['n_steps']):
        loss_epoch = 0
        blank = torch.zeros([1, 256, 256]).to(device)
        for i_iter, sample_train in enumerate(trainloader):  # training stage
            optimizer.zero_grad()
            s_x = sample_train['s_x'].to(device)  # [B, Support, slice_num, 1, 256, 256]
            s_y = sample_train['s_y'].to(device)  # [B, Support, slice_num, 1, 256, 256]
            q_x = sample_train['q_x'].to(device)  # [B, slice_num, 1, 256, 256]
            q_y = sample_train['q_y'].to(device)  # [B, slice_num, 1, 256, 256]

            # Take the first slice of every support volume.
            s_xi = s_x[:, :, 0, :, :, :]  # [B, Support, 1, 256, 256]
            s_yi = s_y[:, :, 0, :, :, :]
            # Fold the support dimension into the batch so all shots go
            # through the encoder in one pass, then average the encodings.
            s_x_merge = s_xi.view(s_xi.size(0) * s_xi.size(1), 1, 256, 256)
            s_y_merge = s_yi.view(s_yi.size(0) * s_yi.size(1), 1, 256, 256)
            s_xi_encode_merge, _ = s_encoder(s_x_merge, s_y_merge)  # [B*S, 512, w, h]
            s_xi_encode = s_xi_encode_merge.view(s_yi.size(0), s_yi.size(1),
                                                 512, encoded_w, encoded_h)
            s_xi_encode_avg = torch.mean(s_xi_encode, dim=1)

            q_xi = q_x[:, 0, :, :, :]
            q_yi = q_y[:, 0, :, :, :]
            q_xi_encode, q_ft_list = q_encoder(q_xi)
            sq_xi = torch.cat((s_xi_encode_avg, q_xi_encode), dim=1)
            yhati = decoder(sq_xi, q_ft_list)  # [B, 1, 256, 256]
            loss = criterion(yhati, q_yi)
            loss.backward()
            optimizer.step()
            # FIX: accumulate a python float instead of the loss tensor so the
            # autograd graph of every iteration is released immediately.
            loss_epoch += loss.item()
            print(f"train, iter:{i_iter}/{iter_n_train}, iter_loss:{loss}",
                  end='\r')

            if _config['record'] and i_iter == 0:
                # FIX: the two original, byte-identical visualization blocks
                # are merged -- frames are computed once and logged under both
                # tags with the same grids as before (nrow=2 / nrow=5).
                batch_i = 0
                frames = []
                frames += overlay_color(q_xi[batch_i], yhati[batch_i].round(),
                                        q_yi[batch_i], scale=_config['scale'])
                writer.add_image("train/visual",
                                 make_grid(frames, normalize=True, nrow=2),
                                 i_epoch)
                writer.add_image("valid/visual",
                                 make_grid(frames, normalize=True, nrow=5),
                                 i_epoch)

        print(f"train - epoch:{i_epoch}/{_config['n_steps']}, "
              f"epoch_loss:{loss_epoch}", end='\n')
        save_fname = f'{_run.observers[0].dir}/snapshots/last.pth'
        _run.log_scalar("training.loss", float(loss_epoch), i_epoch)
        if _config['record']:
            writer.add_scalar('loss/train_loss', loss_epoch, i_epoch)
        torch.save(
            {
                's_encoder': s_encoder.state_dict(),
                'q_encoder': q_encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, save_fname)
    # FIX: only close the writer when it was created; previously this raised
    # NameError whenever _config['record'] was False.
    if _config['record']:
        writer.close()
def main(_run, _config, _log):
    """Train FewShotSegmentorDoubleSDnet (segmentor + conditioner) on BCV.

    Args:
        _run: sacred ``Run`` object; used to archive sources and locate the
            snapshot directory.
        _config: sacred config dict (seed, gpu_id, dataset, n_shot, ...).
        _log: sacred logger.

    Side effects: saves ``snapshots/last.pth`` every epoch; optionally logs
    TensorBoard images when ``_config['record']`` is set.
    """
    settings = Settings()
    common_params, data_params, net_params, train_params, eval_params = \
        settings['COMMON'], settings['DATA'], settings['NETWORK'], \
        settings['TRAINING'], settings['EVAL']

    # Archive experiment sources into the observer directory.
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'BCV':
        make_data = meta_data
    else:
        print(f"data name : {data_name}")
        raise ValueError('Wrong config for dataset!')
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    trainloader = DataLoader(
        dataset=tr_dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=_config['n_work'],
        pin_memory=False,  # True: load data while training gpu
        drop_last=True
    )

    _log.info('###### Create model ######')
    model = fs.FewShotSegmentorDoubleSDnet(net_params).cuda()
    model.train()

    _log.info('###### Set optimizer ######')
    optim = torch.optim.Adam
    optim_args = {"lr": train_params['learning_rate'],
                  "weight_decay": train_params['optim_weight_decay'], }
    # "momentum": train_params['momentum']}
    optim_c = optim(list(model.conditioner.parameters()), **optim_args)
    optim_s = optim(list(model.segmentor.parameters()), **optim_args)
    scheduler_s = lr_scheduler.StepLR(optim_s, step_size=100, gamma=0.1)
    scheduler_c = lr_scheduler.StepLR(optim_c, step_size=100, gamma=0.1)
    criterion = losses.DiceLoss()

    if _config['record']:  # tensorboard visualization
        _log.info('###### define tensorboard writer #####')
        _log.info(f'##### board/train_{_config["board"]}_{date()}')
        writer = SummaryWriter(f'board/train_{_config["board"]}_{date()}')

    iter_print = _config["iter_print"]
    iter_n_train = len(trainloader)
    _log.info('###### Training ######')
    for i_epoch in range(_config['n_steps']):
        epoch_loss = 0
        for i_iter, sample_batched in enumerate(trainloader):
            # Prepare input: squeeze away the singleton slice dimension.
            s_x = sample_batched['s_x'].cuda()  # [B, Support, slice_num=1, 1, 256, 256]
            X = s_x.squeeze(2)  # [B, Support, 1, 256, 256]
            s_y = sample_batched['s_y'].cuda()  # [B, Support, slice_num, 1, 256, 256]
            Y = s_y.squeeze(2)  # [B, Support, 1, 256, 256]
            Y = Y.squeeze(2)  # [B, Support, 256, 256]
            q_x = sample_batched['q_x'].cuda()  # [B, slice_num, 1, 256, 256]
            query_input = q_x.squeeze(1)  # [B, 1, 256, 256]
            q_y = sample_batched['q_y'].cuda()  # [B, slice_num, 1, 256, 256]
            y2 = q_y.squeeze(1)  # [B, 1, 256, 256]
            y2 = y2.squeeze(1)  # [B, 256, 256]
            y2 = y2.type(torch.LongTensor).cuda()

            # Run the conditioner once per support shot, collecting the
            # generated weight maps.
            entire_weights = []
            for shot_id in range(_config["n_shot"]):
                input1 = X[:, shot_id, ...]
                y1 = Y[:, shot_id, ...]
                condition_input = torch.cat((input1, y1.unsqueeze(1)), dim=1)
                weights = model.conditioner(condition_input)  # 2, 10, [B, channel=1, w, h]
                entire_weights.append(weights)

            # Average the first 9 weight maps over the shot dimension; the
            # 10th entry and the second weight group are left as None, as
            # the segmentor expects.
            avg_weights = [[], [None, None, None, None]]
            for w_idx in range(9):
                # FIX: comprehension variable used to shadow the outer
                # ``weights`` from the shot loop; renamed for clarity.
                weight_cat = torch.cat(
                    [shot_w[0][w_idx] for shot_w in entire_weights], dim=1)
                avg_weights[0].append(
                    torch.mean(weight_cat, dim=1, keepdim=True))
            avg_weights[0].append(None)

            output = model.segmentor(query_input, avg_weights)
            loss = criterion(F.softmax(output, dim=1), y2)

            optim_s.zero_grad()
            optim_c.zero_grad()
            loss.backward()
            optim_s.step()
            optim_c.step()
            # FIX: accumulate a python float so each iteration's autograd
            # graph is freed instead of being retained for the whole epoch.
            epoch_loss += loss.item()
            if iter_print:
                print(f"train, iter:{i_iter}/{iter_n_train}, iter_loss:{loss}",
                      end='\r')

        scheduler_c.step()
        scheduler_s.step()
        print(f'step {i_epoch+1}: loss: {epoch_loss} ')

        if _config['record']:
            # Log an overlay of the last batch's query prediction.
            batch_i = 0
            frames = []
            query_pred = output.argmax(dim=1)
            query_pred = query_pred.unsqueeze(1)
            frames += overlay_color(q_x[batch_i, 0],
                                    query_pred[batch_i].float(),
                                    q_y[batch_i, 0])
            visual = make_grid(frames, normalize=True, nrow=2)
            writer.add_image("train/visual", visual, i_epoch)

        # Overwrite the single rolling snapshot each epoch.
        save_fname = f'{_run.observers[0].dir}/snapshots/last.pth'
        torch.save(model.state_dict(), save_fname)
def main(cfg, gpus):
    """Evaluate the attention-based few-shot segmentation model on VOC/COCO.

    Builds the query/memory encoders, attention heads, projection and decoder
    from ``cfg``, wraps them in a SegmentationAttentionSeparateModule behind
    DataParallel, then runs ``cfg.VAL.n_runs`` evaluation passes, recording
    per-class and binary mIoU.  Optional modes: multi-scale testing,
    attention-voting evaluation, debug tensor dumps, and visualization.
    """
    torch.cuda.set_device(gpus[0])

    # Network Builders
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways+1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    # use_softmax=True: decoder outputs probabilities for evaluation.
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways+1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout,
        use_softmax=True)

    # Optional frozen objectness branch (both weight files must be given).
    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        '''net_objectness = ModelBuilder.build_objectness(
            arch='resnet50_deeplab',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='aspp_few_shot',
            input_dim=2048,
            fc_dim=256,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            dropout_rate=0.5,
            use_dropout=True)'''
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        # The objectness branch is frozen during evaluation.
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    crit = nn.NLLLoss(ignore_index=255)

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query, net_enc_memory, net_att_query, net_att_memory,
        net_decoder, net_projection, net_objectness, net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        # debug mode also returns intermediate tensors (see forward call below)
        debug=cfg.is_debug or cfg.eval_att_voting,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)
    segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    segmentation_module.cuda()
    segmentation_module.eval()

    print('###### Prepare data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
        split = cfg.DATASET.data_split + '2014'
        annFile = f'{cfg.DATASET.data_dir}/annotations/instances_{split}.json'
        cocoapi = COCO(annFile)
    else:
        raise ValueError('Wrong config for dataset!')
    #labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    # Evaluate on the classes of the held-out fold.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    transforms = [Resize_test(size=cfg.DATASET.input_size)]
    transforms = Compose(transforms)

    print('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
    with torch.no_grad():
        for run in range(cfg.VAL.n_runs):
            print(f'### Run {run + 1} ###')
            # Each run is re-seeded so episode sampling is reproducible.
            set_seed(cfg.VAL.seed + run)

            print(f'### Load data ###')
            dataset = make_data(
                base_dir=cfg.DATASET.data_dir,
                split=cfg.DATASET.data_split,
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                n_ways=cfg.TASK.n_ways,
                n_shots=cfg.TASK.n_shots,
                n_queries=cfg.TASK.n_queries,
                permute=cfg.VAL.permute_labels,
                exclude_labels=[]
            )
            if data_name == 'COCO':
                # Map COCO category ids to contiguous 1-based label ids.
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset, batch_size=cfg.VAL.n_batch,
                                    shuffle=False, num_workers=1,
                                    pin_memory=True, drop_last=False)
            print(f"Total # of Data: {len(dataset)}")

            count = 0
            if cfg.multi_scale_test:
                scales = [224, 328, 424]
            else:
                scales = [328]
            for sample_batched in tqdm.tqdm(testloader):
                feed_dict = data_preprocess(sample_batched, cfg)
                if data_name == 'COCO':
                    label_ids = [coco_cls_ids.index(x) + 1
                                 for x in sample_batched['class_ids']]
                else:
                    label_ids = list(sample_batched['class_ids'])

                # Predictions at every scale are averaged into
                # query_pred_final.
                for q, scale in enumerate(scales):
                    if len(scales) > 1:
                        feed_dict['img_data'] = nn.functional.interpolate(
                            feed_dict['img_data'].cuda(),
                            size=(scale, scale), mode='bilinear')
                    if cfg.eval_att_voting or cfg.is_debug:
                        # Debug forward also returns attention internals.
                        query_pred, qread, qval, qk_b, mk_b, mv_b, p, feature_enc, feature_memory = segmentation_module(
                            feed_dict,
                            segSize=(feed_dict['seg_label_noresize'].shape[1],
                                     feed_dict['seg_label_noresize'].shape[2]))
                        if cfg.eval_att_voting:
                            # Replace the decoder prediction by directly
                            # voting support masks through the attention
                            # matrix ``p`` (rows = query pixels).
                            height, width = qread.shape[-2], qread.shape[-1]
                            assert p.shape[0] == height*width
                            img_refs_mask_resize = nn.functional.interpolate(
                                feed_dict['img_refs_mask'][0].cuda(),
                                size=(height, width), mode='nearest')
                            img_refs_mask_resize_flat = img_refs_mask_resize[:, 0, :, :].view(
                                img_refs_mask_resize.shape[0], -1)
                            mask_voting_flat = torch.mm(img_refs_mask_resize_flat, p)
                            mask_voting = mask_voting_flat.view(
                                mask_voting_flat.shape[0], height, width)
                            mask_voting = torch.unsqueeze(mask_voting, 0)
                            # Drop the last (background) channel before upsampling.
                            query_pred = nn.functional.interpolate(
                                mask_voting[:, 0:-1],
                                size=cfg.DATASET.input_size,
                                mode='bilinear', align_corners=False)
                            if cfg.is_debug:
                                np.save('debug/img_refs_mask-%04d-%s-%s.npy' % (count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), img_refs_mask_resize.detach().cpu().float().numpy())
                                np.save('debug/query_pred-%04d-%s-%s.npy' % (count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), query_pred.detach().cpu().float().numpy())
                        if cfg.is_debug:
                            np.save('debug/qread-%04d-%s-%s.npy' % (count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), qread.detach().cpu().float().numpy())
                            np.save('debug/qval-%04d-%s-%s.npy' % (count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), qval.detach().cpu().float().numpy())
                            #np.save('debug/qk_b-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), qk_b.detach().cpu().float().numpy())
                            #np.save('debug/mk_b-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), mk_b.detach().cpu().float().numpy())
                            #np.save('debug/mv_b-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), mv_b.detach().cpu().float().numpy())
                            #np.save('debug/p-%04d-%s-%s.npy'%(count, sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), p.detach().cpu().float().numpy())
                            #np.save('debug/feature_enc-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), feature_enc[-1].detach().cpu().float().numpy())
                            #np.save('debug/feature_memory-%s-%s.npy'%(sample_batched['query_ids'][0][0], sample_batched['support_ids'][0][0][0]), feature_memory[-1].detach().cpu().float().numpy())
                    else:
                        #query_pred = segmentation_module(feed_dict, segSize=cfg.DATASET.input_size)
                        query_pred = segmentation_module(
                            feed_dict,
                            segSize=(feed_dict['seg_label_noresize'].shape[1],
                                     feed_dict['seg_label_noresize'].shape[2]))
                    if q == 0:
                        query_pred_final = query_pred/len(scales)
                    else:
                        query_pred_final += query_pred/len(scales)
                query_pred = query_pred_final
                metric.record(
                    np.array(query_pred.argmax(dim=1)[0].cpu()),
                    np.array(feed_dict['seg_label_noresize'][0].cpu()),
                    labels=label_ids, n_run=run)

                if cfg.VAL.visualize:
                    #print(as_numpy(feed_dict['seg_label'][0].cpu()).shape)
                    #print(as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())).shape)
                    #print(feed_dict['img_data'].cpu().shape)
                    query_name = sample_batched['query_ids'][0][0]
                    support_name = sample_batched['support_ids'][0][0][0]
                    if data_name == 'VOC':
                        img = imread(os.path.join(cfg.DATASET.data_dir,
                                                  'JPEGImages',
                                                  query_name+'.jpg'))
                    else:
                        query_name = int(query_name)
                        img_meta = cocoapi.loadImgs(query_name)[0]
                        img = imread(os.path.join(cfg.DATASET.data_dir,
                                                  split,
                                                  img_meta['file_name']))
                    #img = imresize(img, cfg.DATASET.input_size)
                    visualize_result(
                        (img,
                         as_numpy(feed_dict['seg_label_noresize'][0].cpu()),
                         '%05d' % (count)),
                        as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())),
                        os.path.join(cfg.DIR, 'result')
                    )
                count += 1
            # Per-run metrics (printed per run by Metric internally, if at all).
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels), n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)

            '''_run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')'''

    # Aggregate metrics (mean and std) over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

    print('----- Final Result -----')
    print('final_classIoU', classIoU.tolist())
    print('final_classIoU_std', classIoU_std.tolist())
    print('final_meanIoU', meanIoU.tolist())
    print('final_meanIoU_std', meanIoU_std.tolist())
    print('final_classIoU_binary', classIoU_binary.tolist())
    print('final_classIoU_std_binary', classIoU_std_binary.tolist())
    print('final_meanIoU_binary', meanIoU_binary.tolist())
    print('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    print(f'classIoU mean: {classIoU}')
    print(f'classIoU std: {classIoU_std}')
    print(f'meanIoU mean: {meanIoU}')
    print(f'meanIoU std: {meanIoU_std}')
    print(f'classIoU_binary mean: {classIoU_binary}')
    print(f'classIoU_binary std: {classIoU_std_binary}')
    print(f'meanIoU_binary mean: {meanIoU_binary}')
    print(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Train the attention-based few-shot segmentation model on VOC/COCO.

    Builds the network components from ``cfg``, trains on fold
    ``cfg.TASK.fold_idx`` classes, and periodically evaluates on the held-out
    classes, checkpointing 'latest' and 'best' (by validation meanIoU).
    """
    # Network Builders
    torch.cuda.set_device(gpus[0])
    print('###### Create model ######')
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_enc_memory = ModelBuilder.build_encoder_memory_separate(
        arch=cfg.MODEL.arch_memory_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.TASK.n_ways + 1,
        RGB_mask_combine_val=cfg.DATASET.RGB_mask_combine_val,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate)
    net_att_query = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_attention(
        arch=cfg.MODEL.arch_attention,
        input_dim=cfg.MODEL.fc_dim,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_projection = ModelBuilder.build_projection(
        arch=cfg.MODEL.arch_projection,
        input_dim=cfg.MODEL.encoder_dim,
        fc_dim=cfg.MODEL.projection_dim,
        weights=cfg.MODEL.weights_projection)
    # No use_softmax here: raw decoder outputs are used for the training loss;
    # ``net_decoder.use_softmax`` is toggled below during evaluation.
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.decoder_fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=cfg.TASK.n_ways + 1,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)

    # Optional frozen objectness branch (both weight files must be given).
    if cfg.MODEL.weights_objectness and cfg.MODEL.weights_objectness_decoder:
        '''net_objectness = ModelBuilder.build_objectness(
            arch='resnet50_deeplab',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='aspp_few_shot',
            input_dim=2048,
            fc_dim=256,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            dropout_rate=0.5,
            use_dropout=True)'''
        net_objectness = ModelBuilder.build_objectness(
            arch='hrnetv2',
            weights=cfg.MODEL.weights_objectness,
            fix_encoder=True)
        net_objectness_decoder = ModelBuilder.build_decoder(
            arch='c1_nodropout',
            input_dim=720,
            fc_dim=720,
            ppm_dim=256,
            num_class=2,
            weights=cfg.MODEL.weights_objectness_decoder,
            use_dropout=False)
        # Objectness branch stays frozen during training.
        for param in net_objectness.parameters():
            param.requires_grad = False
        for param in net_objectness_decoder.parameters():
            param.requires_grad = False
    else:
        net_objectness = None
        net_objectness_decoder = None

    crit = nn.NLLLoss(ignore_index=255)

    segmentation_module = SegmentationAttentionSeparateModule(
        net_enc_query, net_enc_memory, net_att_query, net_att_memory,
        net_decoder, net_projection, net_objectness, net_objectness_decoder,
        crit,
        zero_memory=cfg.MODEL.zero_memory,
        random_memory_bias=cfg.MODEL.random_memory_bias,
        random_memory_nobias=cfg.MODEL.random_memory_nobias,
        random_scale=cfg.MODEL.random_scale,
        zero_qval=cfg.MODEL.zero_qval,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar,
        memory_feature_aggregation=cfg.MODEL.memory_feature_aggregation,
        memory_noLabel=cfg.MODEL.memory_noLabel,
        mask_feat_downsample_rate=cfg.MODEL.mask_feat_downsample_rate,
        att_mat_downsample_rate=cfg.MODEL.att_mat_downsample_rate,
        objectness_feat_downsample_rate=cfg.MODEL.objectness_feat_downsample_rate,
        segm_downsampling_rate=cfg.DATASET.segm_downsampling_rate,
        mask_foreground=cfg.MODEL.mask_foreground,
        global_pool_read=cfg.MODEL.global_pool_read,
        average_memory_voting=cfg.MODEL.average_memory_voting,
        average_memory_voting_nonorm=cfg.MODEL.average_memory_voting_nonorm,
        mask_memory_RGB=cfg.MODEL.mask_memory_RGB,
        linear_classifier_support=cfg.MODEL.linear_classifier_support,
        decay_lamb=cfg.MODEL.decay_lamb,
        linear_classifier_support_only=cfg.MODEL.linear_classifier_support_only,
        qread_only=cfg.MODEL.qread_only,
        feature_as_key=cfg.MODEL.feature_as_key,
        objectness_multiply=cfg.MODEL.objectness_multiply)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness_debug import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness_debug import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on the fold's classes; validate on all remaining classes.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    # Optionally hide validation classes entirely from the training episodes.
    if cfg.DATASET.exclude_labels:
        exclude_labels = labels_val
    else:
        exclude_labels = []
    transforms = Compose([Resize(size=cfg.DATASET.input_size),
                          RandomMirror()])
    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels,
                        use_ignore=cfg.use_ignore)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    #segmentation_module = nn.DataParallel(segmentation_module, device_ids=gpus)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, net_projection, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    segmentation_module.train(not cfg.TRAIN.fix_bn)
    if net_objectness and net_objectness_decoder:
        net_objectness.eval()
        net_objectness_decoder.eval()

    best_iou = 0
    # main loop
    tic = time.time()
    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)
        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        #print(batch_data)
        loss, acc = segmentation_module(feed_dict)
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            if optimizer:
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters,
                      batch_time.average(), data_time.average(),
                      cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder,
                      ave_acc.average(), ave_total_loss.average()))
            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        # Periodic validation on the held-out classes.
        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                segmentation_module.eval()
                # Decoder emits probabilities during eval; reset below.
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)

                    print(f'### Load validation data ###')
                    dataset_val = make_data(base_dir=cfg.DATASET.data_dir,
                                            split=cfg.DATASET.data_split,
                                            transforms=transforms,
                                            to_tensor=ToTensorNormalize(),
                                            labels=labels_val,
                                            max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                                            n_ways=cfg.TASK.n_ways,
                                            n_shots=cfg.TASK.n_shots,
                                            n_queries=cfg.TASK.n_queries,
                                            permute=cfg.VAL.permute_labels,
                                            exclude_labels=[])
                    if data_name == 'COCO':
                        # Map COCO category ids to contiguous label ids.
                        coco_cls_ids = dataset_val.datasets[0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # NOTE(review): prints len(dataset) (train set), not
                    # len(dataset_val) -- confirm intended.
                    print(f"Total # of validation Data: {len(dataset)}")
                    #for sample_batched in tqdm.tqdm(testloader):
                    for sample_batched in testloader:
                        feed_dict = data_preprocess(sample_batched, cfg,
                                                    is_val=True)
                        if data_name == 'COCO':
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in sample_batched['class_ids']
                            ]
                        else:
                            label_ids = list(sample_batched['class_ids'])
                        query_pred = segmentation_module(
                            feed_dict, segSize=cfg.DATASET.input_size)
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(feed_dict['seg_label'][0].cpu()),
                            labels=label_ids, n_run=run)

                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

                # Aggregate over all validation runs.
                classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
                    labels=sorted(labels_val))
                classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary()

                print('----- Evaluation Result -----')
                print(f'best meanIoU mean: {best_iou}')
                print(f'meanIoU mean: {meanIoU}')
                print(f'meanIoU std: {meanIoU_std}')
                print(f'meanIoU_binary mean: {meanIoU_binary}')
                print(f'meanIoU_binary std: {meanIoU_std_binary}')

            checkpoint(nets, history, cfg, 'latest')
            if meanIoU > best_iou:
                best_iou = meanIoU
                checkpoint(nets, history, cfg, 'best')
            # Restore training mode after evaluation.
            segmentation_module.train(not cfg.TRAIN.fix_bn)
            if net_objectness and net_objectness_decoder:
                net_objectness.eval()
                net_objectness_decoder.eval()
            net_decoder.use_softmax = False

    print('Training Done!')
def main(config):
    """Train FewShotSeg for few-shot video object segmentation on DAVIS 2017.

    Args:
        config: plain config dict (snapshots, palette_dir, seed, dataset,
            optim, lr_milestones, print_interval, save_pred_every, ...).

    Side effects: writes periodic and final model snapshots into
    ``config['snapshots']``.
    """
    # FIX: exist_ok avoids the check-then-create race of the original
    # ``if not os.path.exists(...)`` guard.
    os.makedirs(config['snapshots'], exist_ok=True)

    # 768-entry flat palette (256 RGB triplets) for saving indexed prediction
    # images; currently only consumed by disabled visualization code.
    palette_path = config['palette_dir']
    with open(palette_path) as f:
        palette = f.readlines()
    palette = list(np.asarray([[int(p) for p in pal[0:-1].split(' ')]
                               for pal in palette]).reshape(768))
    snap_shots_dir = config['snapshots']

    # NOTE(review): GPU id is hard-coded here and in DataParallel below
    # instead of coming from config -- confirm intended.
    torch.cuda.set_device(2)
    set_seed(config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.set_num_threads(1)

    model = FewShotSeg(cfg=config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[2])
    model.train()

    data_name = config['dataset']
    if data_name == 'davis':
        make_data = davis2017_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][config['label_sets']]
    transforms = Compose([Resize(size=config['input_size']),
                          RandomMirror()])
    dataset = make_data(
        base_dir=config['path'][data_name]['data_dir'],
        split=config['path'][data_name]['data_split'],
        transforms=transforms,
        to_tensor=ToTensorNormalize(),
        labels=labels,
        max_iters=config['n_steps'] * config['batch_size'],
        n_ways=config['task']['n_ways'],
        n_shots=config['task']['n_shots'],
        n_queries=config['task']['n_queries']
    )
    trainloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True
    )

    optimizer = torch.optim.SGD(model.parameters(), **config['optim'])
    scheduler = MultiStepLR(optimizer, milestones=config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=config['ignore_label'])

    i_iter = 0  # keeps the final save valid even if the loader is empty
    log_loss = {'loss': 0, 'align_loss': 0}
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input: nested [way][shot] lists of tensors moved to GPU.
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        # FIX: 'fg_mask' / 'bg_mask' were pointless f-strings (no placeholders).
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_fg_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_bg_mask']]
        img_size = sample_batched['img_size']
        query_images = [query_image.cuda()
                        for query_image in sample_batched['query_images']]
        query_labels = torch.cat(
            [query_label.long().cuda()
             for query_label in sample_batched['query_labels']], dim=0)
        pre_masks = [query_label.float().cuda()
                     for query_label in sample_batched['query_masks']]

        # Forward and Backward
        optimizer.zero_grad()
        query_pred, align_loss = model(support_images, support_fg_mask,
                                       support_bg_mask, query_images,
                                       pre_masks)
        # Upsample the prediction to the original image size before the loss.
        query_pred = F.interpolate(query_pred, size=img_size, mode="bilinear")
        query_loss = criterion(query_pred, query_labels)
        loss = query_loss + align_loss * config['align_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss (running totals; printed as running averages below).
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy() if align_loss != 0 else 0
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        # print loss and take snapshots
        if (i_iter + 1) % config['print_interval'] == 0:
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter + 1}: loss: {loss}, align_loss: {align_loss}')
        if (i_iter + 1) % config['save_pred_every'] == 0:
            # FIX: os.path.join(f'{snap_shots_dir}', ...) wrapped the path in
            # a no-op f-string; join the plain string instead.
            torch.save(model.state_dict(),
                       os.path.join(snap_shots_dir, f'{i_iter + 1}.pth'))
    # Final snapshot after training completes.
    torch.save(model.state_dict(),
               os.path.join(snap_shots_dir, f'{i_iter + 1}.pth'))
def main(_run, _config, _log):
    """Train the superpixel-based self-supervised few-shot segmenter.

    Sacred entry point: archives the experiment sources, builds the
    SuperpixelDataset pipeline, and trains with query + alignment losses,
    periodically snapshotting and (for some datasets) reloading the buffer.
    """
    if _run.observers:
        # Archive the run's source files under the observer directory.
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=None, cfg=_config['model'])
    model = model.cuda()
    model.train()

    _log.info('###### Load data ######')
    ### Training set
    data_name = _config['dataset']
    if data_name == 'SABS_Superpix':
        baseset_name = 'SABS'
    elif data_name == 'C0_Superpix':
        raise NotImplementedError
        baseset_name = 'C0'  # NOTE(review): unreachable after the raise above
    elif data_name == 'CHAOST2_Superpix':
        baseset_name = 'CHAOST2'
    else:
        raise ValueError(f'Dataset: {data_name} not found')

    ### Transforms for data augmentation
    tr_transforms = myaug.transform_with_label(
        {'aug': myaug.augs[_config['which_aug']]})
    # by default we load the entire dataset directly
    assert _config['scan_per_load'] < 0
    # Unseen labels = all labels minus the training label group.
    test_labels = DATASET_INFO[baseset_name]['LABEL_GROUP'][
        'pa_all'] - DATASET_INFO[baseset_name]['LABEL_GROUP'][
            _config["label_sets"]]
    _log.info(
        f'###### Labels excluded in training : {[lb for lb in _config["exclude_cls_list"]]} ######'
    )
    _log.info(
        f'###### Unseen labels evaluated in testing: {[lb for lb in test_labels]} ######'
    )
    tr_parent = SuperpixelDataset(  # base dataset
        which_dataset=baseset_name,
        base_dir=_config['path'][data_name]['data_dir'],
        idx_split=_config['eval_fold'],
        mode='train',
        min_fg=str(_config["min_fg_data"]),  # dummy entry for superpixel dataset
        transforms=tr_transforms,
        nsup=_config['task']['n_shots'],
        scan_per_load=_config['scan_per_load'],
        exclude_list=_config["exclude_cls_list"],
        superpix_scale=_config["superpix_scale"],
        fix_length=_config["max_iters_per_load"] if
        (data_name == 'C0_Superpix') or
        (data_name == 'CHAOST2_Superpix') else None)

    ### dataloaders
    trainloader = DataLoader(tr_parent,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=_config['num_workers'],
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    if _config['optim_type'] == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    else:
        raise NotImplementedError
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=_config['lr_step_gamma'])
    my_weight = compose_wt_simple(_config["use_wce"], data_name)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'],
                                    weight=my_weight)

    i_iter = 0  # total number of iteration
    n_sub_epoches = _config['n_steps'] // _config[
        'max_iters_per_load']  # number of times for reloading
    log_loss = {'loss': 0, 'align_loss': 0}

    _log.info('###### Training ######')
    for sub_epoch in range(n_sub_epoches):
        _log.info(
            f'###### This is epoch {sub_epoch} of {n_sub_epoches} epoches ######'
        )
        for _, sample_batched in enumerate(trainloader):
            # Prepare input
            i_iter += 1
            # Fixed: mask keys had a pointless f-string prefix.
            support_images = [[shot.cuda() for shot in way]
                              for way in sample_batched['support_images']]
            support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                               for way in sample_batched['support_mask']]
            support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                               for way in sample_batched['support_mask']]
            query_images = [
                query_image.cuda()
                for query_image in sample_batched['query_images']
            ]
            query_labels = torch.cat([
                query_label.long().cuda()
                for query_label in sample_batched['query_labels']
            ], dim=0)

            optimizer.zero_grad()
            # FIXME: in the model definition, filter out the failure case
            # where the pseudolabel falls outside of the image or is too
            # small to calculate a prototype.
            try:
                query_pred, align_loss, debug_vis, assign_mats = model(
                    support_images, support_fg_mask, support_bg_mask,
                    query_images, isval=False, val_wsize=None)
            except Exception:
                # Fixed: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                print('Faulty batch detected, skip')
                continue

            query_loss = criterion(query_pred, query_labels)
            loss = query_loss + align_loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Log loss
            query_loss = query_loss.detach().data.cpu().numpy()
            align_loss = align_loss.detach().data.cpu().numpy() \
                if align_loss != 0 else 0
            _run.log_scalar('loss', query_loss)
            _run.log_scalar('align_loss', align_loss)
            log_loss['loss'] += query_loss
            log_loss['align_loss'] += align_loss

            # print loss and take snapshots
            if (i_iter + 1) % _config['print_interval'] == 0:
                loss = log_loss['loss'] / _config['print_interval']
                align_loss = log_loss['align_loss'] / _config['print_interval']
                log_loss['loss'] = 0
                log_loss['align_loss'] = 0
                print(
                    f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss},'
                )

            if (i_iter + 1) % _config['save_snapshot_every'] == 0:
                _log.info('###### Taking snapshot ######')
                torch.save(
                    model.state_dict(),
                    os.path.join(f'{_run.observers[0].dir}/snapshots',
                                 f'{i_iter + 1}.pth'))

            if data_name == 'C0_Superpix' or data_name == 'CHAOST2_Superpix':
                if (i_iter + 1) % _config['max_iters_per_load'] == 0:
                    _log.info('###### Reloading dataset ######')
                    trainloader.dataset.reload_buffer()
                    print(
                        f'###### New dataset with {len(trainloader.dataset)} slices has been loaded ######'
                    )

            if (i_iter - 2) > _config['n_steps']:
                return 1  # finish up
def main(_run, _config, _log):
    """Train the few-shot segmenter on ScanNet (Sacred entry point).

    Archives run sources, trains with a query cross-entropy loss plus a
    scaled alignment loss, and snapshots periodically and at the end.
    """
    if _run.observers:
        # Keep a copy of the experiment sources next to the run artifacts.
        obs_dir = f'{_run.observers[0].dir}'
        os.makedirs(f'{obs_dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{obs_dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'], ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'ScanNet':
        make_data = scannet_fewshot
    else:
        raise ValueError('Wrong config for dataset!')
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'])
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=20,
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])

    i_iter = 0
    log_loss = {'loss': 0, 'align_loss': 0}
    _log.info('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Move each episode component onto the GPU, preserving the
        # ways-by-shots nesting of the support set.
        support_coords = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_coords']]
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot['fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot['bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        query_coords = [
            query_coord.cuda()
            for query_coord in sample_batched['query_coords']
        ]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ], dim=0)

        # Forward and backward pass.
        optimizer.zero_grad()
        query_pred, align_loss = model(support_coords, support_images,
                                       support_fg_mask, support_bg_mask,
                                       query_coords, query_images)
        query_loss = criterion(query_pred, query_labels[None, ...])
        loss = query_loss + align_loss * _config['align_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Record scalars in Sacred and accumulate for the console summary.
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy() \
            if align_loss != 0 else 0
        _run.log_scalar('loss', query_loss)
        _run.log_scalar('align_loss', align_loss)
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss

        if (i_iter + 1) % _config['print_interval'] == 0:
            # Running mean over all iterations so far (never reset).
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            print(f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}')

        if (i_iter + 1) % _config['save_pred_every'] == 0:
            _log.info('###### Taking snapshot ######')
            torch.save(
                model.state_dict(),
                os.path.join(f'{_run.observers[0].dir}/snapshots',
                             f'{i_iter + 1}.pth'))

    _log.info('###### Saving final model ######')
    torch.save(
        model.state_dict(),
        os.path.join(f'{_run.observers[0].dir}/snapshots',
                     f'{i_iter + 1}.pth'))
def main(_run, _config, _log):
    """Evaluate the support/query encoder-decoder segmenter on the test split.

    Loads a checkpoint, runs inference per test slice, aggregates per-subject
    Dice scores, and optionally logs to TensorBoard and exports NIfTI volumes.
    """
    # Archive the run's source files under the observer directory.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)
    device = torch.device(f"cuda:{_config['gpu_id']}")

    _log.info('###### Create model ######')
    resize_dim = _config['input_size']
    # Spatial size of the encoder bottleneck after n_pool 2x down-samplings.
    encoded_h = int(resize_dim[0] / 2**_config['n_pool'])
    encoded_w = int(resize_dim[1] / 2**_config['n_pool'])

    s_encoder = SupportEncoder(_config['path']['init_path'], device)
    q_encoder = QueryEncoder(_config['path']['init_path'], device)
    decoder = Decoder(input_res=(encoded_h, encoded_w),
                      output_res=resize_dim).to(device)

    checkpoint = torch.load(_config['snapshot'], map_location='cpu')
    s_encoder.load_state_dict(checkpoint['s_encoder'])
    q_encoder.load_state_dict(checkpoint['q_encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    # NOTE(review): the modules are never switched to eval() mode here, so
    # dropout/batch-norm (if any) stay in train mode — confirm intended.

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    max_label = 1
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(
        dataset=ts_dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=False,
        drop_last=False)

    if _config['record']:
        _log.info('###### define tensorboard writer #####')
        board_name = f'board/test_{_config["board"]}_{date()}'
        writer = SummaryWriter(board_name)

    _log.info('###### Testing begins ######')
    img_cnt = 0
    length = len(testloader)
    img_lists = []
    pred_lists = []
    label_lists = []
    # Per-subject accumulator: subj_idx -> list of per-slice records.
    saves = {}
    for subj_idx in range(len(ts_dataset.get_cnts())):
        saves[subj_idx] = []

    with torch.no_grad():
        loss_valid = 0
        batch_i = 0  # use only 1 batch size for testing
        for i, sample_test in enumerate(
                testloader):  # even for upward, down for downward
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list = []
            pred_list = []
            label_list = []
            preds = []
            fnames = sample_test['q_fname']
            s_x = sample_test['s_x'].to(device)  # [B, slice_num, 1, 256, 256]
            s_y = sample_test['s_y'].to(device)  # [B, slice_num, 1, 256, 256]
            q_x = sample_test['q_x'].to(device)  # [B, slice_num, 1, 256, 256]
            q_y = sample_test['q_y'].to(device)  # [B, slice_num, 1, 256, 256]
            s_xi = s_x[:, :, 0, :, :, :]  # [B, Support, 1, 256, 256]
            s_yi = s_y[:, :, 0, :, :, :]
            # NOTE(review): this loop does not use s_idx and recomputes the
            # same averaged support encoding every pass — likely only the
            # final iteration matters; confirm against the training script.
            for s_idx in range(_config["n_shot"]):
                s_x_merge = s_xi.view(s_xi.size(0) * s_xi.size(1), 1, 256, 256)
                s_y_merge = s_yi.view(s_yi.size(0) * s_yi.size(1), 1, 256, 256)
                s_xi_encode_merge, _ = s_encoder(s_x_merge,
                                                 s_y_merge)  # [B*S, 512, w, h]
                s_xi_encode = s_xi_encode_merge.view(s_yi.size(0),
                                                     s_yi.size(1), 512,
                                                     encoded_w, encoded_h)
                # Average the support encodings over the shot dimension.
                s_xi_encode_avg = torch.mean(s_xi_encode, dim=1)
            q_xi = q_x[:, 0, :, :, :]
            q_yi = q_y[:, 0, :, :, :]
            q_xi_encode, q_ft_list = q_encoder(q_xi)
            sq_xi = torch.cat((s_xi_encode_avg, q_xi_encode), dim=1)
            yhati = decoder(sq_xi, q_ft_list)  # [B, 1, 256, 256]
            preds.append(yhati.round())
            img_list.append(q_xi[batch_i].cpu().numpy())
            pred_list.append(yhati[batch_i].round().cpu().numpy())
            label_list.append(q_yi[batch_i].cpu().numpy())
            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list, fnames])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t", end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)

    print("start computing dice similarities ... total ", len(saves))
    dice_similarities = []
    for subj_idx in range(len(saves)):
        imgs, preds, labels = [], [], []
        save_subj = saves[subj_idx]
        for i in range(len(save_subj)):
            subj_idx, idx, img_list, pred_list, label_list, fnames = \
                save_subj[i]
            for j in range(len(img_list)):
                imgs.append(img_list[j])
                preds.append(pred_list[j])
                labels.append(label_list[j])
        img_arr = np.concatenate(imgs, axis=0)
        pred_arr = np.concatenate(preds, axis=0)
        label_arr = np.concatenate(labels, axis=0)
        # Dice = 2|A∩B| / (|A|+|B|) over the whole subject volume.
        dice = np.sum([label_arr * pred_arr
                       ]) * 2.0 / (np.sum(pred_arr) + np.sum(label_arr))
        dice_similarities.append(dice)
        print(f"computing dice scores {subj_idx}/{10}", end='\n')

        if _config['record']:
            frames = []
            for frame_id in range(0, len(save_subj)):
                frames += overlay_color(torch.tensor(imgs[frame_id]),
                                        torch.tensor(preds[frame_id]),
                                        torch.tensor(labels[frame_id]),
                                        scale=_config['scale'])
            visual = make_grid(frames, normalize=True, nrow=5)
            writer.add_image(f"test/{subj_idx}", visual, i)
            writer.add_scalar(f'dice_score/{i}', dice)

        if _config['save_sample']:  # only for internal test (BCV - MICCAI2015)
            sup_idx = _config['s_idx']
            target = _config['target']
            save_name = _config['save_name']
            dirs = ["gt", "pred", "input"]
            save_dir = f"/user/home2/soopil/tmp/PANet/MICCAI2015/sample/fss1000_organ{target}_sup{sup_idx}_{save_name}"
            # Fixed: was `for dir in dirs` (shadowing the builtin) with a
            # bare try/except-pass around makedirs; exist_ok covers it.
            for sub_dir in dirs:
                os.makedirs(os.path.join(save_dir, sub_dir), exist_ok=True)
            subj_name = fnames[0][0].split("/")[-2]
            if target == 14:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Cervix/RawData/Training/img"
                orig_fname = f"{src_dir}/{subj_name}-Image.nii.gz"
            else:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Abdomen/RawData/Training/img"
                orig_fname = f"{src_dir}/img{subj_name}.nii.gz"
            # Carry the source volume's voxel spacing into the exports.
            itk = sitk.ReadImage(orig_fname)
            orig_spacing = itk.GetSpacing()
            label_arr = label_arr * 2.0
            itk = sitk.GetImageFromArray(label_arr)
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/gt/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(pred_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/pred/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(img_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/input/{subj_idx}.nii.gz")

    print(f"test result \n n : {len(dice_similarities)}, mean dice score : \
{np.mean(dice_similarities)} \n dice similarities : {dice_similarities}")
    if _config['record']:
        writer.add_scalar(f'dice_score/mean', np.mean(dice_similarities))
def main(_run, _config, _log):
    """Evaluate the encoder model on VOC/COCO few-shot episodes.

    Runs `n_runs` evaluation passes with different seeds, records per-class
    and binary IoU through `Metric`, and logs per-run plus aggregate results.
    """
    # Archive the run's sources next to the observer's artifacts.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = Encoder(pretrained_path=_config['path']['init_path'])
    model = nn.DataParallel(model.cuda(), device_ids=[_config['gpu_id'], ])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Prepare data ######')
    data_name = _config['dataset']
    if data_name == 'VOC':
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Evaluate on the classes held out of the training label set.
    labels = CLASS_LABELS[data_name]['all'] - \
        CLASS_LABELS[data_name][_config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    if _config['scribble_dilation'] > 0:
        transforms.append(DilateScribble(size=_config['scribble_dilation']))
    transforms = Compose(transforms)

    _log.info('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    with torch.no_grad():
        for run in range(_config['n_runs']):
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)

            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                label_sets=_config['label_sets'],
                max_iters=_config['n_steps'] * _config['batch_size'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'])
            if _config['dataset'] == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")

            for sample_batched in tqdm.tqdm(testloader):
                # Map COCO category ids to contiguous label indices.
                if _config['dataset'] == 'COCO':
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])

                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                support_images = torch.cat(
                    [torch.cat(way, dim=0) for way in support_images], dim=0)

                suffix = 'scribble' if _config['scribble'] else 'mask'
                if _config['bbox']:
                    # Derive fg/bg masks from instance bounding boxes.
                    support_fg_mask = []
                    support_bg_mask = []
                    for i, way in enumerate(sample_batched['support_mask']):
                        fg_masks = []
                        bg_masks = []
                        for j, shot in enumerate(way):
                            fg_mask, bg_mask = get_bbox(
                                shot['fg_mask'],
                                sample_batched['support_inst'][i][j])
                            fg_masks.append(fg_mask.float().cuda())
                            bg_masks.append(bg_mask.float().cuda())
                        support_fg_mask.append(fg_masks)
                        support_bg_mask.append(bg_masks)
                else:
                    support_fg_mask = [[
                        shot['fg_mask'].float().cuda() for shot in way
                    ] for way in sample_batched['support_mask']]
                # Flatten the ways-by-shots nesting into one batch tensor.
                support_fg_mask = torch.cat(
                    [torch.cat(way, dim=0) for way in support_fg_mask], dim=0)

                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_images = torch.cat(query_images, dim=0)
                query_labels = torch.cat([
                    query_label.long().cuda()
                    for query_label in sample_batched['query_labels']
                ], dim=0)

                query_pred = model(support_images, query_images,
                                   support_fg_mask)
                query_pred = F.interpolate(query_pred,
                                           size=query_images.shape[-2:],
                                           mode='bilinear')
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(query_labels[0].cpu()),
                              labels=label_ids,
                              n_run=run)

            # Per-run summary.
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                n_run=run)
            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Aggregate over all runs (mean and std).
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, \
        meanIoU_std_binary = metric.get_mIoU_binary()

    _log.info('----- Final Result -----')
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Evaluate the objectness + decoder networks on VOC/COCO episodes.

    Loads pretrained encoder/decoder weights, runs `cfg.VAL.n_runs`
    evaluation passes, records IoU via `Metric`, and optionally writes
    visualizations of query predictions.
    """
    torch.cuda.set_device(gpus[0])

    # Network builders
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout,
        use_softmax=True)
    # NOTE(review): crit is constructed but never used in this eval path.
    crit = nn.NLLLoss(ignore_index=255)

    net_objectness.cuda()
    net_objectness.eval()
    net_decoder.cuda()
    net_decoder.eval()

    print('###### Prepare data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        if cfg.VAL.test_with_classes:
            from dataloaders.customized import voc_fewshot
        else:
            from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        if cfg.VAL.test_with_classes:
            from dataloaders.customized import coco_fewshot
        else:
            from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
        split = cfg.DATASET.data_split + '2014'
        annFile = f'{cfg.DATASET.data_dir}/annotations/instances_{split}.json'
        cocoapi = COCO(annFile)
    else:
        raise ValueError('Wrong config for dataset!')

    # Held-out classes for the chosen fold.
    labels = CLASS_LABELS[data_name]['all'] - \
        CLASS_LABELS[data_name][cfg.TASK.fold_idx]

    val_transforms = [
        transforms.ToNumpy(),
        transforms.Resize_pad(size=cfg.DATASET.input_size[0])
    ]
    # ImageNet mean/std scaled to 0-255 pixel range (kept for parity with
    # the training script; unused by the transforms above).
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    val_transforms = Compose(val_transforms)

    print('###### Testing begins ######')
    metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
    with torch.no_grad():
        for run in range(cfg.VAL.n_runs):
            print(f'### Run {run + 1} ###')
            set_seed(cfg.VAL.seed + run)

            print(f'### Load data ###')
            dataset = make_data(
                base_dir=cfg.DATASET.data_dir,
                split=cfg.DATASET.data_split,
                transforms=val_transforms,
                to_tensor=transforms.ToTensorNormalize_noresize(),
                labels=labels,
                max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                n_ways=cfg.TASK.n_ways,
                n_shots=cfg.TASK.n_shots,
                n_queries=cfg.TASK.n_queries,
                permute=cfg.VAL.permute_labels,
            )
            if data_name == 'COCO':
                coco_cls_ids = dataset.datasets[0].dataset.coco.getCatIds()
            testloader = DataLoader(dataset,
                                    batch_size=cfg.VAL.n_batch,
                                    shuffle=False,
                                    num_workers=1,
                                    pin_memory=True,
                                    drop_last=False)
            print(f"Total # of Data: {len(dataset)}")

            count = 0
            for sample_batched in tqdm.tqdm(testloader):
                feed_dict = data_preprocess(sample_batched, cfg)
                if data_name == 'COCO':
                    label_ids = [
                        coco_cls_ids.index(x) + 1
                        for x in sample_batched['class_ids']
                    ]
                else:
                    label_ids = list(sample_batched['class_ids'])

                feat = net_objectness(feed_dict['img_data'],
                                      return_feature_maps=True)
                query_pred = net_decoder(feat, segSize=(473, 473))
                metric.record(np.array(query_pred.argmax(dim=1)[0].cpu()),
                              np.array(feed_dict['seg_label'][0].cpu()),
                              labels=label_ids,
                              n_run=run)

                if cfg.VAL.visualize:
                    # Reload the raw query image so the overlay uses the
                    # original (non-normalized) pixels.
                    query_name = sample_batched['query_ids'][0][0]
                    support_name = sample_batched['support_ids'][0][0][0]
                    if data_name == 'VOC':
                        img = imread(
                            os.path.join(cfg.DATASET.data_dir, 'JPEGImages',
                                         query_name + '.jpg'))
                    else:
                        query_name = int(query_name)
                        img_meta = cocoapi.loadImgs(query_name)[0]
                        img = imread(
                            os.path.join(cfg.DATASET.data_dir, split,
                                         img_meta['file_name']))
                    visualize_result(
                        (img, as_numpy(feed_dict['seg_label'][0].cpu()),
                         '%05d' % (count)),
                        as_numpy(np.array(query_pred.argmax(dim=1)[0].cpu())),
                        os.path.join(cfg.DIR, 'result'))
                    count += 1

            # Per-run summary (values recomputed below for the aggregate).
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                n_run=run)

    # Aggregate over all runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, \
        meanIoU_std_binary = metric.get_mIoU_binary()

    print('----- Final Result -----')
    print('final_classIoU', classIoU.tolist())
    print('final_classIoU_std', classIoU_std.tolist())
    print('final_meanIoU', meanIoU.tolist())
    print('final_meanIoU_std', meanIoU_std.tolist())
    print('final_classIoU_binary', classIoU_binary.tolist())
    print('final_classIoU_std_binary', classIoU_std_binary.tolist())
    print('final_meanIoU_binary', meanIoU_binary.tolist())
    print('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    print(f'classIoU mean: {classIoU}')
    print(f'classIoU std: {classIoU_std}')
    print(f'meanIoU mean: {meanIoU}')
    print(f'meanIoU std: {meanIoU_std}')
    print(f'classIoU_binary mean: {classIoU_binary}')
    print(f'classIoU_binary std: {classIoU_std_binary}')
    print(f'meanIoU_binary mean: {meanIoU_binary}')
    print(f'meanIoU_binary std: {meanIoU_std_binary}')
def main(cfg, gpus):
    """Train the objectness + decoder networks with periodic validation.

    Trains with NLL loss on base-fold classes, evaluates every
    ``cfg.TRAIN.eval_freq`` iterations on the held-out classes, and keeps a
    'best' checkpoint by binary mean IoU.
    """
    torch.cuda.set_device(gpus[0])

    print('###### Create model ######')
    net_objectness = ModelBuilder.build_objectness(
        arch=cfg.MODEL.arch_objectness,
        weights=cfg.MODEL.weights_enc_query,
        fix_encoder=cfg.TRAIN.fix_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        input_dim=cfg.MODEL.decoder_dim,
        fc_dim=cfg.MODEL.fc_dim,
        ppm_dim=cfg.MODEL.ppm_dim,
        num_class=2,
        weights=cfg.MODEL.weights_decoder,
        dropout_rate=cfg.MODEL.dropout_rate,
        use_dropout=cfg.MODEL.use_dropout)
    crit = nn.NLLLoss(ignore_index=255)

    print('###### Load data ######')
    data_name = cfg.DATASET.name
    if data_name == 'VOC':
        from dataloaders.customized_objectness import voc_fewshot
        make_data = voc_fewshot
        max_label = 20
    elif data_name == 'COCO':
        from dataloaders.customized_objectness import coco_fewshot
        make_data = coco_fewshot
        max_label = 80
    else:
        raise ValueError('Wrong config for dataset!')
    # Train on the fold's classes; validate on the complement.
    labels = CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    labels_val = CLASS_LABELS[data_name]['all'] - \
        CLASS_LABELS[data_name][cfg.TASK.fold_idx]
    exclude_labels = labels_val

    # ImageNet mean/std scaled to 0-255 pixel range for padding values.
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    train_transform = [
        transforms.ToNumpy(),
        transforms.RandScale([0.9, 1.1]),
        transforms.RandRotate([-10, 10], padding=mean, ignore_label=0),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop(
            [cfg.DATASET.input_size[0], cfg.DATASET.input_size[1]],
            crop_type='rand',
            padding=mean,
            ignore_label=0)
    ]
    train_transform = Compose(train_transform)
    val_transform = Compose([
        transforms.ToNumpy(),
        transforms.Resize_pad(size=cfg.DATASET.input_size[0])
    ])

    dataset = make_data(base_dir=cfg.DATASET.data_dir,
                        split=cfg.DATASET.data_split,
                        transforms=train_transform,
                        to_tensor=transforms.ToTensorNormalize_noresize(),
                        labels=labels,
                        max_iters=cfg.TRAIN.n_iters * cfg.TRAIN.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.TRAIN.permute_labels,
                        exclude_labels=exclude_labels)
    trainloader = DataLoader(dataset,
                             batch_size=cfg.TRAIN.n_batch,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=True)

    net_objectness.cuda()
    net_decoder.cuda()

    # Set up optimizers (one per network; entries may be None).
    nets = (net_objectness, net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    history = {'train': {'iter': [], 'loss': [], 'acc': []}}

    net_objectness.train(not cfg.TRAIN.fix_bn)
    net_decoder.train(not cfg.TRAIN.fix_bn)
    best_iou = 0

    # main loop
    tic = time.time()
    print('###### Training ######')
    for i_iter, sample_batched in enumerate(trainloader):
        # Prepare input
        feed_dict = data_preprocess(sample_batched, cfg)
        data_time.update(time.time() - tic)
        net_objectness.zero_grad()
        net_decoder.zero_grad()

        # adjust learning rate
        adjust_learning_rate(optimizers, i_iter, cfg)

        # forward pass
        feat = net_objectness(feed_dict['img_data'], return_feature_maps=True)
        pred = net_decoder(feat)
        loss = crit(pred, feed_dict['seg_label'])
        acc = pixel_acc(pred, feed_dict['seg_label'])
        loss = loss.mean()
        acc = acc.mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            if optimizer:
                optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        # calculate accuracy, and display
        if i_iter % cfg.TRAIN.disp_iter == 0:
            print('Iter: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      i_iter, i_iter, cfg.TRAIN.n_iters,
                      batch_time.average(), data_time.average(),
                      cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))
            history['train']['iter'].append(i_iter)
            history['train']['loss'].append(loss.data.item())
            history['train']['acc'].append(acc.data.item())

        if (i_iter + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, i_iter + 1)

        if (i_iter + 1) % cfg.TRAIN.eval_freq == 0:
            metric = Metric(max_label=max_label, n_runs=cfg.VAL.n_runs)
            with torch.no_grad():
                print('----Evaluation----')
                net_objectness.eval()
                net_decoder.eval()
                net_decoder.use_softmax = True
                for run in range(cfg.VAL.n_runs):
                    print(f'### Run {run + 1} ###')
                    set_seed(cfg.VAL.seed + run)

                    print(f'### Load validation data ###')
                    dataset_val = make_data(
                        base_dir=cfg.DATASET.data_dir,
                        split='val',
                        transforms=val_transform,
                        to_tensor=transforms.ToTensorNormalize_noresize(),
                        labels=labels_val,
                        max_iters=cfg.VAL.n_iters * cfg.VAL.n_batch,
                        n_ways=cfg.TASK.n_ways,
                        n_shots=cfg.TASK.n_shots,
                        n_queries=cfg.TASK.n_queries,
                        permute=cfg.VAL.permute_labels,
                        exclude_labels=[])
                    if data_name == 'COCO':
                        coco_cls_ids = dataset_val.datasets[
                            0].dataset.coco.getCatIds()
                    testloader = DataLoader(dataset_val,
                                            batch_size=cfg.VAL.n_batch,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=True,
                                            drop_last=False)
                    # Fixed: was len(dataset) — the *training* set size.
                    print(f"Total # of validation Data: {len(dataset_val)}")
                    for sample_batched in testloader:
                        feed_dict = data_preprocess(sample_batched, cfg,
                                                    is_val=True)
                        if data_name == 'COCO':
                            label_ids = [
                                coco_cls_ids.index(x) + 1
                                for x in sample_batched['class_ids']
                            ]
                        else:
                            label_ids = list(sample_batched['class_ids'])
                        feat = net_objectness(feed_dict['img_data'],
                                              return_feature_maps=True)
                        query_pred = net_decoder(
                            feat, segSize=cfg.DATASET.input_size)
                        metric.record(
                            np.array(query_pred.argmax(dim=1)[0].cpu()),
                            np.array(feed_dict['seg_label'][0].cpu()),
                            labels=label_ids,
                            n_run=run)

                    classIoU, meanIoU = metric.get_mIoU(
                        labels=sorted(labels_val), n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)

                classIoU, classIoU_std, meanIoU, meanIoU_std = \
                    metric.get_mIoU(labels=sorted(labels_val))
                classIoU_binary, classIoU_std_binary, meanIoU_binary, \
                    meanIoU_std_binary = metric.get_mIoU_binary()

                print('----- Evaluation Result -----')
                print(f'best meanIoU_binary: {best_iou}')
                print(f'meanIoU mean: {meanIoU}')
                print(f'meanIoU std: {meanIoU_std}')
                print(f'meanIoU_binary mean: {meanIoU_binary}')
                print(f'meanIoU_binary std: {meanIoU_std_binary}')

                if meanIoU_binary > best_iou:
                    best_iou = meanIoU_binary
                    checkpoint(nets, history, cfg, 'best')

            # Restore training mode after evaluation.
            net_objectness.train(not cfg.TRAIN.fix_bn)
            net_decoder.train(not cfg.TRAIN.fix_bn)
            net_decoder.use_softmax = False

    print('Training Done!')
def main(_run, _config, _log):
    """Test script for the support/query encoder--decoder few-shot segmentation model.

    Runs inference over the test split slice-by-slice, saves per-slice
    predictions as .npy next to a mirrored directory tree, and writes one
    dice bar chart per test subject under figure/.

    Args:
        _run: sacred Run object (observers used for source snapshotting).
        _config: sacred config dict (seed, gpu_id, snapshot, dataset, ...).
        _log: sacred logger.
    """
    # Snapshot the experiment sources into the observer dir, then drop
    # sacred's own _sources copy to avoid duplication.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Determinism / device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)
    device = torch.device(f"cuda:{_config['gpu_id']}")

    _log.info('###### Create model ######')
    # Encoder output resolution after n_pool downsamplings of the input size.
    resize_dim = _config['input_size']
    encoded_h = int(resize_dim[0] / 2**_config['n_pool'])
    encoded_w = int(resize_dim[1] / 2**_config['n_pool'])
    s_encoder = SupportEncoder(_config['path']['init_path'], device)  #.to(device)
    q_encoder = QueryEncoder(_config['path']['init_path'], device)  #.to(device)
    decoder = Decoder(input_res=(encoded_h, encoded_w),
                      output_res=resize_dim).to(device)
    # Restore all three sub-networks from one checkpoint file.
    checkpoint = torch.load(_config['snapshot'], map_location='cpu')
    s_encoder.load_state_dict(checkpoint['s_encoder'])
    q_encoder.load_state_dict(checkpoint['q_encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    # NOTE(review): the networks are never put into .eval() mode here (the
    # calls below are commented out) — confirm this is intentional.
    # initializer.eval()
    # encoder.eval()
    # convlstmcell.eval()
    # decoder.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    max_label = 1
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(
        dataset=ts_dataset,
        batch_size=1,
        shuffle=False,
        # num_workers=_config['n_work'],
        pin_memory=False,  # True
        drop_last=False)

    _log.info('###### Testing begins ######')
    # metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    img_cnt = 0
    # length = len(all_samples)
    length = len(testloader)
    img_lists = []
    pred_lists = []
    label_lists = []
    # saves maps subject index -> list of per-slice [subj, idx, img, pred, label].
    saves = {}
    for subj_idx in range(len(ts_dataset.get_cnts())):
        saves[subj_idx] = []
    with torch.no_grad():
        loss_valid = 0
        batch_i = 0  # use only 1 batch size for testing
        for i, sample_test in enumerate(
                testloader):  # even for upward, down for downward
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list = []
            pred_list = []
            label_list = []
            preds = []
            # Tensors are indexed [:, 0, ...] below, i.e. only the first
            # slice of each sample is used per iteration.
            s_x = sample_test['s_x'].to(device)  # [B, slice_num, 1, 256, 256]
            s_y = sample_test['s_y'].to(device)  # [B, slice_num, 1, 256, 256]
            q_x = sample_test['q_x'].to(device)  # [B, slice_num, 1, 256, 256]
            q_y = sample_test['q_y'].to(device)  # [B, slice_num, 1, 256, 256]
            s_fname = sample_test['s_fname']
            q_fname = sample_test['q_fname']
            s_xi = s_x[:, 0, :, :, :]  #[B, 1, 256, 256]
            s_yi = s_y[:, 0, :, :, :]
            # Support branch conditions on image + mask; query branch returns
            # intermediate feature maps consumed by the decoder (skip-style).
            s_xi_encode, _ = s_encoder(s_xi, s_yi)  #[B, 512, w, h]
            q_xi = q_x[:, 0, :, :, :]
            q_yi = q_y[:, 0, :, :, :]
            q_xi_encode, q_ft_list = q_encoder(q_xi)
            sq_xi = torch.cat((s_xi_encode, q_xi_encode), dim=1)
            yhati = decoder(sq_xi, q_ft_list)  # [B, 1, 256, 256]
            # Binarize by rounding — presumably the decoder emits probabilities
            # in [0, 1]; TODO confirm its output activation.
            preds.append(yhati.round())
            img_list.append(q_xi[batch_i].cpu().numpy())
            pred_list.append(yhati[batch_i].round().cpu().numpy())
            label_list.append(q_yi[batch_i].cpu().numpy())
            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t", end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)
            # Mirror the query file path into a "*_pred" tree and dump the
            # rounded prediction as .npy. NOTE(review): the [-6] path
            # component is replaced — assumes a fixed directory depth in
            # q_fname; verify against the dataset layout.
            q_fname_split = q_fname[0][0].split("/")
            q_fname_split[-6] = "Training_2d_2_pred"
            try_mkdirs("/".join(q_fname_split[:-1]))
            o_q_fname = "/".join(q_fname_split)
            np.save(o_q_fname, yhati.round().cpu().numpy())
            # print(q_fname[0][0])
            # print(o_q_fname)

    # Per-subject dice bar charts.
    try_mkdirs("figure")
    print("start computing dice similarities ... total ", len(saves))
    for subj_idx in range(len(saves)):
        save_subj = saves[subj_idx]
        dices = []
        for slice_idx in range(len(save_subj)):
            # NOTE(review): subj_idx is rebound by this unpacking inside the
            # loop that iterates over subj_idx — harmless only because the
            # stored value equals the loop index.
            subj_idx, idx, img_list, pred_list, label_list = save_subj[
                slice_idx]
            for j in range(len(img_list)):
                # Dice = 2|P∩L| / (|P|+|L|); NOTE(review): divides by zero
                # when both pred and label are empty for a slice.
                dice = np.sum([label_list[j] * pred_list[j]]) * 2.0 / (
                    np.sum(pred_list[j]) + np.sum(label_list[j]))
                dices.append(dice)
        plt.clf()
        plt.bar([k for k in range(len(dices))], dices)
        plt.savefig(f"figure/bar_{_config['target']}_{subj_idx}.png")
# CW-attack hyper-parameters (parser is constructed earlier in this script).
parser.add_argument('--kappa', type=float, default=0.,
                    help='min margin in logits adv loss')
parser.add_argument('--attack_lr', type=float, default=1e-2,
                    help='lr in CW optimization')
parser.add_argument('--binary_step', type=int, default=10, metavar='N',
                    help='Binary search step')
parser.add_argument('--num_iter', type=int, default=500, metavar='N',
                    help='Number of iterations in each search step')
parser.add_argument('--local_rank', default=-1, type=int,
                    help='node rank for distributed training')
args = parser.parse_args()

# Resolve per-model defaults keyed by point count / dataset.
# NOTE(review): BATCH_SIZE/BEST_WEIGHTS rebind module-level lookup tables
# to their selected entries.
BATCH_SIZE = BATCH_SIZE[args.num_points]
BEST_WEIGHTS = BEST_WEIGHTS[args.dataset][args.num_points]
if args.batch_size == -1:  # -1 means "use the table default for this model"
    args.batch_size = BATCH_SIZE[args.model]
set_seed(1)  # fixed seed (not configurable here)
print(args)

# Distributed (NCCL) setup; local_rank is provided by the launcher.
dist.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)
cudnn.benchmark = True

# build model
# All victim classifiers are built for 40 output classes (ModelNet40-style).
if args.model.lower() == 'dgcnn':
    model = DGCNN(args.emb_dims, args.k, output_channels=40)
elif args.model.lower() == 'pointnet':
    model = PointNetCls(k=40, feature_transform=args.feature_transform)
elif args.model.lower() == 'pointnet2':
    model = PointNet2ClsSsg(num_classes=40)
elif args.model.lower() == 'pointconv':
    model = PointConvDensityClsSsg(num_classes=40)
def main(_run, _config, _log):
    """Test script for the MedicalFSS model with optional support-set fine-tuning.

    When _config['n_update'] is truthy, first fine-tunes the restored model
    on support data of the target organ (with n_shot temporarily reduced by
    one), then evaluates on the test split; otherwise evaluates directly.
    Computes per-subject dice (and IoU) over reassembled volumes, optionally
    logging to tensorboard and exporting NIfTI volumes.
    """
    # Snapshot experiment sources into the observer dir.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Determinism / device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)
    device = torch.device(f"cuda:{_config['gpu_id']}")

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    q_slice_n = _config['q_slice']  # number of query slices per sample
    iter_print = _config['iter_print']
    if _config['record']:
        _log.info('###### define tensorboard writer #####')
        board_name = f'board/test_{_config["board"]}_{date()}'
        writer = SummaryWriter(board_name)

    if _config["n_update"]:
        _log.info('###### fine tuning with support data of target organ #####')
        # One support shot is held out during fine-tuning; restored below.
        _config["n_shot"] = _config["n_shot"] - 1
        _log.info('###### Create model ######')
        model = MedicalFSS(_config, device).to(device)
        checkpoint = torch.load(_config['snapshot'], map_location='cpu')
        print("checkpoint keys : ", checkpoint.keys())
        # initializer.load_state_dict(checkpoint['initializer'])
        model.load_state_dict(checkpoint['model'])
        # optimizer.load_state_dict(checkpoint['optimizer'])
        tr_dataset, val_dataset, ts_dataset = make_data(_config,
                                                        is_finetuning=True)
        trainloader = DataLoader(dataset=tr_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=False,
                                 drop_last=False)
        optimizer = torch.optim.Adam(list(model.parameters()),
                                     _config['optim']['lr'])
        # optimizer = torch.optim.SGD(list(model.parameters()),1e-5)
        # criterion = nn.BCELoss()
        # Combined Dice + cross-entropy objective per query frame.
        criterion = losses.DiceLoss()
        criterion_ce = nn.CrossEntropyLoss()
        for i_iter, sample_train in enumerate(trainloader):
            preds = []
            loss_per_video = 0.0
            optimizer.zero_grad()
            s_x = sample_train['s_x'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            s_y = sample_train['s_y'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            q_x = sample_train['q_x'].to(device)  # [B, slice_num, 1, 256, 256]
            q_y = sample_train['q_y'].type(torch.LongTensor).to(
                device)  #[B, slice_num, 1, 256, 256]
            preds = model(s_x, s_y, q_x)
            for frame_id in range(q_slice_n):
                q_yi = q_y[:, frame_id, :, :, :]  # [B, 1, 256, 256]
                q_yi2 = q_yi.squeeze(1)  # [B, 256, 256]
                yhati = preds[frame_id]
                # pdb.set_trace()
                # loss = criterion(F.softmax(yhati, dim=1), q_yi2)
                loss = criterion(F.softmax(yhati, dim=1), q_yi2) + criterion_ce(
                    F.softmax(yhati, dim=1), q_yi2)
                loss_per_video += loss
                # NOTE(review): this appends back onto the list returned by
                # the model, not a fresh accumulator — looks vestigial;
                # confirm it is not relied on elsewhere.
                preds.append(yhati)
            loss_per_video.backward()
            optimizer.step()
            if iter_print:
                print(
                    f"train, iter:{i_iter}/{_config['n_update']}, iter_loss:{loss_per_video}",
                    end='\r')
        # Restore the original shot count after fine-tuning.
        _config["n_shot"] = _config["n_shot"] + 1
    else:
        _log.info('###### Create model ######')
        model = MedicalFSS(_config, device).to(device)
        checkpoint = torch.load(_config['snapshot'], map_location='cpu')
        print("checkpoint keys : ", checkpoint.keys())
        # initializer.load_state_dict(checkpoint['initializer'])
        model.load_state_dict(checkpoint['model'])

    model.n_shot = _config["n_shot"]
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(dataset=ts_dataset,
                            batch_size=1,
                            shuffle=False,
                            pin_memory=False,
                            drop_last=False)

    _log.info('###### Testing begins ######')
    # metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    img_cnt = 0
    # length = len(all_samples)
    length = len(testloader)
    blank = torch.zeros([1, 256, 256]).to(device)
    reversed_idx = list(reversed(range(q_slice_n)))
    ch = 256  # number of channels of embedding
    img_lists = []
    pred_lists = []
    label_lists = []
    # saves: subject index -> list of per-sample [subj, idx, imgs, preds, labels, fnames]
    saves = {}
    n_test = len(ts_dataset.q_cnts)
    for subj_idx in range(n_test):
        saves[subj_idx] = []
    with torch.no_grad():
        batch_idx = 0  # use only 1 batch size for testing
        for i, sample_test in enumerate(
                testloader):  # even for upward, down for downward
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list, pred_list, label_list, preds = [], [], [], []
            s_x = sample_test['s_x'].to(device)  # [B, slice_num, 1, 256, 256]
            s_y = sample_test['s_y'].to(device)  # [B, slice_num, 1, 256, 256]
            q_x = sample_test['q_x'].to(device)  # [B, slice_num, 1, 256, 256]
            q_y = sample_test['q_y'].to(device)  # [B, slice_num, 1, 256, 256]
            fnames = sample_test['q_fname']
            preds = model(s_x, s_y, q_x)
            for frame_id in range(q_slice_n):
                q_xi = q_x[:, frame_id, :, :, :]
                q_yi = q_y[:, frame_id, :, :, :]
                yhati = preds[frame_id]
                # NOTE(review): appends onto the model's returned list while
                # indexing it by frame_id — safe only because frame_id stays
                # below the original length; confirm intent.
                preds.append(yhati.argmax(dim=1))
                img_list.append(q_xi[batch_idx].cpu().numpy())
                pred_list.append(yhati.argmax(dim=1).cpu().numpy())
                label_list.append(q_yi[batch_idx].cpu().numpy())
            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list, fnames])
            if iter_print:
                print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t",
                      end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)
            # if _config['record']:
            #     frames = []
            #     for frame_id in range(0, q_x.size(1)):
            #         frames += overlay_color(q_x[batch_idx, frame_id], preds[frame_id-1][batch_idx].round(), q_y[batch_idx, frame_id], scale=_config['scale'])
            #     visual = make_grid(frames, normalize=True, nrow=5)
            #     writer.add_image(f"test/{subj_idx}/{idx}_query_image", visual, i)

    # Reassemble per-subject volumes from overlapping slice windows:
    # the first window contributes its leading half, the last its trailing
    # half, every middle window only its center slice.
    center_idx = (q_slice_n // 2) + 1 - 1  # 5->2 index
    dice_similarities = []
    for subj_idx in range(n_test):
        imgs, preds, labels = [], [], []
        save_subj = saves[subj_idx]
        for i in range(len(save_subj)):
            subj_idx, idx, img_list, pred_list, label_list, fnames = save_subj[
                i]
            # if idx==(q_slice_n//2):
            if idx == 0:
                for j in range((q_slice_n // 2) + 1):  # 5//2 + 1 = 3
                    imgs.append(img_list[idx + j])
                    preds.append(pred_list[idx + j])
                    labels.append(label_list[idx + j])
            elif idx == (len(save_subj) - 1):
                # pdb.set_trace()
                for j in range((q_slice_n // 2) + 1):  # 5//2 + 1 = 3
                    imgs.append(img_list[center_idx + j])
                    preds.append(pred_list[center_idx + j])
                    labels.append(label_list[center_idx + j])
            else:
                imgs.append(img_list[center_idx])
                preds.append(pred_list[center_idx])
                labels.append(label_list[center_idx])
        # pdb.set_trace()
        img_arr = np.concatenate(imgs, axis=0)
        pred_arr = np.concatenate(preds, axis=0)
        label_arr = np.concatenate(labels, axis=0)
        # Volume-level Dice; NOTE(review): zero-division if both volumes
        # are empty.
        dice = np.sum([label_arr * pred_arr
                       ]) * 2.0 / (np.sum(pred_arr) + np.sum(label_arr))
        ## IoU
        union = np.clip(pred_arr + label_arr, 0, 1)
        IoU = np.sum([label_arr * pred_arr]) / np.sum(union)
        dice_similarities.append(dice)
        print(
            f"{len(imgs)} slice -> computing dice scores. {subj_idx}/{n_test}. {ts_dataset.q_cnts[subj_idx] }/{len(save_subj)} => {len(imgs)}",
            end='\r')
        if _config['record']:
            frames = []
            for frame_id in range(0, len(imgs)):
                frames += overlay_color(torch.tensor(imgs[frame_id]),
                                        torch.tensor(preds[frame_id]),
                                        torch.tensor(labels[frame_id]),
                                        scale=_config['scale'])
            print(len(frames))
            visual = make_grid(frames, normalize=True, nrow=5)
            writer.add_image(f"test/{subj_idx}", visual, i)
            # NOTE(review): `i` here is the stale index from the inner loop
            # above — looks like `subj_idx` was intended; confirm.
            writer.add_scalar(f'dice_score/{i}', dice)
        if _config['save_sample']:
            ## only for internal test (BCV - MICCAI2015)
            sup_idx = _config['s_idx']
            target = _config['target']
            save_name = _config['save_name']
            dirs = ["gt", "pred", "input"]
            save_dir = f"../sample/bigru_organ{target}_sup{sup_idx}_{save_name}"
            for dir in dirs:
                try:
                    os.makedirs(os.path.join(save_dir, dir))
                except:
                    pass
            subj_name = fnames[0][0].split("/")[-2]
            # Organ id 14 is the cervix task; everything else is abdomen.
            if target == 14:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Cervix/RawData/Training/img"
                orig_fname = f"{src_dir}/{subj_name}-Image.nii.gz"
                pass
            else:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Abdomen/RawData/Training/img"
                orig_fname = f"{src_dir}/img{subj_name}.nii.gz"
            # Copy voxel spacing from the original scan so the exported
            # volumes are geometrically aligned.
            itk = sitk.ReadImage(orig_fname)
            orig_spacing = itk.GetSpacing()
            label_arr = label_arr * 2.0
            # label_arr = np.concatenate([np.zeros([1,256,256]), label_arr,np.zeros([1,256,256])])
            # pred_arr = np.concatenate([np.zeros([1,256,256]), pred_arr,np.zeros([1,256,256])])
            # img_arr = np.concatenate([np.zeros([1,256,256]), img_arr,np.zeros([1,256,256])])
            itk = sitk.GetImageFromArray(label_arr)
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/gt/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(pred_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/pred/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(img_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/input/{subj_idx}.nii.gz")

    print(f"test result \n n : {len(dice_similarities)}, mean dice score : \
{np.mean(dice_similarities)} \n dice similarities : {dice_similarities}")
    if _config['record']:
        writer.add_scalar(f'dice_score/mean', np.mean(dice_similarities))
def main(_run, _config, _log):
    """Train a PANet-style few-shot segmentation model (VOC/COCO).

    Training loop with periodic evaluation (saving the best-IoU snapshot),
    followed by a final 5-run test using the best checkpoint. Metrics are
    logged to sacred (_run), tensorboard (tbwriter), and the sacred logger.
    """
    logdir = f'{_run.observers[0].dir}/'
    print(logdir)
    # VOC class names; appears informational only (not read below).
    category = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]
    # Snapshot experiment sources into the observer dir.
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    data_name = _config['dataset']
    max_label = 20 if data_name == 'VOC' else 80
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    print(_config['ckpt_dir'])
    tbwriter = SummaryWriter(osp.join(_config['ckpt_dir']))
    # Tensorboard tag names for training losses and eval metrics.
    training_tags = {
        'loss': "ATraining/total_loss",
        "query_loss": "ATraining/query_loss",
        'aligned_loss': "ATraining/aligned_loss",
        'base_loss': "ATraining/base_loss",
    }
    infer_tags = {
        'mean_iou': "MeanIoU/mean_iou",
        "mean_iou_binary": "MeanIoU/mean_iou_binary",
    }

    _log.info('###### Create model ######')
    if _config['model']['part']:
        model = FewshotSegPartResnet(
            pretrained_path=_config['path']['init_path'], cfg=_config)
        _log.info('Model: FewshotSegPartResnet')
    else:
        model = FewshotSegResnet(pretrained_path=_config['path']['init_path'],
                                 cfg=_config)
        _log.info('Model: FewshotSegResnet')
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    model.train()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = voc_fewshot
    labels = CLASS_LABELS[data_name][_config['label_sets']]
    transforms = Compose([Resize(size=_config['input_size']), RandomMirror()])
    dataset = make_data(base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['n_steps'] * _config['batch_size'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'],
                        n_unlabel=_config['task']['n_unlabels'],
                        cfg=_config)
    trainloader = DataLoader(dataset,
                             batch_size=_config['batch_size'],
                             shuffle=True,
                             num_workers=_config['num_workers'],
                             pin_memory=True,
                             drop_last=True)

    _log.info('###### Set optimizer ######')
    if _config['fix']:
        # Only layer3/layer4 of the encoder are optimized; earlier layers
        # stay fixed (and are put into eval() each iteration below).
        print('Optimizer: fix')
        optimizer = torch.optim.SGD(
            params=[
                {
                    "params": model.module.encoder.layer3.parameters(),
                    "lr": _config['optim']['lr'],
                    "weight_decay": _config['optim']['weight_decay']
                },
                {
                    "params": model.module.encoder.layer4.parameters(),
                    "lr": _config['optim']['lr'],
                    "weight_decay": _config['optim']['weight_decay']
                },
            ],
            momentum=_config['optim']['momentum'],
        )
    else:
        print('Optimizer: Not fix')
        optimizer = torch.optim.SGD(model.parameters(), **_config['optim'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    criterion = nn.CrossEntropyLoss(ignore_index=_config['ignore_label'])
    log_loss = {'loss': 0, 'align_loss': 0, 'base_loss': 0}

    _log.info('###### Training ######')
    highest_iou = 0
    metrics = {}
    for i_iter, sample_batched in enumerate(trainloader):
        if _config['fix']:
            # Keep frozen encoder stages in eval mode (BN statistics fixed).
            model.module.encoder.conv1.eval()
            model.module.encoder.bn1.eval()
            model.module.encoder.layer1.eval()
            model.module.encoder.layer2.eval()
        if _config['eval']:
            if i_iter == 0:
                break
        # Prepare input
        support_images = [[shot.cuda() for shot in way]
                          for way in sample_batched['support_images']]
        support_fg_mask = [[shot[f'fg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        support_bg_mask = [[shot[f'bg_mask'].float().cuda() for shot in way]
                           for way in sample_batched['support_mask']]
        query_images = [
            query_image.cuda()
            for query_image in sample_batched['query_images']
        ]
        query_labels = torch.cat([
            query_label.long().cuda()
            for query_label in sample_batched['query_labels']
        ],
                                 dim=0)  #1*417*417
        # base_loss is a zero placeholder here; only its scaled value enters
        # the total loss.
        base_loss = torch.zeros(1).to(torch.device('cuda'))
        # Forward and Backward
        optimizer.zero_grad()
        query_pred, _, align_loss = model(support_images, support_fg_mask,
                                          support_bg_mask, query_images)
        query_loss = criterion(query_pred,
                               query_labels)  #1*3*417*417, 1*417*417
        loss = query_loss + align_loss * _config[
            'align_loss_scaler'] + base_loss * _config['base_loss_scaler']
        loss.backward()
        optimizer.step()
        scheduler.step()
        # Log loss
        query_loss = query_loss.detach().data.cpu().numpy()
        align_loss = align_loss.detach().data.cpu().numpy()
        base_loss = base_loss.detach().data.cpu().numpy()
        log_loss['loss'] += query_loss
        log_loss['align_loss'] += align_loss
        log_loss['base_loss'] += base_loss
        # print loss and take snapshots
        if (i_iter + 1) % _config['print_interval'] == 0:
            # Running averages over all iterations so far.
            loss = log_loss['loss'] / (i_iter + 1)
            align_loss = log_loss['align_loss'] / (i_iter + 1)
            base_loss = log_loss['base_loss'] / (i_iter + 1)
            print(
                f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}, base_loss: {base_loss}'
            )
            _log.info(
                f'step {i_iter+1}: loss: {loss}, align_loss: {align_loss}, base_loss: {base_loss}'
            )
            metrics['loss'] = loss
            metrics['query_loss'] = query_loss
            metrics['aligned_loss'] = align_loss
            metrics['base_loss'] = base_loss
            # for k, v in metrics.items():
            #     tbwriter.add_scalar(training_tags[k], v, i_iter)
        if (i_iter + 1) % _config['evaluate_interval'] == 0:
            # Periodic evaluation on the held-out label set; keeps the best
            # meanIoU checkpoint as best.pth and also saves a per-iter one.
            _log.info('###### Evaluation begins ######')
            print(_config['ckpt_dir'])
            model.eval()
            labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
                _config['label_sets']]
            transforms = [Resize(size=_config['input_size'])]
            transforms = Compose(transforms)
            metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
            with torch.no_grad():
                for run in range(1):
                    _log.info(f'### Run {run + 1} ###')
                    set_seed(_config['seed'] + run)
                    _log.info(f'### Load data ###')
                    dataset = make_data(
                        base_dir=_config['path'][data_name]['data_dir'],
                        split=_config['path'][data_name]['data_split'],
                        transforms=transforms,
                        to_tensor=ToTensorNormalize(),
                        labels=labels,
                        max_iters=_config['infer_max_iters'],
                        n_ways=_config['task']['n_ways'],
                        n_shots=_config['task']['n_shots'],
                        n_queries=_config['task']['n_queries'],
                        n_unlabel=_config['task']['n_unlabels'],
                        cfg=_config)
                    testloader = DataLoader(dataset,
                                            batch_size=_config['batch_size'],
                                            shuffle=False,
                                            num_workers=_config['num_workers'],
                                            pin_memory=True,
                                            drop_last=False)
                    _log.info(f"Total # of Data: {len(dataset)}")
                    for sample_batched in tqdm.tqdm(testloader):
                        label_ids = list(sample_batched['class_ids'])
                        support_images = [[
                            shot.cuda() for shot in way
                        ] for way in sample_batched['support_images']]
                        suffix = 'mask'
                        support_fg_mask = [[
                            shot[f'fg_{suffix}'].float().cuda()
                            for shot in way
                        ] for way in sample_batched['support_mask']]
                        support_bg_mask = [[
                            shot[f'bg_{suffix}'].float().cuda()
                            for shot in way
                        ] for way in sample_batched['support_mask']]
                        query_images = [
                            query_image.cuda()
                            for query_image in sample_batched['query_images']
                        ]
                        query_labels = torch.cat([
                            query_label.cuda()
                            for query_label in sample_batched['query_labels']
                        ],
                                                 dim=0)
                        query_pred, _, _ = model(support_images,
                                                 support_fg_mask,
                                                 support_bg_mask,
                                                 query_images)
                        curr_iou = metric.record(query_pred.argmax(dim=1)[0],
                                                 query_labels[0],
                                                 labels=label_ids,
                                                 n_run=run)
                    classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                        n_run=run)
                    classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(
                        n_run=run)
                    _run.log_scalar('classIoU', classIoU.tolist())
                    _run.log_scalar('meanIoU', meanIoU.tolist())
                    _run.log_scalar('classIoU_binary',
                                    classIoU_binary.tolist())
                    _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
                    _log.info(f'classIoU: {classIoU}')
                    _log.info(f'meanIoU: {meanIoU}')
                    _log.info(f'classIoU_binary: {classIoU_binary}')
                    _log.info(f'meanIoU_binary: {meanIoU_binary}')
                    print(
                        f'meanIoU: {meanIoU}, meanIoU_binary: {meanIoU_binary}'
                    )
                    metrics = {}
                    metrics['mean_iou'] = meanIoU
                    metrics['mean_iou_binary'] = meanIoU_binary
                    for k, v in metrics.items():
                        tbwriter.add_scalar(infer_tags[k], v, i_iter)
                    if meanIoU > highest_iou:
                        print(
                            f'The highest iou is in iter: {i_iter} : {meanIoU}, save: {_config["ckpt_dir"]}/best.pth'
                        )
                        highest_iou = meanIoU
                        torch.save(
                            model.state_dict(),
                            os.path.join(f'{_config["ckpt_dir"]}/best.pth'))
                    else:
                        print(
                            f'The highest iou is in iter: {i_iter} : {meanIoU}'
                        )
                    torch.save(model.state_dict(),
                               os.path.join(f'{_config["ckpt_dir"]}/{i_iter + 1}.pth'))
            model.train()
    print(_config['ckpt_dir'])

    # Final testing: reload the best checkpoint and average metrics over
    # 5 independent runs with different seeds.
    _log.info(' --------- Testing begins ---------')
    labels = CLASS_LABELS[data_name]['all'] - CLASS_LABELS[data_name][
        _config['label_sets']]
    transforms = [Resize(size=_config['input_size'])]
    transforms = Compose(transforms)
    ckpt = os.path.join(f'{_config["ckpt_dir"]}/best.pth')
    print(f'{_config["ckpt_dir"]}/best.pth')
    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
    model.eval()
    metric = Metric(max_label=max_label, n_runs=5)
    with torch.no_grad():
        for run in range(5):
            n_iter = 0
            _log.info(f'### Run {run + 1} ###')
            set_seed(_config['seed'] + run)
            _log.info(f'### Load data ###')
            dataset = make_data(
                base_dir=_config['path'][data_name]['data_dir'],
                split=_config['path'][data_name]['data_split'],
                transforms=transforms,
                to_tensor=ToTensorNormalize(),
                labels=labels,
                max_iters=_config['infer_max_iters'],
                n_ways=_config['task']['n_ways'],
                n_shots=_config['task']['n_shots'],
                n_queries=_config['task']['n_queries'],
                n_unlabel=_config['task']['n_unlabels'],
                cfg=_config)
            testloader = DataLoader(dataset,
                                    batch_size=_config['batch_size'],
                                    shuffle=False,
                                    num_workers=_config['num_workers'],
                                    pin_memory=True,
                                    drop_last=False)
            _log.info(f"Total # of Data: {len(dataset)}")
            for sample_batched in tqdm.tqdm(testloader):
                label_ids = list(sample_batched['class_ids'])
                support_images = [[shot.cuda() for shot in way]
                                  for way in sample_batched['support_images']]
                suffix = 'mask'
                support_fg_mask = [[
                    shot[f'fg_{suffix}'].float().cuda() for shot in way
                ] for way in sample_batched['support_mask']]
                support_bg_mask = [[
                    shot[f'bg_{suffix}'].float().cuda() for shot in way
                ] for way in sample_batched['support_mask']]
                query_images = [
                    query_image.cuda()
                    for query_image in sample_batched['query_images']
                ]
                query_labels = torch.cat([
                    query_label.cuda()
                    for query_label in sample_batched['query_labels']
                ],
                                         dim=0)
                query_pred, _, _ = model(support_images, support_fg_mask,
                                         support_bg_mask, query_images)
                curr_iou = metric.record(query_pred.argmax(dim=1)[0],
                                         query_labels[0],
                                         labels=label_ids,
                                         n_run=run)
                n_iter += 1
            classIoU, meanIoU = metric.get_mIoU(labels=sorted(labels),
                                                n_run=run)
            classIoU_binary, meanIoU_binary = metric.get_mIoU_binary(n_run=run)
            _run.log_scalar('classIoU', classIoU.tolist())
            _run.log_scalar('meanIoU', meanIoU.tolist())
            _run.log_scalar('classIoU_binary', classIoU_binary.tolist())
            _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
            _log.info(f'classIoU: {classIoU}')
            _log.info(f'meanIoU: {meanIoU}')
            _log.info(f'classIoU_binary: {classIoU_binary}')
            _log.info(f'meanIoU_binary: {meanIoU_binary}')

    # Aggregate mean/std across the 5 runs.
    classIoU, classIoU_std, meanIoU, meanIoU_std = metric.get_mIoU(
        labels=sorted(labels))
    classIoU_binary, classIoU_std_binary, meanIoU_binary, meanIoU_std_binary = metric.get_mIoU_binary(
    )
    _run.log_scalar('meanIoU', meanIoU.tolist())
    _run.log_scalar('meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_classIoU', classIoU.tolist())
    _run.log_scalar('final_classIoU_std', classIoU_std.tolist())
    _run.log_scalar('final_meanIoU', meanIoU.tolist())
    _run.log_scalar('final_meanIoU_std', meanIoU_std.tolist())
    _run.log_scalar('final_classIoU_binary', classIoU_binary.tolist())
    _run.log_scalar('final_classIoU_std_binary', classIoU_std_binary.tolist())
    _run.log_scalar('final_meanIoU_binary', meanIoU_binary.tolist())
    _run.log_scalar('final_meanIoU_std_binary', meanIoU_std_binary.tolist())
    _log.info('----- Final Result -----')
    _log.info(f'classIoU mean: {classIoU}')
    _log.info(f'classIoU std: {classIoU_std}')
    _log.info(f'meanIoU mean: {meanIoU}')
    _log.info(f'meanIoU std: {meanIoU_std}')
    _log.info(f'classIoU_binary mean: {classIoU_binary}')
    _log.info(f'classIoU_binary std: {classIoU_std_binary}')
    _log.info(f'meanIoU_binary mean: {meanIoU_binary}')
    _log.info(f'meanIoU_binary std: {meanIoU_std_binary}')
    _log.info("## ------------------------------------------ ##")
    _log.info(f'###### Setting: {_run.observers[0].dir} ######')
    _log.info(
        "Running {num_run} runs, meanIoU:{miou:.4f}, meanIoU_binary:{mbiou:.4f} "
        "meanIoU_std:{miou_std:.4f}, meanIoU_binary_std:{mbiou_std:.4f}".
        format(num_run=5,
               miou=meanIoU,
               mbiou=meanIoU_binary,
               miou_std=meanIoU_std,
               mbiou_std=meanIoU_std_binary))
    _log.info(f"Current setting is {_run.observers[0].dir}")
    print(
        "Running {num_run} runs, meanIoU:{miou:.4f}, meanIoU_binary:{mbiou:.4f} "
        "meanIoU_std:{miou_std:.4f}, meanIoU_binary_std:{mbiou_std:.4f}".
        format(num_run=5,
               miou=meanIoU,
               mbiou=meanIoU_binary,
               miou_std=meanIoU_std,
               mbiou_std=meanIoU_std_binary))
    print(f"Current setting is {_run.observers[0].dir}")
    print(_config['ckpt_dir'])
    print(logdir)
def main(_run, _config, _log):
    """Test script: PANet-style FewShotSeg on 2D medical slices.

    Runs single-slice inference over the test split, accumulates per-subject
    predictions, computes volume-level dice per subject, and optionally logs
    overlays to tensorboard and exports NIfTI volumes for the BCV dataset.
    """
    # Snapshot experiment sources into the observer dir.
    for source_file, _ in _run.experiment_info['sources']:
        os.makedirs(
            os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
            exist_ok=True)
        _run.observers[0].save_file(source_file, f'source/{source_file}')
    shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    # Determinism / device setup.
    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Create model ######')
    model = FewShotSeg(pretrained_path=_config['path']['init_path'],
                       cfg=_config['model'])
    model = nn.DataParallel(model.cuda(), device_ids=[
        _config['gpu_id'],
    ])
    if not _config['notrain']:
        model.load_state_dict(
            torch.load(_config['snapshot'], map_location='cpu'))
    model.eval()

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    make_data = meta_data
    max_label = 1
    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    testloader = DataLoader(
        dataset=ts_dataset,
        batch_size=1,
        shuffle=False,
        # num_workers=_config['n_work'],
        pin_memory=False,  # True
        drop_last=False)
    if _config['record']:
        _log.info('###### define tensorboard writer #####')
        board_name = f'board/test_{_config["board"]}_{date()}'
        writer = SummaryWriter(board_name)

    _log.info('###### Testing begins ######')
    # metric = Metric(max_label=max_label, n_runs=_config['n_runs'])
    img_cnt = 0
    # length = len(all_samples)
    length = len(testloader)
    img_lists = []
    pred_lists = []
    label_lists = []
    # saves: subject index -> list of per-slice [subj, idx, imgs, preds, labels, fnames]
    saves = {}
    for subj_idx in range(len(ts_dataset.get_cnts())):
        saves[subj_idx] = []
    with torch.no_grad():
        loss_valid = 0
        batch_i = 0  # use only 1 batch size for testing
        for i, sample_test in enumerate(
                testloader):  # even for upward, down for downward
            subj_idx, idx = ts_dataset.get_test_subj_idx(i)
            img_list = []
            pred_list = []
            label_list = []
            preds = []
            fnames = sample_test['q_fname']
            # Squeeze away the singleton slice dimension to get PANet's
            # expected [ways][shots] nested input layout below.
            s_x_orig = sample_test['s_x'].cuda(
            )  # [B, Support, slice_num=1, 1, 256, 256]
            s_x = s_x_orig.squeeze(2)  # [B, Support, 1, 256, 256]
            s_y_fg_orig = sample_test['s_y'].cuda(
            )  # [B, Support, slice_num, 1, 256, 256]
            s_y_fg = s_y_fg_orig.squeeze(2)  # [B, Support, 1, 256, 256]
            s_y_fg = s_y_fg.squeeze(2)  # [B, Support, 256, 256]
            # Background mask is the complement of the foreground mask.
            s_y_bg = torch.ones_like(s_y_fg) - s_y_fg
            q_x_orig = sample_test['q_x'].cuda()  # [B, slice_num, 1, 256, 256]
            q_x = q_x_orig.squeeze(1)  # [B, 1, 256, 256]
            q_y_orig = sample_test['q_y'].cuda()  # [B, slice_num, 1, 256, 256]
            q_y = q_y_orig.squeeze(1)  # [B, 1, 256, 256]
            q_y = q_y.squeeze(1).long()  # [B, 256, 256]
            # Single-way episode: outer list is ways, inner list is shots.
            s_xs = [[s_x[:, shot, ...] for shot in range(_config["n_shot"])]]
            s_y_fgs = [[
                s_y_fg[:, shot, ...] for shot in range(_config["n_shot"])
            ]]
            s_y_bgs = [[
                s_y_bg[:, shot, ...] for shot in range(_config["n_shot"])
            ]]
            q_xs = [q_x]
            q_yhat, align_loss = model(s_xs, s_y_fgs, s_y_bgs, q_xs)
            # q_yhat = q_yhat[:,1:2, ...]
            # Class scores -> hard label map, restored to [B, 1, H, W].
            q_yhat = q_yhat.argmax(dim=1)
            q_yhat = q_yhat.unsqueeze(1)
            preds.append(q_yhat)
            img_list.append(q_x_orig[batch_i, 0].cpu().numpy())
            pred_list.append(q_yhat[batch_i].cpu().numpy())
            label_list.append(q_y_orig[batch_i, 0].cpu().numpy())
            saves[subj_idx].append(
                [subj_idx, idx, img_list, pred_list, label_list, fnames])
            print(f"test, iter:{i}/{length} - {subj_idx}/{idx} \t\t", end='\r')
            img_lists.append(img_list)
            pred_lists.append(pred_list)
            label_lists.append(label_list)

    # Reassemble each subject's slices and compute volume-level dice.
    print("start computing dice similarities ... total ", len(saves))
    dice_similarities = []
    for subj_idx in range(len(saves)):
        imgs, preds, labels = [], [], []
        save_subj = saves[subj_idx]
        for i in range(len(save_subj)):
            # print(len(save_subj), len(save_subj)-q_slice_n+1, q_slice_n, i)
            subj_idx, idx, img_list, pred_list, label_list, fnames = save_subj[
                i]
            # print(subj_idx, idx, is_reverse, len(img_list))
            # print(i, is_reverse, is_reverse_next, is_flip)
            for j in range(len(img_list)):
                imgs.append(img_list[j])
                preds.append(pred_list[j])
                labels.append(label_list[j])
        # pdb.set_trace()
        img_arr = np.concatenate(imgs, axis=0)
        pred_arr = np.concatenate(preds, axis=0)
        label_arr = np.concatenate(labels, axis=0)
        # pdb.set_trace()
        # print(ts_dataset.slice_cnts[subj_idx] , len(imgs))
        # pdb.set_trace()
        # Dice = 2|P∩L| / (|P|+|L|); NOTE(review): zero-division if both
        # volumes are empty.
        dice = np.sum([label_arr * pred_arr
                       ]) * 2.0 / (np.sum(pred_arr) + np.sum(label_arr))
        dice_similarities.append(dice)
        print(f"computing dice scores {subj_idx}/{10}", end='\n')
        if _config['record']:
            frames = []
            for frame_id in range(0, len(save_subj)):
                frames += overlay_color(torch.tensor(imgs[frame_id]),
                                        torch.tensor(preds[frame_id]).float(),
                                        torch.tensor(labels[frame_id]))
            visual = make_grid(frames, normalize=True, nrow=5)
            writer.add_image(f"test/{subj_idx}", visual, i)
            # NOTE(review): `i` here is the stale inner-loop index — looks
            # like `subj_idx` was intended; confirm.
            writer.add_scalar(f'dice_score/{i}', dice)
        if _config['save_sample']:
            ## only for internal test (BCV - MICCAI2015)
            sup_idx = _config['s_idx']
            target = _config['target']
            save_name = _config['save_name']
            dirs = ["gt", "pred", "input"]
            save_dir = f"../sample/panet_organ{target}_sup{sup_idx}_{save_name}"
            for dir in dirs:
                try:
                    os.makedirs(os.path.join(save_dir, dir))
                except:
                    pass
            subj_name = fnames[0][0].split("/")[-2]
            # Organ id 14 is the cervix task; everything else is abdomen.
            if target == 14:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Cervix/RawData/Training/img"
                orig_fname = f"{src_dir}/{subj_name}-Image.nii.gz"
                pass
            else:
                src_dir = "/user/home2/soopil/Datasets/MICCAI2015challenge/Abdomen/RawData/Training/img"
                orig_fname = f"{src_dir}/img{subj_name}.nii.gz"
            # Copy voxel spacing from the original scan so exported volumes
            # are geometrically aligned.
            itk = sitk.ReadImage(orig_fname)
            orig_spacing = itk.GetSpacing()
            label_arr = label_arr * 2.0
            # label_arr = np.concatenate([np.zeros([1,256,256]), label_arr,np.zeros([1,256,256])])
            # pred_arr = np.concatenate([np.zeros([1,256,256]), pred_arr,np.zeros([1,256,256])])
            # img_arr = np.concatenate([np.zeros([1,256,256]), img_arr,np.zeros([1,256,256])])
            itk = sitk.GetImageFromArray(label_arr)
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/gt/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(pred_arr.astype(float))
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/pred/{subj_idx}.nii.gz")
            itk = sitk.GetImageFromArray(img_arr)
            itk.SetSpacing(orig_spacing)
            sitk.WriteImage(itk, f"{save_dir}/input/{subj_idx}.nii.gz")

    print(f"test result \n n : {len(dice_similarities)}, mean dice score : \
{np.mean(dice_similarities)} \n dice similarities : {dice_similarities}")
    if _config['record']:
        writer.add_scalar(f'dice_score/mean', np.mean(dice_similarities))