Beispiel #1
0
def postprocess(case_id=0):
    """Load one prediction/label pair, print its Dice score and show both surfaces.

    The volumes are read with ``nib.load(DIR + get_file(kind, case_id))``
    (``DIR`` and ``get_file`` are defined elsewhere in this module) and rounded
    before scoring.

    Args:
        case_id: case index passed to ``get_file``. Defaults to 0, the case the
            original implementation hard-coded.
    """
    # LOAD DATA
    pred = np.round(nib.load(DIR + get_file('pred', case_id)).get_fdata())
    label = nib.load(DIR + get_file('lbl', case_id)).get_fdata()
    # Replace NaNs in the label with 0 before rounding it.
    label = np.round(np.nan_to_num(label))

    dice_f = dice_score(pred, label)
    print('Loaded data', pred.shape, 'DICE:', dice_f)

    # CONNECTIVITY (optional, currently disabled). When re-enabling it, the
    # Dice score has to be recomputed on the filtered prediction:
    #   pred, max_label = postprocess_connectivity(pred)
    #   dice_f = dice_score(pred, label)
    #   print('Connectivity | max:', max_label, 'DICE:', dice_f)

    # CREATE AND SHOW SURFACE (label and prediction side by side)
    label_surface = Volume(label).isosurface(threshold=[0, 0.5, 1])
    pred_surface = Volume(pred).isosurface(threshold=[0, 0.5, 1])
    # label_surface.alpha(0.5).lw(0.1)

    show(label_surface, pred_surface, N=2, axes=8)
