Example #1
    def __init__(self, image_dir, label_dir, label_mapper, patch_size, transformations=None):
        self.image_dir = Path(image_dir)
        self.label_dir = Path(label_dir)
        # sort so images and labels pair up deterministically (glob order is arbitrary)
        self.image_list = sorted(self.image_dir.glob('*.nii*'))
        self.label_list = sorted(self.label_dir.glob('*.nii*'))
        self.transformations = transformations
        self.label_mapper = label_mapper
        self.patch_size = patch_size

        # some sanity checks on the data
        assert len(self.image_list) == len(self.label_list), 'Number of images and labels found are not equal.'
        assert all(i.name == l.name for i, l in zip(self.image_list, self.label_list)), 'Image and label names do not correspond.'

        # Data transforms + patch selection
        self.train_transforms = transforms.Compose(
            [
                transforms.LoadNiftid(keys=["img", "seg"], as_closest_canonical=True),
                # transforms.adaptor(label_mapper_transform, 'seg'),
                # transforms.AddChannelD(keys=["img", "seg"]),
                transforms.Whiteningd(keys=["img"]),
                transforms.RandSpatialCropD(keys=["img", "seg"], roi_size=self.patch_size, random_center=True,
                                            random_size=False),
                transforms.CastToTypeD(keys=["seg"], dtype=np.long)
            ]
        )
        self.val_transforms = transforms.Compose(
            [
                transforms.LoadNiftid(keys=["img", "seg"], as_closest_canonical=True),
                # transforms.adaptor(label_mapper_transform, 'seg'),
                # transforms.AddChannelD(keys=["img", "seg"]),
                transforms.Whiteningd(keys=["img"]),
                transforms.CastToTypeD(keys=["seg"], dtype=np.long)
            ]
        )
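For orientation, a minimal usage sketch for the class above (the class name SegDataset and both directory paths are placeholders; applying train_transforms would normally happen inside __getitem__):

ds = SegDataset('images/', 'labels/', label_mapper=None, patch_size=(96, 96, 96))
sample = ds.train_transforms({"img": str(ds.image_list[0]), "seg": str(ds.label_list[0])})
print(sample["img"].shape, sample["seg"].shape)  # randomly cropped, patch-sized arrays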
Example #2
    def setup(self, stage=None):
        # transform
        min_val, max_val, scale_val = -1500, 3000, 1000

        transform = transforms.Compose([
            transforms.AddChannel(),
            # clamp intensities to [min_val, max_val]
            transforms.ThresholdIntensity(threshold=max_val,
                                          cval=max_val,
                                          above=False),
            transforms.ThresholdIntensity(threshold=min_val,
                                          cval=min_val,
                                          above=True),
            # ScaleIntensity multiplies by (1 + factor), i.e. divides by scale_val here
            transforms.ScaleIntensity(minv=None,
                                      maxv=None,
                                      factor=(-1 + 1 / scale_val)),
            transforms.ShiftIntensity(offset=1),
            transforms.ToTensor(),
            DepthPadAndCrop(output_depth=128),  # must be last because it also outputs the label
        ])

        if self.rescale_input:
            transform = transforms.Compose(
                [transform,
                 Interpolate(size=self.rescale_input, mode='area')])

        dataset = CTScanDataset(self.path,
                                transform=transform,
                                spacing=(0.976, 0.976, 3))

        train_len = int(len(dataset) * self.train_frac)
        val_len = len(dataset) - train_len

        # train/val split
        train_split, val_split = random_split(dataset, [train_len, val_len])

        # assign to use in dataloaders
        self.train_dataset = train_split
        self.train_len = train_len
        self.train_batch_size = self.batch_size

        self.val_dataset = val_split
        self.val_len = val_len
        self.val_batch_size = self.batch_size
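Since setup() stores the splits and batch sizes on self, the matching dataloader hook could look like the sketch below (assuming this class is a PyTorch Lightning DataModule; num_workers is an arbitrary choice):

    def train_dataloader(self):
        # serve the split prepared in setup()
        return DataLoader(self.train_dataset,
                          batch_size=self.train_batch_size,
                          shuffle=True,
                          num_workers=4)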
Example #3
    def __init__(self, cfg, is_train=True):
        super().__init__()
        if isinstance(cfg, Box):
            raise ValueError('Pass a dict instead')
        self.save_hyperparameters('cfg', 'is_train')
        self.cfg = Box(cfg)
        self.batch_size = self.cfg.solver.ims_per_batch
        self.learning_rate = self.cfg.solver.default_lr
        self.criterion = losses.Loss(**self.cfg.solver.loss.params)
        self.postprocessing = tf.Compose([
            tf.Activations(sigmoid=True),
            tf.AsDiscrete(threshold_values=True)
        ])
        self.metrics = {
            'dice': smp.utils.metrics.Fscore(),
            'iou': smp.utils.metrics.IoU()
        }
        self.net = architectures.build(self.cfg.model.name,
                                       **self.cfg.model.get('parameters', {}))
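The postprocessing pipeline and metrics above would typically meet in a step method like this sketch (assuming a LightningModule and an (image, mask) batch layout):

    def validation_step(self, batch, batch_idx):
        images, targets = batch  # assumed batch layout
        preds = self.postprocessing(self.net(images))  # sigmoid + binarize
        self.log('val_dice', self.metrics['dice'](preds, targets))
        self.log('val_iou', self.metrics['iou'](preds, targets))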
Example #4
def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.0)):
    """
    Read and transform Prostate_T2W_AX_1.nii
    Args:
        translate_params: a tuple of 3 floats; translation is in pixels/voxels relative to the
                center of the input image. Defaults to no translation.
        rotate_params: a rotation angle in radians, a tuple of 3 floats for 3D.
                Defaults to no rotation.
    Returns:
        numpy array of shape HWD
    """
    transform_list = [
        transforms.LoadImaged(keys="img"),
        transforms.Affined(
            keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None
        ),
        transforms.NormalizeIntensityd(keys=["img"]),
    ]
    transform = transforms.Compose(transform_list)
    return transform({"img": FILE_PATH})["img"]
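An example call (the translation and rotation values are arbitrary):

import numpy as np

# shift 10 voxels along the first axis, rotate ~22.5 degrees about the last
moved = transformation(translate_params=(10.0, 0.0, 0.0),
                       rotate_params=(0.0, 0.0, np.pi / 8))
print(moved.shape)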
Example #5
def read_single_dict():
    # image = '../../dataset/LCTSC_1/test_nii_1/LCTSC-Test-S1-101.nii.gz'
    # label = '../../dataset/LCTSC_1/test_mask_nii/LCTSC-Test-S1-101.nii.gz'
    # set image path
    image = 'D:\\code\\U-net\\data\\imgs\\COVID-19-CT-Seg_20cases\\coronacases_org_001.nii'
    # set label path
    label = 'D:\\code\\U-net\\data\\masks\\Lung_and_Infection_Mask\\coronacases_001.nii'
    keys = ('image', 'label')
    mn_tfs = mn_tf.Compose([
        mn_tf.LoadNiftiD(keys),
        # mn_tf.AsChannelFirstD('image'),
        mn_tf.AddChannelD(keys),
        # mn_tf.SpacingD(keys, pixdim=(1., 1., 1.), mode=('bilinear', 'nearest')),
        mn_tf.OrientationD(keys, axcodes='RAS'),
        mn_tf.ThresholdIntensityD('image',
                                  threshold=600,
                                  above=False,
                                  cval=600),
        mn_tf.ThresholdIntensityD('image',
                                  threshold=-1000,
                                  above=True,
                                  cval=-1000),

        # proportionally compress intensities to the range [0, 255]
        mn_tf.ScaleIntensityD('image', minv=0.0, maxv=255.0),  # show image
        mn_tf.ScaleIntensityD('image'),  # then rescale to the default [0, 1]

        # resize in-plane to 512 x 512 (-1 keeps the depth unchanged)
        mn_tf.ResizeD(keys, (512, 512, -1), mode=('trilinear', 'nearest')),
        mn_tf.AsChannelFirstD(keys),
        # mn_tf.RandAffineD(keys, spatial_size=(-1, -1, -1),
        #                   rotate_range=(0, 0, np.pi / 2),
        #                   scale_range=(0.1, 0.1),
        #                   mode=('bilinear', 'nearest'),
        #                   prob=1.0),
        mn_tf.ToTensorD(keys)
    ])
    data_dict = mn_tfs({'image': image, 'label': label})
    print(data_dict['image'].shape, data_dict['label'].shape)
    slices = data_dict['image']
    masks = data_dict['label']

    for idx, item in enumerate(zip(slices, masks)):
        image = item[0][0]
        label = item[1][0] * 255
        #index = torch.where(label == 1275)
        # plt.savefig('D:\code\Pytorch-UNet-master\data_test\mask'+str(idx)+'.jpg')
        # path_test =  'D:\\code\\U-net\\train_images_1\\1\\img\\' + str(idx) + '_image' + '.png'
        # print(path_test)
        plt.imsave('D:\\code\\U-net\\train_images_1\\0\\img\\' + str(idx) +
                   '.png',
                   image,
                   cmap='gray')
        plt.imsave('D:\\code\\U-net\\train_images_1\\0\\mask\\' + str(idx) +
                   '_mask' + '.png',
                   label,
                   cmap='gray')
        if 120 <= idx <= 121:  # preview a couple of mid-volume slices
            # print(image)
            print(image.min(), image.max())
            plt.imshow(image, cmap='gray')
            # plt.imshow(image)
            plt.show()
            plt.close()
            plt.imshow(label, cmap='gray')
            # plt.imshow(label)
            plt.show()
            plt.close()
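For reference, the paired ThresholdIntensityD calls above just clamp the CT intensities to the window [-1000, 600] HU; a standalone NumPy sketch of the same operation:

import numpy as np

hu = np.random.uniform(-2000, 2000, size=(64, 64, 32))  # stand-in volume in Hounsfield units
windowed = np.clip(hu, -1000, 600)  # equivalent to the two ThresholdIntensityD calls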
Example #6
def get_train_loaders_1(config):
    assert 'loaders' in config, 'Could not find data loaders configuration'
    loaders_config = config['loaders']
    logger.info('Creating training and validation set loaders...')
    dataset_cls_str = loaders_config.get('dataset', None)

    keys = ('image', 'label')
    mn_train_tfs = mn_tf.Compose([
        mn_tf.LoadNiftiD(keys),
        # mn_tf.AsChannelFirstD('image'),
        mn_tf.AddChannelD(keys),
        # mn_tf.SpacingD(keys, pixdim=(1., 1., 1.), mode=('bilinear', 'nearest')),
        mn_tf.OrientationD(keys, axcodes='RAS'),
        mn_tf.ThresholdIntensityD('image',
                                  threshold=600,
                                  above=False,
                                  cval=600),
        mn_tf.ThresholdIntensityD('image',
                                  threshold=-1000,
                                  above=True,
                                  cval=-1000),
        # mn_tf.ScaleIntensityD('image', minv=0.0, maxv=225.0), # show image
        mn_tf.ScaleIntensityD('image'),
        mn_tf.ResizeD(keys, (512, 512, 128), mode=('trilinear', 'nearest')),
        mn_tf.AsChannelFirstD(keys),
        mn_tf.RandAffineD(keys,
                          spatial_size=(-1, -1, -1),
                          rotate_range=(0, 0, np.pi / 2),
                          scale_range=(0.1, 0.1),
                          mode=('bilinear', 'nearest'),
                          prob=1.0),
        mn_tf.ToTensorD(keys)
    ])

    mn_val_tfs = mn_tf.Compose([
        mn_tf.LoadNiftiD(keys),
        mn_tf.AddChannelD(keys),
        mn_tf.OrientationD(keys, axcodes='RAS'),
        mn_tf.ThresholdIntensityD('image',
                                  threshold=600,
                                  above=False,
                                  cval=600),
        mn_tf.ThresholdIntensityD('image',
                                  threshold=-1000,
                                  above=True,
                                  cval=-1000),
        mn_tf.ScaleIntensityD('image'),
        mn_tf.ResizeD(keys, (512, 512, 128), mode=('trilinear', 'nearest')),
        mn_tf.AsChannelFirstD(keys),
        mn_tf.ToTensorD(keys)
    ])
    train_datasets = lung.LungDataset(root_dir=loaders_config['root_dir'],
                                      transforms=mn_train_tfs,
                                      train=True)
    val_datasets = lung.LungDataset(root_dir=loaders_config['root_dir'],
                                    transforms=mn_val_tfs,
                                    train=False)
    num_workers = loaders_config.get('num_workers', 1)
    logger.info(f'Number of workers for train/val dataloader: {num_workers}')
    batch_size = loaders_config.get('batch_size', 1)
    logger.info(f'Batch size for train/val loader: {batch_size}')
    # when training with volumetric data use batch_size of 1 due to GPU memory constraints
    return {
        'train':
        DataLoader(dataset=train_datasets,
                   batch_size=batch_size,
                   shuffle=True,
                   num_workers=num_workers),
        'val':
        DataLoader(dataset=val_datasets,
                   batch_size=batch_size,
                   shuffle=False,  # validation data does not need shuffling
                   num_workers=num_workers)
    }
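A minimal usage sketch for the loader factory above (the config keys mirror what the function reads; the root_dir path is a placeholder):

config = {'loaders': {'root_dir': '/data/LCTSC_1', 'num_workers': 2, 'batch_size': 1}}
loaders = get_train_loaders_1(config)
for batch in loaders['train']:
    images, labels = batch['image'], batch['label']
    break  # inspect a single batch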
Example #7
if __name__ == '__main__':
    keys = ('image', 'label')
    mn_tfs = mn_tf.Compose([
        mn_tf.LoadNiftiD(keys),
        # mn_tf.AsChannelFirstD('image'),
        mn_tf.AddChannelD(keys),
        # mn_tf.SpacingD(keys, pixdim=(1., 1., 1.), mode=('bilinear', 'nearest')),
        mn_tf.OrientationD(keys, axcodes='RAS'),
        mn_tf.ThresholdIntensityD('image',
                                  threshold=600,
                                  above=False,
                                  cval=600),
        mn_tf.ThresholdIntensityD('image',
                                  threshold=-1000,
                                  above=True,
                                  cval=-1000),
        # mn_tf.ScaleIntensityD('image', minv=0.0, maxv=225.0), # show image
        mn_tf.ScaleIntensityD('image'),
        mn_tf.ResizeD(keys, (512, 512, 128), mode=('trilinear', 'nearest')),
        mn_tf.AsChannelFirstD(keys),
        mn_tf.RandAffineD(keys,
                          spatial_size=(-1, -1, -1),
                          rotate_range=(0, 0, np.pi / 2),
                          scale_range=(0.1, 0.1),
                          mode=('bilinear', 'nearest'),
                          prob=1.0),
        mn_tf.ToTensorD(keys)
    ])
    dataset = LungDataset(root_dir='../../../../dataset/LCTSC_1',
                          transforms=mn_tfs,
                          train=False)
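To sanity-check the pipeline, one might pull a single sample (assuming LungDataset applies the dict transforms in __getitem__):

sample = dataset[0]
print(sample['image'].shape, sample['label'].shape)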
Example #8
def parsingInputs(inps, debug):
    CACHE_PATH = None
    INPUT_PATH = inps['inputpath']
    if 'cachepath' in inps:
        CACHE_PATH = inps['cachepath']

    # Load dataset and convert to dictionaries
    traind, validated, testd = convertInputsToDictionaies(INPUT_PATH, debug)
    res_loader = {}
    params = initParams()
    params['debug'] = debug
    updateParams(inps['model'], params, 'model')

    statuslist = []
    if 'train' in inps:
        statuslist.append('train')
        if not validated:
            params['val_percent_check'] = 0
        else:
            statuslist.append('validate')
    if 'test' in inps:
        statuslist.append('test')


    # Writing through vars() inside a function is unreliable in Python 3,
    # so keep the per-status objects in explicit dictionaries instead.
    datad = {'train': traind, 'validate': validated, 'test': testd}
    transformations = {}

    for status in statuslist:
        input_defaultT = None
        input_augmentation = None

        # Parse transformation parameters
        transformation = TRANSFORMATION()
        transformations[status] = transformation
        if status == 'test' and isinstance(testd[0]['SEGM'], float):
            params['testSegm'] = False
            transformation.keys = ['IMAGE']
            transformation.prefix = (status + '_image',)
        else:
            transformation.keys = ['IMAGE', 'SEGM']
            transformation.prefix = (status + '_image', status + '_segm')


        if status in inps and 'defaulttransformation' in inps[status]:
            input_defaultT = inps[status].pop('defaulttransformation')
        elif 'defaulttransformation' in inps:
            input_defaultT = inps['defaulttransformation']
        if status in inps and 'augmentation' in inps[status]:
            input_augmentation = inps[status].pop('augmentation')
        elif 'augmentation' in inps:
            input_augmentation = inps['augmentation']

        transformation.parsingtransformations(status, debug, input_defaultT, input_augmentation)

        trans_lists = transformation.comfuncs


        # Parse status-specific arguments
        if status in inps:
            updateParams(inps[status], params, status)
        else:
            updateParams({}, params, status)

        if CACHE_PATH is None:
            ds = monaiData.Dataset(
                data=datad[status],
                transform=monaiTrans.Compose(trans_lists)
            )
        else:
            # Keep CACHE_PATH consistent if multiple programs should share the cache
            ds = monaiData.PersistentDataset(
                data=datad[status],
                transform=monaiTrans.Compose(trans_lists),
                cache_dir=CACHE_PATH
            )

        res_loader[status + '_loader'] = torchDataloader(
            ds,
            batch_size=params['batch_size'],
            shuffle=params['shuffle'],
            num_workers=params['num_workers'],
            collate_fn=params['collate_fn']
        )


    # TODO: the user should keep the parameters below consistent between training and testing
    try:
        params['model']['patch_size'] = transformations['train'].patch_size
        params['orientation'] = transformations['train'].orientation
        params['spacing'] = transformations['train'].spacing
    except (KeyError, AttributeError):
        pass
    try:
        params['orientation'] = transformations['test'].orientation
        params['spacing'] = transformations['test'].spacing
    except (KeyError, AttributeError):
        pass

    return params, res_loader
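A hedged usage sketch (the inps dictionary simply mirrors the keys the function reads; all paths and the model section are placeholders):

inps = {
    'inputpath': '/data/my_dataset',   # placeholder
    'cachepath': '/tmp/monai_cache',   # optional: enables PersistentDataset caching
    'model': {'name': 'unet'},         # placeholder model section
    'train': {},                       # presence of the key enables the training phase
}
params, loaders = parsingInputs(inps, debug=False)
train_loader = loaders['train_loader']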
Example #9
def get_dataflow(seed, data_dir, cache_dir, batch_size):
    img = nib.load(str(data_dir / "average_smwc1.nii"))
    img_data_1 = img.get_fdata()
    img_data_1 = np.expand_dims(img_data_1, axis=0)

    img = nib.load(str(data_dir / "average_smwc2.nii"))
    img_data_2 = img.get_fdata()
    img_data_2 = np.expand_dims(img_data_2, axis=0)

    img = nib.load(str(data_dir / "average_smwc3.nii"))
    img_data_3 = img.get_fdata()
    img_data_3 = np.expand_dims(img_data_3, axis=0)

    mask = np.concatenate((img_data_1, img_data_2, img_data_3))
    mask[mask > 0.3] = 1
    mask[mask <= 0.3] = 0

    # Define transformations
    train_transforms = transforms.Compose([
        transforms.LoadNiftid(keys=["c1", "c2", "c3"]),
        transforms.AddChanneld(keys=["c1", "c2", "c3"]),
        transforms.ConcatItemsd(keys=["c1", "c2", "c3"], name="img"),
        transforms.MaskIntensityd(keys=["img"], mask_data=mask),
        transforms.ScaleIntensityd(keys="img"),
        transforms.ToTensord(keys=["img", "label"])
    ])

    val_transforms = transforms.Compose([
        transforms.LoadNiftid(keys=["c1", "c2", "c3"]),
        transforms.AddChanneld(keys=["c1", "c2", "c3"]),
        transforms.ConcatItemsd(keys=["c1", "c2", "c3"], name="img"),
        transforms.MaskIntensityd(keys=["img"], mask_data=mask),
        transforms.ScaleIntensityd(keys="img"),
        transforms.ToTensord(keys=["img", "label"])
    ])

    # Get img paths
    df = pd.read_csv(data_dir / "banc2019_training_dataset.csv")
    df = df.sample(frac=1, random_state=seed)
    df["NormAge"] = (((df["Age"] - 18) / (92 - 18)) * 2) - 1
    data_dicts = []
    for index, row in df.iterrows():
        study_dir = data_dir / row["Study"] / "derivatives" / "spm"
        subj = list(study_dir.glob(f"sub-{row['Subject']}"))

        if not subj:
            subj = list(study_dir.glob(f"*sub-{row['Subject']}*"))
            if not subj:
                # rstrip('_S1') would strip any trailing '_', 'S', '1' characters;
                # removesuffix drops exactly the '_S1' suffix (Python 3.9+)
                subj = list(
                    study_dir.glob(f"*sub-{row['Subject'].removesuffix('_S1')}*"))
                if not subj:
                    if row["Study"] == "SALD":
                        subj = list(
                            study_dir.glob(f"sub-{int(row['Subject']):06d}*"))
                        if not subj:
                            print(f"{row['Study']} {row['Subject']}")
                            continue
                    else:
                        print(f"{row['Study']} {row['Subject']}")
                        continue

        c1_img = list(subj[0].glob("./smwc1*"))
        c2_img = list(subj[0].glob("./smwc2*"))
        c3_img = list(subj[0].glob("./smwc3*"))

        if not (c1_img and c2_img and c3_img):
            print(f"{row['Study']} {row['Subject']}")
            continue

        data_dicts.append({
            "c1": str(c1_img[0]),
            "c2": str(c2_img[0]),
            "c3": str(c3_img[0]),
            "label": row["NormAge"]
        })

    print(f"Found {len(data_dicts)} subjects.")
    val_size = len(data_dicts) // 10
    # Create datasets and dataloaders
    train_ds = PersistentDataset(data=data_dicts[:-val_size],
                                 transform=train_transforms,
                                 cache_dir=cache_dir)
    # train_ds = Dataset(data=data_dicts[:-val_size], transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=4,
                              collate_fn=list_data_collate)

    val_ds = PersistentDataset(data=data_dicts[-val_size:],
                               transform=val_transforms,
                               cache_dir=cache_dir)
    # val_ds = Dataset(data=data_dicts[-val_size:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=4)

    return train_loader, val_loader
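A minimal call sketch (paths are placeholders; batch_size is kept small because these are full 3D volumes):

from pathlib import Path

train_loader, val_loader = get_dataflow(seed=42,
                                        data_dir=Path('/data/banc2019'),    # placeholder
                                        cache_dir=Path('/tmp/banc_cache'),  # placeholder
                                        batch_size=2)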