def get_sample(data):
    """Turn one dataset sample into a single-graph batch placed on ``DEVICE``.

    Args:
        data: mapping whose ``'points'`` entry is the graph object to batch.

    Returns:
        The batched graph, with ``bvec`` renamed to ``batch`` and ``lengths``
        collapsed to the Python scalar stored in its first entry.
    """
    batched = gBatch().from_data_list([data['points']]).to(DEVICE)

    # Expose the precomputed assignment vector under the ``batch`` attribute
    # and drop the original field.
    batched.batch = batched.bvec.clone()
    del batched.bvec

    # Only the first length is kept (as a plain Python number).
    first_length = batched['lengths'][0].item()
    batched['lengths'] = first_length
    return batched
def get_gbatch_sample(sample, sample_size, same_size, return_name=False):
    """Collate a list of samples into one graph batch.

    Args:
        sample: iterable of mappings with a ``'points'`` graph and a ``'name'``.
        sample_size: number of elements per sample; used both to offset each
            sample's ``bvec`` (in place) and to build ``ori_batch``.
        same_size: when true, ``lengths`` is replaced by the scalar in its
            first entry.
        return_name: when true, the list of sample names is returned as well.

    Returns:
        The batch, or ``(batch, names)`` if ``return_name`` is set.
    """
    graphs, names, origin_rows = [], [], []
    for idx, item in enumerate(sample):
        graph = item['points']
        if 'bvec' in graph.keys:
            # Shift this sample's assignment ids (in place) so they stay
            # distinct from those of the samples collated before it.
            graph.bvec += sample_size * idx
        graphs.append(graph)
        names.append(item['name'])
        origin_rows.append([idx] * sample_size)

    batch = gBatch().from_data_list(graphs)
    # For every element, the index of the sample it came from.
    batch.ori_batch = torch.tensor(origin_rows).flatten().long()

    if 'bvec' in batch.keys:
        batch.batch = batch.bvec.clone()
        del batch.bvec

    if same_size:
        batch['lengths'] = batch['lengths'][0].item()

    return (batch, names) if return_name else batch
def val_iter(cfg, val_dataloader, classifier, writer, epoch, cluster_loss_fn,
             best_epoch, best_pred, logdir):
    """Run one validation pass, log epoch metrics and checkpoint the best model.

    For each batch the loss selected through ``cfg['loss']`` is computed
    together with accuracy, IoU, precision and recall.  The tp/fp/fn counts
    are derived from the product and ordering of predicted vs. target labels,
    so they are only meaningful for {0, 1} labels.  Per-epoch means are
    written to ``writer``; if the mean IoU exceeds ``best_pred`` the best
    score/epoch are updated and, when ``cfg['save_model']`` is set, the
    weights are saved under ``logdir/cfg['model_dir']`` (replacing any
    previous ``best_model*.pth``).

    Args:
        cfg: configuration mapping (``cfg['loss']`` is indexed as a sequence
            of loss names).
        val_dataloader: iterable of validation batches.
        classifier: model to evaluate; put in eval mode here.
        writer: object exposing ``add_scalar(tag, value, step)``.
        epoch: current epoch, used for logging and loss switching.
        cluster_loss_fn: unused; kept for interface compatibility.
        best_epoch: epoch of the best IoU so far.
        best_pred: best mean IoU so far.
        logdir: root directory for checkpoints.

    Returns:
        ``(best_epoch, best_pred, ep_loss)`` where ``ep_loss`` is the summed
        (not averaged) loss over all batches.

    Raises:
        ValueError: if the selected loss type is not ``nll``/``LLh``/``LLm``.
    """
    import glob  # function-scope import: only needed for checkpoint cleanup

    num_classes = int(cfg['n_classes'])
    batch_size = 1
    sample_size = int(cfg['fixed_size'])
    ep_loss = 0.
    n_batches = 0

    classifier.eval()
    with torch.no_grad():
        print('\n\n')
        mean_val_acc = torch.tensor([])
        mean_val_iou = torch.tensor([])
        mean_val_prec = torch.tensor([])
        mean_val_recall = torch.tensor([])
        # Never filled in this function, so its mean is NaN; it is still
        # logged to keep the emitted scalars/prints unchanged.
        mean_val_iou_c = torch.tensor([])

        for j, data in enumerate(val_dataloader):
            n_batches += 1

            if 'graph' not in cfg['dataset']:
                points = data['points'].cuda()
                target = data['gt'].cuda()
            else:
                data_list = []
                for i, d in enumerate(data):
                    if 'bvec' in d['points'].keys:
                        # keep assignment ids of different samples disjoint
                        d['points'].bvec += sample_size * i
                    data_list.append(d['points'])
                points = gBatch().from_data_list(data_list)
                if 'bvec' in points.keys:
                    points.batch = points.bvec.clone()
                    del points.bvec
                target = points['y']
                if cfg['same_size']:
                    points['lengths'] = points['lengths'][0].item()
                points, target = points.to('cuda'), target.to('cuda')

            logits = classifier(points)

            # With two configured losses, the first is used up to (and
            # including) ``switch_loss_epoch``, the second afterwards.
            if len(cfg['loss']) == 2 and epoch > int(cfg['switch_loss_epoch']):
                loss_type = cfg['loss'][1]
            else:
                loss_type = cfg['loss'][0]

            if loss_type == 'nll':
                pred = F.log_softmax(logits, dim=-1).view(-1, num_classes)
                pred_choice = pred.data.max(1)[1].int()
                if cfg['nll_w']:
                    # down-weight class 0 relative to all the others
                    ce_w = torch.tensor([1.5e-2] + [1.] *
                                        (num_classes - 1)).cuda()
                else:
                    ce_w = torch.tensor([1.] * num_classes).cuda()
                loss = F.nll_loss(pred, target.long(), weight=ce_w)
            elif loss_type == 'LLh':
                pred_choice = (logits.data > 0).int()
                loss = L.lovasz_hinge(
                    logits.view(batch_size, sample_size, 1),
                    target.view(batch_size, sample_size, 1),
                    per_image=False)
            elif loss_type == 'LLm':
                pred = F.softmax(logits, dim=-1)
                pred_choice = pred.data.max(1)[1].int()
                loss = L.lovasz_softmax_flat(
                    pred, target, op=cfg['llm_op'],
                    only_present=cfg['multi_category'])
            else:
                raise ValueError('unsupported loss type: %r' % (loss_type,))

            ep_loss += loss

            print('val max class pred ', pred_choice.max().item())
            print('val min class pred ', pred_choice.min().item())
            print('# class pred ', len(np.unique(pred_choice.cpu().numpy())))

            target_int = target.data.int()
            correct = pred_choice.eq(target_int).cpu().sum()
            acc = correct.item() / float(target.size(0))

            # small epsilon keeps the ratios defined when there are no positives
            tp = torch.mul(pred_choice.data,
                           target_int).cpu().sum().item() + 0.00001
            fp = pred_choice.gt(target_int).cpu().sum().item()
            fn = pred_choice.lt(target_int).cpu().sum().item()
            iou = torch.tensor([float(tp) / (tp + fp + fn)])
            prec = torch.tensor([float(tp) / (tp + fp)])
            recall = torch.tensor([float(tp) / (tp + fn)])

            print('VALIDATION [%d: %d/%d] val loss: %f acc: %f iou: %f' %
                  (epoch, j, len(val_dataloader), loss, acc, iou))

            mean_val_prec = torch.cat((mean_val_prec, prec), 0)
            mean_val_recall = torch.cat((mean_val_recall, recall), 0)
            mean_val_iou = torch.cat((mean_val_iou, iou), 0)
            mean_val_acc = torch.cat((mean_val_acc, torch.tensor([acc])), 0)

        macro_iou = torch.mean(mean_val_iou)
        macro_prec = torch.mean(mean_val_prec)
        macro_recall = torch.mean(mean_val_recall)
        macro_iou_c = torch.mean(mean_val_iou_c)

        epoch_iou = macro_iou.item()

    writer.add_scalar('val/epoch_acc', torch.mean(mean_val_acc).item(), epoch)
    writer.add_scalar('val/epoch_iou', epoch_iou, epoch)
    writer.add_scalar('val/epoch_prec', macro_prec.item(), epoch)
    writer.add_scalar('val/epoch_recall', macro_recall.item(), epoch)
    writer.add_scalar('val/epoch_iou_c', macro_iou_c.item(), epoch)
    # Average over the number of batches (the last index ``j`` is one short
    # and is zero for a single batch).
    writer.add_scalar('val/loss', ep_loss / max(n_batches, 1), epoch)
    print('VALIDATION ACCURACY: %f' % torch.mean(mean_val_acc).item())
    print('VALIDATION IOU: %f' % epoch_iou)
    print('VALIDATION IOUC: %f' % macro_iou_c.item())
    print('\n\n')

    if epoch_iou > best_pred:
        best_pred = epoch_iou
        best_epoch = epoch
        if cfg['save_model']:
            modeldir = os.path.join(logdir, cfg['model_dir'])
            if not os.path.exists(modeldir):
                os.makedirs(modeldir)
            else:
                # drop the previous best checkpoint(s) without going through
                # a shell
                for stale in glob.glob(
                        os.path.join(modeldir, 'best_model*.pth')):
                    os.remove(stale)
            torch.save(
                classifier.state_dict(), '%s/best_model_iou-%f_ep-%d.pth' %
                (modeldir, best_pred, epoch))

    return best_epoch, best_pred, ep_loss
def train_iter(cfg, dataloader, classifier, optimizer, writer, epoch, n_iter,
               cluster_loss_fn):
    """Train ``classifier`` for one epoch and collect per-batch metrics.

    Every batch is moved to CUDA, the loss selected through ``cfg['loss']``
    is minimised with one optimizer step, and accuracy / IoU / precision /
    recall are computed from the predicted labels.  The tp/fp/fn counts are
    derived from the product and ordering of predicted vs. target labels, so
    they are only meaningful for {0, 1} labels.

    Args:
        cfg: configuration mapping (``cfg['loss']`` is indexed as a sequence
            of loss names).
        dataloader: iterable of training batches.
        classifier: model to train; put in train mode here.
        optimizer: optimizer stepping the classifier parameters.
        writer: object exposing ``add_scalar`` and ``logdir``.
        epoch: current epoch, used for logging and loss switching.
        n_iter: running iteration counter; incremented once per batch.
        cluster_loss_fn: unused; kept for interface compatibility.

    Returns:
        ``(mean_acc, mean_prec, mean_iou, mean_recall, mean_loss, n_iter)``;
        the first four are 1-D tensors holding one value per batch and
        ``mean_loss`` is the loss averaged over the batches.

    Raises:
        ValueError: if the selected loss type is not ``nll``/``LLh``/``LLm``.
    """
    num_classes = int(cfg['n_classes'])
    batch_size = int(cfg['batch_size'])
    sample_size = int(cfg['fixed_size'])
    num_batch = cfg['num_batch']

    ep_loss = 0.
    n_batches = 0

    mean_acc = torch.tensor([])
    mean_iou = torch.tensor([])
    mean_prec = torch.tensor([])
    mean_recall = torch.tensor([])

    ### state that the model will run in train mode
    classifier.train()

    for i_batch, sample_batched in enumerate(dataloader):
        n_batches += 1

        ### get batch
        if 'graph' not in cfg['dataset']:
            points = sample_batched['points']
            target = sample_batched['gt']
        else:
            data_list = []
            for i, d in enumerate(sample_batched):
                if 'bvec' in d['points'].keys:
                    # keep assignment ids of different samples disjoint
                    d['points'].bvec += sample_size * i
                data_list.append(d['points'])
            points = gBatch().from_data_list(data_list)
            if 'bvec' in points.keys:
                points.batch = points.bvec.clone()
                del points.bvec
            target = points['y']
            if cfg['same_size']:
                points['lengths'] = points['lengths'][0].item()

        points, target = points.to('cuda'), target.to('cuda')

        ### initialize gradients
        optimizer.zero_grad()

        ### forward
        logits = classifier(points)

        ### minimize the loss
        # With two configured losses, the first is used up to (and including)
        # ``switch_loss_epoch``, the second afterwards.
        if len(cfg['loss']) == 2 and epoch > int(cfg['switch_loss_epoch']):
            loss_type = cfg['loss'][1]
        else:
            loss_type = cfg['loss'][0]

        if loss_type == 'nll':
            pred = F.log_softmax(logits, dim=-1).view(-1, num_classes)
            pred_choice = pred.data.max(1)[1].int()
            if cfg['nll_w']:
                # down-weight class 0 relative to all the others
                ce_w = torch.tensor([1.5e-2] + [1.] *
                                    (num_classes - 1)).cuda()
            else:
                ce_w = torch.tensor([1.] * num_classes).cuda()
            loss = F.nll_loss(pred, target.long(), weight=ce_w)
        elif loss_type == 'LLh':
            pred_choice = (logits.data > 0).int()
            loss = L.lovasz_hinge(logits.view(batch_size, sample_size, 1),
                                  target.view(batch_size, sample_size, 1),
                                  per_image=False)
        elif loss_type == 'LLm':
            pred = F.softmax(logits, dim=-1)
            pred_choice = pred.data.max(1)[1].int()
            loss = L.lovasz_softmax_flat(pred,
                                         target,
                                         op=cfg['llm_op'],
                                         only_present=cfg['multi_category'])
        else:
            raise ValueError('unsupported loss type: %r' % (loss_type,))

        # Accumulate a detached copy: adding the graph-attached loss would
        # keep the autograd history of every iteration alive for the epoch.
        ep_loss += loss.detach()

        if cfg['print_bwgraph']:
            # Debug path: render the backward graph, clean the run's log
            # directories and stop the process.
            from torchviz import make_dot
            g = make_dot(loss, params=dict(classifier.named_parameters()))
            g.view('pointnet_mgf')
            print('classifier parameters: %d' %
                  int(count_parameters(classifier)))
            os.system('rm -r runs/%s' % writer.logdir.split('/', 1)[1])
            os.system('rm -r tb_logs/%s' % writer.logdir.split('/', 1)[1])
            import sys
            sys.exit()

        loss.backward()
        optimizer.step()

        ### compute performance
        target_int = target.data.int()
        correct = pred_choice.eq(target_int).sum()
        acc = correct.item() / float(target.size(0))

        # small epsilon keeps the ratios defined when there are no positives
        tp = torch.mul(pred_choice.data, target_int).sum().item() + 0.00001
        fp = pred_choice.gt(target_int).sum().item()
        fn = pred_choice.lt(target_int).sum().item()
        iou = float(tp) / (tp + fp + fn)
        prec = float(tp) / (tp + fp)
        recall = float(tp) / (tp + fn)

        print('[%d: %d/%d] train loss: %f acc: %f iou: %f' \
              % (epoch, i_batch, num_batch, loss.item(), acc, iou))

        mean_prec = torch.cat((mean_prec, torch.tensor([prec])), 0)
        mean_recall = torch.cat((mean_recall, torch.tensor([recall])), 0)
        mean_acc = torch.cat((mean_acc, torch.tensor([acc])), 0)
        mean_iou = torch.cat((mean_iou, torch.tensor([iou])), 0)
        n_iter += 1

    # ``max`` guards the (degenerate) empty-dataloader case.
    mean_loss = ep_loss / max(n_batches, 1)
    writer.add_scalar('train/epoch_loss', mean_loss, epoch)

    return mean_acc, mean_prec, mean_iou, mean_recall, mean_loss, n_iter