def evaluate_saved_model(model_config,
                         split='validation',
                         model_path=None,
                         data_path=None,
                         save_directory=None,
                         save_nii=False,
                         save_npz=False):
    """Evaluate a trained model on one dataset split and write per-case statistics.

    For every case the model output is compared with the label (Dice for
    background/lesion, single-class Dice, precision, recall, specificity, the
    ``md``/``hd`` values from ``distance_metric`` and Jaccard). The prior stored
    in input channel 5 is scored against the label as well. All rows go to
    ``<save_directory>/<split>_stats.csv`` and mean +- std per metric is printed.

    Args:
        model_config: path to the JSON options file.
        split: dataset split passed to the dataset class.
        model_path: optional checkpoint that overrides
            ``model.path_pre_trained_model`` from the options.
        data_path: optional dataset root; defaults to the path derived from
            the options.
        save_directory: output directory; defaults to
            ``<dirname(model_config)>/<split>_evaluation``.
        save_nii: additionally write input channels 0, 1 and 5, the label and
            the prediction as ``<iteration>_{img,cbf,prior,lbl,pred}.nii.gz``.
        save_npz: additionally save all predictions to ``predictions.npz``.

    Raises:
        Exception: if the number of configured data channels differs from the
            model's ``input_nc`` or from the last entry of the augmentation
            ``scale_size``.
    """
    # Load options
    json_opts = json_file_to_pyobj(model_config)
    train_opts = json_opts.training
    model_opts = json_opts.model
    data_path_opts = json_opts.data_path

    if model_path is not None:
        model_opts = model_opts._replace(path_pre_trained_model=model_path)

    # An empty gpu_ids list presumably makes get_model build a CPU model — confirm.
    model_opts = model_opts._replace(gpu_ids=[])

    # Setup the NN Model
    model = get_model(model_opts)
    if save_directory is None:
        save_directory = os.path.join(os.path.dirname(model_config),
                                      split + '_evaluation')
    mkdir(save_directory)

    # Setup Dataset and Augmentation
    ds_class = get_dataset(train_opts.arch_type)
    if data_path is None:
        data_path = get_dataset_path(train_opts.arch_type, data_path_opts)
    dataset_transform = get_dataset_transformation(train_opts.arch_type,
                                                   opts=json_opts.augmentation)

    # Setup channels
    channels = json_opts.data_opts.channels
    if len(channels) != json_opts.model.input_nc \
            or len(channels) != getattr(json_opts.augmentation, train_opts.arch_type).scale_size[-1]:
        raise Exception(
            'Number of data channels must match number of model channels, and patch and scale size dimensions'
        )

    # Setup Data Loader
    split_opts = json_opts.data_split
    dataset = ds_class(data_path,
                       split=split,
                       transform=dataset_transform['valid'],
                       preload_data=train_opts.preloadData,
                       train_size=split_opts.train_size,
                       test_size=split_opts.test_size,
                       valid_size=split_opts.validation_size,
                       split_seed=split_opts.seed,
                       channels=channels)
    data_loader = DataLoader(dataset=dataset,
                             num_workers=8,
                             batch_size=1,
                             shuffle=False)

    # Setup stats logger
    stat_logger = StatLogger()
    all_predicted = []

    if save_nii:
        # Imported once, and only when NIfTI output is requested.
        import SimpleITK as sitk

        def write_nifti(volume, file_pattern, iteration):
            """Write ``volume`` to ``save_directory/file_pattern.format(iteration)``.

            The array axes are reversed before conversion, and the direction
            matrix negates the first two axes, as in the original export code.
            """
            image = sitk.GetImageFromArray(np.transpose(volume, (2, 1, 0)))
            image.SetDirection([-1, 0, 0, 0, -1, 0, 0, 0, 1])
            sitk.WriteImage(image,
                            os.path.join(save_directory,
                                         file_pattern.format(iteration)))

    # test
    for iteration, data in tqdm(enumerate(data_loader, 1),
                                total=len(data_loader)):
        model.set_input(data[0], data[1])
        model.test()

        raw_input = np.squeeze(data[0].cpu().numpy())
        input_arr = raw_input.astype(np.float32)
        # Channel 5 holds the prior; positive values are set to 1.
        # NOTE(review): the int16 cast truncates towards zero, so prior values
        # in (0, 1) end up as background — confirm the prior is integer-valued.
        prior_arr = raw_input[5].astype(np.int16)
        prior_arr[prior_arr > 0] = 1
        label_arr = np.squeeze(data[1].cpu().numpy()).astype(np.int16)
        ids = dataset.get_ids(data[2])
        output_arr = np.squeeze(model.pred_seg.cpu().byte().numpy()).astype(
            np.int16)

        # Compute statistics of the prediction against the label (2 classes)
        dice_vals = dice_score(label_arr, output_arr, n_class=2)
        single_class_dice = single_class_dice_score(label_arr, output_arr)
        md, hd = distance_metric(label_arr, output_arr, dx=2.00, k=1)
        precision, recall = precision_and_recall(label_arr,
                                                 output_arr,
                                                 n_class=2)
        sp = specificity(label_arr, output_arr)
        jaccard = jaccard_score(label_arr.flatten(), output_arr.flatten())

        # Compute the same statistics for the prior that is used
        prior_dice = single_class_dice_score(label_arr, prior_arr)
        prior_precision, prior_recall = precision_and_recall(label_arr,
                                                             prior_arr,
                                                             n_class=2)

        stat_logger.update(split=split,
                           input_dict={
                               'img_name': ids[0],
                               'dice_bg': dice_vals[0],
                               'dice_les': dice_vals[1],
                               'dice2_les': single_class_dice,
                               'prec_les': precision[1],
                               'reca_les': recall[1],
                               'specificity': sp,
                               'md_les': md,
                               'hd_les': hd,
                               'jaccard': jaccard,
                               'dice_prior': prior_dice,
                               'prec_prior': prior_precision[1],
                               'reca_prior': prior_recall[1]
                           })

        if save_nii:
            write_nifti(input_arr[0], '{}_img.nii.gz', iteration)
            write_nifti(input_arr[1], '{}_cbf.nii.gz', iteration)
            write_nifti(input_arr[5], '{}_prior.nii.gz', iteration)
            write_nifti(label_arr, '{}_lbl.nii.gz', iteration)
            write_nifti(output_arr, '{}_pred.nii.gz', iteration)

        if save_npz:
            all_predicted.append(output_arr)

    stat_logger.statlogger2csv(split=split,
                               out_csv_name=os.path.join(
                                   save_directory, split + '_stats.csv'))
    for key, (mean_val,
              std_val) in stat_logger.get_errors(split=split).items():
        print('-', key, ': \t{0:.3f}+-{1:.3f}'.format(mean_val, std_val), '-')

    if save_npz:
        np.savez_compressed(os.path.join(save_directory, 'predictions.npz'),
                            predicted=np.array(all_predicted))
