示例#1
0
 def __create_dataset(self):
     """Create datasets and data loaders for the current stage.

     For ``self.stage == 'train'`` this builds ``self.train_loader`` and
     ``self.valid_loader``; otherwise it builds the test loader.  The
     'RegionSwop' transform is excluded from the per-sample transform
     pipelines and kept separately in ``self.region_swop`` — presumably
     because it must be applied at the batch level (TODO confirm).

     NOTE(review): the test loader was historically stored under the
     misspelled attribute ``self.test_loder``; that name is kept for
     backward compatibility and a correctly spelled ``self.test_loader``
     alias is added.
     """
     root_dir = self.config['dataset']['root_dir']
     train_csv = self.config['dataset'].get('train_csv', None)
     valid_csv = self.config['dataset'].get('valid_csv', None)
     test_csv = self.config['dataset'].get('test_csv', None)
     modal_num = self.config['dataset']['modal_num']
     if (self.stage == 'train'):
         transform_names = self.config['dataset']['train_transform']
         validtransform_names = self.config['dataset']['valid_transform']
         # Build the validation pipeline, skipping RegionSwop (handled below).
         self.validtransform_list = [
             get_transform(name, self.config['dataset'])
             for name in validtransform_names if name != 'RegionSwop']
     else:
         transform_names = self.config['dataset']['test_transform']
     # Per-sample transform pipeline for the main (train or test) split.
     self.transform_list = [
         get_transform(name, self.config['dataset'])
         for name in transform_names if name != 'RegionSwop']
     if ('RegionSwop' in transform_names):
         self.region_swop = get_transform('RegionSwop',
                                          self.config['dataset'])
     else:
         self.region_swop = None
     if (self.stage == 'train'):
         train_dataset = NiftyDataset(root_dir=root_dir,
                                      csv_file=train_csv,
                                      modal_num=modal_num,
                                      with_label=True,
                                      transform=transforms.Compose(
                                          self.transform_list))
         valid_dataset = NiftyDataset(root_dir=root_dir,
                                      csv_file=valid_csv,
                                      modal_num=modal_num,
                                      with_label=True,
                                      transform=transforms.Compose(
                                          self.validtransform_list))
         batch_size = self.config['training']['batch_size']
         self.train_loader = torch.utils.data.DataLoader(
             train_dataset,
             batch_size=batch_size,
             shuffle=True,
             num_workers=batch_size * 2)
         # Validation runs one volume at a time (batch_size=1).
         self.valid_loader = torch.utils.data.DataLoader(
             valid_dataset,
             batch_size=1,
             shuffle=False,
             num_workers=batch_size * 2)
     else:
         test_dataset = NiftyDataset(root_dir=root_dir,
                                     csv_file=test_csv,
                                     modal_num=modal_num,
                                     with_label=False,
                                     transform=transforms.Compose(
                                         self.transform_list))
         batch_size = 1
         self.test_loder = torch.utils.data.DataLoader(
             test_dataset,
             batch_size=batch_size,
             shuffle=False,
             num_workers=batch_size)
         # Correctly spelled alias; keep the old name for existing callers.
         self.test_loader = self.test_loder
示例#2
0
    def get_stage_dataset_from_config(self, stage):
        """Return a ``NiftyDataset`` configured for one stage.

        :param stage: one of ``'train'``, ``'valid'`` or ``'test'``.
            Validation falls back to the training transform list when no
            ``valid_transform`` entry exists in the configuration.
        :return: a ``NiftyDataset``; labels are loaded for every stage
            except ``'test'``.
        :raises ValueError: if a transform name is not registered in
            ``self.transform_dict``.
        """
        assert (stage in ['train', 'valid', 'test'])
        root_dir = self.config['dataset']['root_dir']
        modal_num = self.config['dataset']['modal_num']

        transform_key = stage + '_transform'
        # Fall back to the training transforms for validation when no
        # stage-specific list is configured.
        if (stage == "valid" and transform_key not in self.config['dataset']):
            transform_key = "train_transform"
        transform_names = self.config['dataset'][transform_key]

        self.transform_list = []
        if (transform_names is None or len(transform_names) == 0):
            data_transform = None
        else:
            # Use a shallow copy so injecting the 'task' key does not
            # mutate the shared self.config['dataset'] dictionary (the
            # original code leaked this key into the global config).
            transform_param = dict(self.config['dataset'])
            transform_param['task'] = 'segmentation'
            for name in transform_names:
                if (name not in self.transform_dict):
                    raise (ValueError("Undefined transform {0:}".format(name)))
                one_transform = self.transform_dict[name](transform_param)
                self.transform_list.append(one_transform)
            data_transform = transforms.Compose(self.transform_list)

        csv_file = self.config['dataset'].get(stage + '_csv', None)
        dataset = NiftyDataset(root_dir=root_dir,
                               csv_file=csv_file,
                               modal_num=modal_num,
                               with_label=not (stage == 'test'),
                               transform=data_transform)
        return dataset
