Example #1
def _epoch_valid(net, loss_func, data, n_class, device, i_epoch):
    """
    Run one epoch of validation.
    :param net: the network
    :param loss_func: loss function
    :param data: valid data set
    :param n_class: number of classes
    :param device: torch.device CPU or GPU
    :param i_epoch: epoch index, used in the progress bar description
    :return: loss, miou
    """
    net.to(device)
    net.eval()  # evaluation mode for validation

    total_loss = 0.  # accumulated loss over the validation epoch
    total_cm = np.zeros((n_class, n_class))  # ndarray, confusion matrix accumulated over the epoch
    total_batch_miou = 0.

    with torch.no_grad():  # no gradients are needed during validation, which saves memory
        bar_format = '{desc}{postfix}|{n_fmt}/{total_fmt}|{percentage:3.0f}%|{bar}|{elapsed}<{remaining}'
        # {desc}{postfix}|{current}/{total}|{percent}|{bar}|{elapsed}<{remaining}
        tqdm_data = tqdm(data,
                         ncols=120,  # progress bar width of 120 columns; must be set on Linux, otherwise the terminal default of 80 is used
                         bar_format=bar_format,  # progress bar format
                         desc='Epoch {:02d} Valid'.format(i_epoch))  # the {desc} part of the progress bar
        for i_batch, (im, lb) in enumerate(tqdm_data, start=1):
            im = im.to(device)  # [N,C,H,W] tensor, one validation batch of images
            lb = lb.to(device)  # [N,H,W] tensor, one validation batch of labels

            output = net(im)  # [N,C,H,W] tensor, forward pass for one validation batch
            loss = loss_func(output, lb.type(torch.long))  # loss for one validation batch
            batch_loss = loss.detach().item()  # detach before reading the scalar value
            total_loss += batch_loss  # accumulate the batch loss

            # no backward pass during validation
            pred = torch.argmax(F.softmax(output, dim=1), dim=1)  # [N,H,W] tensor, dense prediction
            batch_cm = get_confusion_matrix(pred.cpu().numpy(),
                                            lb.cpu().numpy(),
                                            n_class)  # confusion matrix for this batch
            total_cm += batch_cm
            batch_miou = get_metrics(batch_cm, metrics='mean_iou')
            total_batch_miou += batch_miou

            tqdm_str = 'Loss={:.4f}|mIoU={:.4f}|bat_mIoU={:.4f}'  # progress bar postfix
            tqdm_data.set_postfix_str(
                tqdm_str.format(total_loss / i_batch,
                                get_metrics(total_cm, metrics='mean_iou'),
                                total_batch_miou / i_batch))
            pass
        total_loss /= len(data)  # average loss over the validation epoch
        mean_iou = get_metrics(total_cm, metrics='mean_iou')  # float, mIoU from the accumulated confusion matrix
        total_batch_miou /= len(data)  # average of the per-batch mIoU values

        # log the validation results
        log_str = ('Valid Loss: {:.4f}|'
                   'Valid mIoU: {:.4f}|'
                   'Valid bat_mIoU: {:.4f}')
        log_str = log_str.format(total_loss, mean_iou, total_batch_miou)
        get_logger().info(log_str)
        return total_loss, mean_iou, total_batch_miou
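
A minimal sketch of how _epoch_valid might be driven from an outer script, added purely for context: MyNet, valid_loader, N_CLASS and the choice of loss function are placeholders, not part of the example above.

# Hypothetical usage sketch; MyNet, valid_loader and N_CLASS are assumed placeholders.
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = MyNet(n_class=N_CLASS)       # hypothetical segmentation network producing [N,C,H,W] output
loss_func = nn.CrossEntropyLoss()  # matches loss_func(output, lb.type(torch.long)) above

for i_epoch in range(1, 11):
    valid_loss, valid_miou, valid_bat_miou = _epoch_valid(
        net, loss_func, valid_loader, N_CLASS, device, i_epoch)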
Example #2
def test_model(model, scanname, config):
    """ 
    metrics = test_model(model, scanname, config)
    
    Gets DICE scores for all organs of the given test study.
    """
    tqdm.write("Testing model on " + scanname)
    time.sleep(0.5)  # just for tqdm
    vol_gt = datafunctions.load_GT_volume(config['dataset'], scanname)
    vol_segmented = segmentStudy(model, scanname, config)
    metrics = tools.get_metrics(vol_gt, vol_segmented, config)
    if (config["saveSegs"]):
        datafunctions.save_prediction(vol_segmented.astype(vol_gt.dtype),
                                      config,
                                      scanname,
                                      stage='segmentation')
    return metrics
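
For context, test_model evaluates one study at a time; a short driver loop over several studies could look like the sketch below. The test_scans list and the way results are collected are assumptions, not part of the example.

# Hypothetical driver loop over several test studies.
all_metrics = {}
for scanname in test_scans:
    all_metrics[scanname] = test_model(model, scanname, config)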
Example #3
def test(net, data, device, resize_to=256, n_class=8, compare=False):
    """
    Test the network on the whole test set.
    :param net: the network
    :param data: test dataset
    :param device: torch.device GPU or CPU
    :param resize_to: size the input is resized to before the forward pass
    :param n_class: number of classes
    :param compare: whether to save comparison figures
    :return: mean IoU over the test set
    """
    net.to(device)
    net.eval()  # evaluation mode for testing
    total_cm = np.zeros((n_class, n_class))  # confusion matrix accumulated over the whole test set
    total_batch_miou = 0.  # accumulated per-image mIoU

    offset = 690  # crop the top 690 rows (a 690x3384 strip) off the image
    pair_crop = PairCrop(offsets=(offset, None))  # crop off the top 690 rows
    pair_resize = PairResize(size=resize_to)
    pair_norm_to_tensor = PairNormalizeToTensor(norm=True)  # convert to tensor and normalize

    with torch.no_grad():  # no gradients are needed during testing, which saves memory
        bar_format = '{desc}{postfix}|{n_fmt}/{total_fmt}|{percentage:3.0f}%|{bar}|{elapsed}<{remaining}'
        # {desc}{postfix}|{current}/{total}|{percent}|{bar}|{elapsed}<{remaining}
        tqdm_data = tqdm(data, ncols=120, bar_format=bar_format, desc='Test')
        for i_batch, (im, lb) in enumerate(tqdm_data, start=1):
            # if i_batch > 1:
            #     break
            im_t, lb_t = pair_crop(im, lb)  # PIL Image, PIL Image
            im_t, lb_t = pair_resize(im_t, lb_t)  # PIL Image, PIL Image
            im_t, lb_t = pair_norm_to_tensor(im_t,
                                             lb_t)  # [C,H,W] tensor, [H,W] tensor

            im_t = im_t.to(device)  # move the [C,H,W] tensor to the device
            im_t = im_t.unsqueeze(0)  # add a batch dimension: [N,C,H,W] tensor
            output = net(im_t)  # model output, [N,C,H,W] tensor
            pred = torch.argmax(F.softmax(output, dim=1),
                                dim=1)  # [N,H,W] tensor

            pred = pred.unsqueeze(1)  # [N,C,H,W] tensor; F.interpolate expects [N,C,H,W]
            pred = pred.type(torch.float)  # F.interpolate only operates on float tensors, not int or long
            pred = F.interpolate(pred,
                                 size=(lb.size[1] - offset, lb.size[0]),
                                 mode='nearest')  # resize pred back with nearest-neighbor interpolation
            pred = pred.type(torch.uint8)  # back to an integer type
            pred = pred.squeeze(0).squeeze(0)  # [H,W] tensor
            pred = pred.cpu().numpy()  # [H,W] ndarray

            supplement = np.zeros((offset, lb.size[0]),
                                  dtype=np.uint8)  # [H,W] ndarray, the cropped rows filled with background
            pred = np.append(
                supplement, pred,
                axis=0)  # final prediction, [H,W] ndarray; concatenate along H to restore the cropped 690x3384 strip
            batch_cm = get_confusion_matrix(pred, lb, n_class)  # confusion matrix for this image
            total_cm += batch_cm  # accumulate
            batch_miou = get_metrics(batch_cm, metrics='mean_iou')  # mIoU for this image
            total_batch_miou += batch_miou  # accumulate even when no comparison figure is made

            if compare:  # generate a comparison figure
                fontsize = 16  # font size for text in the figure
                fig, ax = plt.subplots(2, 2, figsize=(20, 15))  # canvas
                ax = ax.flatten()

                ax[0].imshow(im)  # top left: the input image
                ax[0].set_title('Input Image', fontsize=fontsize)

                ax[1].imshow(LaneSegDataset.decode_rgb(
                    np.asarray(lb)))  # top right: the ground truth
                ax[1].set_title('Ground Truth', fontsize=fontsize)

                fig.suptitle('mIoU:{:.4f}'.format(batch_miou),
                             fontsize=fontsize)  # use this image's mIoU as the figure title

                mask = (pred != 0).astype(
                    np.uint8) * 255  # [H,W] ndarray, mask for alpha compositing

                pred = LaneSegDataset.decode_rgb(pred)  # [H,W,C=3] ndarray, RGB
                ax[3].imshow(pred)  # bottom right: the prediction
                ax[3].set_title('Pred', fontsize=fontsize)

                mask = mask[..., np.newaxis]  # [H,W,C=1] ndarray
                pred = np.append(pred, mask,
                                 axis=2)  # [H,W,C=4] ndarray, RGB plus an alpha channel gives RGBA

                im = im.convert('RGBA')
                pred = Image.fromarray(pred).convert('RGBA')
                im_comp = Image.alpha_composite(im, pred)  # alpha compositing
                ax[2].imshow(im_comp)  # bottom left: prediction composited over the input
                ax[2].set_title('Pred over Input', fontsize=fontsize)

                plt.subplots_adjust(left=0.01,
                                    bottom=0.01,
                                    right=0.99,
                                    top=0.99,
                                    wspace=0.01,
                                    hspace=0.01)  # adjust subplot margins and spacing
                plt.savefig('/home/mist/imfolder/pred-{:s}.jpg'.format(
                    now_str()))  # save the figure
                plt.close(fig)
                pass
            tqdm_str = 'mIoU={:.4f}|bat_mIoU={:.4f}'  # progress bar postfix
            tqdm_data.set_postfix_str(
                tqdm_str.format(get_metrics(total_cm, metrics='mean_iou'),
                                total_batch_miou / i_batch))
            pass
        mean_iou = get_metrics(total_cm, metrics='mean_iou')  # mIoU over the whole test set
        total_batch_miou /= len(data)  # average of the per-image mIoU values

        logger = get_logger()
        msg = ('Test mIoU : {:.4f}|'
               'Test bat_mIoU : {:.4f}').format(mean_iou, total_batch_miou)
        logger.info(msg)
        return mean_iou
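
A rough invocation sketch for this test routine, assuming a trained network, a checkpoint file and a dataset that yields (PIL image, PIL label) pairs; the checkpoint path and the test_dataset name are placeholders, not part of the example.

# Hypothetical invocation; checkpoint path and dataset construction are assumptions.
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net.load_state_dict(torch.load('checkpoint.pth', map_location=device))  # hypothetical checkpoint
test_miou = test(net, test_dataset, device, resize_to=256, n_class=8, compare=True)
print('Test mIoU: {:.4f}'.format(test_miou))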
Example #4
def _epoch_train(net, loss_func, optimizer, data, n_class, device, i_epoch):
    """
    Run one epoch of training.
    :param net: the network
    :param loss_func: loss function
    :param optimizer: optimizer
    :param data: train data set
    :param n_class: number of classes
    :param device: torch.device CPU or GPU
    :param i_epoch: epoch index, used in the progress bar description
    :return: loss, miou
    """
    net.to(device)
    net.train()  # training mode

    total_loss = 0.  # accumulated loss over the training epoch
    total_cm = np.zeros((n_class, n_class))  # ndarray, confusion matrix accumulated over the epoch
    total_batch_miou = 0.

    bar_format = '{desc}{postfix}|{n_fmt}/{total_fmt}|{percentage:3.0f}%|{bar}|{elapsed}<{remaining}'
    # {desc}{postfix}|{current}/{total}|{percent}|{bar}|{elapsed}<{remaining}
    tqdm_data = tqdm(data,
                     ncols=120,  # progress bar width of 120 columns; must be set on Linux, otherwise the terminal default of 80 is used
                     bar_format=bar_format,  # progress bar format
                     desc='Epoch {:02d} Train'.format(i_epoch))  # the {desc} part of the progress bar
    for i_batch, (im, lb) in enumerate(tqdm_data, start=1):
        im = im.to(device)  # [N,C,H,W] tensor, one training batch of images
        lb = lb.to(device)  # [N,H,W] tensor, one training batch of labels

        optimizer.zero_grad()  # clear the gradients

        output = net(im)  # [N,C,H,W] tensor, forward pass for one training batch

        loss = loss_func(output, lb.type(torch.long))  # loss for one training batch
        batch_loss = loss.detach().item()  # the training loss carries gradients, so detach before reading the value
        total_loss += batch_loss  # accumulate the batch loss

        loss.backward()  # backward pass
        optimizer.step()  # optimizer update

        pred = torch.argmax(F.softmax(output, dim=1), dim=1)  # [N,H,W] tensor, dense prediction (the C dimension is removed)
        batch_cm = get_confusion_matrix(pred.cpu().numpy(),
                                        lb.cpu().numpy(),
                                        n_class)  # confusion matrix for this batch
        total_cm += batch_cm
        batch_miou = get_metrics(batch_cm, metrics='mean_iou')
        total_batch_miou += batch_miou

        tqdm_str = 'Loss={:.4f}|mIoU={:.4f}|bat_mIoU={:.4f}'  # progress bar postfix
        tqdm_data.set_postfix_str(
            tqdm_str.format(total_loss / i_batch,
                            get_metrics(total_cm, metrics='mean_iou'),
                            total_batch_miou / i_batch))
        pass
    total_loss /= len(data)  # float, average loss over the training epoch
    mean_iou = get_metrics(total_cm, metrics='mean_iou')  # float, mIoU from the accumulated confusion matrix
    total_batch_miou /= len(data)  # average of the per-batch mIoU values

    # log the training results
    log_str = ('Train Loss: {:.4f}|'
               'Train mIoU: {:.4f}|'
               'Train bat_mIoU: {:.4f}')
    log_str = log_str.format(total_loss, mean_iou, total_batch_miou)
    get_logger().info(log_str)
    return total_loss, mean_iou, total_batch_miou
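
Examples #1 and #4 pair naturally into an epoch loop; the sketch below shows one plausible way to alternate them. The optimizer, learning rate, epoch count and loader names are assumptions, not part of the examples.

# Hypothetical training loop combining _epoch_train and _epoch_valid.
import torch

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)  # assumed optimizer and learning rate
for i_epoch in range(1, n_epoch + 1):
    train_loss, train_miou, train_bat_miou = _epoch_train(
        net, loss_func, optimizer, train_loader, n_class, device, i_epoch)
    valid_loss, valid_miou, valid_bat_miou = _epoch_valid(
        net, loss_func, valid_loader, n_class, device, i_epoch)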
inference(model, testLoader, config)
del model, testLoader

logging.info("Proceding with volume reconstruction")
for studyIdx in range(0, len(config["volumePaths"])):

    inputStudyPath = config["volumePaths"][studyIdx]
    numberIdx = [(i, c) for i, c in enumerate(inputStudyPath) if c == "/"]
    study = inputStudyPath[numberIdx[-2][0] + 1:numberIdx[-1][0]]
    modality = inputStudyPath[numberIdx[-1][0] + 1:-7]
    outStudyPath = os.path.join(config["outputFolder"], str(study))

    logging.info("Reconstructing %s...", outStudyPath + "/" + modality)
    dh.reconstruct_study(outStudyPath, inputStudyPath, modality, config)
    segImg = dh.read_oriented(outStudyPath + "/" + modality +
                              "_segmentation.nii.gz")

    # If annotations available, obtain metrics
    if config["gtPaths"] != None:

        gtImg = dh.read_oriented(
            config["gtPaths"][studyIdx])  # Read annotations
        # Read nifti file to obtain original affine for evaluation
        inputVolume = nib.load(config["gtPaths"][studyIdx])
        originalAffine = inputVolume.affine

        metrics = tools.get_metrics(gtImg, segImg,
                                    originalAffine.diagonal()[:-1])
        logging.info('Metrics for study: %s', metrics)

logging.info("Finished")
def load_probs(mainFolder,
               outputFolder,
               studies,
               niftiType,
               planes,
               metrics=False,
               segsFolder=None):
    """ Function to aggregate probabilities for T1 modalities """

    for study in studies:
        for run in get_runs4study(study, mainFolder):
            logging.info(
                'Reconstructing study: %s',
                os.path.join(mainFolder, "Plane" + str(planes[0]), run,
                             "Outputs", str(study), "T1..."))
            inputVolume = nib.load(
                os.path.join(mainFolder, "Plane" + str(planes[0]), run,
                             "Outputs", str(study),
                             "T1OP_" + "probs" + str(2) + ".nii.gz"))
            originalAffine = inputVolume.affine
            originalShape = inputVolume.dataobj.shape

            oneHot = np.zeros(
                (5, originalShape[0], originalShape[1], originalShape[2]))

            for modality in ["T1OP_", "T1IP_"]:

                for label in [0, 1, 2, 3, 4]:

                    if 0 in planes:
                        path = os.path.join(
                            mainFolder, "Plane" + str(0), run, "Outputs",
                            str(study),
                            modality + "probs" + str(label) + ".nii.gz")
                        volume = nib.load(path).get_fdata()
                        oneHot[label, :, :, :] += volume

                    if 1 in planes:
                        path = os.path.join(
                            mainFolder, "Plane" + str(1), run, "Outputs",
                            str(study),
                            modality + "probs" + str(label) + ".nii.gz")
                        volume = nib.load(path).get_fdata()
                        oneHot[label, :, :, :] += volume

                    if 2 in planes:
                        path = os.path.join(
                            mainFolder, "Plane" + str(2), run, "Outputs",
                            str(study),
                            modality + "probs" + str(label) + ".nii.gz")
                        volume = nib.load(path).get_fdata()
                        oneHot[label, :, :, :] += volume

            oneHot /= len(planes)
            segm = np.argmax(oneHot, axis=0)

            isoVolNifti = nib.Nifti1Image(segm.astype(np.int16),
                                          originalAffine)
            path = os.path.join(outputFolder, run, str(study))
            if not os.path.exists(path): os.makedirs(path)
            nib.save(isoVolNifti,
                     path + "/" + modality[:2] + "_segmentation.nii.gz")

            if metrics:
                gtImg = read_oriented(segsFolder + "/" + str(study) + "/" +
                                      modality[:2] + "GT.nii.gz")
                segImg = read_oriented(path + "/" + modality[:2] +
                                       "_segmentation.nii.gz")
                metrics = tools.get_metrics(gtImg, segImg)
                logging.info('Metrics for study: %s', metrics)
                logging.info(
                    "####################################################")
def load_all_planes(mainFolder,
                    outputFolder,
                    studies,
                    niftiType,
                    planes,
                    metrics=False,
                    segsFolder=None):
    """ Function to aggregate different planes, works with either segmentations or probabilities """

    for study in studies:
        for modality in ["T1OP_", "T1IP_", "T2_"]:
            for run in get_runs4study(study, mainFolder):

                logging.info(
                    'Reconstructing study: %s',
                    os.path.join(mainFolder, "Plane" + str(0), run, "Outputs",
                                 str(study), modality[:-1] + "..."))
                inputVolume = nib.load(
                    os.path.join(mainFolder, "Plane" + str(0), run, "Outputs",
                                 str(study),
                                 modality + "probs" + str(2) + ".nii.gz"))
                originalAffine = inputVolume.affine
                originalShape = inputVolume.dataobj.shape

                if niftiType == "probs":

                    oneHot = np.zeros((5, originalShape[0], originalShape[1],
                                       originalShape[2]))
                    for label in [0, 1, 2, 3, 4]:

                        if 0 in planes:
                            path = os.path.join(
                                mainFolder, "Plane" + str(0), run, "Outputs",
                                str(study),
                                modality + "probs" + str(label) + ".nii.gz")
                            volume = nib.load(path).get_fdata()
                            oneHot[label, :, :, :] += volume

                        if 1 in planes:
                            path = os.path.join(
                                mainFolder, "Plane" + str(1), run, "Outputs",
                                str(study),
                                modality + "probs" + str(label) + ".nii.gz")
                            volume = nib.load(path).get_fdata()
                            oneHot[label, :, :, :] += volume

                        if 2 in planes:
                            path = os.path.join(
                                mainFolder, "Plane" + str(2), run, "Outputs",
                                str(study),
                                modality + "probs" + str(label) + ".nii.gz")
                            volume = nib.load(path).get_fdata()
                            oneHot[label, :, :, :] += volume

                    oneHot /= len(planes)
                    segm = np.argmax(oneHot, axis=0)

                if niftiType == "segmentation":

                    oneHot = np.zeros((3, 5, originalShape[0],
                                       originalShape[1], originalShape[2]))
                    segm = np.zeros(
                        (originalShape[0], originalShape[1], originalShape[2]))

                    if 0 in planes:
                        path = os.path.join(mainFolder, "Plane" + str(0), run,
                                            "Outputs", str(study),
                                            modality + "segmentation.nii.gz")
                        volImg = nib.load(path).get_fdata()
                        for label in range(1, 5):
                            oneHot[0, label, :, :, :][np.where(
                                volImg == label)] = 1
                            oneHot[0, label, :, :, :] = binary_closing(
                                oneHot[0, label, :, :, :], ball(3))

                    if 1 in planes:
                        path = os.path.join(mainFolder, "Plane" + str(1), run,
                                            "Outputs", str(study),
                                            modality + "segmentation.nii.gz")
                        volImg = nib.load(path).get_fdata()
                        for label in range(1, 5):
                            oneHot[1, label, :, :, :][np.where(
                                volImg == label)] = 1
                            oneHot[1, label, :, :, :] = binary_closing(
                                oneHot[1, label, :, :, :], ball(3))

                    if 2 in planes:
                        path = os.path.join(mainFolder, "Plane" + str(2), run,
                                            "Outputs", str(study),
                                            modality + "segmentation.nii.gz")
                        volImg = nib.load(path).get_fdata()
                        for label in range(1, 5):
                            oneHot[2, label, :, :, :][np.where(
                                volImg == label)] = 1
                            oneHot[2, label, :, :, :] = binary_closing(
                                oneHot[2, label, :, :, :], ball(3))

                    oneHotSum = np.sum(oneHot, axis=0)
                    oneHotSum[np.where(oneHotSum < 2)] = 0
                    oneHotSum[np.where(oneHotSum != 0)] = 1

                    for label in range(1, 5):
                        segm[np.where(
                            oneHotSum[label, :, :, :] == 1)] = 1 * label

                isoVolNifti = nib.Nifti1Image(segm.astype(np.int16),
                                              originalAffine)
                path = os.path.join(outputFolder, run, str(study))
                if not os.path.exists(path): os.makedirs(path)
                nib.save(isoVolNifti,
                         path + "/" + modality + "segmentation.nii.gz")

                if metrics:
                    gtImg = read_oriented(segsFolder + "/" + str(study) + "/" +
                                          modality[:2] + "GT.nii.gz")
                    segImg = read_oriented(path + "/" + modality +
                                           "segmentation.nii.gz")
                    metrics = tools.get_metrics(gtImg, segImg, originalAffine)
                    logging.info('Metrics for study: %s', metrics)
                    logging.info(
                        "####################################################")