def test_dice_flat(self):
    """Dice loss of identical prediction/target tensors must be exactly 0."""
    dice_loss = DiceLoss()
    x = torch.FloatTensor([[0., 1.], [1., 0.]])
    y = torch.FloatTensor([[0., 1.], [1., 0.]])
    dice = dice_loss(x, y)
    # Perfect overlap -> zero loss. (Removed leftover debug print.)
    self.assertTrue(torch.eq(dice, 0))
def test_weighted_dice_score(self):
    """WeightedLoss(DiceLoss) on a half-matching pair gives the expected score.

    NOTE(review): renamed from ``test_weighted_dice`` — a second method with
    that exact name is defined later in this test class, silently shadowing
    this one so it never executed.
    """
    loss = DiceLoss()
    weighted_loss = WeightedLoss(loss)
    x = torch.FloatTensor([[0., 1.], [1., 1.]])
    y = torch.FloatTensor([[0., 0.], [1., 1.]])
    w = torch.FloatTensor([[0.25, 0.25], [0., 0.]])
    score = weighted_loss(x, y, w)
    self.assertEqual(round(score.item(), 2), 0.26)
def test_weighted_dice(self):
    """gradcheck must pass for WeightedLoss wrapping DiceLoss on float64 input."""
    loss = WeightedLoss(DiceLoss())
    shape = (1, 1, 10, 10)
    # prediction, target and weight tensors, all requiring gradients
    inputs = tuple(
        torch.ones(shape, requires_grad=True, dtype=torch.float64)
        for _ in range(3))
    ok = torch.autograd.gradcheck(loss, inputs, raise_exception=False)
    self.assertTrue(ok)
def __init__(self, opt):
    """Build the UNet3D network, its Adam optimizer and the Dice loss from ``opt``."""
    super().__init__(opt)
    self.loss_fn = DiceLoss(opt)
    net = UNet3D(opt)
    self.net = self.to_gpu(net)
    adam_kwargs = dict(lr=opt.lr,
                       betas=tuple(opt.betas),
                       weight_decay=opt.weight_decay)
    self.optimizer = Adam(net.parameters(), **adam_kwargs)
    # self.lr_scheduler = ReduceLROnPlateau()