示例#3
0
    def get_stage_dataset_from_config(self, stage):
        """Build the ``NiftyDataset`` for one stage of the experiment.

        Training and validation share the ``train_transform`` list (and may
        load pixel-wise weights); testing uses ``test_transform`` without
        weights or labels.
        """
        assert (stage in ['train', 'valid', 'test'])
        dataset_cfg = self.config['dataset']
        root_dir = dataset_cfg['root_dir']
        modal_num = dataset_cfg['modal_num']

        if (stage == "test"):
            transform_names = dataset_cfg['test_transform']
            with_weight = False
        elif (stage == "train" or stage == "valid"):
            transform_names = dataset_cfg['train_transform']
            with_weight = dataset_cfg.get('load_pixelwise_weight', False)
        else:
            # Defensive path: only reachable if the assert above is
            # stripped (python -O).
            raise ValueError("Incorrect value for stage: {0:}".format(stage))

        self.transform_list = []
        if (transform_names is None or len(transform_names) == 0):
            data_transform = None
        else:
            for name in transform_names:
                if (name not in self.transform_dict):
                    raise (ValueError("Undefined transform {0:}".format(name)))
                self.transform_list.append(
                    self.transform_dict[name](dataset_cfg))
            data_transform = transforms.Compose(self.transform_list)

        return NiftyDataset(root_dir=root_dir,
                            csv_file=dataset_cfg.get(stage + '_csv', None),
                            modal_num=modal_num,
                            with_label=(stage != 'test'),
                            with_weight=with_weight,
                            transform=data_transform)
示例#4
0
    def get_stage_dataset_from_config(self, stage):
        """Create the ``NiftyDataset`` used for one stage.

        Training and validation share the ``train_transform`` list; testing
        uses ``test_transform`` and loads the images without labels.
        """
        assert (stage in ['train', 'valid', 'test'])
        dataset_cfg = self.config['dataset']
        root_dir = dataset_cfg['root_dir']
        modal_num = dataset_cfg['modal_num']

        if (stage == "test"):
            transform_names = dataset_cfg['test_transform']
        elif (stage == "train" or stage == "valid"):
            transform_names = dataset_cfg['train_transform']
        else:
            # Defensive path: only reachable if the assert above is
            # stripped (python -O).
            raise ValueError("Incorrect value for stage: {0:}".format(stage))

        self.transform_list = []
        for name in transform_names:
            self.transform_list.append(get_transform(name, dataset_cfg))

        return NiftyDataset(root_dir=root_dir,
                            csv_file=dataset_cfg.get(stage + '_csv', None),
                            modal_num=modal_num,
                            with_label=(stage != 'test'),
                            transform=transforms.Compose(
                                self.transform_list))
示例#5
0

if __name__ == "__main__":
    # Demo: build train/valid BraTS 2018 datasets and their data loaders.
    root_dir = '/home/guotai/data/brats/BraTS2018_Training'
    train_csv_file = '/home/guotai/projects/torch_brats/brats/config/brats18_train_train.csv'
    valid_csv_file = '/home/guotai/projects/torch_brats/brats/config/brats18_train_valid.csv'

    # Shared preprocessing pipeline: crop to a fixed bounding box, rescale,
    # normalize each channel, map BraTS labels {0,1,2,4} -> {0,1,2,3},
    # then take a random 80^3 patch.
    transform_list = [
        CropWithBoundingBox(start=None, output_size=[4, 144, 176, 144]),
        Rescale(output_size=[96, 128, 96]),
        ChannelWiseNormalize(mean=None, std=None, zero_to_random=True),
        LabelConvert([0, 1, 2, 4], [0, 1, 2, 3]),
        RandomCrop([80, 80, 80]),
    ]
    data_transform = transforms.Compose(transform_list)

    train_dataset = NiftyDataset(root_dir=root_dir,
                                 csv_file=train_csv_file,
                                 modal_num=4,
                                 with_label=True,
                                 transform=data_transform)
    valid_dataset = NiftyDataset(root_dir=root_dir,
                                 csv_file=valid_csv_file,
                                 modal_num=4,
                                 with_label=True,
                                 transform=data_transform)

    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=2,
                                              shuffle=True,
                                              num_workers=8)
    validloader = torch.utils.data.DataLoader(valid_dataset,
                                              batch_size=2,
                                              shuffle=True,
                                              num_workers=8)
示例#6
0
from pymic.io.nifty_dataset import NiftyDataset
from pymic.io.transform3d import *

if __name__ == "__main__":
    root_dir = '/home/guotai/data/brats/BraTS2018_Training'
    csv_file = '/home/guotai/projects/torch_brats/brats/config/brats18_train_train.csv'

    crop1 = CropWithBoundingBox(start=None, output_size=[4, 144, 180, 144])
    norm = ChannelWiseNormalize(mean=None, std=None, zero_to_random=True)
    labconv = LabelConvert([0, 1, 2, 4], [0, 1, 2, 3])
    crop2 = RandomCrop([128, 128, 128])
    rescale = Rescale([64, 64, 64])
    transform_list = [crop1, norm, labconv, crop2, rescale, ToTensor()]
    transformed_dataset = NiftyDataset(
        root_dir=root_dir,
        csv_file=csv_file,
        modal_num=4,
        transform=transforms.Compose(transform_list))
    dataloader = DataLoader(transformed_dataset,
                            batch_size=4,
                            shuffle=True,
                            num_workers=4)
    # Helper function to show a batch

    for i_batch, sample_batched in enumerate(dataloader):
        print(i_batch, sample_batched['image'].size(),
              sample_batched['label'].size())

        # # observe 4th batch and stop.
        modals = ['flair', 't1ce', 't1', 't2']
        if i_batch == 0: