Example no. 1
def test_dice_flat(self):
    dice_loss = DiceLoss()
    x = torch.FloatTensor([[0., 1.], [1., 0.]])
    y = torch.FloatTensor([[0., 1.], [1., 0.]])
    dice = dice_loss(x, y)
    print('DICE', dice)
    self.assertTrue(torch.eq(dice, 0))
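(None of these snippets show DiceLoss itself. A minimal sketch consistent with the tests in Examples no. 1, 9 and 17, i.e. soft Dice over flattened tensors, could look as follows; the class internals and the smoothing term are assumptions, not the tested implementation.)

import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    """Soft Dice loss: 0 for identical tensors, approaching 1 for disjoint ones."""

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth  # assumed smoothing term; keeps the ratio defined for all-zero inputs

    def forward(self, pred, target):
        pred = pred.contiguous().view(-1)
        target = target.contiguous().view(-1)
        intersection = (pred * target).sum()
        # smoothing numerator and denominator equally preserves loss == 0 when pred == target
        return 1 - (2 * intersection + self.smooth) / (
            pred.sum() + target.sum() + self.smooth)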
Example no. 2
def test_weighted_dice(self):
    loss = DiceLoss()
    weighted_loss = WeightedLoss(loss)
    x = torch.FloatTensor([[0., 1.], [1., 1.]])
    y = torch.FloatTensor([[0., 0.], [1., 1.]])
    w = torch.FloatTensor([[0.25, 0.25], [0., 0.]])
    score = weighted_loss(x, y, w)
    self.assertEqual(round(score.item(), 2), 0.26)
Example no. 3
def test_weighted_dice(self):
    sub_loss = DiceLoss()
    loss = WeightedLoss(sub_loss)
    x = torch.ones((1, 1, 10, 10), requires_grad=True, dtype=torch.float64)
    y = torch.ones((1, 1, 10, 10), requires_grad=True, dtype=torch.float64)
    w = torch.ones((1, 1, 10, 10), requires_grad=True, dtype=torch.float64)
    self.assertTrue(
        torch.autograd.gradcheck(loss, (x, y, w), raise_exception=False))
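(Note: torch.autograd.gradcheck compares analytical gradients against finite-difference estimates and is intended to be run in double precision, which is why these tests build their inputs with dtype=torch.float64 and requires_grad=True.)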
Example no. 4
def __init__(self, opt):
    super().__init__(opt)
    net = UNet3D(opt)
    self.net = self.to_gpu(net)
    self.optimizer = Adam(net.parameters(),
                          lr=opt.lr,
                          betas=tuple(opt.betas),
                          weight_decay=opt.weight_decay)
    # self.lr_scheduler = ReduceLROnPlateau()
    self.loss_fn = DiceLoss(opt)
Example no. 5
def main(config, resume):
    torch.manual_seed(42)
    train_logger = Logger()

    # DATA LOADERS
    # config['train_supervised']['n_labeled_examples'] = config['n_labeled_examples']
    # config['train_unsupervised']['n_labeled_examples'] = config['n_labeled_examples']
    config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
    supervised_loader = dataloaders.LiTS(config['train_supervised'])
    unsupervised_loader = dataloaders.LiTS(config['train_unsupervised'])
    val_loader = dataloaders.LiTS(config['val_loader'])
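    # len(DataLoader) is the number of batches, i.e. the number of iterations per epoch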
    iter_per_epoch = len(unsupervised_loader)

    # SUPERVISED LOSS
    if config['model']['sup_loss'] == 'CE':
        sup_loss = CE_loss
    elif config['model']['sup_loss'] == 'FL':
        alpha = get_alpha(supervised_loader)  # calculate class occurrences
        sup_loss = FocalLoss(apply_nonlin=softmax_helper,
                             alpha=alpha,
                             gamma=2,
                             smooth=1e-5)
    elif config['model']['sup_loss'] == 'DC':
        sup_loss = DiceLoss(val_loader.dataset.num_classes)
    else:
        sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch,
                             epochs=config['trainer']['epochs'],
                             num_classes=val_loader.dataset.num_classes)

    # MODEL
    rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
    cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'],
                                      iters_per_epoch=len(unsupervised_loader),
                                      rampup_ends=rampup_ends)

    model = models.CCT(num_classes=val_loader.dataset.num_classes,
                       conf=config['model'],
                       sup_loss=sup_loss,
                       cons_w_unsup=cons_w_unsup,
                       weakly_loss_w=config['weakly_loss_w'],
                       use_weak_lables=config['use_weak_lables'])
    # ignore_index=val_loader.dataset.ignore_index)
    print(f'\n{model}\n')

    # TRAINING
    trainer = Trainer(model=model,
                      resume=resume,
                      config=config,
                      supervised_loader=supervised_loader,
                      unsupervised_loader=unsupervised_loader,
                      val_loader=val_loader,
                      iter_per_epoch=iter_per_epoch,
                      train_logger=train_logger)

    trainer.train()
Example no. 6
def train():
    net = CSNet3D(classes=2, channels=1).cuda()
    net = nn.DataParallel(net, device_ids=[0, 1]).cuda()
    optimizer = optim.Adam(net.parameters(),
                           lr=args['lr'],
                           weight_decay=0.0005)

    # load train dataset
    train_data = Data(args['data_path'], train=True)
    batchs_data = DataLoader(train_data,
                             batch_size=args['batch_size'],
                             num_workers=4,
                             shuffle=True)

    criterion_wce = WeightedCrossEntropyLoss().cuda()
    criterion_ce = nn.CrossEntropyLoss().cuda()
    criterion_dice = DiceLoss().cuda()
    # Start training
    print("\033[1;30;44m {} Start training ... {}\033[0m".format(
        "*" * 8, "*" * 8))

    iters = 1
    for epoch in range(args['epochs']):
        net.train()
        for idx, batch in enumerate(batchs_data):
            image = batch[0].cuda()
            label = batch[1].cuda()
            optimizer.zero_grad()
            pred = net(image)
            loss_dice = criterion_dice(pred, label)
            # the CE losses expect class-index targets, so drop the channel dim first
            label = label.squeeze(1)
            loss_ce = criterion_ce(pred, label)
            loss_wce = criterion_wce(pred, label)
            # weighted sum of the three losses, scaled by 1/3
            loss = (loss_ce + 0.6 * loss_wce + 0.4 * loss_dice) / 3
            loss.backward()
            optimizer.step()
            tp, fn, fp, iou = metrics3d(pred, label, pred.shape[0])
            # alternate the console colour (cyan / green) between epochs
            color = '\033[1;36m' if (epoch % 2) == 0 else '\033[1;32m'
            print(
                '{0} [{1:d}:{2:d}] \u2501\u2501\u2501 loss:{3:.10f}\tTP:{4:.4f}\tFN:{5:.4f}\tFP:{6:.4f}\tIoU:{7:.4f} '
                .format(color, epoch + 1, iters, loss.item(),
                        tp / pred.shape[0], fn / pred.shape[0],
                        fp / pred.shape[0], iou / pred.shape[0]))

            iters += 1
            # # ---------------------------------- visdom --------------------------------------------------
            X = x_tp = x_fn = x_fp = x_dc = iters
            Y = loss.item()
            y_tp = tp / pred.shape[0]
            y_fn = fn / pred.shape[0]
            y_fp = fp / pred.shape[0]
            y_dc = iou / pred.shape[0]

            update_lines(env, panel, X, Y)
            update_lines(env1, panel1, x_tp, y_tp)
            update_lines(env2, panel2, x_fn, y_fn)
            update_lines(env3, panel3, x_fp, y_fp)
            update_lines(env6, panel6, x_dc, y_dc)

            # # --------------------------------------------------------------------------------------------
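        # adjust_lr with power=0.9 presumably implements the common "poly" schedule,
        # lr = base_lr * (1 - iter / max_iter) ** power (the helper is not shown in this snippet)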

        adjust_lr(optimizer,
                  base_lr=args['lr'],
                  iter=epoch,
                  max_iter=args['epochs'],
                  power=0.9)

        if (epoch + 1) % args['snapshot'] == 0:
            save_ckpt(net, str(epoch + 1))

        # model eval
        if (epoch + 1) % args['test_step'] == 0:
            test_tp, test_fn, test_fp, test_dc = model_eval(
                net, criterion_ce, iters)
            print(
                "Average TP:{0:.4f}, average FN:{1:.4f},  average FP:{2:.4f}".
                format(test_tp, test_fn, test_fp))
Example no. 7
def get_loss_function(opt):
    if opt.loss == 'dice':
        return DiceLoss(sigmoid_normalization=True, weight=opt.class_weights)
    else:
        raise ValueError("Only 'dice' loss is supported now.")
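
(A hypothetical call site, only to show which attributes the factory reads off opt; the Namespace below is an assumption, not part of the example.)

from argparse import Namespace

opt = Namespace(loss='dice', class_weights=None)  # hypothetical options object
loss_fn = get_loss_function(opt)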
Example no. 8
    with open(args.classes, 'r') as f:
        classes = f.read().splitlines()

    val_images = glob.glob(os.path.normpath(args.val_image_path) + '/*.jpg')
    val_masks = glob.glob(os.path.normpath(args.val_label_path) + '/*.png')
    val_images.sort()
    val_masks.sort()

    if args.backbone == 'resnet50':
        model = Deeplabv3Resnet50(len(classes)).to(device)
    else:
        model = Deeplabv3Resnet101(len(classes)).to(device)

    model.load_state_dict(torch.load(args.pt))
    model = model.eval()

    dice_loss = DiceLoss()

    iou_metric = IoU()
    accuracy_metric = Accuracy()
    precision_metric = Precision()
    recall_metric = Recall()
    f_score_metric = Fscore()

    val_dataset = SegmentationDataset(val_images, val_masks, classes, args.size, False)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers)

    sum_losses = 0
    sum_iou_metric = 0
    sum_accuracy_metric = 0
    sum_precision_metric = 0
    sum_recall_metric = 0
Example no. 9
def test_diff_dice(self):
    loss = DiceLoss()
    x = torch.ones((1, 1, 10, 10), requires_grad=True, dtype=torch.float64)
    y = torch.ones((1, 1, 10, 10), requires_grad=True, dtype=torch.float64)
    self.assertTrue(
        torch.autograd.gradcheck(loss, (x, y), raise_exception=False))
Example no. 10
def criterion(pred, label):
    # return loss_fn(pred, label) + DiceLoss()(pred, label)
    # return nn.L1Loss()(pred, label)
    # return DiceLoss()(pred, label)
    return nn.BCELoss()(pred, label) + DiceLoss()(pred, label)
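
(Note that nn.BCELoss expects pred to already hold probabilities in [0, 1], i.e. a sigmoid output. When the network emits raw logits, a numerically safer sketch is the variant below; it assumes this DiceLoss also takes probabilities, which the example does not show.)

def criterion_logits(pred_logits, label):
    # BCEWithLogitsLoss fuses the sigmoid into the loss for numerical stability
    return (nn.BCEWithLogitsLoss()(pred_logits, label) +
            DiceLoss()(torch.sigmoid(pred_logits), label))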
Example no. 11
        image_a = data['A'][2].cuda()
        target_a = data['A'][1].cuda()
        ctr_a = data['A'][3].cuda()
        edt_a = data['A'][4].cuda()

        # data B
        # image_b = data['B'][2].cuda()

        optimiser.zero_grad()

        a1, a2, a3, a4, a5 = net.downsample(image_a)
        pred_seg_a, pred_ctr_a, pred_edt_a, _ = net.upsample(
            a1, a2, a3, a4, a5)

        loss_seg_a = criterion(pred_seg_a, target_a)
        loss_ctr_a = DiceLoss()(pred_ctr_a, ctr_a)
        loss_edt_a = nn.L1Loss()(pred_edt_a, edt_a)

        loss = loss_seg_a + loss_ctr_a + loss_edt_a

        loss.backward()
        # loss_seg_a.backward()
        optimiser.step()

        # dice_score = dice_coeff(torch.round(pred), l).item()
        # epoch_train_loss_rec.append(loss_recon.item())
        epoch_train_loss_seg.append(loss_seg_a.item())

    # mean_loss_rec = np.mean(epoch_train_loss_rec)
    mean_loss_seg = np.mean(epoch_train_loss_seg)
Example no. 12
def build_dice_critn(C):
    from utils.losses import DiceLoss
    return DiceLoss()
Example no. 13
def criterion_seg(self, prediction, target):
    return nn.BCELoss()(prediction, target) + DiceLoss()(prediction, target)
Example no. 14
def criterion(pred, label):
    # return symmetric_lovasz(pred, label)
    return nn.BCELoss()(pred, label) + DiceLoss()(pred, label)
Example no. 15
def criterion_seg(pred, label):
    return nn.BCELoss()(pred, label) + DiceLoss()(pred, label)
Example no. 16
def train_main(cfg):
    '''
    Main training function.
    :param cfg: configuration
    :return:
    '''

    # config
    train_cfg = cfg.train_cfg
    dataset_cfg = cfg.dataset_cfg
    model_cfg = cfg.model_cfg
    is_parallel = cfg.setdefault(key='is_parallel', default=False)
    device = cfg.device
    is_online_train = cfg.setdefault(key='is_online_train', default=False)

    # configure the logger
    logging.basicConfig(filename=cfg.logfile,
                        filemode='a',
                        level=logging.INFO,
                        format='%(asctime)s\n%(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    logger = logging.getLogger()

    #
    # build the datasets
    train_dataset = LandDataset(DIR_list=dataset_cfg.train_dir_list,
                                mode='train',
                                input_channel=dataset_cfg.input_channel,
                                transform=dataset_cfg.train_transform)
    split_val_from_train_ratio = dataset_cfg.setdefault(
        key='split_val_from_train_ratio', default=None)
    if split_val_from_train_ratio is None:
        val_dataset = LandDataset(DIR_list=dataset_cfg.val_dir_list,
                                  mode='val',
                                  input_channel=dataset_cfg.input_channel,
                                  transform=dataset_cfg.val_transform)
    else:
        val_size = int(len(train_dataset) * split_val_from_train_ratio)
        train_size = len(train_dataset) - val_size
        train_dataset, val_dataset = random_split(
            train_dataset, [train_size, val_size],
            generator=torch.manual_seed(cfg.random_seed))
        # val_dataset.dataset.transform = dataset_cfg.val_transform  # the val transform still needs to be configured here
        print(f"Splitting a validation set from the training set with ratio {split_val_from_train_ratio}...")

    # build the dataloaders; seed each worker deterministically
    def _init_fn(worker_id):
        np.random.seed(cfg.random_seed + worker_id)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_cfg.batch_size,
                                  shuffle=True,
                                  num_workers=train_cfg.num_workers,
                                  drop_last=True,
                                  worker_init_fn=_init_fn)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=train_cfg.batch_size,
                                num_workers=train_cfg.num_workers,
                                shuffle=False,
                                drop_last=True,
                                worker_init_fn=_init_fn)

    # build the model
    if train_cfg.is_swa:
        # pass map_location so the checkpoint is loaded straight onto `device`,
        # instead of first landing on cuda:0 and only then being moved by .to()
        model = torch.load(train_cfg.check_point_file,
                           map_location=device).to(device)
        swa_model = torch.load(train_cfg.check_point_file,
                               map_location=device).to(device)
        if is_parallel:
            model = torch.nn.DataParallel(model)
            swa_model = torch.nn.DataParallel(swa_model)
        swa_n = 0
        parameters = swa_model.parameters()
    else:
        model = build_model(model_cfg).to(device)
        if is_parallel:
            model = torch.nn.DataParallel(model)
        parameters = model.parameters()

    # define the optimizer
    optimizer_cfg = train_cfg.optimizer_cfg
    lr_scheduler_cfg = train_cfg.lr_scheduler_cfg
    if optimizer_cfg.type == 'adam':
        optimizer = optim.Adam(params=parameters,
                               lr=optimizer_cfg.lr,
                               weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'adamw':
        optimizer = optim.AdamW(params=parameters,
                                lr=optimizer_cfg.lr,
                                weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'sgd':
        optimizer = optim.SGD(params=parameters,
                              lr=optimizer_cfg.lr,
                              momentum=optimizer_cfg.momentum,
                              weight_decay=optimizer_cfg.weight_decay)
    elif optimizer_cfg.type == 'RMS':
        optimizer = optim.RMSprop(params=parameters,
                                  lr=optimizer_cfg.lr,
                                  weight_decay=optimizer_cfg.weight_decay)
    else:
        raise Exception('Unknown optimizer type!')

    if not lr_scheduler_cfg:
        lr_scheduler = None
    elif lr_scheduler_cfg.policy == 'cos':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            lr_scheduler_cfg.T_0,
            lr_scheduler_cfg.T_mult,
            lr_scheduler_cfg.eta_min,
            last_epoch=lr_scheduler_cfg.last_epoch)
    elif lr_scheduler_cfg.policy == 'LambdaLR':
        import math
        lf = lambda x: (((1 + math.cos(x * math.pi / train_cfg.num_epochs)) / 2
                         )**1.0) * 0.95 + 0.05  # cosine
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                         lr_lambda=lf)
        lr_scheduler.last_epoch = 0
    else:
        lr_scheduler = None

    # define the loss function
    DiceLoss_fn = DiceLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    loss_func = L.JointLoss(first=DiceLoss_fn,
                            second=SoftCrossEntropy_fn,
                            first_weight=0.5,
                            second_weight=0.5).cuda()
    # loss_cls_func = torch.nn.BCEWithLogitsLoss()

    # create the folder for saving checkpoints
    check_point_dir = os.path.dirname(model_cfg.check_point_file)
    if not os.path.exists(check_point_dir):  # create it if it does not exist
        os.mkdir(check_point_dir)

    # start training
    auto_save_epoch_list = train_cfg.setdefault(
        key='auto_save_epoch_list',
        default=[5])  # epochs at which an extra checkpoint is saved; a list, since `epoch in ...` is checked below
    train_loss_list = []
    val_loss_list = []
    val_loss_min = 999999
    best_epoch = 0
    best_miou = 0
    train_loss = 10  # initial placeholder value
    logger.info('Starting to train the {1} model on {0}...'.format(
        device, model_cfg.type))
    logger.info('Additional info: {}\n'.format(
        cfg.setdefault(key='info', default='None')))
    for epoch in range(train_cfg.num_epochs):
        print()
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        start_time = time.time()
        print(f"Running training epoch {epoch}...")
        logger.info('*' * 10 + f"Epoch {epoch}" + '*' * 10)
        #
        # train one epoch
        if train_cfg.is_swa:  # SWA training mode
            train_loss = train_epoch(swa_model, optimizer, lr_scheduler,
                                     loss_func, train_dataloader, epoch,
                                     device)
            moving_average(model, swa_model, 1.0 / (swa_n + 1))
            swa_n += 1
            bn_update(train_dataloader, model, device)
        else:
            train_loss = train_epoch(model, optimizer, lr_scheduler, loss_func,
                                     train_dataloader, epoch, device)
            # train_loss = train_unet3p_epoch(model, optimizer, lr_scheduler, loss_func, train_dataloader, epoch, device)

        #
        # evaluate the model on the validation set
        # val_loss, val_miou = evaluate_unet3p_model(model, val_dataset, loss_func, device,
        #                                     cfg.num_classes, train_cfg.num_workers, batch_size=train_cfg.batch_size)
        if not is_online_train:  # the model only needs evaluating in offline training
            val_loss, val_miou = evaluate_model(model, val_dataloader,
                                                loss_func, device,
                                                cfg.num_classes)
        else:
            val_loss = 0
            val_miou = 0

        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        # save the model
        if not is_online_train:  # the best model only needs saving in offline training
            if val_loss < val_loss_min:
                val_loss_min = val_loss
                best_epoch = epoch
                best_miou = val_miou
                if is_parallel:
                    torch.save(model.module, model_cfg.check_point_file)
                else:
                    torch.save(model, model_cfg.check_point_file)

        if epoch in auto_save_epoch_list:  # save an extra checkpoint at the listed epochs
            model_file = model_cfg.check_point_file.split(
                '.pth')[0] + '-epoch{}.pth'.format(epoch)
            if is_parallel:
                torch.save(model.module, model_file)
            else:
                torch.save(model, model_file)

        # print intermediate results
        end_time = time.time()
        run_time = int(end_time - start_time)
        m, s = divmod(run_time, 60)
        time_str = "{:02d}m{:02d}s".format(m, s)
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        out_str = "Epoch {} finished in {},\ttrain loss={:.6f};\tval loss={:.4f}, mIoU={:.6f}\tbest epoch so far: {}, mIoU={:.6f}" \
            .format(epoch, time_str, train_loss, val_loss, val_miou, best_epoch, best_miou)
        # out_str = "Epoch {} finished in {},\ntrain segm_loss={:.6f}, cls_loss={:.6f}\nval segm_loss={:.4f}, cls_loss={:.4f}, mIoU={:.6f}\nbest epoch so far: {}, mIoU={:.6f}" \
        #     .format(epoch, time_str, train_loss, train_cls_loss, val_loss, val_cls_loss, val_miou, best_epoch,
        #             best_miou)
        print(out_str)
        logger.info(out_str + '\n')
Example no. 17
def test_dice_negatives(self):
    dice_loss = DiceLoss()
    x = torch.FloatTensor([[1., 1.], [1., 0.]])
    y = torch.FloatTensor([[0., 1.], [1., 0.]])
    dice = dice_loss(x, y)
    self.assertEqual(round(dice.item(), 2), 0.2)
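(As a quick check: the element-wise product of x and y sums to 2, while x and y themselves sum to 3 and 2, so the soft Dice loss is 1 - 2*2/(3+2) = 0.2, exactly the value the assertion rounds to.)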