def test(cfg):
    """Evaluate a trained segmentation/classification model on the test set.

    The dataset named by ``cfg['dataset']`` is built, the weights in
    ``cfg['weights_path']`` (or the first ``best*`` file under
    ``cfg['exp_path']/models``) are loaded, and the dataset is traversed by
    index.  For datasets exposing a truthy ``split_obj`` attribute an object
    is processed in several chunks: index ``j`` is re-read until
    ``dataset.remaining[j]`` is empty, and chunk predictions are scattered
    into full-size buffers through ``data['obj_idxs']``.

    When ``cfg['with_gt']`` is set, per-object accuracy / IoU / precision /
    recall are computed, predictions and targets are saved as ``.npy`` next
    to the data, and score files are written into the log directory.  When
    ``cfg['save_pred']`` is set, predicted indices are pickled per sample.

    Args:
        cfg: configuration mapping.  ``cfg['loss']`` is a space-separated
            string on entry and is replaced in place by the resulting list;
            ``cfg['batch_size']`` is forced to 1 and ``cfg['weights_path']``
            may be filled in.

    Raises:
        ValueError: if the selected loss type is not ``nll``/``LLh``/``LLm``.
    """
    num_classes = int(cfg['n_classes'])
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')
    batch_size = 1
    cfg['batch_size'] = batch_size
    # NOTE(review): eval() executes whatever the config contains; kept for
    # backward compatibility, consider int()/ast.literal_eval if the value is
    # always a literal.
    epoch = eval(str(cfg['n_epochs']))

    trans_val = []
    if cfg['rnd_sampling']:
        trans_val.append(ds.TestSampling(sample_size))

    dataset_name = cfg['dataset']
    if dataset_name == 'left_ifof_ss_sl':
        dataset = ds.LeftIFOFSupersetDataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            uniform_size=True,
            train=False,
            split_obj=True,
            with_gt=cfg['with_gt'])
    elif dataset_name == 'hcp20_graph':
        dataset = ds.HCP20Dataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            with_gt=cfg['with_gt'],
            return_edges=True,
            split_obj=True,
            train=False,
            load_one_full_subj=False,
            standardize=cfg['standardization'])
    elif dataset_name == 'left_ifof_ss_sl_graph':
        dataset = ds.LeftIFOFSupersetGraphDataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            train=False,
            split_obj=True,
            with_gt=cfg['with_gt'])
    elif dataset_name == 'left_ifof_emb':
        dataset = ds.EmbDataset(cfg['sub_list_test'],
                                cfg['emb_dataset_dir'],
                                cfg['gt_dataset_dir'],
                                transform=transforms.Compose(trans_val),
                                load_all=cfg['load_all_once'],
                                precompute_graph=cfg['precompute_graph'],
                                k_graph=int(cfg['knngraph']))
    elif dataset_name == 'psb_airplane':
        dataset = ds.PsbAirplaneDataset(cfg['dataset_dir'], train=False)
    elif dataset_name == 'shapes':
        dataset = ds.ShapesDataset(cfg['dataset_dir'],
                                   train=False,
                                   multi_cat=cfg['multi_category'])
    elif dataset_name == 'shapenet':
        dataset = ds.ShapeNetCore(cfg['dataset_dir'],
                                  train=False,
                                  multi_cat=cfg['multi_category'])
    elif dataset_name == 'modelnet':
        dataset = ds.ModelNetDataset(cfg['dataset_dir'],
                                     split=cfg['mn40_split'],
                                     fold_size=int(cfg['mn40_fold_size']),
                                     load_all=cfg['load_all_once'])
    elif dataset_name == 'scanobj':
        dataset = ds.ScanObjNNDataset(cfg['dataset_dir'],
                                      run='test',
                                      variant=cfg['scanobj_variant'],
                                      background=cfg['scanobj_bg'],
                                      load_all=cfg['load_all_once'])
    else:
        dataset = ds.DRLeftIFOFSupersetDataset(
            cfg['sub_list_test'],
            cfg['val_dataset_dir'],
            transform=transforms.Compose(trans_val),
            with_gt=cfg['with_gt'])

    # The dataset is indexed directly below, so no DataLoader is needed.
    print("Validation dataset loaded, found %d samples" % (len(dataset)))

    # First unused ``test_<n>`` directory (falls back to ``test_99``).
    for ext in range(100):
        logdir = '%s/test_%d' % (cfg['exp_path'], ext)
        if not os.path.exists(logdir):
            break
    writer = SummaryWriter(logdir)

    if cfg['weights_path'] == '':
        cfg['weights_path'] = glob.glob(cfg['exp_path'] + '/models/best*')[0]
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])
    elif 'ep-' in cfg['weights_path']:
        # checkpoint names end in ``ep-<epoch>.pth``
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])

    # Mirror the tensorboard event file under ``tb_logs/`` via a symlink.
    tb_log_name = glob.glob('%s/events*' % writer.logdir)[0].rsplit('/', 1)[1]
    tb_log_dir = 'tb_logs/%s' % logdir.split('/', 1)[1]
    os.system('mkdir -p %s' % tb_log_dir)
    os.system('ln -sr %s/%s %s/%s ' %
              (writer.logdir, tb_log_name, tb_log_dir, tb_log_name))

    #### BUILD THE MODEL
    classifier = get_model(cfg)
    classifier.cuda()
    classifier.load_state_dict(torch.load(cfg['weights_path']))
    classifier.eval()

    # The loss type only decides how logits become labels here; it does not
    # depend on the sample, so resolve it once.
    if len(cfg['loss']) == 2 and epoch > int(cfg['switch_loss_epoch']):
        loss_type = cfg['loss'][1]
    else:
        loss_type = cfg['loss'][0]
    if loss_type not in ('nll', 'LLh', 'LLm'):
        raise ValueError('unsupported loss type: %r' % (loss_type,))

    is_graph = 'graph' in cfg['dataset']

    with torch.no_grad():
        pred_buffer = {}
        print('\n\n')
        mean_val_acc = torch.tensor([])
        mean_val_iou = torch.tensor([])
        mean_val_prec = torch.tensor([])
        mean_val_recall = torch.tensor([])

        if 'split_obj' in dir(dataset) and dataset.split_obj:
            split_obj = True
        else:
            split_obj = False
            dataset.transform = []

        # Without chunking every read completes an object.
        consumed = not split_obj

        j = 0
        new_obj_read = True
        while j < len(dataset):
            data = dataset[j]

            if split_obj:
                if new_obj_read:
                    # full-size buffers filled chunk by chunk
                    obj_pred_choice = torch.zeros(data['obj_full_size'],
                                                  dtype=torch.int).cuda()
                    obj_target = torch.zeros(data['obj_full_size'],
                                             dtype=torch.int).cuda()
                    new_obj_read = False
                if len(dataset.remaining[j]) == 0:
                    consumed = True

            if isinstance(data['name'], str):
                sample_name = data['name']
            else:
                sample_name = data['name'][0]

            if not is_graph:
                points = data['points']
                if cfg['with_gt']:
                    target = data['gt'].to('cuda').view(-1, 1)[:, 0]
            else:
                points = gBatch().from_data_list([data['points']])
                if 'bvec' in points.keys:
                    points.batch = points.bvec.clone()
                    del points.bvec
                if cfg['with_gt']:
                    target = points['y'].to('cuda').view(-1, 1)[:, 0]
                if cfg['same_size']:
                    points['lengths'] = points['lengths'][0].item()
            points = points.to('cuda')

            logits = classifier(points).view(-1, num_classes)

            if loss_type == 'nll':
                pred = F.log_softmax(logits, dim=-1).view(-1, num_classes)
                pred_choice = pred.data.max(1)[1].int()
            elif loss_type == 'LLh':
                pred_choice = (logits.data > 0).int()
            else:  # 'LLm'
                pred = F.softmax(logits, dim=-1)
                pred_choice = pred.data.max(1)[1].int()

            if split_obj:
                obj_pred_choice[data['obj_idxs']] = pred_choice
                if cfg['with_gt']:
                    obj_target[data['obj_idxs']] = target.int()
            else:
                obj_pred_choice = pred_choice
                if cfg['with_gt']:
                    obj_target = target

            if cfg['with_gt'] and consumed:
                print('val max class red ', obj_pred_choice.max().item())
                print('val min class pred ', obj_pred_choice.min().item())
                y_pred = obj_pred_choice.cpu().numpy()
                np.save(data['dir'] + '/y_pred_sDEC_k5_16pts_nodropout',
                        y_pred)
                y_test = obj_target.cpu().numpy()
                np.save(data['dir'] + '/y_test_sDEC_k5_16pts_nodropout',
                        y_test)

                obj_target_int = obj_target.data.int()
                correct = obj_pred_choice.eq(obj_target_int).cpu().sum()
                acc = correct.item() / float(obj_target.size(0))

                if num_classes > 2:
                    iou, prec, recall = L.iou_multi(
                        obj_pred_choice.data.int().cpu().numpy(),
                        obj_target_int.cpu().numpy(),
                        num_classes,
                        multi_cat=cfg['multi_category'])
                    assert (np.isnan(iou).sum() == 0)
                    if cfg['multi_category']:
                        # zero the scores of parts outside this category's
                        # label range [s, e)
                        s, n_parts = data['gt_offset']
                        e = s + n_parts
                        iou[0, :s], prec[0, :s], recall[0, :s] = 0., 0., 0.
                        iou[0, e:], prec[0, e:], recall[0, e:] = 0., 0., 0.
                        iou = torch.from_numpy(iou).float()
                        prec = torch.from_numpy(prec).float()
                        recall = torch.from_numpy(recall).float()
                        category = data['category'].squeeze().nonzero(
                        ).float()
                        iou = torch.cat([iou, category], 1)
                    else:
                        iou = torch.tensor([iou.mean()])
                        prec = torch.tensor([prec.mean()])
                        recall = torch.tensor([recall.mean()])
                    assert (torch.isnan(iou).sum() == 0)
                else:
                    # small epsilon keeps the ratios defined without positives
                    tp = torch.mul(obj_pred_choice.data,
                                   obj_target_int).cpu().sum().item() + 0.00001
                    fp = obj_pred_choice.gt(obj_target_int).cpu().sum().item()
                    fn = obj_pred_choice.lt(obj_target_int).cpu().sum().item()
                    iou = torch.tensor([float(tp) / (tp + fp + fn)])
                    prec = torch.tensor([float(tp) / (tp + fp)])
                    recall = torch.tensor([float(tp) / (tp + fn)])

                mean_val_prec = torch.cat((mean_val_prec, prec), 0)
                mean_val_recall = torch.cat((mean_val_recall, recall), 0)
                mean_val_iou = torch.cat((mean_val_iou, iou), 0)
                mean_val_acc = torch.cat(
                    (mean_val_acc, torch.tensor([acc])), 0)
                print('VALIDATION [%d: %d/%d] val accuracy: %f' \
                      % (epoch, j, len(dataset), acc))

            if cfg['save_pred'] and consumed:
                print('buffering prediction %s' % sample_name)
                flat_pred = obj_pred_choice.data.cpu().view(-1).numpy()
                if num_classes > 2:
                    # one index list per class
                    pred_buffer[sample_name] = [
                        np.where(flat_pred == cl)[0].tolist()
                        for cl in range(num_classes)
                    ]
                else:
                    pred_buffer[sample_name] = np.where(
                        flat_pred == 1)[0].tolist()

            if consumed:
                print(j)
                j += 1
                if split_obj:
                    consumed = False
                    new_obj_read = True

        macro_iou = torch.mean(mean_val_iou)
        macro_prec = torch.mean(mean_val_prec)
        macro_recall = torch.mean(mean_val_recall)

    if cfg['save_pred']:
        pred_dir = writer.logdir + '/predictions_test_%d' % epoch
        if not os.path.exists(pred_dir):
            os.makedirs(pred_dir)
        print('saving files')
        for filename, value in pred_buffer.items():
            with open(os.path.join(pred_dir, filename) + '.pkl', 'wb') as f:
                pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

    if cfg['with_gt']:
        print('TEST ACCURACY: %f' % torch.mean(mean_val_acc).item())
        print('TEST PRECISION: %f' % macro_prec.item())
        print('TEST RECALL: %f' % macro_recall.item())
        print('TEST IOU: %f' % macro_iou.item())

        # Dice score as the harmonic mean of precision and recall
        mean_val_dsc = mean_val_prec * mean_val_recall * 2 / (
            mean_val_prec + mean_val_recall)

        final_scores_file = writer.logdir + '/final_scores_test_%d.txt' % epoch
        scores_file = writer.logdir + '/scores_test_%d.txt' % epoch
        print('saving scores')
        per_sample_scores = (
            ('acc', mean_val_acc),
            ('prec', mean_val_prec),
            ('recall', mean_val_recall),
            ('dsc', mean_val_dsc),
            ('iou', mean_val_iou),
        )
        with open(scores_file, 'w') as f:
            for label, values in per_sample_scores:
                f.write('%s\n' % label)
                f.writelines('%f\n' % v for v in values.tolist())
        with open(final_scores_file, 'w') as f:
            for label, values in per_sample_scores:
                f.write('%s\n' % label)
                f.write('%f\n' % values.mean())
                f.write('%f\n' % values.std())

    print('\n\n')
def test(cfg):
    """Evaluate a trained regression model on the ``hcp20_graph`` test set.

    The weights in ``cfg['weights_path']`` (or the first ``best*`` file under
    ``cfg['exp_path']/models``) are loaded and the dataset is traversed by
    index.  When the dataset exposes a truthy ``split_obj`` attribute an
    object is processed in several chunks: index ``j`` is re-read until
    ``dataset.remaining[j]`` is empty and chunk outputs are scattered into
    full-size buffers through ``data['obj_idxs']``.

    When ``cfg['with_gt']`` is set, MSE, MAE and Spearman's rho are computed
    per object, the object's predictions are saved as ``.npy`` in the log
    directory, and score files are written there.

    Args:
        cfg: configuration mapping.  ``cfg['loss']`` is a space-separated
            string on entry and is replaced in place by the resulting list;
            ``cfg['batch_size']`` is forced to 1 and ``cfg['weights_path']``
            may be filled in.

    Raises:
        ValueError: if ``cfg['dataset']`` is not ``'hcp20_graph'``.
    """
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')
    batch_size = 1
    cfg['batch_size'] = batch_size
    # NOTE(review): eval() executes whatever the config contains; kept for
    # backward compatibility, consider int()/ast.literal_eval if the value is
    # always a literal.
    epoch = eval(str(cfg['n_epochs']))

    trans_val = []
    if cfg['rnd_sampling']:
        trans_val.append(TestSampling(sample_size))
    if cfg['standardization']:
        trans_val.append(SampleStandardization())

    if cfg['dataset'] == 'hcp20_graph':
        dataset = ds.HCP20Dataset(
            cfg['sub_list_test'],
            cfg['dataset_dir'],
            transform=transforms.Compose(trans_val),
            with_gt=cfg['with_gt'],
            return_edges=True,
            split_obj=True,
            train=False,
            load_one_full_subj=False,
            labels_dir=cfg['labels_dir'])
    else:
        # previously fell through to an unbound ``dataset``
        raise ValueError('unsupported dataset: %r' % (cfg['dataset'],))

    # The dataset is indexed directly below, so no DataLoader is needed.
    print("Validation dataset loaded, found %d samples" % (len(dataset)))

    # First unused ``test_<n>`` directory (falls back to ``test_99``).
    for ext in range(100):
        logdir = '%s/test_%d' % (cfg['exp_path'], ext)
        if not os.path.exists(logdir):
            break
    writer = SummaryWriter(logdir)

    if cfg['weights_path'] == '':
        cfg['weights_path'] = glob.glob(cfg['exp_path'] + '/models/best*')[0]
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])
    elif 'ep-' in cfg['weights_path']:
        # checkpoint names end in ``ep-<epoch>.pth``
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])

    # Mirror the tensorboard event file under ``tb_logs/`` via a symlink.
    tb_log_name = glob.glob('%s/events*' % writer.logdir)[0].rsplit('/', 1)[1]
    tb_log_dir = 'tb_logs/%s' % logdir.split('/', 1)[1]
    os.system('mkdir -p %s' % tb_log_dir)
    os.system('ln -sr %s/%s %s/%s ' %
              (writer.logdir, tb_log_name, tb_log_dir, tb_log_name))

    #### BUILD THE MODEL
    classifier = get_model(cfg)
    classifier.cuda()
    classifier.load_state_dict(torch.load(cfg['weights_path']))
    classifier.eval()

    with torch.no_grad():
        print('\n\n')
        mean_val_mse = torch.tensor([])
        mean_val_mae = torch.tensor([])
        mean_val_rho = torch.tensor([])

        if 'split_obj' in dir(dataset) and dataset.split_obj:
            split_obj = True
        else:
            split_obj = False
            dataset.transform = []

        # Without chunking every read completes an object.
        consumed = not split_obj

        j = 0
        new_obj_read = True
        while j < len(dataset):
            data = dataset[j]

            if split_obj:
                if new_obj_read:
                    # full-size buffers filled chunk by chunk
                    obj_pred_choice = torch.zeros(
                        data['obj_full_size'], dtype=torch.float32).cuda()
                    obj_target = torch.zeros(data['obj_full_size'],
                                             dtype=torch.float32).cuda()
                    new_obj_read = False
                if len(dataset.remaining[j]) == 0:
                    consumed = True

            if isinstance(data['name'], str):
                sample_name = data['name']
            else:
                sample_name = data['name'][0]
            print(sample_name)

            points = gBatch().from_data_list([data['points']])
            if 'bvec' in points.keys:
                points.batch = points.bvec.clone()
                del points.bvec
            if cfg['with_gt']:
                target = points['y'].to('cuda')
            if cfg['same_size']:
                points['lengths'] = points['lengths'][0].item()
            points = points.to('cuda')

            pred = classifier(points)

            if split_obj:
                obj_pred_choice[data['obj_idxs']] = pred.view(-1)
                if cfg['with_gt']:
                    obj_target[data['obj_idxs']] = target.float()
            else:
                # The raw regression output is the prediction (there is no
                # class choice in this variant).
                obj_pred_choice = pred.view(-1)
                if cfg['with_gt']:
                    obj_target = target

            if cfg['with_gt'] and consumed:
                print('val max class pred ', obj_pred_choice.max().item())
                print('val min class pred ', obj_pred_choice.min().item())
                print('val max class target ', obj_target.max().item())
                print('val min class target ', obj_target.min().item())

                target_cpu = obj_target.data.cpu()
                pred_cpu = obj_pred_choice.data.cpu()
                residual = target_cpu - pred_cpu
                mae = torch.mean(abs(residual)).item()
                mse = torch.mean(residual**2).item()
                rho, pval = spearmanr(target_cpu.numpy(), pred_cpu.numpy())
                np.save(writer.logdir + '/predictions_' + sample_name + '.npy',
                        pred_cpu.numpy())

                mean_val_mae = torch.cat(
                    (mean_val_mae, torch.tensor([mae])), 0)
                mean_val_mse = torch.cat(
                    (mean_val_mse, torch.tensor([mse])), 0)
                # float() keeps the accumulator's dtype consistent (spearmanr
                # returns a numpy float64)
                mean_val_rho = torch.cat(
                    (mean_val_rho, torch.tensor([float(rho)])), 0)
                print('VALIDATION [%d: %d/%d] val mse: %f val mae: %f val rho: %f' \
                      % (epoch, j, len(dataset), mse, mae, rho))

            if cfg['save_pred'] and consumed:
                print('buffering prediction %s' % sample_name)

            if consumed:
                print(j)
                j += 1
                if split_obj:
                    consumed = False
                    new_obj_read = True

    if cfg['with_gt']:
        print('TEST MSE: %f' % torch.mean(mean_val_mse).item())
        print('TEST MAE: %f' % torch.mean(mean_val_mae).item())
        print('TEST RHO: %f' % torch.mean(mean_val_rho).item())

        final_scores_file = writer.logdir + '/final_scores_test_%d.txt' % epoch
        scores_file = writer.logdir + '/scores_test_%d.txt' % epoch
        print('saving scores')
        per_sample_scores = (
            ('mse', mean_val_mse),
            ('mae', mean_val_mae),
            ('rho', mean_val_rho),
        )
        with open(scores_file, 'w') as f:
            for label, values in per_sample_scores:
                f.write('%s\n' % label)
                f.writelines('%f\n' % v for v in values.tolist())
        with open(final_scores_file, 'w') as f:
            for label, values in per_sample_scores:
                f.write('%s\n' % label)
                f.write('%f\n' % values.mean())
                f.write('%f\n' % values.std())

    print('\n\n')
def val_iter(cfg, val_dataloader, classifier, writer, epoch, best_epoch, best_pred, logdir):
    """Run one validation pass and log per-epoch metrics to ``writer``.

    Each element yielded by ``val_dataloader`` is iterated as a sequence of
    dicts holding a graph sample under ``'points'``; the samples are collated
    with ``gBatch().from_data_list`` and moved to CUDA before the forward
    pass. The loss is NLL over ``log_softmax`` of the classifier output.

    Args:
        cfg: Mapping read for ``'n_classes'``, ``'fixed_size'`` and
            ``'same_size'``.
        val_dataloader: Iterable of batches (see above). ``len()`` is called
            on it for the progress print.
        classifier: Model; put in eval mode and called on the collated batch.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        epoch: Step index used for logging.
        best_epoch: Returned unchanged.
        best_pred: Returned unchanged.
        logdir: Unused here; kept for interface compatibility.

    Returns:
        Tuple ``(best_epoch, best_pred, ep_loss)`` where ``ep_loss`` is the
        *sum* of the per-batch losses (not the mean).
    """
    num_classes = int(cfg['n_classes'])
    sample_size = int(cfg['fixed_size'])
    ep_loss = 0.
    n_batches = 0

    classifier.eval()

    with torch.no_grad():
        print('\n\n')
        mean_val_acc = torch.tensor([])
        mean_val_iou = torch.tensor([])
        mean_val_prec = torch.tensor([])
        mean_val_recall = torch.tensor([])
        # NOTE(review): nothing is ever appended to this tensor, so the
        # 'val/epoch_iou_c' scalar below is the mean of an empty tensor (NaN).
        # Kept as-is to preserve the logged tags.
        mean_val_iou_c = torch.tensor([])

        for j, data in enumerate(val_dataloader):
            # Collate the samples of this batch into a single graph batch.
            data_list = []
            for i, d in enumerate(data):
                if 'bvec' in d['points'].keys:
                    # Offset the per-sample batch vector so that indices stay
                    # distinct once samples are concatenated.
                    # NOTE(review): this mutates the sample in place; confirm
                    # the dataset does not hand out cached objects.
                    d['points'].bvec += sample_size * i
                data_list.append(d['points'])

            points = gBatch().from_data_list(data_list)
            if 'bvec' in points.keys:
                points.batch = points.bvec.clone()
                del points.bvec
            target = points['y']
            if cfg['same_size']:
                # Collapse the lengths to a single Python scalar taken from
                # the first entry.
                points['lengths'] = points['lengths'][0].item()

            points, target = points.to('cuda'), target.to('cuda')

            logits = classifier(points)
            pred = F.log_softmax(logits, dim=-1)
            pred = pred.view(-1, num_classes)
            pred_choice = pred.data.max(1)[1].int()

            loss = F.nll_loss(pred, target.long())
            ep_loss += loss
            n_batches = j + 1

            print('val max class pred ', pred_choice.max().item())
            print('val min class pred ', pred_choice.min().item())
            print('# class pred ', len(np.unique(pred_choice.cpu().numpy())))

            target_int = target.data.int()
            correct = pred_choice.eq(target_int).cpu().sum()
            acc = correct.item() / float(target.size(0))

            # NOTE(review): these counts read as TP/FP/FN only for binary
            # {0, 1} labels -- confirm against the dataset. The small epsilon
            # added to tp keeps the ratios below from dividing by zero.
            tp = torch.mul(pred_choice.data, target_int).cpu().sum().item() + 0.00001
            fp = pred_choice.gt(target_int).cpu().sum().item()
            fn = pred_choice.lt(target_int).cpu().sum().item()

            iou = torch.tensor([float(tp) / (tp + fp + fn)])
            prec = torch.tensor([float(tp) / (tp + fp)])
            recall = torch.tensor([float(tp) / (tp + fn)])

            print('VALIDATION [%d: %d/%d] val loss: %f acc: %f iou: %f'
                  % (epoch, j, len(val_dataloader), loss, acc, iou))

            mean_val_prec = torch.cat((mean_val_prec, prec), 0)
            mean_val_recall = torch.cat((mean_val_recall, recall), 0)
            mean_val_iou = torch.cat((mean_val_iou, iou), 0)
            mean_val_acc = torch.cat((mean_val_acc, torch.tensor([acc])), 0)

        macro_iou = torch.mean(mean_val_iou)
        macro_prec = torch.mean(mean_val_prec)
        macro_recall = torch.mean(mean_val_recall)
        macro_iou_c = torch.mean(mean_val_iou_c)
        epoch_iou = macro_iou.item()

        writer.add_scalar('val/epoch_acc', torch.mean(mean_val_acc).item(), epoch)
        writer.add_scalar('val/epoch_iou', epoch_iou, epoch)
        writer.add_scalar('val/epoch_prec', macro_prec.item(), epoch)
        writer.add_scalar('val/epoch_recall', macro_recall.item(), epoch)
        writer.add_scalar('val/epoch_iou_c', macro_iou_c.item(), epoch)
        # Average over the number of batches seen. Dividing by the last loop
        # index would be off by one and would divide by zero for a
        # single-batch loader.
        writer.add_scalar('val/loss', ep_loss / max(n_batches, 1), epoch)

        print('VALIDATION ACCURACY: %f' % torch.mean(mean_val_acc).item())
        print('VALIDATION IOU: %f' % epoch_iou)
        print('VALIDATION IOUC: %f' % macro_iou_c.item())
        print('\n\n')

    return best_epoch, best_pred, ep_loss