def main(config, resume):
    """Entry point: build LiTS data loaders, the supervised loss, the CCT model
    and run training.

    Args:
        config: dict-like experiment configuration (loader/model/trainer keys).
        resume: checkpoint path (or falsy) passed straight through to Trainer.
    """
    torch.manual_seed(42)  # fixed seed for reproducibility
    train_logger = Logger()
    # DATA LOADERS
    # config['train_supervised']['n_labeled_examples'] = config['n_labeled_examples']
    # config['train_unsupervised']['n_labeled_examples'] = config['n_labeled_examples']
    # NOTE(review): 'use_weak_lables' is misspelled but must match the key the
    # rest of the project reads — do not "fix" the string alone.
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
    supervised_loader = dataloaders.LiTS(config['train_supervised'])
    unsupervised_loader = dataloaders.LiTS(config['train_unsupervised'])
    val_loader = dataloaders.LiTS(config['val_loader'])
    # one "epoch" is defined by the unsupervised loader's length
    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS — selected by config['model']['sup_loss']
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    elif config['model']['sup_loss'] == 'FL':
        alpha = get_alpha(supervised_loader)  # calculare class occurences
        sup_loss = FocalLoss(apply_nonlin=softmax_helper, alpha=alpha,
                             gamma=2, smooth=1e-5)
    elif config['model']['sup_loss'] == 'DC':
        sup_loss = DiceLoss(val_loader.dataset.num_classes)
    else:
        # fallback: abCE loss (annealed over the training schedule)
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    # consistency weight ramps up to config['unsupervised_w'] over rampup_ends epochs
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=len(unsupervised_loader),
                                      rampup_ends=rampup_ends)
    model = models.CCT(num_classes=val_loader.dataset.num_classes,
                       conf=config['model'],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'])
    # ignore_index=val_loader.dataset.ignore_index)
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(model=model,
                      resume=resume,
                      config=config,
                      supervised_loader=supervised_loader,
                      unsupervised_loader=unsupervised_loader,
                      val_loader=val_loader,
                      iter_per_epoch=iter_per_epoch,
                      train_logger=train_logger)
    trainer.train()
def train():
    """Train CSNet3D with a combined CE + weighted-CE + Dice loss.

    Relies on module-level globals: ``args`` (hyper-parameter dict), the visdom
    environments/panels (``env``..``env6``, ``panel``..``panel6``) and the
    helpers ``update_lines``, ``adjust_lr``, ``save_ckpt``, ``model_eval``,
    ``metrics3d``.
    """
    net = CSNet3D(classes=2, channels=1).cuda()
    net = nn.DataParallel(net, device_ids=[0, 1]).cuda()  # two-GPU data parallel
    optimizer = optim.Adam(net.parameters(), lr=args['lr'], weight_decay=0.0005)
    # load train dataset
    train_data = Data(args['data_path'], train=True)
    batchs_data = DataLoader(train_data,
                             batch_size=args['batch_size'],
                             num_workers=4,
                             shuffle=True)
    # NOTE(review): "critrion" is a misspelling of "criterion", kept as-is.
    critrion2 = WeightedCrossEntropyLoss().cuda()
    critrion = nn.CrossEntropyLoss().cuda()
    critrion3 = DiceLoss().cuda()
    # Start training
    print("\033[1;30;44m {} Start training ... {}\033[0m".format(
        "*" * 8, "*" * 8))
    iters = 1
    for epoch in range(args['epochs']):
        net.train()
        for idx, batch in enumerate(batchs_data):
            image = batch[0].cuda()
            label = batch[1].cuda()
            optimizer.zero_grad()
            pred = net(image)
            loss_dice = critrion3(pred, label)
            # the CE losses need the channel dim squeezed off the label volume
            label = label.squeeze(1)
            loss_ce = critrion(pred, label)
            loss_wce = critrion2(pred, label)
            # fixed-weight mix of the three losses
            loss = (loss_ce + 0.6 * loss_wce + 0.4 * loss_dice) / 3
            loss.backward()
            optimizer.step()
            tp, fn, fp, iou = metrics3d(pred, label, pred.shape[0])
            # alternate console colors between even/odd epochs
            if (epoch % 2) == 0:
                print(
                    '\033[1;36m [{0:d}:{1:d}] \u2501\u2501\u2501 loss:{2:.10f}\tTP:{3:.4f}\tFN:{4:.4f}\tFP:{5:.4f}\tIoU:{6:.4f} '
                    .format(epoch + 1, iters, loss.item(), tp / pred.shape[0],
                            fn / pred.shape[0], fp / pred.shape[0],
                            iou / pred.shape[0]))
            else:
                print(
                    '\033[1;32m [{0:d}:{1:d}] \u2501\u2501\u2501 loss:{2:.10f}\tTP:{3:.4f}\tFN:{4:.4f}\tFP:{5:.4f}\tIoU:{6:.4f} '
                    .format(epoch + 1, iters, loss.item(), tp / pred.shape[0],
                            fn / pred.shape[0], fp / pred.shape[0],
                            iou / pred.shape[0]))
            iters += 1
            # # ---------------------------------- visdom --------------------------------------------------
            X, x_tp, x_fn, x_fp, x_dc = iters, iters, iters, iters, iters
            Y, y_tp, y_fn, y_fp, y_dc = loss.item(), tp / pred.shape[0], fn / pred.shape[0], fp / pred.shape[0], iou / \
                pred.shape[0]
            update_lines(env, panel, X, Y)
            update_lines(env1, panel1, x_tp, y_tp)
            update_lines(env2, panel2, x_fn, y_fn)
            update_lines(env3, panel3, x_fp, y_fp)
            update_lines(env6, panel6, x_dc, y_dc)
            # # --------------------------------------------------------------------------------------------
        # per-epoch polynomial LR decay
        adjust_lr(optimizer, base_lr=args['lr'], iter=epoch,
                  max_iter=args['epochs'], power=0.9)
        if (epoch + 1) % args['snapshot'] == 0:
            save_ckpt(net, str(epoch + 1))
        # model eval
        if (epoch + 1) % args['test_step'] == 0:
            test_tp, test_fn, test_fp, test_dc = model_eval(
                net, critrion, iters)
            # NOTE(review): test_dc is computed but never reported here.
            print(
                "Average TP:{0:.4f}, average FN:{1:.4f}, average FP:{2:.4f}".
                format(test_tp, test_fn, test_fp))
def get_loss_function(opt):
    """Resolve the loss named in ``opt.loss``; only 'dice' is available.

    Raises:
        ValueError: for any loss name other than 'dice'.
    """
    if opt.loss != 'dice':
        raise ValueError("Only 'dice' loss is supported now.")
    return DiceLoss(sigmoid_normalization=True, weight=opt.class_weights)
# --- evaluation-script setup (fragment; the metric-accumulation loop that
# consumes these running totals continues past this excerpt) ---
classes = open(args.classes, 'r').read().splitlines()  # one class name per line
val_images = glob.glob(os.path.normpath(args.val_image_path) + '/*.jpg')
val_masks = glob.glob(os.path.normpath(args.val_label_path) + '/*.png')
# sort both lists so image/mask pairs line up by filename
val_images.sort()
val_masks.sort()
if args.backbone == 'resnet50':
    model = Deeplabv3Resnet50(len(classes)).to(device)
else:
    # any other backbone value falls through to ResNet-101
    model = Deeplabv3Resnet101(len(classes)).to(device)
model.load_state_dict(torch.load(args.pt))
model = model.eval()
# loss and metric objects used during validation
dice_loss = DiceLoss()
iou_metric = IoU()
accuracy_metric = Accuracy()
precision_metric = Precision()
recall_metric = Recall()
f_score_metric = Fscore()
val_dataset = SegmentationDataset(val_images, val_masks, classes, args.size,
                                  False)
val_dataloader = DataLoader(val_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=args.num_workers)
# running totals, presumably averaged over the loader afterwards — the
# consuming loop is outside this excerpt
sum_losses = 0
sum_iou_metric = 0
sum_accuracy_metric = 0
sum_precision_metric = 0
sum_recall_metric = 0
def test_diff_dice(self):
    """DiceLoss must be differentiable: gradcheck on all-ones float64 tensors."""
    dice = DiceLoss()
    gradcheck_args = tuple(
        torch.ones((1, 1, 10, 10), dtype=torch.float64, requires_grad=True)
        for _ in range(2))
    passed = torch.autograd.gradcheck(dice, gradcheck_args,
                                      raise_exception=False)
    self.assertTrue(passed)
def criterion(pred, label):
    """Segmentation loss: binary cross-entropy plus Dice.

    (Removed the commented-out alternative losses that were left behind from
    earlier experiments.)
    """
    return nn.BCELoss()(pred, label) + DiceLoss()(pred, label)
# --- one training step on domain 'A' (fragment; the enclosing loop and the
# definitions of `data`, `net`, `optimiser`, etc. are outside this excerpt) ---
image_a = data['A'][2].cuda()   # input image
target_a = data['A'][1].cuda()  # segmentation target
ctr_a = data['A'][3].cuda()     # presumably a contour map — confirm with dataset
edt_a = data['A'][4].cuda()     # presumably a distance transform — confirm
# data B
# image_b = data['B'][2].cuda()
optimiser.zero_grad()
# encoder features at five scales, then the three-headed decoder
a1, a2, a3, a4, a5 = net.downsample(image_a)
pred_seg_a, pred_ctr_a, pred_edt_a, _ = net.upsample(
    a1, a2, a3, a4, a5)
# three task losses: segmentation, contour (Dice) and distance map (L1)
loss_seg_a = criterion(pred_seg_a, target_a)
loss_ctr_a = DiceLoss()(pred_ctr_a, ctr_a)
loss_edt_a = nn.L1Loss()(pred_edt_a, edt_a)
loss = loss_seg_a + loss_ctr_a + loss_edt_a
loss.backward()
# loss_seg_a.backward()
optimiser.step()
# dice_score = dice_coeff(torch.round(pred), l).item()
# epoch_train_loss_rec.append(loss_recon.item())
# only the segmentation loss is tracked for the epoch average
epoch_train_loss_seg.append(loss_seg_a.item())
# mean_loss_rec = np.mean(epoch_train_loss_rec)
mean_loss_seg = np.mean(epoch_train_loss_seg)
def build_dice_critn(C):
    """Factory returning a fresh DiceLoss criterion.

    ``C`` is unused but kept so all criterion builders share one signature.
    """
    from utils.losses import DiceLoss
    criterion = DiceLoss()
    return criterion
def criterion_seg(self, prediction, target):
    """Segmentation loss: binary cross-entropy plus Dice."""
    bce_term = nn.BCELoss()(prediction, target)
    dice_term = DiceLoss()(prediction, target)
    return bce_term + dice_term
def criterion(pred, label):
    """Combined segmentation loss: BCE + Dice."""
    # return symmetric_lovasz(pred, label)
    bce = nn.BCELoss()
    dice = DiceLoss()
    return bce(pred, label) + dice(pred, label)
def criterion_seg(pred, label):
    """Segmentation loss combining binary cross-entropy with Dice."""
    total = nn.BCELoss()(pred, label)
    total = total + DiceLoss()(pred, label)
    return total
def train_main(cfg):
    """Main training entry point.

    Args:
        cfg: experiment configuration object exposing ``train_cfg``,
            ``dataset_cfg``, ``model_cfg``, ``device`` and misc settings.

    Fixes over the previous revision:
      * ``worker_init_fn`` was being *called* (``_init_fn()``), so its ``None``
        return value — not the function — was handed to the DataLoaders.
      * ``auto_save_epoch_list`` defaulted to the int ``5``, making
        ``epoch in auto_save_epoch_list`` raise TypeError; default is ``[5]``.
      * ``os.mkdir`` replaced by ``os.makedirs(..., exist_ok=True)`` so nested
        checkpoint directories can be created.
    """
    # config
    train_cfg = cfg.train_cfg
    dataset_cfg = cfg.dataset_cfg
    model_cfg = cfg.model_cfg
    is_parallel = cfg.setdefault(key='is_parallel', default=False)
    device = cfg.device
    is_online_train = cfg.setdefault(key='is_online_train', default=False)

    # configure the logger
    logging.basicConfig(filename=cfg.logfile,
                        filemode='a',
                        level=logging.INFO,
                        format='%(asctime)s\n%(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    logger = logging.getLogger()

    # build datasets
    train_dataset = LandDataset(DIR_list=dataset_cfg.train_dir_list,
                                mode='train',
                                input_channel=dataset_cfg.input_channel,
                                transform=dataset_cfg.train_transform)
    split_val_from_train_ratio = dataset_cfg.setdefault(
        key='split_val_from_train_ratio', default=None)
    if split_val_from_train_ratio is None:
        # dedicated validation directories
        val_dataset = LandDataset(DIR_list=dataset_cfg.val_dir_list,
                                  mode='val',
                                  input_channel=dataset_cfg.input_channel,
                                  transform=dataset_cfg.val_transform)
    else:
        # carve the validation set out of the training set
        val_size = int(len(train_dataset) * split_val_from_train_ratio)
        train_size = len(train_dataset) - val_size
        train_dataset, val_dataset = random_split(
            train_dataset, [train_size, val_size],
            generator=torch.manual_seed(cfg.random_seed))
        # val_dataset.dataset.transform = dataset_cfg.val_transform  # TODO: val split still uses the train transform
        print(f"按照{split_val_from_train_ratio}切分训练集...")

    # build dataloaders
    def _init_fn(worker_id):
        # Deterministic seeding for every DataLoader worker.
        # BUGFIX: previously `_init_fn()` was *called* at DataLoader
        # construction and its None return passed as worker_init_fn.
        np.random.seed(cfg.random_seed)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_cfg.batch_size,
                                  shuffle=True,
                                  num_workers=train_cfg.num_workers,
                                  drop_last=True,
                                  worker_init_fn=_init_fn)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=train_cfg.batch_size,
                                num_workers=train_cfg.num_workers,
                                shuffle=False,
                                drop_last=True,
                                worker_init_fn=_init_fn)

    # build the model
    if train_cfg.is_swa:
        # Pass device via map_location — otherwise weights are first loaded
        # onto cuda:0 and only moved to `device` by .to() afterwards.
        model = torch.load(train_cfg.check_point_file,
                           map_location=device).to(device)
        swa_model = torch.load(train_cfg.check_point_file,
                               map_location=device).to(device)
        if is_parallel:
            model = torch.nn.DataParallel(model)
            swa_model = torch.nn.DataParallel(swa_model)
        swa_n = 0  # number of models folded into the SWA average so far
        parameters = swa_model.parameters()
    else:
        model = build_model(model_cfg).to(device)
        if is_parallel:
            model = torch.nn.DataParallel(model)
        parameters = model.parameters()

    # optimizer
    optimizer_cfg = train_cfg.optimizer_cfg
    lr_scheduler_cfg = train_cfg.lr_scheduler_cfg
    if optimizer_cfg.type == 'adam':
        optimizer = optim.Adam(params=parameters,
                               lr=optimizer_cfg.lr,
                               weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'adamw':
        optimizer = optim.AdamW(params=parameters,
                                lr=optimizer_cfg.lr,
                                weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'sgd':
        optimizer = optim.SGD(params=parameters,
                              lr=optimizer_cfg.lr,
                              momentum=optimizer_cfg.momentum,
                              weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'RMS':
        optimizer = optim.RMSprop(params=parameters,
                                  lr=optimizer_cfg.lr,
                                  weight_decay=optimizer_cfg.weight_decay)
    else:
        raise Exception('没有该优化器!')

    # learning-rate scheduler (optional)
    if not lr_scheduler_cfg:
        lr_scheduler = None
    elif lr_scheduler_cfg.policy == 'cos':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            lr_scheduler_cfg.T_0,
            lr_scheduler_cfg.T_mult,
            lr_scheduler_cfg.eta_min,
            last_epoch=lr_scheduler_cfg.last_epoch)
    elif lr_scheduler_cfg.policy == 'LambdaLR':
        import math
        lf = lambda x: (((1 + math.cos(x * math.pi / train_cfg.num_epochs)) / 2
                         )**1.0) * 0.95 + 0.05  # cosine
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                         lr_lambda=lf)
        lr_scheduler.last_epoch = 0
    else:
        lr_scheduler = None

    # loss: joint Dice + soft cross-entropy, equally weighted
    DiceLoss_fn = DiceLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    loss_func = L.JointLoss(first=DiceLoss_fn,
                            second=SoftCrossEntropy_fn,
                            first_weight=0.5,
                            second_weight=0.5).cuda()
    # loss_cls_func = torch.nn.BCEWithLogitsLoss()

    # create the checkpoint directory
    # BUGFIX: makedirs(exist_ok=True) handles nested paths; mkdir would fail.
    check_point_dir = '/'.join(model_cfg.check_point_file.split('/')[:-1])
    os.makedirs(check_point_dir, exist_ok=True)

    # start training
    # Epochs at which a snapshot is always saved.
    # BUGFIX: the default must be a list — `epoch in 5` raises TypeError.
    auto_save_epoch_list = train_cfg.setdefault(key='auto_save_epoch_list',
                                                default=[5])
    train_loss_list = []
    val_loss_list = []
    val_loss_min = 999999
    best_epoch = 0
    best_miou = 0
    train_loss = 10  # initial placeholder value
    logger.info('开始在{}上训练{}模型...'.format(device, model_cfg.type))
    logger.info('补充信息:{}\n'.format(cfg.setdefault(key='info', default='None')))
    for epoch in range(train_cfg.num_epochs):
        print()
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        start_time = time.time()
        print(f"正在进行第{epoch}轮训练...")
        logger.info('*' * 10 + f"第{epoch}轮" + '*' * 10)

        # train one epoch
        if train_cfg.is_swa:
            # SWA: train the running model, then fold it into the average
            # and refresh batch-norm statistics.
            train_loss = train_epoch(swa_model, optimizer, lr_scheduler,
                                     loss_func, train_dataloader, epoch,
                                     device)
            moving_average(model, swa_model, 1.0 / (swa_n + 1))
            swa_n += 1
            bn_update(train_dataloader, model, device)
        else:
            train_loss = train_epoch(model, optimizer, lr_scheduler, loss_func,
                                     train_dataloader, epoch, device)

        # evaluate — only needed for offline training
        if not is_online_train:
            val_loss, val_miou = evaluate_model(model, val_dataloader,
                                                loss_func, device,
                                                cfg.num_classes)
        else:
            val_loss = 0
            val_miou = 0

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        # save the best model — only when training offline
        if not is_online_train:
            if val_loss < val_loss_min:
                val_loss_min = val_loss
                best_epoch = epoch
                best_miou = val_miou
                if is_parallel:
                    torch.save(model.module, model_cfg.check_point_file)
                else:
                    torch.save(model, model_cfg.check_point_file)

        # periodic snapshot at the configured epochs
        if epoch in auto_save_epoch_list:
            model_file = model_cfg.check_point_file.split(
                '.pth')[0] + '-epoch{}.pth'.format(epoch)
            if is_parallel:
                torch.save(model.module, model_file)
            else:
                torch.save(model, model_file)

        # progress report
        end_time = time.time()
        run_time = int(end_time - start_time)
        m, s = divmod(run_time, 60)
        time_str = "{:02d}分{:02d}秒".format(m, s)
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        out_str = "第{}轮训练完成,耗时{},\t训练集上的loss={:.6f};\t验证集上的loss={:.4f},mIoU={:.6f}\t最好的结果是第{}轮,mIoU={:.6f}" \
            .format(epoch, time_str, train_loss, val_loss, val_miou,
                    best_epoch, best_miou)
        print(out_str)
        logger.info(out_str + '\n')
def test_dice_negatives(self):
    """Dice loss for a prediction with one extra positive should round to 0.2."""
    prediction = torch.FloatTensor([[1., 1.], [1., 0.]])
    ground_truth = torch.FloatTensor([[0., 1.], [1., 0.]])
    loss_value = DiceLoss()(prediction, ground_truth)
    self.assertEqual(round(loss_value.item(), 2), 0.2)