def __create_dataset(self):
    """Create the dataset(s) and DataLoader(s) for the current stage.

    For ``self.stage == 'train'`` this builds a training and a validation
    ``NiftyDataset`` and exposes them as ``self.train_loader`` and
    ``self.valid_loader``.  Otherwise it builds a test dataset (without
    labels) and exposes it as ``self.test_loader`` (the misspelled
    ``self.test_loder`` is kept as an alias for backward compatibility).

    Side effects: sets ``self.transform_list``, ``self.region_swop`` and,
    in training, ``self.validtransform_list``.
    """
    data_cfg  = self.config['dataset']
    root_dir  = data_cfg['root_dir']
    modal_num = data_cfg['modal_num']

    if self.stage == 'train':
        transform_names      = data_cfg['train_transform']
        validtransform_names = data_cfg['valid_transform']
        # 'RegionSwop' operates across samples, so it is handled separately
        # below and excluded from the per-sample transform pipeline.
        self.validtransform_list = [get_transform(name, data_cfg)
                                    for name in validtransform_names
                                    if name != 'RegionSwop']
    else:
        transform_names = data_cfg['test_transform']

    self.transform_list = [get_transform(name, data_cfg)
                           for name in transform_names
                           if name != 'RegionSwop']
    # NOTE: only the train/test transform list is consulted here; a
    # 'RegionSwop' entry in valid_transform alone is ignored (as before).
    self.region_swop = get_transform('RegionSwop', data_cfg) \
        if 'RegionSwop' in transform_names else None

    if self.stage == 'train':
        train_csv = data_cfg.get('train_csv', None)
        valid_csv = data_cfg.get('valid_csv', None)
        train_dataset = NiftyDataset(
            root_dir=root_dir,
            csv_file=train_csv,
            modal_num=modal_num,
            with_label=True,
            transform=transforms.Compose(self.transform_list))
        valid_dataset = NiftyDataset(
            root_dir=root_dir,
            csv_file=valid_csv,
            modal_num=modal_num,
            with_label=True,
            transform=transforms.Compose(self.validtransform_list))
        batch_size = self.config['training']['batch_size']
        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=batch_size * 2)
        # Validation runs one volume at a time (variable spatial sizes).
        self.valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=1, shuffle=False,
            num_workers=batch_size * 2)
    else:
        test_csv = data_cfg.get('test_csv', None)
        test_dataset = NiftyDataset(
            root_dir=root_dir,
            csv_file=test_csv,
            modal_num=modal_num,
            with_label=False,
            transform=transforms.Compose(self.transform_list))
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=1, shuffle=False, num_workers=1)
        # Backward-compatible alias: older callers read the misspelled name.
        self.test_loder = self.test_loader
def get_stage_dataset_from_config(self, stage):
    """Construct and return a ``NiftyDataset`` for the requested stage.

    Both 'train' and 'valid' stages use the ``train_transform`` pipeline
    from the config; 'test' uses ``test_transform`` and loads no labels.
    As a side effect, ``self.transform_list`` is set to the instantiated
    transform objects.

    :param stage: one of 'train', 'valid' or 'test'.
    :return: the configured ``NiftyDataset``.
    :raises ValueError: if ``stage`` is not a recognized value.
    """
    assert (stage in ['train', 'valid', 'test'])
    dataset_cfg = self.config['dataset']
    root_dir = dataset_cfg['root_dir']
    modal_num = dataset_cfg['modal_num']
    if stage in ("train", "valid"):
        transform_names = dataset_cfg['train_transform']
    elif stage == "test":
        transform_names = dataset_cfg['test_transform']
    else:
        raise ValueError("Incorrect value for stage: {0:}".format(stage))
    self.transform_list = [get_transform(name, dataset_cfg)
                           for name in transform_names]
    csv_file = dataset_cfg.get(stage + '_csv', None)
    return NiftyDataset(root_dir=root_dir,
                        csv_file=csv_file,
                        modal_num=modal_num,
                        with_label=(stage != 'test'),
                        transform=transforms.Compose(self.transform_list))