예제 #1
0
 def __create_dataset(self):
     """Create the datasets and data loaders for the current ``self.stage``.

     Settings come from ``self.config['dataset']`` (``root_dir``,
     ``modal_num``, the ``*_csv`` paths and the ``*_transform`` name lists)
     and, for training, ``self.config['training']['batch_size']``.

     A ``'RegionSwop'`` entry in a transform list is never added to the
     composed per-sample pipeline. For the main (train/test) list it is
     instead built on its own and stored in ``self.region_swop``, which is
     ``None`` when that list does not name it.

     Attributes set when ``self.stage == 'train'``: ``transform_list``,
     ``validtransform_list``, ``region_swop``, ``train_loader`` and
     ``valid_loader``. For any other stage: ``transform_list``,
     ``region_swop`` and ``test_loader`` (also kept under the original,
     misspelled name ``test_loder`` for backward compatibility).
     """
     dataset_cfg = self.config['dataset']
     root_dir = dataset_cfg['root_dir']
     modal_num = dataset_cfg['modal_num']
     is_train = (self.stage == 'train')

     def build_transforms(names):
         # 'RegionSwop' is handled separately (see self.region_swop below),
         # so it is excluded from the composed transform list.
         return [get_transform(name, dataset_cfg)
                 for name in names if name != 'RegionSwop']

     if is_train:
         transform_names = dataset_cfg['train_transform']
         self.validtransform_list = build_transforms(
             dataset_cfg['valid_transform'])
     else:
         transform_names = dataset_cfg['test_transform']
     self.transform_list = build_transforms(transform_names)
     if 'RegionSwop' in transform_names:
         self.region_swop = get_transform('RegionSwop', dataset_cfg)
     else:
         self.region_swop = None

     if is_train:
         train_dataset = NiftyDataset(
             root_dir=root_dir,
             csv_file=dataset_cfg.get('train_csv', None),
             modal_num=modal_num,
             with_label=True,
             transform=transforms.Compose(self.transform_list))
         valid_dataset = NiftyDataset(
             root_dir=root_dir,
             csv_file=dataset_cfg.get('valid_csv', None),
             modal_num=modal_num,
             with_label=True,
             transform=transforms.Compose(self.validtransform_list))
         batch_size = self.config['training']['batch_size']
         self.train_loader = torch.utils.data.DataLoader(
             train_dataset,
             batch_size=batch_size,
             shuffle=True,
             num_workers=batch_size * 2)
         # The validation loader yields one sample per batch; its worker
         # count still follows the training batch size, as before.
         self.valid_loader = torch.utils.data.DataLoader(
             valid_dataset,
             batch_size=1,
             shuffle=False,
             num_workers=batch_size * 2)
     else:
         test_dataset = NiftyDataset(
             root_dir=root_dir,
             csv_file=dataset_cfg.get('test_csv', None),
             modal_num=modal_num,
             with_label=False,
             transform=transforms.Compose(self.transform_list))
         self.test_loader = torch.utils.data.DataLoader(
             test_dataset,
             batch_size=1,
             shuffle=False,
             num_workers=1)
         # The attribute was originally spelled 'test_loder'; keep that name
         # as an alias so any existing readers of it continue to work.
         self.test_loder = self.test_loader
예제 #2
0
    def get_stage_dataset_from_config(self, stage):
        """Build the ``NiftyDataset`` for one stage from ``self.config['dataset']``.

        Args:
            stage: one of ``'train'``, ``'valid'`` or ``'test'``.

        Returns:
            A ``NiftyDataset`` rooted at ``root_dir`` that reads the CSV given
            by the ``'<stage>_csv'`` key (``None`` when that key is absent).
            It is built with labels for ``'train'``/``'valid'`` and without
            labels for ``'test'``. As a side effect, ``self.transform_list``
            is set to the list of transform objects used.

        Raises:
            ValueError: if ``stage`` is not one of the three values above.
        """
        # Validate explicitly rather than with ``assert``: asserts are
        # stripped under ``python -O``, and with asserts enabled the
        # ValueError branch could never be reached.
        if stage not in ('train', 'valid', 'test'):
            raise ValueError("Incorrect value for stage: {0:}".format(stage))

        dataset_cfg = self.config['dataset']
        root_dir = dataset_cfg['root_dir']
        modal_num = dataset_cfg['modal_num']

        # Each stage uses its own '<stage>_transform' list. Validation falls
        # back to 'train_transform' when no 'valid_transform' is configured,
        # so configs without that key behave exactly as before.
        transform_key = stage + '_transform'
        if stage == 'valid' and transform_key not in dataset_cfg:
            transform_key = 'train_transform'
        transform_names = dataset_cfg[transform_key]

        self.transform_list = [get_transform(name, dataset_cfg)
                               for name in transform_names]
        csv_file = dataset_cfg.get(stage + '_csv', None)
        dataset = NiftyDataset(root_dir=root_dir,
                               csv_file=csv_file,
                               modal_num=modal_num,
                               with_label=(stage != 'test'),
                               transform=transforms.Compose(
                                   self.transform_list))
        return dataset