def validation(json_name):
    """Score the model configured in ``json_name`` on the validation split.

    Each case is logged with the Dice of classes 1-3 ('dice_LV', 'dice_MY',
    'dice_RV') and the precision, recall and ``distance_metric`` values of
    class 2. Its input, label and prediction volumes are written as NIfTI files
    into ``<model.save_dir>/<arch_type>``. Aggregated statistics are written to
    ``stats.csv`` in the same directory and mean +- std is printed.

    Note: the statistics are logged under split key ``'test'`` although the
    data comes from the ``'validation'`` split.
    """
    opts = json_file_to_pyobj(json_name)
    arch_type = opts.training.arch_type

    # Model and output directory
    model = get_model(opts.model)
    out_dir = os.path.join(model.save_dir, arch_type)
    mkdirfun(out_dir)

    # Dataset, transforms and loader (one case per batch, fixed order)
    dataset_cls = get_dataset(arch_type)
    root = get_dataset_path(arch_type, opts.data_path)
    transforms = get_dataset_transformation(arch_type, opts=opts.augmentation)
    dataset = dataset_cls(root, split='validation', transform=transforms['valid'])
    loader = DataLoader(dataset=dataset, num_workers=8, batch_size=1, shuffle=False)

    logger = StatLogger()
    # Direction matrix that negates the first two image axes.
    flip_xy = [-1, 0, 0, 0, -1, 0, 0, 0, 1]

    for case_no, batch in enumerate(loader, 1):
        image, target = batch[0], batch[1]
        model.set_input(image, target)
        model.test()

        image_np = np.squeeze(image.cpu().numpy()).astype(np.float32)
        target_np = np.squeeze(target.cpu().numpy()).astype(np.int16)
        pred_np = np.squeeze(model.pred_seg.cpu().byte().numpy()).astype(np.int16)

        # Per-case statistics (4 classes)
        dice = dice_score(target_np, pred_np, n_class=4)
        mean_dist, haus_dist = distance_metric(target_np, pred_np, dx=2.00, k=2)
        prec, rec = precision_and_recall(target_np, pred_np, n_class=4)
        logger.update(split='test',
                      input_dict={
                          'img_name': '',
                          'dice_LV': dice[1],
                          'dice_MY': dice[2],
                          'dice_RV': dice[3],
                          'prec_MYO': prec[2],
                          'reca_MYO': rec[2],
                          'md_MYO': mean_dist,
                          'hd_MYO': haus_dist
                      })

        # Export input, label and prediction as NIfTI volumes
        import SimpleITK as sitk
        exports = ((image_np, '{}_img.nii.gz'),
                   (target_np, '{}_lbl.nii.gz'),
                   (pred_np, '{}_pred.nii.gz'))
        for volume, pattern in exports:
            itk_image = sitk.GetImageFromArray(np.transpose(volume, (2, 1, 0)))
            itk_image.SetDirection(flip_xy)
            sitk.WriteImage(itk_image,
                            os.path.join(out_dir, pattern.format(case_no)))

    logger.statlogger2csv(split='test',
                          out_csv_name=os.path.join(out_dir, 'stats.csv'))
    for name, (mean, std) in logger.get_errors(split='test').items():
        print('-', name, ': \t{0:.3f}+-{1:.3f}'.format(mean, std), '-')
