def train(train_loader, model):
    running_loss = 0.0
    data_size = len(train_data)
    model.train()
    # for inputs, masks, labels in progress_bar(train_loader, parent=mb):
    for inputs, masks, labels in train_loader:
        inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            if args.is_pseudo:
                logit, logit_pixel, logit_image = model(inputs)
                loss1 = lovasz_hinge(logit.squeeze(1), masks.squeeze(1))
                # nn.BCELoss expects probabilities; logit_image is assumed to be
                # sigmoid output here (use BCEWithLogitsLoss for raw logits)
                loss2 = nn.BCELoss()(logit_image, labels)
                loss3 = lovasz_hinge2(logit_pixel.squeeze(1), masks.squeeze(1))
                loss = loss1 + loss2 + loss3
            else:
                logit = model(inputs)
                loss = lovasz_hinge(logit.squeeze(1), masks.squeeze(1))
            loss.backward()
            optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        # mb.child.comment = 'loss: {}'.format(loss.item())
    epoch_loss = running_loss / data_size
    return epoch_loss
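# A minimal, self-contained sketch (not part of the snippets here) of the tensor
# shapes this train() loop feeds to lovasz_hinge: (B, 1, H, W) logits squeezed to
# (B, H, W), with {0, 1} masks of the same shape. Assumes Maxim Berman's reference
# lovasz_losses.py is importable.
import torch
from lovasz_losses import lovasz_hinge  # reference implementation (assumption)

logit = torch.randn(4, 1, 101, 101)                 # raw one-channel model output
mask = (torch.rand(4, 1, 101, 101) > 0.5).float()   # binary ground truth
loss = lovasz_hinge(logit.squeeze(1), mask.squeeze(1), per_image=True)
print(loss.item())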
def criterion(args, output, target, epoch=0):
    mask_output, ship_output = output
    mask_target, ship_target = target
    #dice_loss = mixed_dice_bce_loss(mask_output, mask_target)
    focal_loss = focal_loss2d(mask_output, mask_target)
    #lovasz_loss = lovasz_hinge(mask_output, mask_target)
    lovasz_loss = (lovasz_hinge(mask_output, mask_target) +
                   lovasz_hinge(-mask_output, 1 - mask_target)) / 2
    bce_loss = F.binary_cross_entropy_with_logits(mask_output, mask_target)
    cls_loss = F.binary_cross_entropy_with_logits(ship_output, ship_target)
    if args.train_cls:
        #cls_loss = F.binary_cross_entropy_with_logits(ship_output, ship_target)
        # five returned values: 1. loss for grad, 2.-4. scalars for display, 5. for measurement
        return (lovasz_loss + bce_loss + cls_loss, focal_loss.item(),
                lovasz_loss.item(), bce_loss.item(), cls_loss.item())
    #if epoch < 10:
    #    return bce_loss, focal_loss.item(), lovasz_loss.item(), 0., lovasz_loss.item() + focal_loss.item()*focal_weight
    #else:
    #    return lovasz_loss+focal_loss*focal_weight, focal_loss.item(), lovasz_loss.item(), 0., lovasz_loss.item() + focal_loss.item()*focal_weight
    return (lovasz_loss + bce_loss * 0.1, focal_loss.item(),
            lovasz_loss.item(), bce_loss.item(), cls_loss.item())
def test(test_loader, model):
    running_loss = 0.0
    predicts = []
    truths = []
    model.eval()
    for inputs, masks in test_loader:
        inputs, masks = inputs.to(device), masks.to(device)
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            # crop the padded prediction back to the original fine_size window
            outputs = outputs[:, :, args.pad_left:args.pad_left + args.fine_size,
                              args.pad_left:args.pad_left + args.fine_size].contiguous()
            loss = lovasz_hinge(outputs.squeeze(1), masks.squeeze(1))
        predicts.append(torch.sigmoid(outputs).detach().cpu().numpy())
        truths.append(masks.detach().cpu().numpy())
        running_loss += loss.item() * inputs.size(0)
    predicts = np.concatenate(predicts).squeeze()
    truths = np.concatenate(truths).squeeze()
    precision, _, _ = do_kaggle_metric(predicts, truths, 0.52)
    precision = precision.mean()
    epoch_loss = running_loss / len(val_data)
    return epoch_loss, precision
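# Hedged illustration of the pad-then-crop convention test() relies on: inputs
# were reflection-padded from fine_size up to the network input size, so outputs
# are cropped back using args.pad_left. The concrete values (fine_size=101,
# pad_left=13, pad_right=14, giving 128) are illustrative assumptions.
import torch
import torch.nn.functional as F

fine_size, pad_left, pad_right = 101, 13, 14
x = torch.randn(1, 1, fine_size, fine_size)
x_pad = F.pad(x, (pad_left, pad_right, pad_left, pad_right), mode='reflect')  # -> 128x128
x_back = x_pad[:, :, pad_left:pad_left + fine_size, pad_left:pad_left + fine_size]
assert torch.equal(x, x_back)  # the crop exactly undoes the padding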
def weighted_loss(output, target, epoch=0):
    mask_output, _ = output
    mask_target, _ = target
    lovasz_loss = lovasz_hinge(mask_output, mask_target)
    dice_loss = mixed_dice_bce_loss(mask_output, mask_target)
    #print(bce_loss, lovasz_loss)
    if epoch < 5:
        return dice_loss
    else:
        return lovasz_loss  #, lovasz_loss.item(), bce_loss.item()
def criterion(self, logit_clf, truth, logit_mask=None, mask=None):
    """Define the (customized) loss function here."""
    ## 1. classification loss
    Loss_FUNC = FocalLoss()
    #Loss_FUNC = nn.BCEWithLogitsLoss()  #nn.MultiLabelSoftMarginLoss()
    loss_clf = Loss_FUNC(logit_clf, truth)
    if logit_mask is not None:
        ## 2. segmentation mask loss
        loss_mask = L.lovasz_hinge(logit_mask, mask, ignore=255)
        return loss_clf, loss_mask
    else:
        return loss_clf
def forward(self, input, target):
    pred = input.view(-1)
    truth = target.view(-1)
    bce_loss = nn.BCEWithLogitsLoss()(pred, truth).double()
    # lovasz loss
    lovasz_loss = L.lovasz_hinge(input, target, per_image=False)
    loss = bce_loss + lovasz_loss.double()
    return loss, bce_loss, lovasz_loss.double()
def __call__(self, logits, labels):
    loss = (1 - self.jaccard_weight) * lovasz_hinge(
        logits, labels, per_image=True, ignore=None) + self.focal_loss(
            logits, labels) * self.focal_weight
    if self.jaccard_weight:
        eps = 1e-15
        jaccard_target = (labels == 1).float()
        jaccard_output = torch.sigmoid(logits)
        intersection = (jaccard_output * jaccard_target).sum()
        union = jaccard_output.sum() + jaccard_target.sum()
        loss -= self.jaccard_weight * torch.log(
            (intersection + eps) / (union - intersection + eps))
    return loss
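# Standalone sketch of the log-Jaccard term above: with soft intersection I and
# union U = |p| + |t|, the Jaccard index is I / (U - I), and the class subtracts
# jaccard_weight * log(jaccard). Dummy tensors; eps matches the class constant.
import torch

logits = torch.randn(2, 1, 64, 64)
labels = (torch.rand(2, 1, 64, 64) > 0.5).float()
eps = 1e-15
p = torch.sigmoid(logits)
intersection = (p * labels).sum()
union = p.sum() + labels.sum()
log_jaccard = torch.log((intersection + eps) / (union - intersection + eps))
print((-log_jaccard).item())  # the loss contribution before weighting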
def weighted_loss(args, output, target, epoch=0):
    mask_output, salt_output = output
    mask_target, salt_target = target
    lovasz_loss = lovasz_hinge(mask_output, mask_target)
    focal_loss = focal_loss2d(mask_output, mask_target)
    focal_weight = 0.2
    if salt_output is not None and args.train_cls:
        salt_loss = F.binary_cross_entropy_with_logits(salt_output, salt_target)
        return (salt_loss, focal_loss.item(), lovasz_loss.item(), salt_loss.item(),
                lovasz_loss.item() + focal_loss.item() * focal_weight)
    return (lovasz_loss + focal_loss * focal_weight, focal_loss.item(),
            lovasz_loss.item(), 0.,
            lovasz_loss.item() + focal_loss.item() * focal_weight)
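# focal_loss2d is not defined in these snippets; a common binary focal loss it
# might correspond to (an assumption, not the repo's actual implementation):
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, gamma=2.0):
    """Focal loss on raw logits: per-pixel BCE scaled by (1 - p_t)**gamma."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = torch.exp(-bce)  # since bce = -log(p_t), this recovers p_t
    return ((1 - p_t) ** gamma * bce).mean()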
def train(train_loader, model):
    running_loss = 0.0
    model.train()
    # for inputs, masks, labels in progress_bar(train_loader, parent=mb):
    for inputs, masks, labels in train_loader:
        inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            logit = model(inputs)
            loss = lovasz_hinge(logit.squeeze(1), masks.squeeze(1))
            loss.backward()
            optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        # mb.child.comment = 'loss: {}'.format(loss.item())
    epoch_loss = running_loss / len(train_data)
    return epoch_loss
def __call__(self, logits, labels):
    return lovasz_hinge(logits, labels, self.per_image, self.ignore) + \
        self.focal_loss(logits, labels) * self.focal_weight
def test(cfg):
    num_classes = int(cfg['n_classes'])
    sample_size = int(cfg['fixed_size'])
    cfg['loss'] = cfg['loss'].split(' ')
    batch_size = 1
    cfg['batch_size'] = batch_size
    epoch = eval(cfg['n_epochs'])
    #n_gf = int(cfg['num_gf'])
    input_size = int(cfg['data_dim'])

    trans_val = []
    if cfg['rnd_sampling']:
        trans_val.append(ds.TestSampling(sample_size))
    #if cfg['standardization']:
    #    trans_val.append(ds.SampleStandardization())

    if cfg['dataset'] == 'left_ifof_ss_sl':
        dataset = ds.LeftIFOFSupersetDataset(cfg['sub_list_test'],
                                             cfg['dataset_dir'],
                                             transform=transforms.Compose(trans_val),
                                             uniform_size=True,
                                             train=False,
                                             split_obj=True,
                                             with_gt=cfg['with_gt'])
    elif cfg['dataset'] == 'hcp20_graph':
        dataset = ds.HCP20Dataset(cfg['sub_list_test'],
                                  cfg['dataset_dir'],
                                  transform=transforms.Compose(trans_val),
                                  with_gt=cfg['with_gt'],
                                  #distance=T.Distance(norm=True, cat=False),
                                  return_edges=True,
                                  split_obj=True,
                                  train=False,
                                  load_one_full_subj=False,
                                  standardize=cfg['standardization'])
    elif cfg['dataset'] == 'left_ifof_ss_sl_graph':
        dataset = ds.LeftIFOFSupersetGraphDataset(cfg['sub_list_test'],
                                                  cfg['dataset_dir'],
                                                  transform=transforms.Compose(trans_val),
                                                  train=False,
                                                  split_obj=True,
                                                  with_gt=cfg['with_gt'])
    elif cfg['dataset'] == 'left_ifof_emb':
        dataset = ds.EmbDataset(cfg['sub_list_test'],
                                cfg['emb_dataset_dir'],
                                cfg['gt_dataset_dir'],
                                transform=transforms.Compose(trans_val),
                                load_all=cfg['load_all_once'],
                                precompute_graph=cfg['precompute_graph'],
                                k_graph=int(cfg['knngraph']))
    elif cfg['dataset'] == 'psb_airplane':
        dataset = ds.PsbAirplaneDataset(cfg['dataset_dir'], train=False)
    elif cfg['dataset'] == 'shapes':
        dataset = ds.ShapesDataset(cfg['dataset_dir'], train=False,
                                   multi_cat=cfg['multi_category'])
    elif cfg['dataset'] == 'shapenet':
        dataset = ds.ShapeNetCore(cfg['dataset_dir'], train=False,
                                  multi_cat=cfg['multi_category'])
    elif cfg['dataset'] == 'modelnet':
        dataset = ds.ModelNetDataset(cfg['dataset_dir'],
                                     split=cfg['mn40_split'],
                                     fold_size=int(cfg['mn40_fold_size']),
                                     load_all=cfg['load_all_once'])
    elif cfg['dataset'] == 'scanobj':
        dataset = ds.ScanObjNNDataset(cfg['dataset_dir'],
                                      run='test',
                                      variant=cfg['scanobj_variant'],
                                      background=cfg['scanobj_bg'],
                                      load_all=cfg['load_all_once'])
    else:
        dataset = ds.DRLeftIFOFSupersetDataset(cfg['sub_list_test'],
                                               cfg['val_dataset_dir'],
                                               transform=transforms.Compose(trans_val),
                                               with_gt=cfg['with_gt'])
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)
    print("Validation dataset loaded, found %d samples" % (len(dataset)))

    for ext in range(100):
        logdir = '%s/test_%d' % (cfg['exp_path'], ext)
        if not os.path.exists(logdir):
            break
    writer = SummaryWriter(logdir)
    if cfg['weights_path'] == '':
        cfg['weights_path'] = glob.glob(cfg['exp_path'] + '/models/best*')[0]
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])
    elif 'ep-' in cfg['weights_path']:
        epoch = int(cfg['weights_path'].rsplit('-', 1)[1].split('.')[0])

    tb_log_name = glob.glob('%s/events*' % writer.logdir)[0].rsplit('/', 1)[1]
    tb_log_dir = 'tb_logs/%s' % logdir.split('/', 1)[1]
    os.system('mkdir -p %s' % tb_log_dir)
    os.system('ln -sr %s/%s %s/%s' %
              (writer.logdir, tb_log_name, tb_log_dir, tb_log_name))

    #### BUILD THE MODEL
    classifier = get_model(cfg)
    classifier.cuda()
    classifier.load_state_dict(torch.load(cfg['weights_path']))
    classifier.eval()

    with torch.no_grad():
        pred_buffer = {}
        sm_buffer = {}
        sm2_buffer = {}
        gf_buffer = {}
        emb_buffer = {}
        print('\n\n')
        mean_val_acc = torch.tensor([])
        mean_val_iou = torch.tensor([])
        mean_val_prec = torch.tensor([])
        mean_val_recall = torch.tensor([])

        if 'split_obj' in dir(dataset) and dataset.split_obj:
            split_obj = True
        else:
            split_obj = False
        dataset.transform = []

        if split_obj:
            consumed = False
        else:
            consumed = True
        j = 0
        visualized = 0
        new_obj_read = True
        sls_count = 1
        while j < len(dataset):
            #while sls_count <= len(dataset):
            data = dataset[j]
            if split_obj:
                if new_obj_read:
                    obj_pred_choice = torch.zeros(data['obj_full_size'],
                                                  dtype=torch.int).cuda()
                    obj_target = torch.zeros(data['obj_full_size'],
                                             dtype=torch.int).cuda()
                    new_obj_read = False
                    #if cfg['save_embedding']:
                    #    obj_embedding = torch.empty((data['obj_full_size'], int(cfg['embedding_size']))).cuda()
                if len(dataset.remaining[j]) == 0:
                    consumed = True
            sample_name = data['name'] if type(data['name']) == str else data['name'][0]
            #print(points)
            #if len(points.shape()) == 2:
            #    points = points.unsqueeze(0)
            if 'graph' not in cfg['dataset']:
                points = data['points']
                if cfg['with_gt']:
                    target = data['gt']
                    target = target.to('cuda')
                    target = target.view(-1, 1)[:, 0]
            else:
                #print(data)
                points = gBatch().from_data_list([data['points']])
                #points = data['points']
                if 'bvec' in points.keys:
                    points.batch = points.bvec.clone()
                    del points.bvec
                    # points.ori_batch = torch.zeros(points.x.size(0)).long()
                if cfg['with_gt']:
                    target = points['y']
                    target = target.to('cuda')
                    target = target.view(-1, 1)[:, 0]
            if cfg['same_size']:
                points['lengths'] = points['lengths'][0].item()
            #if cfg['model'] == 'pointnet_cls':
            #    points = points.view(len(data['obj_idxs']), -1, input_size)
            points = points.to('cuda')
            #print('streamline number:', sls_count)
            #sls_count += 1
            #print('lengths:', points['lengths'].item())

            ### add one-hot labels if multi-category task
            #new_k = points['lengths'].item()*(5/16)
            #print('new k:', new_k, 'rounded k:', int(round(new_k)))
            #classifier.conv2.k = int(round(new_k))
            #if cfg['multi_loss']:
            #    logits, gf = classifier(points)
            #else:
            logits = classifier(points)
            logits = logits.view(-1, num_classes)

            if len(cfg['loss']) == 2:
                if epoch <= int(cfg['switch_loss_epoch']):
                    loss_type = cfg['loss'][0]
                else:
                    loss_type = cfg['loss'][1]
            else:
                loss_type = cfg['loss'][0]

            if loss_type == 'nll':
                pred = F.log_softmax(logits, dim=-1)
                pred = pred.view(-1, num_classes)
                probas = torch.exp(pred.data)
                pred_choice = pred.data.max(1)[1].int()
                if cfg['with_gt']:
                    loss_seg = F.nll_loss(pred, target.long())
            elif loss_type == 'LLh':
                pred_choice = (logits.data > 0).int()
                if cfg['with_gt']:
                    loss_seg = L.lovasz_hinge(
                        logits.view(batch_size, sample_size, 1),
                        target.view(batch_size, sample_size, 1),
                        per_image=False)
                    #loss = L.lovasz_hinge_flat(pred.view(-1), target.view(-1))
            elif loss_type == 'LLm':
                pred = F.softmax(logits, dim=-1)
                probas = pred.data
                pred_choice = pred.data.max(1)[1].int()
                if cfg['with_gt']:
                    loss = L.lovasz_softmax_flat(pred, target,
                                                 op=cfg['llm_op'],
                                                 only_present=cfg['multi_category'])
            #print('pred:', pred)
            #print('pred shape:', pred.shape)
            #print('pred choice:', pred_choice)
            #print('pred choice shape:', pred_choice.shape)
            #if visualized < int(cfg['viz_clusters']):
            #    visualized += 1
            #    colors = torch.from_numpy(get_spaced_colors(n_gf))
            #    sm_out = classifier.feat.mf.softmax_out[0,:,:].max(1)[1].squeeze().int()
            #    writer.add_mesh('latent clustering', points, colors[sm_out.tolist()].unsqueeze(0))
            #    if 'bg' in data.keys():
            #        bg_msk = data['bg']*-1
            #        writer.add_mesh('bg_mask', points, colors[bg_msk.tolist()].unsqueeze(0))
            if split_obj:
                obj_pred_choice[data['obj_idxs']] = pred_choice
                obj_target[data['obj_idxs']] = target.int()
                #if cfg['save_embedding']:
                #    obj_embedding[data['obj_idxs']] = classifier.embedding.squeeze()
            else:
                obj_data = points
                obj_pred_choice = pred_choice
                obj_target = target
                if cfg['save_embedding']:
                    obj_embedding = classifier.embedding.squeeze()

            if cfg['with_gt'] and consumed:
                #if cfg['multi_loss']:
                #    loss_cluster = cluster_loss_fn(gf.squeeze(3))
                #    loss = loss_seg + alfa * loss_cluster
                #pred_choice = torch.sigmoid(pred.view(-1,1)).data.round().type_as(target.data)
                #print('points:', points['streamlines'])
                #print('points shape:', points['streamlines'].shape)
                data_dir = cfg['dataset_dir']
                #streamlines, head, leng, idxs = load_streamlines(data['dir']+'/'+data['name']+'.trk')
                #print('tract:', len(streamlines))
                #print('pred:', obj_pred_choice)
                #print('target:', obj_target)
                #print('pred shape:', obj_pred_choice.shape)
                #print('target shape:', obj_target.shape)
                print('val max class pred ', obj_pred_choice.max().item())
                print('val min class pred ', obj_pred_choice.min().item())
                y_pred = obj_pred_choice.cpu().numpy()
                np.save(data['dir'] + '/y_pred_sDEC_k5_16pts_nodropout', y_pred)
                y_test = obj_target.cpu().numpy()
                np.save(data['dir'] + '/y_test_sDEC_k5_16pts_nodropout', y_test)
                #np.save(data['dir']+'/streamlines_lstm_GIN', streamlines)
                correct = obj_pred_choice.eq(obj_target.data.int()).cpu().sum()
                acc = correct.item() / float(obj_target.size(0))
                if num_classes > 2:
                    iou, prec, recall = L.iou_multi(
                        obj_pred_choice.data.int().cpu().numpy(),
                        obj_target.data.int().cpu().numpy(),
                        num_classes,
                        multi_cat=cfg['multi_category'])
                    assert (np.isnan(iou).sum() == 0)
                    if cfg['multi_category']:
                        s, n_parts = data['gt_offset']
                        e = s + n_parts
                        iou[0, :s], prec[0, :s], recall[0, :s] = 0., 0., 0.
                        iou[0, e:], prec[0, e:], recall[0, e:] = 0., 0., 0.
                        iou = torch.from_numpy(iou).float()
                        prec = torch.from_numpy(prec).float()
                        recall = torch.from_numpy(recall).float()
                        category = data['category'].squeeze().nonzero().float()
                        iou = torch.cat([iou, category], 1)
                    else:
                        iou = torch.tensor([iou.mean()])
                        prec = torch.tensor([prec.mean()])
                        recall = torch.tensor([recall.mean()])
                    assert (torch.isnan(iou).sum() == 0)
                else:
                    tp = torch.mul(obj_pred_choice.data,
                                   obj_target.data.int()).cpu().sum().item() + 0.00001
                    fp = obj_pred_choice.gt(obj_target.data.int()).cpu().sum().item()
                    fn = obj_pred_choice.lt(obj_target.data.int()).cpu().sum().item()
                    tn = correct.item() - tp
                    iou = torch.tensor([float(tp) / (tp + fp + fn)])
                    prec = torch.tensor([float(tp) / (tp + fp)])
                    recall = torch.tensor([float(tp) / (tp + fn)])
                mean_val_prec = torch.cat((mean_val_prec, prec), 0)
                mean_val_recall = torch.cat((mean_val_recall, recall), 0)
                mean_val_iou = torch.cat((mean_val_iou, iou), 0)
                mean_val_acc = torch.cat((mean_val_acc, torch.tensor([acc])), 0)
                print('VALIDATION [%d: %d/%d] val accuracy: %f'
                      % (epoch, j, len(dataset), acc))

            if cfg['save_pred'] and consumed:
                print('buffering prediction %s' % sample_name)
                # note: the original referenced an undefined obj_pred here;
                # obj_pred_choice holds the per-item class predictions
                if num_classes > 2:
                    for cl in range(num_classes):
                        sl_idx = np.where(
                            obj_pred_choice.data.cpu().view(-1).numpy() == cl)[0]
                        if cl == 0:
                            pred_buffer[sample_name] = []
                        pred_buffer[sample_name].append(sl_idx.tolist())
                else:
                    sl_idx = np.where(
                        obj_pred_choice.data.cpu().view(-1).numpy() == 1)[0]
                    pred_buffer[sample_name] = sl_idx.tolist()
            #if cfg['save_softmax_out']:
            #    if cfg['model'] in 'pointnet_mgfml':
            #        if sample_name not in sm_buffer.keys():
            #            sm_buffer[sample_name] = []
            #        if classifier.feat.multi_feat > 1:
            #            sm_buffer[sample_name].append(
            #                classifier.feat.mf.softmax_out.cpu().numpy())
            #    if cfg['model'] == 'pointnet_mgfml':
            #        for l in classifier.layers:
            #            sm_buffer[sample_name].append(
            #                l.mf.softmax_out.cpu().numpy())
            #    sm2_buffer[sample_name] = probas.cpu().numpy()
            #if cfg['save_gf']:
            #    gf_buffer[sample_name] = np.unique(
            #        classifier.feat.globalfeat.data.cpu().squeeze().numpy(), axis=0)
            #    gf_buffer[sample_name] = classifier.globalfeat
            #if cfg['save_embedding'] and consumed:
            #    emb_buffer[sample_name] = obj_embedding

            if consumed:
                print(j)
                j += 1
                if split_obj:
                    consumed = False
                    new_obj_read = True

    macro_iou = torch.mean(mean_val_iou)
    macro_prec = torch.mean(mean_val_prec)
    macro_recall = torch.mean(mean_val_recall)
    epoch_iou = macro_iou.item()

    if cfg['save_pred']:
        #os.system('rm -r %s/predictions_test*' % writer.logdir)
        pred_dir = writer.logdir + '/predictions_test_%d' % epoch
        if not os.path.exists(pred_dir):
            os.makedirs(pred_dir)
        print('saving files')
        for filename, value in pred_buffer.items():
            with open(os.path.join(pred_dir, filename) + '.pkl', 'wb') as f:
                pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)

    #if cfg['save_softmax_out']:
    #    os.system('rm -r %s/sm_out_test*' % writer.logdir)
    #    sm_dir = writer.logdir + '/sm_out_test_%d' % epoch
    #    if not os.path.exists(sm_dir):
    #        os.makedirs(sm_dir)
    #    for filename, value in sm_buffer.iteritems():
    #        with open(os.path.join(sm_dir, filename) + '_sm_1.pkl', 'wb') as f:
    #            pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)
    #    for filename, value in sm2_buffer.iteritems():
    #        with open(os.path.join(sm_dir, filename) + '_sm_2.pkl', 'wb') as f:
    #            pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)
    #if cfg['save_gf']:
    #    os.system('rm -r %s/gf_test*' % writer.logdir)
    #    gf_dir = writer.logdir + '/gf_test_%d' % epoch
    #    if not os.path.exists(gf_dir):
    #        os.makedirs(gf_dir)
    #    i = 0
    #    for filename, value in gf_buffer.items():
    #        if i == 3:
    #            break
    #        with open(os.path.join(gf_dir, filename) + '.pkl', 'wb') as f:
    #            pickle.dump(value, f, protocol=pickle.HIGHEST_PROTOCOL)
    #if cfg['save_embedding']:
    #    print('saving embedding')
    #    emb_dir = writer.logdir + '/embedding_test_%d' % epoch
    #    if not os.path.exists(emb_dir):
    #        os.makedirs(emb_dir)
    #    for filename, value in emb_buffer.iteritems():
    #        np.save(os.path.join(emb_dir, filename), value.cpu().numpy())

    if cfg['with_gt']:
        print('TEST ACCURACY: %f' % torch.mean(mean_val_acc).item())
        print('TEST PRECISION: %f' % macro_prec.item())
        print('TEST RECALL: %f' % macro_recall.item())
        print('TEST IOU: %f' % macro_iou.item())
        mean_val_dsc = mean_val_prec * mean_val_recall * 2 / (mean_val_prec + mean_val_recall)
        final_scores_file = writer.logdir + '/final_scores_test_%d.txt' % epoch
        scores_file = writer.logdir + '/scores_test_%d.txt' % epoch
        print('saving scores')
        with open(scores_file, 'w') as f:
            f.write('acc\n')
            f.writelines('%f\n' % v for v in mean_val_acc.tolist())
            f.write('prec\n')
            f.writelines('%f\n' % v for v in mean_val_prec.tolist())
            f.write('recall\n')
            f.writelines('%f\n' % v for v in mean_val_recall.tolist())
            f.write('dsc\n')
            f.writelines('%f\n' % v for v in mean_val_dsc.tolist())
            f.write('iou\n')
            f.writelines('%f\n' % v for v in mean_val_iou.tolist())
        with open(final_scores_file, 'w') as f:
            f.write('acc\n')
            f.write('%f\n' % mean_val_acc.mean())
            f.write('%f\n' % mean_val_acc.std())
            f.write('prec\n')
            f.write('%f\n' % mean_val_prec.mean())
            f.write('%f\n' % mean_val_prec.std())
            f.write('recall\n')
            f.write('%f\n' % mean_val_recall.mean())
            f.write('%f\n' % mean_val_recall.std())
            f.write('dsc\n')
            f.write('%f\n' % mean_val_dsc.mean())
            f.write('%f\n' % mean_val_dsc.std())
            f.write('iou\n')
            f.write('%f\n' % mean_val_iou.mean())
            f.write('%f\n' % mean_val_iou.std())
    print('\n\n')
def criterion(self, logit, truth):
    logit = logit.squeeze(1)
    truth = truth.squeeze(1)
    loss = L.lovasz_hinge(logit, truth, per_image=True, ignore=None)
    return loss
train_loss = []
train_iou = []
model.train()
with tqdm(train_loader) as pbar:
    for images, masks in pbar:
        masks = masks.cuda()
        y_pred = model(Variable(images).cuda())
        prob = torch.sigmoid(y_pred).cpu().data.numpy()
        truth = masks.cpu().data.numpy()
        iou = do_kaggle_metric(prob, truth, threshold=0.5)
        train_iou.append(iou)
        loss = L.lovasz_hinge(y_pred.squeeze(),
                              masks.squeeze().cuda(),
                              per_image=True,
                              ignore=None)
        train_loss.append(loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_description("Loss: %.3f, IoU: %.3f, Progress" % (loss, iou))

val_loss = []
val_iou = []
model.eval()
with tqdm(val_loader) as pbar:
    for images, masks in pbar:
        if len(images) == 2:
            image_ori, image_rev = images
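# The validation loop above receives (image_ori, image_rev) pairs, suggesting
# horizontal-flip test-time augmentation. A hedged sketch of how the two views
# are typically merged (the helper name and NCHW flip axis are assumptions):
import torch

def tta_hflip_merge(model, image_ori, image_rev):
    """Average sigmoid predictions over the original and h-flipped views."""
    probs_ori = torch.sigmoid(model(image_ori))
    probs_rev = torch.sigmoid(model(image_rev))
    return (probs_ori + torch.flip(probs_rev, dims=[3])) / 2  # un-flip width, then average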
def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
    if not ModelPhase.is_valid_phase(phase):
        raise ValueError("ModelPhase {} is not valid!".format(phase))
    if ModelPhase.is_train(phase):
        width = cfg.TRAIN_CROP_SIZE[0]
        height = cfg.TRAIN_CROP_SIZE[1]
    else:
        width = cfg.EVAL_CROP_SIZE[0]
        height = cfg.EVAL_CROP_SIZE[1]

    image_shape = [-1, cfg.DATASET.DATA_DIM, height, width]
    grt_shape = [-1, 1, height, width]
    class_num = cfg.DATASET.NUM_CLASSES

    with fluid.program_guard(main_prog, start_prog):
        with fluid.unique_name.guard():
            # When exporting the model, add image-standardization preprocessing to
            # simplify image handling at deployment time: the deployed predictor
            # then only needs to add a batch_size dimension to the input image
            if ModelPhase.is_predict(phase):
                if cfg.SLIM.PREPROCESS:
                    image = fluid.data(name='image', shape=image_shape, dtype='float32')
                else:
                    origin_image = fluid.data(
                        name='image',
                        shape=[-1, -1, -1, cfg.DATASET.DATA_DIM],
                        dtype='float32')
                    image, valid_shape, origin_shape = export_preprocess(origin_image)
            else:
                image = fluid.data(name='image', shape=image_shape, dtype='float32')
            label = fluid.data(name='label', shape=grt_shape, dtype='int32')
            mask = fluid.data(name='mask', shape=grt_shape, dtype='int32')

            # use DataLoader when doing training and evaluation
            if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
                data_loader = fluid.io.DataLoader.from_generator(
                    feed_list=[image, label, mask],
                    capacity=cfg.DATALOADER.BUF_SIZE,
                    iterable=False,
                    use_double_buffer=True)

            loss_type = cfg.SOLVER.LOSS
            if not isinstance(loss_type, list):
                loss_type = list(loss_type)

            # lovasz_hinge_loss, dice_loss and bce_loss only apply to two-class segmentation
            if class_num > 2 and (("lovasz_hinge_loss" in loss_type) or
                                  ("dice_loss" in loss_type) or
                                  ("bce_loss" in loss_type)):
                raise Exception(
                    "lovasz hinge loss, dice loss and bce loss are only applicable to binary classification."
                )

            # In two-class segmentation, when the loss is lovasz_hinge_loss, dice_loss
            # or bce_loss, the final logit output is set to a single channel
            if ("dice_loss" in loss_type) or ("bce_loss" in loss_type) or (
                    "lovasz_hinge_loss" in loss_type):
                class_num = 1
                if ("softmax_loss" in loss_type) or ("lovasz_softmax_loss" in loss_type):
                    raise Exception(
                        "softmax loss or lovasz softmax loss can not combine with bce loss or dice loss or lovasz hinge loss."
                    )

            logits = seg_model(image, class_num)

            # compute the loss terms corresponding to the selected loss functions
            if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
                loss_valid = False
                avg_loss_list = []
                valid_loss = []
                if "softmax_loss" in loss_type:
                    weight = cfg.SOLVER.CROSS_ENTROPY_WEIGHT
                    avg_loss_list.append(
                        multi_softmax_with_loss(logits, label, mask, class_num, weight))
                    loss_valid = True
                    valid_loss.append("softmax_loss")
                if "dice_loss" in loss_type:
                    avg_loss_list.append(multi_dice_loss(logits, label, mask))
                    loss_valid = True
                    valid_loss.append("dice_loss")
                if "bce_loss" in loss_type:
                    avg_loss_list.append(multi_bce_loss(logits, label, mask))
                    loss_valid = True
                    valid_loss.append("bce_loss")
                if "lovasz_hinge_loss" in loss_type:
                    avg_loss_list.append(lovasz_hinge(logits, label, ignore=mask))
                    loss_valid = True
                    valid_loss.append("lovasz_hinge_loss")
                if "lovasz_softmax_loss" in loss_type:
                    probas = fluid.layers.softmax(logits, axis=1)
                    avg_loss_list.append(lovasz_softmax(probas, label, ignore=mask))
                    loss_valid = True
                    valid_loss.append("lovasz_softmax_loss")
                if not loss_valid:
                    raise Exception(
                        "SOLVER.LOSS: {} is set wrong. It should include at least one of "
                        "(softmax_loss, bce_loss, dice_loss, lovasz_hinge_loss, lovasz_softmax_loss). "
                        "Examples: ['softmax_loss'], ['dice_loss'], ['bce_loss', 'dice_loss'], "
                        "['lovasz_hinge_loss', 'bce_loss'], ['lovasz_softmax_loss', 'softmax_loss']"
                        .format(cfg.SOLVER.LOSS))

                invalid_loss = [x for x in loss_type if x not in valid_loss]
                if len(invalid_loss) > 0:
                    print("Warning: the loss {} you set is invalid; it will not be "
                          "included in the computed loss.".format(invalid_loss))

                avg_loss = 0
                for i in range(0, len(avg_loss_list)):
                    loss_name = valid_loss[i].upper()
                    loss_weight = eval('cfg.SOLVER.LOSS_WEIGHT.' + loss_name)
                    avg_loss += loss_weight * avg_loss_list[i]

            # get pred result in original size
            if isinstance(logits, tuple):
                logit = logits[0]
            else:
                logit = logits
            if logit.shape[2:] != label.shape[2:]:
                logit = fluid.layers.resize_bilinear(logit, label.shape[2:])

            # return image input and logit output for inference graph prune
            if ModelPhase.is_predict(phase):
                # In two-class segmentation, lovasz_hinge_loss / dice_loss / bce_loss
                # yield a single-channel logit; convert it to two channels
                if class_num == 1:
                    logit = sigmoid_to_softmax(logit)
                else:
                    logit = softmax(logit)
                # crop out the valid region
                if cfg.SLIM.PREPROCESS:
                    return image, logit
                else:
                    logit = fluid.layers.slice(
                        logit, axes=[2, 3], starts=[0, 0], ends=valid_shape)
                    logit = fluid.layers.resize_bilinear(
                        logit, out_shape=origin_shape,
                        align_corners=False, align_mode=0)
                    logit = fluid.layers.argmax(logit, axis=1)
                    return origin_image, logit

            if class_num == 1:
                out = sigmoid_to_softmax(logit)
                out = fluid.layers.transpose(out, [0, 2, 3, 1])
            else:
                out = fluid.layers.transpose(logit, [0, 2, 3, 1])
            pred = fluid.layers.argmax(out, axis=3)
            pred = fluid.layers.unsqueeze(pred, axes=[3])

            if ModelPhase.is_visual(phase):
                if class_num == 1:
                    logit = sigmoid_to_softmax(logit)
                else:
                    logit = softmax(logit)
                return pred, logit
            if ModelPhase.is_eval(phase):
                return data_loader, avg_loss, pred, label, mask
            if ModelPhase.is_train(phase):
                optimizer = solver.Solver(main_prog, start_prog)
                decayed_lr = optimizer.optimise(avg_loss)
                return data_loader, avg_loss, decayed_lr, pred, label, mask
def forward(self, input, target):
    loss = L.lovasz_hinge(input, target)
    return loss
# train
net.train()
for batch_image, batch_mask in tqdm(dataloader['train']):
    optimizer.zero_grad()
    # pdb.set_trace()
    batch_image = batch_image.cuda()
    batch_mask = batch_mask.cuda()
    with torch.set_grad_enabled(True):
        outputs = net(batch_image).squeeze(dim=1)
        if epoch < 120:
            loss = criterion(outputs, batch_mask)
        else:
            loss = lovasz_hinge(outputs, batch_mask)
        loss.backward()
        optimizer.step()
    train_running_corrects += torch.sum(
        (outputs > 0.5) == (batch_mask > 0.5)).item()
    train_running_loss += loss.item() * batch_image.size(0)
    train_running_dice_loss += dice_loss(outputs, batch_mask).item() * batch_image.size(0)
    train_iou += iou(outputs, batch_mask) * batch_image.size(0)

# val
net.eval()
for batch_image, batch_mask in tqdm(dataloader['val']):
    batch_image = batch_image.cuda()
    batch_mask = batch_mask.cuda()
def forward(self, input, target):
    input = input.squeeze(1)
    target = target.squeeze(1)
    loss = lovasz_hinge(input, target, per_image=True)
    return loss
def symmetric_lovasz(outputs, targets):
    return (lovasz_hinge(outputs, targets) +
            lovasz_hinge(-outputs, 1 - targets)) / 2
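# Usage sketch: lovasz_hinge optimizes IoU of the positive class only, so
# averaging the loss on (logits, targets) and (-logits, 1 - targets) scores
# foreground and background symmetrically. Dummy tensors; the reference
# lovasz_losses module is an assumption.
import torch
from lovasz_losses import lovasz_hinge  # reference implementation (assumption)

logits = torch.randn(2, 128, 128, requires_grad=True)
targets = (torch.rand(2, 128, 128) > 0.5).float()
loss = symmetric_lovasz(logits, targets)
loss.backward()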
def forward(self, input, target):
    # loss = self.alpha*self.focal(input, target) - torch.log(dice_loss(input, target))
    loss = L.lovasz_hinge(input, target)
    return loss.mean()
def unet_train():
    batch_size = 1
    num_epochs = [5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000]
    num_workers = 2
    lr = 0.0001
    losslist = ['dice']  # ['focal', 'bce', 'dice', 'lovasz']
    optimlist = ['adam']  # ['adam', 'sgd']
    iflog = True

    SC_root_dir = '../dataset-EdmSealedCrack-512'
    train_files, val_files, test_files = myutils.organize_SC_files(SC_root_dir)
    train_RC_dataset = DatasetRealCrack('../dataset-EdmCrack600-512/A/train',
                                        transform=transform)
    train_SC_dataset = DatasetSealedCrack(files=train_files,
                                          root_dir=SC_root_dir,
                                          transform=data_Train_transforms)
    val_RC_dataset = DatasetRealCrack('../dataset-EdmCrack600-512/A/val',
                                      transform=transform)
    val_SC_dataset = DatasetSealedCrack(files=val_files,
                                        root_dir=SC_root_dir,
                                        transform=data_Test_transforms)
    train_loader = torch.utils.data.DataLoader(
        ConcatDataset(train_RC_dataset, train_SC_dataset),
        batch_size=2, shuffle=True, num_workers=2)

    criterion = nn.BCELoss()
    focallos = FocalLoss(gamma=2)
    doubleFocalloss = focalloss.FocalLoss_2_datasets(gamma=2)

    epoidx = -1
    for los in losslist:
        for opt in optimlist:
            start = time.time()
            print(los, opt)
            torch.manual_seed(77)
            torch.cuda.manual_seed(77)
            #################
            #unet = Unet_SpatialPyramidPooling(3).cuda()
            #################
            unet = Unet(3).cuda()
            SC_classifier = classifier(64, 2).cuda()
            RC_classifier = classifier(64, 2).cuda()
            ##################
            #unet = smp.Unet('resnet34', encoder_weights='imagenet').cuda()
            #unet.segmentation_head = torch.nn.Sequential().cuda()
            #SC_classifier = classifier(16, 2).cuda()
            #RC_classifier = classifier(16, 2).cuda()

            # UNCOMMENT TO KEEP TRAINING THE BEST MODEL
            prev_epoch = 0  # if loading model 58, set prev_epoch = 58; saved models are then named 59, 60, 61...
            #unet.load_state_dict(torch.load('trained_models/unet_adam_dice_58.pkl'))
            #SC_classifier.load_state_dict(torch.load('trained_models/SC_classifier_adam_dice_58.pkl'))
            #RC_classifier.load_state_dict(torch.load('trained_models/RC_classifier_adam_dice_58.pkl'))

            history = []
            if 'adam' in opt:
                optimizer = torch.optim.Adam(unet.parameters(), lr=lr)
            elif 'sgd' in opt:
                optimizer = torch.optim.SGD(unet.parameters(), lr=10 * lr, momentum=0.9)
            logging.basicConfig(filename='./logs/logger_unet.log', level=logging.INFO)
            total_step = len(train_loader)
            epoidx += 1
            for epoch in range(num_epochs[epoidx]):
                totalloss = 0
                for i, (realCrack_batch, sealedCrack_batch) in enumerate(train_loader):
                    SC_images = sealedCrack_batch[0].cuda()
                    SC_masks = sealedCrack_batch[1].cuda()
                    RC_images = realCrack_batch[0].cuda()
                    RC_masks = realCrack_batch[1].cuda()
                    SC_encoder = unet(SC_images)
                    RC_encoder = unet(RC_images)
                    #############
                    SC_outputs = SC_classifier(SC_encoder)
                    RC_outputs = RC_classifier(RC_encoder)
                    #############
                    # Deeplab v3
                    #SC_outputs = SC_classifier(SC_encoder['out'])
                    #RC_outputs = RC_classifier(RC_encoder['out'])
                    ##############
                    if 'bce' in los:
                        # note: this branch predates the two-dataset refactor and
                        # still references the old single-dataset masks/outputs names
                        masks = onehot(masks)
                        loss = criterion(outputs, masks)
                    elif 'dice' in los:
                        branch_RC = {'outputs': RC_outputs, 'masks': RC_masks}
                        branch_SC = {'outputs': SC_outputs, 'masks': SC_masks}
                        loss = dice_loss_2_datasets(branch_RC, branch_SC)
                        #masks = onehot(masks)
                        #loss = dice_loss(outputs, masks)
                    elif 'lovasz' in los:
                        # note: same as the 'bce' branch, masks/outputs are stale names here
                        masks = onehot(masks)
                        loss = L.lovasz_hinge(outputs, masks)
                    elif 'focal' in los:
                        #loss = focallos(outputs, masks.long())
                        branch_RC = {'outputs': RC_outputs, 'masks': RC_masks.long()}
                        branch_SC = {'outputs': SC_outputs, 'masks': SC_masks.long()}
                        loss = doubleFocalloss(branch_RC, branch_SC)
                    # .item() keeps the running total from retaining the autograd graph
                    totalloss += loss.item() * RC_images.size(0)  # *2?
                    #print(RC_images.size(0))
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    if i % 10 == 0:
                        print(epoch, i)
                        print("total loss: ", totalloss)
                    if i % 1000 == 0:
                        print("Epoch:%d; Iteration:%d; Loss:%f" % (epoch, i, loss))
                    if i + 1 == total_step:  # and epoch%1==0: and val_miou>0.85:
                        torch.save(unet.state_dict(),
                                   './trained_models/unet_' + opt + '_' + los + '_' +
                                   str(epoch + 1 + prev_epoch) + '.pkl')
                        torch.save(RC_classifier.state_dict(),
                                   './trained_models/RC_classifier_' + opt + '_' + los + '_' +
                                   str(epoch + 1 + prev_epoch) + '.pkl')
                        torch.save(SC_classifier.state_dict(),
                                   './trained_models/SC_classifier_' + opt + '_' + los + '_' +
                                   str(epoch + 1 + prev_epoch) + '.pkl')
                        history_np = np.array(history)
                        np.save('./logs/unet_' + opt + '_' + los + '.npy', history_np)
            end = time.time()
            print((end - start) / 60)