Code Example #1
def main():
    #############
    # init args #
    #############
    train_T2_path = '/home/Multi_Modality/data/fold4/train/T2_aug'
    train_target_path = '/home/Multi_Modality/data/fold4/train/label_aug'
    train_DWI_path = '/home/Multi_Modality/data/fold4/train/DWI_aug'

    test_T2_path = '/home/Multi_Modality/data/fold4/test/test_T2'
    test_target_path = '/home/Multi_Modality/data/fold4/test/test_label'
    test_DWI_path = '/home/Multi_Modality/data/fold4/test/test_DWI'

    args = get_args()

    best_prec1 = 0.
    best_prec2 = 0.
    best_prec3 = 0.

    args.cuda = torch.cuda.is_available()
    if args.inference == '':
        args.save = args.save or 'work/AMRSegNet_fold4_SE.{}'.format(datestr())
    else:
        args.save = args.save or 'work/AMRSegNet_fold4_SE_inference.{}'.format(
            datestr())

    weight_decay = args.weight_decay
    setproctitle.setproctitle(args.save)

    torch.manual_seed(1)
    if args.cuda:
        torch.cuda.manual_seed(1)

    # TensorBoard writer: mirror the save directory name under runs/
    # (args.save is always set above, so the log dir can be derived directly
    # for both training and inference runs)
    idx = args.save.rfind('/')
    log_dir = 'runs' + args.save[idx:]
    print('log_dir', log_dir)
    writer = SummaryWriter(log_dir)

    #########################
    # building  AMRSegNet   #
    #########################
    print("building AMRSegNet-----")
    batch_size = args.ngpu * args.batchSz
    # model = unet.UNet(relu=False)
    model = AMRSegNet_noalpha.AMRSegNet()

    x = torch.zeros((1, 1, 256, 256))
    writer.add_graph(model, (x, x))

    if args.cuda:
        model = model.cuda()

    model = nn.parallel.DataParallel(model, list(range(args.ngpu)))

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        model.apply(weights_init)

    print('Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # if args.cuda:
    #     model = model.cuda()

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save, exist_ok=True)

    # define a logger and write information
    logger = Logger(os.path.join(args.save, 'log.txt'))
    logger.print3('batch size is %d' % args.batchSz)
    logger.print3('nums of gpu is %d' % args.ngpu)
    logger.print3('num of epochs is %d' % args.nEpochs)
    logger.print3('start-epoch is %d' % args.start_epoch)
    logger.print3('weight-decay is %e' % args.weight_decay)
    logger.print3('optimizer is %s' % args.opt)

    ################
    # prepare data #
    ################
    # train_transform = transforms.Compose([RandomHorizontalFlip(p=0.7),
    #                                       RandomRotation(30),
    #                                       Crop(),
    #                                       ToTensor(),
    #                                       Normalize(0.5, 0.5)])
    train_transform = transforms.Compose(
        [Crop(), ToTensor(), Normalize(0.5, 0.5)])
    # train_transform = transforms.Compose([Crop(), ToTensor(), Normalize(0.5, 0.5)])
    test_transform = transforms.Compose(
        [Crop(), ToTensor(), Normalize(0.5, 0.5)])

    # inference dataset
    if args.inference != '':
        if not args.resume:
            print("args.resume must be set to do inference")
            exit(1)
        kwargs = {'num_workers': 0} if args.cuda else {}
        T2_src = args.inference
        DWI_src = args.dwiinference
        tar = args.target

        inference_batch_size = 1
        dataset = Lung_dataset(image_path=T2_src,
                               image2_path=DWI_src,
                               target_path=tar,
                               transform=test_transform)
        loader = DataLoader(dataset,
                            batch_size=inference_batch_size,
                            shuffle=False,
                            **kwargs)
        inference(args, loader, model)

        return

    # train dataset
    print("loading train set --- ")
    kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
    train_set = Lung_dataset(image_path=train_T2_path,
                             image2_path=train_DWI_path,
                             target_path=train_target_path,
                             transform=train_transform)
    test_set = Lung_dataset(image_path=test_T2_path,
                            image2_path=test_DWI_path,
                            target_path=test_target_path,
                            transform=test_transform,
                            mode='test')

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              **kwargs)
    test_loader = DataLoader(test_set,
                             batch_size=batch_size,
                             shuffle=True,
                             **kwargs)

    # class_weights
    target_mean = train_set.get_target_mean()
    bg_weight = target_mean / (1. + target_mean)
    fg_weight = 1. - bg_weight
    class_weights = torch.FloatTensor([bg_weight, fg_weight])
    if args.cuda:
        class_weights = class_weights.cuda()

    #############
    # optimizer #
    #############
    lr = 0.7 * 1e-2
    if args.opt == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=1e-1,
                              momentum=0.99,
                              weight_decay=weight_decay)
    elif args.opt == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  weight_decay=weight_decay)
    else:
        raise ValueError('unknown optimizer: {}'.format(args.opt))

    # loss function
    loss_fn = {}
    loss_fn['surface_loss'] = SurfaceLoss()
    loss_fn['ti_loss'] = TILoss()
    loss_fn['dice_loss'] = DiceLoss()
    loss_fn['l1_loss'] = nn.L1Loss()
    loss_fn['CELoss'] = nn.CrossEntropyLoss()

    ############
    # training #
    ############
    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')
    err_best = 0.

    for epoch in range(1, args.nEpochs + 1):
        # adjust_opt(args.opt, optimizer, epoch)
        if epoch > 20:
            lr = 1e-3
        if epoch > 30:
            lr = 1e-4
        if epoch > 50:
            lr = 1e-5
        # if epoch > 40:
        #     lr = 1e-5
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        mean_loss = train(args, epoch, model, train_loader, optimizer, trainF,
                          loss_fn, writer)
        dice, recall, precision = test(args, epoch, model, test_loader,
                                       optimizer, testF, loss_fn, logger,
                                       writer)
        writer.add_scalar('fold4_train_loss/epoch', mean_loss, epoch)

        is_best1, is_best2, is_best3 = False, False, False
        if dice > best_prec1:
            is_best1 = True
            best_prec1 = dice
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            }, is_best1, args.save, "AMRSegNet_dice")

    trainF.close()
    testF.close()

    writer.close()
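
The `save_checkpoint` helper that main() relies on is not part of this listing. Below is a minimal sketch matching the call site `save_checkpoint(state, is_best1, args.save, "AMRSegNet_dice")`; the exact file names are assumptions.

import os
import shutil

import torch


def save_checkpoint(state, is_best, save_dir, prefix):
    # always write the latest checkpoint
    path = os.path.join(save_dir, '{}_checkpoint.pth.tar'.format(prefix))
    torch.save(state, path)
    # keep a separate copy whenever the tracked metric improved
    if is_best:
        shutil.copyfile(
            path,
            os.path.join(save_dir, '{}_model_best.pth.tar'.format(prefix)))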
Code Example #2
File: train.py  Project: vigorwei/Segmentation
def train(model, device, args, num_fold=0):
    dataset_train = myDataset(args.data_root,
                              args.target_root,
                              args.crop_size,
                              "train",
                              k_fold=args.k_fold,
                              imagefile_csv=args.dataset_file_list,
                              num_fold=num_fold)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    num_train_data = len(dataset_train)  # size of the training set
    dataset_val = myDataset(args.data_root,
                            args.target_root,
                            args.crop_size,
                            "val",
                            k_fold=args.k_fold,
                            imagefile_csv=args.dataset_file_list,
                            num_fold=num_fold)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers,
                                pin_memory=True,
                                drop_last=True)
    num_train_val = len(dataset_val)  # size of the validation set
    ####################
    writer = SummaryWriter(log_dir=args.log_dir[num_fold], comment='tb_log')

    if args.optim == "SGD":
        opt = torch.optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        opt = torch.optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    # define the loss functions
    if args.OHEM:
        criterion = OhemCrossEntropy(thres=0.8, min_kept=10000)
    else:
        criterion = nn.CrossEntropyLoss(
            torch.tensor(args.class_weight, device=device))
    criterion_dice = DiceLoss()

    cp_manager = utils.save_checkpoint_manager(3)
    step = 0
    for epoch in range(args.num_epochs):
        model.train()
        lr = utils.poly_learning_rate(args, opt, epoch)  # learning-rate schedule

        desc = (f'[Train] fold[{num_fold}/{args.k_fold}] '
                f'Epoch[{epoch + 1}/{args.num_epochs} LR{lr:.8f}] ')
        with tqdm(total=num_train_data, desc=desc, unit='img') as pbar:
            for batch in dataloader_train:
                step += 1
                # read a training batch
                image = batch["image"]
                label = batch["label"]
                assert len(image.size()) == 4
                assert len(label.size()) == 3
                image = image.to(device, dtype=torch.float32)
                label = label.to(device, dtype=torch.long)

                # forward pass
                opt.zero_grad()
                outputs = model(image)
                main_out = outputs["main_out"]

                # compute the losses
                diceloss = criterion_dice(main_out, label)
                celoss = criterion(main_out, label)
                totall_loss = celoss + diceloss * args.dice_weight
                if "sim_loss" in outputs.keys():
                    totall_loss += outputs["sim_loss"] * 0.2
                if "aux_out" in outputs.keys():  # 计算辅助损失函数
                    aux_losses = 0
                    for aux_p in outputs["aux_out"]:
                        auxloss = (criterion(aux_p, label) + criterion_dice(
                            aux_p, label) * args.dice_weight) * args.aux_weight
                        totall_loss += auxloss
                        aux_losses += auxloss

                if "mu" in outputs.keys():  # EMAU 的基更新
                    with torch.no_grad():
                        mu = outputs["mu"]
                        mu = mu.mean(dim=0, keepdim=True)
                        momentum = 0.9
                        # model.emau.mu *= momentum
                        # model.emau.mu += mu * (1 - momentum)
                        model.effcient_module.em.mu *= momentum
                        model.effcient_module.em.mu += mu * (1 - momentum)
                if "mu1" in outputs.keys():
                    with torch.no_grad():
                        mu1 = outputs['mu1'].mean(dim=0, keepdim=True)
                        model.donv_up1.em.mu = model.donv_up1.em.mu * 0.9 + mu1 * (
                            1 - 0.9)

                        mu2 = outputs['mu2'].mean(dim=0, keepdim=True)
                        model.donv_up2.em.mu = model.donv_up2.em.mu * 0.9 + mu2 * (
                            1 - 0.9)

                        mu3 = outputs['mu3'].mean(dim=0, keepdim=True)
                        model.donv_up3.em.mu = model.donv_up3.em.mu * 0.9 + mu3 * (
                            1 - 0.9)

                        mu4 = outputs['mu4'].mean(dim=0, keepdim=True)
                        model.donv_up4.em.mu = model.donv_up4.em.mu * 0.9 + mu4 * (
                            1 - 0.9)
                totall_loss.backward()
                opt.step()

                if step % 5 == 0:
                    writer.add_scalar("Train/CE_loss", celoss.item(), step)
                    writer.add_scalar("Train/Dice_loss", diceloss.item(), step)
                    if args.aux:
                        writer.add_scalar("Train/aux_losses", aux_losses, step)
                    if "sim_loss" in outputs.keys():
                        writer.add_scalar("Train/sim_loss",
                                          outputs["sim_loss"], step)
                    writer.add_scalar("Train/Totall_loss", totall_loss.item(),
                                      step)

                pbar.set_postfix(**{'loss': totall_loss.item()})  # show loss in the bar
                pbar.update(image.size()[0])

        if (epoch + 1) % args.val_step == 0:
            # validation
            mDice, mIoU, mAcc, mSensitivity, mSpecificity = val(
                model, dataloader_val, num_train_val, device, args)
            writer.add_scalar("Valid/Dice_val", mDice, step)
            writer.add_scalar("Valid/IoU_val", mIoU, step)
            writer.add_scalar("Valid/Acc_val", mAcc, step)
            writer.add_scalar("Valid/Sen_val", mSensitivity, step)
            writer.add_scalar("Valid/Spe_val", mSpecificity, step)
            # append the metrics to the csv file
            val_result = [
                num_fold, epoch + 1, mDice, mIoU, mAcc, mSensitivity,
                mSpecificity
            ]
            with open(args.val_result_file, "a") as f:
                w = csv.writer(f)
                w.writerow(val_result)
            # save the model
            cp_manager.save(
                model,
                os.path.join(args.checkpoint_dir[num_fold],
                             f'CP_epoch{epoch + 1}_{float(mDice):.4f}.pth'),
                float(mDice))
            if (epoch + 1) == (args.num_epochs):
                torch.save(
                    model.state_dict(),
                    os.path.join(
                        args.checkpoint_dir[num_fold],
                        f'CP_epoch{epoch + 1}_{float(mDice):.4f}.pth'))
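
`utils.poly_learning_rate` is not shown in this listing. Below is a sketch of the usual polynomial ("poly") decay used in segmentation training, assuming `args` carries `lr` and `num_epochs` and that the helper returns the new rate (as the call site suggests); the power of 0.9 is an assumption.

def poly_learning_rate(args, optimizer, epoch, power=0.9):
    # lr = base_lr * (1 - epoch / num_epochs) ** power
    lr = args.lr * (1 - epoch / args.num_epochs) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr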
Code Example #3
def inference(args, loader, model):
    src = args.inference
    model.eval()
    dice_list = []
    mean_precision = []
    mean_recall = []
    mean_hausdorff = []
    mean_jaccard = []

    with torch.no_grad():
        for num, sample in enumerate(loader):
            data, data2, target = sample['image'], sample['image_b'], sample[
                'target']
            if args.cuda:
                data, data2, target = data.cuda(), data2.cuda(), target.cuda()
            data, data2, target = Variable(data), Variable(data2), Variable(
                target)

            output = model(data, data2)

            dice_coef, jaccard = DiceLoss.dice_coeficient(output, target)
            precision, recall = confusion(output, target)
            hausdorff_distance = compute_hausdorff(output.cpu().numpy(),
                                                   target.cpu().numpy())
            dice = dice_coef.cpu().numpy()
            dice_list.append(dice)
            mean_precision.append(precision.item())
            mean_recall.append(recall.item())
            mean_hausdorff.append(hausdorff_distance)
            mean_jaccard.append(jaccard.item())

            data = data * 0.5 + 0.5
            data2 = data2 * 0.5 + 0.5
            img = make_grid(data, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
            img2 = make_grid(data2, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
            target = target.view(data.shape).float()
            gt = make_grid(target, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
            # _, pre = output_softmax.max(1)
            pre = (output > 0.5).float().view(data.shape)
            pre = make_grid(pre, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]

            gt_img = label2rgb(gt, img, bg_label=0)
            pre_img = label2rgb(pre, img, bg_label=0)
            gt_img2 = label2rgb(gt, img2, bg_label=0)

            fig = plt.figure()
            ax = fig.add_subplot(231)
            ax.imshow(gt_img)
            ax.set_title('T2 ground truth')
            ax.axis('off')
            ax = fig.add_subplot(233)
            ax.imshow(pre_img)
            ax.set_title('prediction')
            ax.axis('off')
            ax = fig.add_subplot(232)
            ax.imshow(gt_img2)
            ax.set_title('DWI ground truth')
            ax.axis('off')
            ax = fig.add_subplot(234)
            ax.imshow(img)
            ax.set_title('T2 image')
            ax.axis('off')
            ax = fig.add_subplot(235)
            ax.imshow(img2)
            ax.set_title('DWI image')
            ax.axis('off')
            fig.tight_layout()
            fig.savefig(
                '/home/Multi_Modality/data/fold5/inference/AMRSegNet_noalpha/%d_%.4f.png'
                % (num, dice))

            print('processing {}/{}\r dice:{}'.format(num, len(loader.dataset),
                                                      dice))

        mean_jaccard = np.array(mean_jaccard).mean()
        mean_dice = np.array(dice_list).mean()
        std_dice = np.std(np.array(dice_list))
        mean_recall = np.mean(mean_recall)
        mean_precision = np.mean(mean_precision)
        mean_hausdorff = np.mean(mean_hausdorff)
        F1_score = 2 * mean_recall * mean_precision / (mean_recall +
                                                       mean_precision)

        print('mean_jaccard: %.4f' % mean_jaccard)
        print('mean_dice: %.4f' % mean_dice)
        print('std: %.4f' % std_dice)
        print('mean_recall: %.4f' % mean_recall)
        print('mean_precision: %.4f' % mean_precision)
        print('F1_score: ', F1_score)
        print('mean_hausdorff: %.4f' % mean_hausdorff)
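
The `confusion` helper used above (and in the test and train loops below) returns precision and recall. A minimal sketch for binary masks follows; the 0.5 threshold and the eps smoothing are assumptions.

import torch


def confusion(output, target, threshold=0.5, eps=1e-7):
    pred = (output > threshold).float().view(-1)
    truth = (target > 0.5).float().view(-1)
    tp = (pred * truth).sum()            # true positives
    precision = tp / (pred.sum() + eps)  # TP / (TP + FP)
    recall = tp / (truth.sum() + eps)    # TP / (TP + FN)
    return precision, recall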
Code Example #4
def test(args, epoch, model, test_loader, optimizer, testF, loss_fn, logger,
         writer):
    model.eval()
    mean_dice = []
    mean_jaccard = []
    mean_precision = []
    mean_recall = []
    mean_hausdorff = []

    with torch.no_grad():
        for sample in test_loader:
            data, data2, target = sample['image'], sample['image_b'], sample[
                'target']
            if args.cuda:
                data, data2, target = data.cuda(), data2.cuda(), target.cuda()
            data, data2, target = Variable(data), Variable(data2), Variable(
                target, requires_grad=False)

            output = model(data, data2)

            # target = target.view(target.numel())
            # loss = loss_fn['dice_loss'](output, target[:,:,7:-7,7:-7])
            # dice = 1 - loss
            # m = nn.Softmax(dim=1)
            # output = m(output)

            # pdb.set_trace()
            # Hausdorff Distance
            hausdorff_distance = compute_hausdorff(output.cpu().numpy(),
                                                   target.cpu().numpy())
            # Dice coefficient
            dice, jaccard = DiceLoss.dice_coeficient(output, target)
            precision, recall = confusion(output, target)

            mean_precision.append(precision.item())
            mean_recall.append(recall.item())
            mean_dice.append(dice.item())
            mean_jaccard.append(jaccard.item())
            mean_hausdorff.append(hausdorff_distance)

        # visualize the last batch of the epoch
        shape = [data.shape[0], 1, data.shape[2], data.shape[3]]
        if epoch % 1 == 0:
            data = data * 0.5 + 0.5
            data2 = data2 * 0.5 + 0.5
            img = make_grid(data, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
            img2 = make_grid(data2, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
            # print('img.shape', img.shape)
            target = target.view(shape).float()
            gt = make_grid(target, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
            # _, pre = output_softmax.max(1)
            pre = (output > 0.5).float().view(shape)
            pre = make_grid(pre, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]

            gt_img = label2rgb(gt, img, bg_label=0)
            pre_img = label2rgb(pre, img, bg_label=0)
            gt_img2 = label2rgb(gt, img2, bg_label=0)

            fig = plt.figure()
            ax = fig.add_subplot(311)
            ax.imshow(gt_img)
            ax.set_title('T2 ground truth')
            ax = fig.add_subplot(312)
            ax.imshow(pre_img)
            ax.set_title('prediction')
            ax = fig.add_subplot(313)
            ax.imshow(gt_img2)
            ax.set_title('DWI ground truth')
            fig.tight_layout()

            writer.add_figure('test result', fig, epoch)
            fig.clear()

        writer.add_scalar('fold4_test_dice/epoch', np.mean(mean_dice), epoch)
        writer.add_scalar('fold4_test_jaccard/epoch', np.mean(mean_jaccard),
                          epoch)
        writer.add_scalar('fold4_test_precision/epoch', np.mean(mean_precision),
                          epoch)
        writer.add_scalar('fold4_test_recall/epoch', np.mean(mean_recall),
                          epoch)
        writer.add_scalar('fold4_hausdorff_distance/epoch',
                          np.mean(mean_hausdorff), epoch)

        print('test mean_dice: ', np.mean(mean_dice))
        print('test mean jaccard: ', np.mean(mean_jaccard))
        print('mean_dice_length ', len(mean_dice))
        testF.write('{},{},{},{}\n'.format(epoch, np.mean(mean_dice),
                                           np.mean(mean_precision),
                                           np.mean(mean_recall)))
        testF.flush()
        return np.mean(mean_dice), np.mean(mean_recall), np.mean(
            mean_precision)
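
`DiceLoss.dice_coeficient` returns a Dice score and a Jaccard (IoU) score. A sketch of what it plausibly computes is below, with the eps smoothing as an assumption; written this way it accepts either soft probabilities (as in test()) or thresholded masks (as in train()).

def dice_coeficient(output, target, eps=1e-7):
    pred = output.float().contiguous().view(-1)
    truth = target.float().contiguous().view(-1)
    intersection = (pred * truth).sum()
    union = pred.sum() + truth.sum()
    dice = (2. * intersection + eps) / (union + eps)                 # 2|A∩B| / (|A|+|B|)
    jaccard = (intersection + eps) / (union - intersection + eps)    # |A∩B| / |A∪B|
    return dice, jaccard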
Code Example #5
def train(args, epoch, model, train_loader, optimizer, trainF, loss_fn,
          writer):
    model.train()
    nProcessed = 0
    nTrain = len(train_loader.dataset)
    loss_list = []
    print('--------------------Epoch{}------------------------'.format(epoch))
    for batch_idx, sample in enumerate(train_loader):
        # read data
        data, data2, target = sample['image'], sample['image_b'], sample[
            'target']
        # pdb.set_trace()
        if args.cuda:
            data, data2, target = data.cuda(), data2.cuda(), target.cuda()
        data, data2, target = Variable(data), Variable(data2), Variable(
            target, requires_grad=False)

        # print('data.shape: ', data.shape)
        # print('data2.shape: ', data2.shape)
        # feed to model
        output = model(data, data2)
        target = target.view(output.shape[0],
                             target.numel() // output.shape[0])

        # loss
        loss = loss_fn['dice_loss'](output, target)

        target = target.long()

        dice, jaccard = DiceLoss.dice_coeficient(output > 0.5, target)
        precision, recall = confusion(output > 0.5, target)

        # back propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # show some result on tensorboard
        nProcessed += len(data)
        partialEpoch = epoch + batch_idx / len(train_loader) - 1
        print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format(
            partialEpoch, nProcessed, nTrain,
            100. * batch_idx / len(train_loader), loss.item()))
        print('Jaccard index: %.6f, soft dice: %.6f' % (jaccard, dice))

        # writer.add_scalar('train_loss/epoch', loss, partialEpoch)
        trainF.write('{},{},{}\n'.format(partialEpoch, loss.item(),
                                         loss.item()))
        trainF.flush()
        # show images on tensorboard
        with torch.no_grad():
            shape = [data.shape[0], 1, data.shape[2], data.shape[3]]
            if batch_idx % 4 == 0:
                # img = make_grid(data, padding=20).cpu().numpy().transpose(1, 2, 0)[0]
                # pdb.set_trace()
                data = (data * 0.5 + 0.5)
                data2 = (data2 * 0.5 + 0.5)

                img = make_grid(data, padding=20).cpu().detach().numpy().transpose(1, 2, 0)[:, :, 0]
                img2 = make_grid(data2, padding=20).cpu().detach().numpy().transpose(1, 2, 0)[:, :, 0]
                # print('img.shape', img.shape)
                target = target.view(shape).float()
                gt = make_grid(target, padding=20).cpu().detach().numpy().transpose(1, 2, 0)[:, :, 0]
                # _, pre = output_softmax.max(1)
                pre = (output > 0.5).float().view(shape)
                pre = make_grid(pre, padding=20).cpu().numpy().transpose(1, 2, 0)[:, :, 0]
                gt_img = label2rgb(gt, img, bg_label=0)
                pre_img = label2rgb(pre, img, bg_label=0)
                gt_img2 = label2rgb(gt, img2, bg_label=0)
                # pdb.set_trace()

                fig = plt.figure()
                ax = fig.add_subplot(311)
                ax.imshow(gt_img)
                ax.set_title('T2 ground truth')
                ax = fig.add_subplot(312)
                ax.imshow(pre_img)
                ax.set_title('prediction')
                ax = fig.add_subplot(313)
                ax.imshow(gt_img2)
                ax.set_title('DWI ground truth')
                fig.tight_layout()

                writer.add_figure('train result', fig, epoch)
                fig.clear()
        # record each batch's loss so the epoch mean covers all batches
        loss_list.append(loss.item())

    return np.mean(loss_list)
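
The `DiceLoss` module used as `loss_fn['dice_loss']` is likewise external to the listing. A minimal soft-Dice sketch (one minus the per-sample soft Dice, averaged over the batch); the real class may add class weights or a different smoothing term.

import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, output, target):
        # flatten everything except the batch dimension
        output = output.contiguous().view(output.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1).float()
        intersection = (output * target).sum(dim=1)
        union = output.sum(dim=1) + target.sum(dim=1)
        dice = (2. * intersection + self.eps) / (union + self.eps)
        return 1. - dice.mean()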
Code Example #6
def Train_Val(epoches, net, train_data, val_data):
    net = net.train()
    net = net.cuda()
    loss1 = nn.BCEWithLogitsLoss().cuda()
    loss2 = DiceLoss().cuda()
    Sum_Train_miou = 0
    Sum_Val_miou = 0
    for e in range(epoches):
        #train_loss = 0
        train_mean_iou = 0
        j = 0
        process = tqdm(train_data)
        losses = []
        
        for data in process:
            j += 1
            with torch.no_grad():
                im = Variable(data[0].cuda())
                label = Variable(data[1].cuda())  # one-hot label
                #label1 = Variable(data[2].cuda())
            #print("im.shape:", im.shape)        # torch.Size([2, 3, 256, 768])
            #print("label.shape:", label.shape)  # torch.Size([2, 8, 256, 768])
            out = net(im)
            #out_softmax = F.log_softmax(out, dim=1)
            sig = torch.sigmoid(out)

            loss = loss1(out, label) + loss2(sig, label)
            losses.append(loss.item())

            # backward; note that optimizer is defined outside this function
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update the progress-bar loss readout
            process.set_postfix_str(f"loss {np.mean(losses)}")
            pred = torch.argmax(F.softmax(out, dim=1), dim=1)
            mask = torch.argmax(F.softmax(label, dim=1), dim=1)
            #print("pred.shape:",pred.shape)#torch.Size([2, 256, 768])
            #print("mask.shape:",mask.shape) #  torch.Size([2, 256, 768])         
            result = compute_iou(pred, mask)
            if j % 200 == 0:
                tmiou = []
                for i in range(1, 8):
                    if result["TA"][i] != 0:
                        t_miou_i = result["TP"][i] / result["TA"][i]
                        print("{}: {:.4f}".format(i, t_miou_i))
                        tmiou.append(t_miou_i)
                t_miou = np.mean(tmiou)
                print("train_mean_iou:", t_miou)

                # sum the foreground-class counts; iterate over .values() so
                # the batch counter j is not clobbered by a loop variable
                TP_sum = np.array(list(result["TP"].values()))[1:].sum()
                TA_sum = np.array(list(result["TA"].values()))[1:].sum()
                print("acc:", '%.5f' % (TP_sum / TA_sum))
                
            if j % 500 == 0:
                torch.save(net.state_dict(), 'deeplabv3p_baidulane.pth')

        torch.save(net.state_dict(), 'deeplabv3p_baidulane.pth')
        
        j = 0
        #net.load_state_dict(torch.load('./deeplabv3p_baidulane.pth'))
        #net = net.cuda()
        process = tqdm(val_data)
        losses = []
        result = {
            "TP": {i: 0 for i in range(8)},
            "TA": {i: 0 for i in range(8)}
        }

        net = net.eval()
        val_mean_iou = 0
        for data in process:
            j += 1
            with torch.no_grad():
                im = Variable(data[0].cuda())
                label = Variable(data[1].cuda())
                #label_1 = Variable(data[2].cuda())
            # forward
            out = net(im)
            sig = torch.sigmoid(out)
            loss = loss1(out, label) + loss2(sig, label)
            losses.append(loss.item())

            pred = torch.argmax(F.softmax(out, dim=1), dim=1)
            mask = torch.argmax(F.softmax(label, dim=1), dim=1)
            result = compute_iou(pred, mask)
            process.set_postfix_str(f"loss {np.mean(losses)}")

            if j % 200 == 0:
                vmiou = []
                for i in range(1, 8):
                    if result["TA"][i] != 0:
                        v_miou_i = result["TP"][i] / result["TA"][i]
                        print("{}: {:.4f}".format(i, v_miou_i))
                        vmiou.append(v_miou_i)
                v_miou = np.mean(vmiou)
                print("val_mean_iou:", v_miou)
                # as above, avoid shadowing the batch counter j
                TP_sum = np.array(list(result["TP"].values()))[1:].sum()
                TA_sum = np.array(list(result["TA"].values()))[1:].sum()
                print("acc:", '%.5f' % (TP_sum / TA_sum))
                

        epoch_str = 'Epoch: {},  Train Mean IoU: {:.5f},  Valid Mean IoU: {:.5f}'.format(
            e, t_miou, v_miou)
        print(epoch_str)
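
`compute_iou` returns per-class pixel counts as two dicts. Below is a sketch inferred from the call sites, where TP[i] counts correctly predicted pixels of class i and TA[i] counts all ground-truth pixels of class i (so TP/TA is per-class recall); the per-batch, non-accumulating behaviour is an assumption.

def compute_iou(pred, mask, num_classes=8):
    result = {"TP": {i: 0 for i in range(num_classes)},
              "TA": {i: 0 for i in range(num_classes)}}
    for i in range(num_classes):
        gt_i = (mask == i)
        result["TA"][i] = gt_i.sum().item()                  # pixels of class i
        result["TP"][i] = ((pred == i) & gt_i).sum().item()  # correct hits
    return result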
Code Example #7
def train_general(args):
    args.optimizer = 'Adam'
    args.n_classes = 2
    args.batch_size = 8
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    # print(args.model_name)
    # print(args.test)
    if args.model_name == 'FCNet':
        model = FCNet(args).cuda()
        model = torch.nn.DataParallel(model)
        if args.optimizer == 'SGD':
            optimizer = SGD(model.parameters(),
                            .1,
                            weight_decay=5e-4,
                            momentum=.99)
        elif args.optimizer == 'Adam':
            optimizer = Adam(model.parameters(), .1, weight_decay=5e-4)
        criterion = cross_entropy2d
        scheduler = MultiStepLR(optimizer, [100, 200, 400, 800, 3200], .1)
    elif args.model_name == 'CENet':
        model = CE_Net_(args).cuda()
        model = torch.nn.DataParallel(model)
        if args.optimizer == 'SGD':
            optimizer = SGD(model.parameters(),
                            .1,
                            weight_decay=5e-4,
                            momentum=.99)
            scheduler = MultiStepLR(optimizer, [100, 200, 400, 800, 3200], .1)
        elif args.optimizer == 'Adam':
            optimizer = Adam(model.parameters(), .001, weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer, [400, 3200], .1)
        # criterion = cross_entropy2d
        criterion = DiceLoss()
        # scheduler = MultiStepLR(optimizer, [100, 200, 400, 800, 3200], .1)
    start_iter = 0
    if args.model_path is not None:
        if os.path.isfile(args.model_path):
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
        else:
            print('Unable to load {}'.format(args.model_name))

    train_loader, valid_loader = get_loaders(args)

    os.makedirs('logs/', exist_ok=True)
    os.makedirs('results/', exist_ok=True)
    os.makedirs('results/' + args.model_name, exist_ok=True)
    writer = SummaryWriter(log_dir='logs/')

    best = -100.0
    i = start_iter
    flag = True

    running_metrics_val = Acc_Meter()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    # while i <= args.niter and flag:
    while i <= 300000 and flag:
        for (images, labels) in train_loader:
            i += 1
            start_ts = time.time()
            model.train()
            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(input=outputs, target=labels)

            loss.backward()
            optimizer.step()
            scheduler.step()  # step the scheduler after the optimizer (PyTorch >= 1.1 ordering)

            time_meter.update(time.time() - start_ts)

            # if (i + 1) % cfg["training"]["print_interval"] == 0:
            if (i + 1) % 50 == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    300000,
                    loss.item(),
                    time_meter.avg / args.batch_size,
                )

                print(print_str)
                # logger.info(print_str)
                # writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                # time_meter.reset()

            # if (i + 1) % cfg["training"]["val_interval"] == 0 or (i + 1) == cfg["training"]["train_iters"]:
            if (i + 1) % 500 == 0 or (i + 1) == 300000:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valid_loader)):
                        images_val = images_val.cuda()  # to(device)
                        labels_val = labels_val.cuda()  # to(device)

                        outputs = model(images_val)
                        # val_loss = loss_fn(input=outputs, target=labels_val)
                        val_loss = criterion(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                # writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                print("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                results = running_metrics_val.get_acc()
                for k, v in results.items():
                    writer.add_scalar(k, v, i + 1)
                print(results)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if results['cls_acc'] >= best:
                    best = results['cls_acc']
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best": best,
                    }
                    save_path = "results/{}/results_{}_best_model.pkl".format(
                        args.model_name, i + 1)
                    torch.save(state, save_path)

            if (i + 1) == 300000:
                flag = False
                break
    writer.close()
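
`averageMeter` (and the similar `Acc_Meter`) are small bookkeeping classes not shown here. A sketch of the running-average meter following the common pytorch-semseg pattern, which matches the `.update()`, `.avg`, and `.reset()` usage above:

class averageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # track the latest value and the running mean over all updates
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0.0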