def fit_one_epoch(net, epoch, epoch_size_train, epoch_size_val, gen_train, gen_val, Epoch, cuda):
    """Train ``net`` for one epoch, validate it and save a checkpoint.

    Args:
        net: network model; switched to train mode, then to eval mode.
        epoch: index of the current epoch (displayed as ``epoch + 1``).
        epoch_size_train: maximum number of training iterations (batches).
        epoch_size_val: maximum number of validation iterations (batches).
        gen_train: iterable of ``(imgs, pngs, labels)`` numpy training batches.
        gen_val: iterable of ``(imgs, pngs, labels)`` numpy validation batches.
        Epoch: total number of epochs (display only).
        cuda: move each batch to the GPU when true.

    Relies on module-level names defined elsewhere: ``optimizer``, ``model``,
    ``CE_Loss``, ``Dice_loss``, ``dice_loss`` (flag), ``dice_score``,
    ``NUM_CLASSES`` and ``get_lr``.

    The epoch-level losses and Dice scores that are printed and encoded in the
    checkpoint filename are averaged over the batches actually processed.
    """

    def to_tensors(batch):
        """Convert a numpy ``(imgs, pngs, labels)`` batch to tensors, optionally on GPU."""
        imgs, pngs, labels = batch
        imgs = torch.from_numpy(imgs).type(torch.FloatTensor)
        pngs = torch.from_numpy(pngs).type(torch.FloatTensor).long()
        labels = torch.from_numpy(labels).type(torch.FloatTensor)
        if cuda:
            imgs, pngs, labels = imgs.cuda(), pngs.cuda(), labels.cuda()
        return imgs, pngs, labels

    # ---- training ----
    train_total_loss = 0
    train_total_dice_score = 0
    train_steps = 0
    net.train()
    print('Start Training')
    start_time = time.time()
    with tqdm(total=epoch_size_train, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen_train):
            if iteration >= epoch_size_train:
                break
            imgs, pngs, labels = to_tensors(batch)

            optimizer.zero_grad()
            outputs = net(imgs)
            # Loss of one batch: cross-entropy, plus the Dice loss when enabled.
            loss = CE_Loss(outputs, pngs, num_classes=NUM_CLASSES)
            if dice_loss:
                loss = loss + Dice_loss(outputs, labels)
            # The Dice score is only a metric, so no graph is needed for it.
            with torch.no_grad():
                dice = dice_score(outputs, labels)
            loss.backward()
            optimizer.step()

            train_total_loss += loss.item()
            train_total_dice_score += dice.item()
            train_steps += 1
            pbar.set_postfix(**{'train_loss': train_total_loss / train_steps,
                                'train_dice_score': train_total_dice_score / train_steps,
                                's/step'    : time.time() - start_time,
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)
            start_time = time.time()
    print('Finish Training')

    # ---- validation ----
    val_total_loss = 0
    val_total_dice_score = 0
    val_steps = 0
    net.eval()
    print('Start Validation')
    with tqdm(total=epoch_size_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen_val):
            if iteration >= epoch_size_val:
                break
            imgs, pngs, labels = to_tensors(batch)
            with torch.no_grad():
                outputs = net(imgs)
                val_loss = CE_Loss(outputs, pngs, num_classes=NUM_CLASSES)
                if dice_loss:
                    val_loss = val_loss + Dice_loss(outputs, labels)
                dice = dice_score(outputs, labels)
            val_total_loss += val_loss.item()
            val_total_dice_score += dice.item()
            val_steps += 1
            pbar.set_postfix(**{'val_loss'  : val_total_loss / val_steps,
                                'val_dice_score' : val_total_dice_score / val_steps,
                                'lr'        : get_lr(optimizer)})
            pbar.update(1)
    print('Finish Validation')

    # Average over the batches actually processed (previously divided by
    # epoch_size + 1, which biased every average low); max(..., 1) guards
    # against an empty loader.
    train_loss_mean = train_total_loss / max(train_steps, 1)
    val_loss_mean = val_total_loss / max(val_steps, 1)
    train_dice_mean = train_total_dice_score / max(train_steps, 1)
    val_dice_mean = val_total_dice_score / max(val_steps, 1)

    print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
    print('Train Loss: %.4f || Val Loss: %.4f || Train Dice: %.4f || Val Dice: %.4f' % (train_loss_mean, val_loss_mean, train_dice_mean, val_dice_mean))
    print('Saving state, epoch:', str(epoch + 1))
    # NOTE(review): this saves the module-level ``model``, not ``net``; presumably
    # ``net`` wraps ``model`` (e.g. DataParallel) — confirm.
    torch.save(model.state_dict(), '图像分割/FCN/logs_fcn_resnet50_360/Epoch%d-Train_Loss%.4f-Val_Loss%.4f-Train_Dice%.4f-Val_Dice%.4f.pth' % ((epoch + 1), train_loss_mean, val_loss_mean, train_dice_mean, val_dice_mean))
Beispiel #5
0
def run_seg(config_file_seg):
    """Evaluate the best segmentation checkpoint for several ``min_size`` values.

    The model ``config.model.arch`` is built, and ``<work_dir>/checkpoints/best.pth``
    is loaded into it. Predictions on the validation loader are post-processed
    with ``config.test.best_threshold`` and each ``min_size`` in ``min_sizes``.
    The mean Dice per class and ``min_size`` is printed and written once to
    ``<work_dir>/threshold_search.json`` as ``{min_size: {class_index: dice}}``.

    Note: sets ``CUDA_VISIBLE_DEVICES=0`` in the process environment.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # ------------------------------------------------------------------------------------------------------------
    # 2. segmentation inference
    # ------------------------------------------------------------------------------------------------------------
    config = load_config(config_file_seg)
    num_classes = config.data.num_classes

    validloader = make_loader(
        data_folder=config.data.train_dir,
        df_path=config.data.train_df_path,
        phase='valid',
        batch_size=config.train.batch_size,
        num_workers=config.num_workers,
        idx_fold=config.data.params.idx_fold,
        transforms=get_transforms(config.transforms.test),
        num_classes=num_classes,
    )

    # create segmentation model with pre-trained encoder
    model = getattr(smp, config.model.arch)(
        encoder_name=config.model.encoder,
        encoder_weights=config.model.pretrained,
        classes=num_classes,
        activation=None,
    )
    model.to(config.device)
    model.eval()
    checkpoint = load_checkpoint(f"{config.work_dir}/checkpoints/best.pth")
    model.load_state_dict(checkpoint['model_state_dict'])

    min_sizes = [100, 300, 500, 750, 1000, 1500, 2000, 3000]
    # all_dice[min_size][cls] collects one Dice score per validation sample.
    all_dice = {min_size: {cls: [] for cls in range(num_classes)}
                for min_size in min_sizes}

    with torch.no_grad():
        for batch_images, batch_masks in tqdm(validloader):
            batch_images = batch_images.to(config.device)
            batch_preds = predict_batch(model,
                                        batch_images,
                                        tta=config.test.tta)

            batch_masks = batch_masks.cpu().numpy()

            for masks, preds in zip(batch_masks, batch_preds):
                for cls in range(num_classes):
                    for min_size in min_sizes:
                        pred, _ = post_process(preds[cls, :, :],
                                               config.test.best_threshold,
                                               min_size)
                        mask = masks[cls, :, :]
                        all_dice[min_size][cls].append(dice_score(pred, mask))

    # With 4 classes the reported class id is offset by one (channel 0 -> class 1).
    class_offset = 1 if num_classes == 4 else 0
    for cls in range(num_classes):
        for min_size in min_sizes:
            scores = all_dice[min_size][cls]
            all_dice[min_size][cls] = sum(scores) / len(scores)
            print('average dice score for class{} for min_size {}: {}'.format(
                cls + class_offset, min_size, all_dice[min_size][cls]))

    # Write the results once, after every entry has been averaged (previously
    # this ran inside the loop and rewrote a partially converted dict each time).
    dict_to_json(all_dice, config.work_dir + '/threshold_search.json')
def validation(config_file_seg):
    """Grid-search post-processing parameters per class on the validation fold.

    For every class, each combination of
      * a label threshold on the spatially max-pooled sigmoid of the prediction
        (``label_thresholds``): if the pooled score is <= the threshold, the
        class is predicted empty;
      * a mask threshold (``mask_thresholds``), passed to ``post_process``;
      * a ``min_size`` (``min_sizes``), passed to ``post_process``
    is scored by the mean Dice over all validation samples. The full score grid
    is saved with ``np.save('all_dice', ...)`` in the current working
    directory. The best combination per class goes to
    ``<work_dir>/parameters.json`` and is printed, and the mean of the
    per-class best Dice is printed as ``cv_score``.

    Note: sets ``CUDA_VISIBLE_DEVICES=0`` in the process environment and
    prefixes ``config.work_dir`` when running on Colab or Kaggle.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    config = load_config(config_file_seg)
    if 'COLAB_GPU' in os.environ:
        config.work_dir = '/content/drive/My Drive/kaggle_cloud/' + config.work_dir
    elif 'KAGGLE_WORKING_DIR' in os.environ:
        config.work_dir = '/kaggle/working/' + config.work_dir

    validloader = make_loader(
        data_folder=config.data.train_dir,
        df_path=config.data.train_df_path,
        phase='valid',
        img_size=(config.data.height, config.data.width),
        batch_size=config.test.batch_size,
        num_workers=config.num_workers,
        idx_fold=config.data.params.idx_fold,
        transforms=get_transforms(config.transforms.test),
        num_classes=config.data.num_classes,
    )

    model = load_model(config_file_seg)

    # Size every per-class structure from the config instead of a hard-coded 4.
    num_classes = config.data.num_classes
    min_sizes = np.arange(0, 20000, 5000)
    label_thresholds = [0.6, 0.7, 0.8]
    mask_thresholds = [0.2, 0.3, 0.4]
    all_dice = np.zeros(
        (num_classes, len(label_thresholds), len(mask_thresholds), len(min_sizes)))
    count = 0

    with torch.no_grad():
        for batch_images, batch_masks in tqdm(validloader):
            batch_images = batch_images.to(config.device)
            batch_preds = predict_batch(model,
                                        batch_images,
                                        tta=config.test.tta)

            # Per-sample, per-class score: spatial max of the sigmoid output.
            batch_labels = torch.nn.functional.adaptive_max_pool2d(
                torch.sigmoid(torch.Tensor(batch_preds)),
                1).view(batch_preds.shape[0], -1)

            batch_masks = batch_masks.cpu().numpy()
            batch_labels = batch_labels.cpu().numpy()

            batch_masks = resize_batch_images(batch_masks, SUB_HEIGHT,
                                              SUB_WIDTH)
            batch_preds = resize_batch_images(batch_preds, SUB_HEIGHT,
                                              SUB_WIDTH)

            for labels, masks, preds in zip(batch_labels, batch_masks,
                                            batch_preds):
                for cls in range(num_classes):
                    prob = preds[cls, :, :]
                    mask = masks[cls, :, :]
                    # keep[li] is True where the class survives label threshold li
                    # (written as "not <=" to match the original comparison).
                    keep = [not (labels[cls] <= label_th)
                            for label_th in label_thresholds]

                    # Dice of an empty prediction; shared by every rejected cell.
                    empty_dice = dice_score(np.zeros(prob.shape), mask)

                    # post_process does not depend on the label threshold, so
                    # evaluate it once per (mask_th, min_size), and only if some
                    # label threshold keeps the class.
                    kept_dice = np.zeros((len(mask_thresholds), len(min_sizes)))
                    if any(keep):
                        for j, mask_th in enumerate(mask_thresholds):
                            for k, min_size in enumerate(min_sizes):
                                pred, _ = post_process(prob,
                                                       mask_th,
                                                       min_size,
                                                       height=SUB_HEIGHT,
                                                       width=SUB_WIDTH)
                                kept_dice[j, k] = dice_score(pred, mask)

                    for li, kept in enumerate(keep):
                        all_dice[cls, li] += kept_dice if kept else empty_dice
                count += 1

    all_dice = all_dice / count
    np.save('all_dice', all_dice)

    parameters = {
        'label_thresholds': [],
        'mask_thresholds': [],
        'min_sizes': [],
        'dice': [],
    }
    cv_score = 0

    for cls in range(num_classes):
        best = all_dice[cls].max()
        # First (C-order) grid cell that reaches the best score.
        i, j, k = np.where(all_dice[cls] == best)
        parameters['label_thresholds'].append(float(label_thresholds[i[0]]))
        parameters['mask_thresholds'].append(float(mask_thresholds[j[0]]))
        parameters['min_sizes'].append(int(min_sizes[k[0]]))
        parameters['dice'].append(float(best))
        cv_score += best / num_classes

    print('cv_score:', cv_score)
    dict_to_json(parameters, config.work_dir + '/parameters.json')
    print(pd.DataFrame(parameters))
Beispiel #7
0
 def test_dice_score(self):
     """dice_score of [1, 0, 0, 1] vs [1, 1, 0, 0] is 0.5.

     One shared positive out of two positives in each vector: with the usual
     Dice definition 2 * 1 / (2 + 2) = 0.5.
     """
     pred = np.array([1, 0, 0, 1])
     target = np.array([1, 1, 0, 0])
     dice = metrics.dice_score(pred, target)
     np.testing.assert_almost_equal(dice, 0.5, decimal=3)