def train_iter(cfg, dataloader, classifier, optimizer, writer, epoch, n_iter, cluster_loss_fn):
    """Run one training epoch and return per-batch metrics.

    Each element yielded by ``dataloader`` is iterated as a sequence of dicts
    holding a graph sample under ``'points'``; the samples are collated with
    ``gBatch().from_data_list`` and moved to CUDA. For every batch the
    gradients are zeroed, the NLL loss over ``log_softmax`` of the classifier
    output is back-propagated, and the optimizer takes one step.

    Args:
        cfg: Mapping read for ``'n_classes'``, ``'fixed_size'``,
            ``'num_batch'`` and ``'same_size'``.
        dataloader: Iterable of batches (see above).
        classifier: Model; put in train mode and called on the collated batch.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        epoch: Step index used for logging.
        n_iter: Running iteration counter; incremented once per batch.
        cluster_loss_fn: Unused here; kept for interface compatibility.

    Returns:
        Tuple ``(mean_acc, mean_prec, mean_iou, mean_recall, mean_loss,
        n_iter)``. The first four are 1-D tensors with one entry per batch,
        ``mean_loss`` is the summed loss divided by the number of batches, and
        ``n_iter`` is the updated counter.
    """
    num_classes = int(cfg['n_classes'])
    sample_size = int(cfg['fixed_size'])
    num_batch = cfg['num_batch']

    ep_loss = 0.
    n_batches = 0
    mean_acc = torch.tensor([])
    mean_iou = torch.tensor([])
    mean_prec = torch.tensor([])
    mean_recall = torch.tensor([])

    # State that the model will run in train mode.
    classifier.train()

    for i_batch, sample_batched in enumerate(dataloader):
        # Collate the samples of this batch into a single graph batch.
        data_list = []
        for i, d in enumerate(sample_batched):
            if 'bvec' in d['points'].keys:
                # Offset the per-sample batch vector so that indices stay
                # distinct once samples are concatenated.
                # NOTE(review): this mutates the sample in place; confirm the
                # dataset does not hand out cached objects across epochs.
                d['points'].bvec += sample_size * i
            data_list.append(d['points'])

        points = gBatch().from_data_list(data_list)
        if 'bvec' in points.keys:
            points.batch = points.bvec.clone()
            del points.bvec
        target = points['y']
        if cfg['same_size']:
            # Collapse the lengths to a single Python scalar taken from the
            # first entry.
            points['lengths'] = points['lengths'][0].item()

        points, target = points.to('cuda'), target.to('cuda')

        # Initialize gradients, then forward.
        optimizer.zero_grad()
        logits = classifier(points)

        pred = F.log_softmax(logits, dim=-1)
        pred = pred.view(-1, num_classes)
        pred_choice = pred.data.max(1)[1].int()

        loss = F.nll_loss(pred, target.long())
        # Accumulate a detached copy: adding the live loss would make the
        # running total carry autograd history across the whole epoch.
        ep_loss += loss.detach()
        n_batches = i_batch + 1

        loss.backward()
        optimizer.step()

        # Compute performance.
        target_int = target.data.int()
        correct = pred_choice.eq(target_int).sum()
        acc = correct.item() / float(target.size(0))

        # NOTE(review): these counts read as TP/FP/FN only for binary {0, 1}
        # labels -- confirm against the dataset. The small epsilon added to tp
        # keeps the ratios below from dividing by zero.
        tp = torch.mul(pred_choice.data, target_int).sum().item() + 0.00001
        fp = pred_choice.gt(target_int).sum().item()
        fn = pred_choice.lt(target_int).sum().item()

        iou = float(tp) / (tp + fp + fn)
        prec = float(tp) / (tp + fp)
        recall = float(tp) / (tp + fn)

        print('[%d: %d/%d] train loss: %f acc: %f iou: %f' \
              % (epoch, i_batch, num_batch, loss.item(), acc, iou))

        mean_prec = torch.cat((mean_prec, torch.tensor([prec])), 0)
        mean_recall = torch.cat((mean_recall, torch.tensor([recall])), 0)
        mean_acc = torch.cat((mean_acc, torch.tensor([acc])), 0)
        mean_iou = torch.cat((mean_iou, torch.tensor([iou])), 0)
        n_iter += 1

    # Guard the divisor so an empty dataloader does not fail on an unbound
    # loop index; for a non-empty loader this equals (last index + 1).
    mean_loss = ep_loss / max(n_batches, 1)
    writer.add_scalar('train/epoch_loss', mean_loss, epoch)
    return mean_acc, mean_prec, mean_iou, mean_recall, mean_loss, n_iter