def test_infinite_multi_threaded(self):
        data = list(range(123))
        num_workers = 3

        # A finite loader must exhaust eventually and signal StopIteration.
        finite_dl = DummyDataLoader(deepcopy(data), 12, num_workers, 1,
                                    return_incomplete=True, shuffle=True,
                                    infinite=False)
        finite_mt = MultiThreadedAugmenter(finite_dl, None, num_workers, 1,
                                           None, False)
        with self.assertRaises(StopIteration):
            for _ in range(1000):
                next(finite_mt)

        # An infinite loader keeps producing batches well past one epoch.
        infinite_dl = DummyDataLoader(deepcopy(data), 12, num_workers, 1,
                                      return_incomplete=True, shuffle=True,
                                      infinite=True)
        infinite_mt = MultiThreadedAugmenter(infinite_dl, None, num_workers, 1,
                                             None, False)
        for _ in range(1000):
            next(infinite_mt)
def get_generators(patch_size,
                   batch_size,
                   preprocess_func,
                   output_reshape_func,
                   num_validation,
                   train_processes,
                   train_cache,
                   train_data_dir='data/train/'):
    """
    Builds augmented batch loaders wrapped as Keras-compatible generators for
    training and validation.

    :param patch_size: network input size, without the batch dimension
    :param batch_size: number of samples per training batch
    :param preprocess_func: callable to preprocess data per sample
    :param output_reshape_func: callable to reshape preprocessed and augmented
        data per sample
    :param num_validation: number of samples to validate on
    :param train_processes: number of workers to load, preprocess and augment
        training data
    :param train_cache: number of augmented samples to cache per queue
    :param train_data_dir: directory containing the training data
    :return: tuple (train KerasGenerator, validation KerasGenerator)
    """
    dirs = util.get_data_list(train_data_dir)
    labels = util.parse_labels_months()
    train_paths, validation_paths = util.train_validation_split(dirs, labels)

    # --- training side: multi-process loading + augmentation ---
    train_loader = CTBatchLoader(train_paths,
                                 batch_size,
                                 patch_size,
                                 num_threads_in_multithreaded=1,
                                 preprocess_func=preprocess_func)
    train_augmenter = MultiThreadedAugmenter(train_loader,
                                             get_train_transform(patch_size),
                                             num_processes=train_processes,
                                             num_cached_per_queue=train_cache,
                                             seeds=None,
                                             pin_memory=False)
    # adapt the augmenter to the Keras generator protocol
    train_generator = KerasGenerator(train_augmenter,
                                     output_reshapefunc=output_reshape_func)

    # --- validation side: one batch containing all validation samples ---
    valid_loader = CTBatchLoader(validation_paths,
                                 num_validation,
                                 patch_size,
                                 num_threads_in_multithreaded=1,
                                 preprocess_func=preprocess_func)
    valid_augmenter = MultiThreadedAugmenter(valid_loader,
                                             get_valid_transform(patch_size),
                                             num_processes=1,
                                             num_cached_per_queue=1,
                                             seeds=None,
                                             pin_memory=False)
    valid_generator = KerasGenerator(valid_augmenter, output_reshape_func, 1)

    return train_generator, valid_generator
    def __init__(self,
                 base_dir,
                 mode="train",
                 batch_size=16,
                 num_batches=10000000,
                 seed=None,
                 num_processes=8,
                 num_cached_per_queue=8 * 4,
                 target_size=128,
                 file_pattern='*.png',
                 do_reshuffle=True,
                 keys=None):
        """
        Build a multithreaded augmentation pipeline around a MedImageDataLoader.

        :param base_dir: root directory forwarded to MedImageDataLoader
        :param mode: dataset split selector, also used to pick the transforms
        :param batch_size: number of samples per batch
        :param num_batches: number of batches the loader will produce
        :param seed: RNG seed forwarded to the loader and the augmenter
        :param num_processes: number of augmentation worker processes
        :param num_cached_per_queue: batches cached per worker queue
        :param target_size: image size forwarded to get_transforms
        :param file_pattern: glob pattern for image files
        :param do_reshuffle: forwarded as ``shuffle`` to the augmenter
        :param keys: forwarded to MedImageDataLoader; semantics defined there
        """

        data_loader = MedImageDataLoader(base_dir=base_dir,
                                         mode=mode,
                                         batch_size=batch_size,
                                         num_batches=num_batches,
                                         seed=seed,
                                         file_pattern=file_pattern,
                                         keys=keys)

        self.data_loader = data_loader
        self.batch_size = batch_size
        #self.do_reshuffle = do_reshuffle
        self.number_of_slices = 1

        self.transforms = get_transforms(mode=mode, target_size=target_size)
        # NOTE(review): a single scalar is passed as `seeds` — confirm this
        # MultiThreadedAugmenter variant accepts a scalar rather than a list
        self.augmenter = MultiThreadedAugmenter(
            data_loader,
            self.transforms,
            num_processes=num_processes,
            num_cached_per_queue=num_cached_per_queue,
            seeds=seed,
            shuffle=do_reshuffle)
        # start the worker processes immediately so batches get pre-cached
        self.augmenter.restart()
    def test_image_pipeline_and_pin_memory(self):
        """
        Smoke test: the pinned-memory image pipeline must run without
        crashing and must yield pinned torch tensors.
        """
        try:
            import torch
        except ImportError:
            # torch is optional; silently skip when it is not installed
            return

        from batchgenerators.transforms import MirrorTransform, NumpyToTensor, TransposeAxesTransform, Compose

        composed = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1),
                                   p_per_sample=0.5),
            NumpyToTensor(keys='data', cast_to='float'),
        ])

        mt = MultiThreadedAugmenter(self.dl_images, composed, 4, 1, None, True)

        res = None
        for _ in range(50):
            res = mt.next()

        assert isinstance(res['data'], torch.Tensor)
        assert res['data'].is_pinned()

        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)
    def __next__(self):
        """
        Advance to the next cross-validation fold.

        Builds the fold's batch generators, then either creates the
        train/validation MultiThreadedAugmenter pair (first call) or re-points
        the existing augmenters at the new generators (later calls), which
        avoids respawning worker processes for every fold.

        :return: tuple (train_bg_augmenter, valid_bg_augmenter)
        :raises StopIteration: once all folds have been consumed
        """

        if self.n <= self.n_folds:
            # t0 = time.time()
            self._build_generator()
            # t1 = time.time()
            # print("\n\n\nDuration: {0} {1}\n\n\n".format(np.round(t1-t0,2), self.n))

            if self.train_bg_augmenter is None:
                # first fold: spin up the training augmenter
                self.train_bg_augmenter = MultiThreadedAugmenter(
                    self.train_batchgenerator,
                    self.image_transform,
                    self.num_workers,
                    pin_memory=True)
            else:
                # later folds: reuse the workers, just swap the generator
                self.train_bg_augmenter.set_generator(
                    self.train_batchgenerator)

            if self.valid_bg_augmenter is None:
                self.valid_bg_augmenter = MultiThreadedAugmenter(
                    self.valid_batchgenerator,
                    self.image_transform,
                    self.num_workers,
                    pin_memory=True)
            else:
                self.valid_bg_augmenter.set_generator(
                    self.valid_batchgenerator)

            # fold counter; iteration stops after n_folds folds
            self.n += 1

            return self.train_bg_augmenter, self.valid_bg_augmenter
        else:
            raise StopIteration
# ----- 示例#6 (Example 6) — 0 -----
    def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch):
        """
        Trains the network a single epoch

        Parameters
        ----------
        batchgen : MultiThreadedAugmenter
            Generator yielding the training batches
        epoch : int
            current epoch

        """
        # put the wrapped module into training mode
        self.module.training = True

        # progress-bar total only; assumes each worker process yields
        # generator.num_batches batches -- TODO confirm against the
        # MultiThreadedAugmenter implementation in use
        n_batches = batchgen.generator.num_batches * batchgen.num_processes
        pbar = tqdm(enumerate(batchgen),
                    unit=' batch',
                    total=n_batches,
                    desc='Epoch %d' % epoch)

        for batch_nr, batch in pbar:

            data_dict = batch

            # closure presumably runs forward/backward and optimizer steps;
            # its three return values are intentionally discarded here
            _, _, _ = self.closure_fn(self.module,
                                      data_dict,
                                      optimizers=self.optimizers,
                                      losses=self.losses,
                                      metrics=self.metrics,
                                      fold=self.fold,
                                      batch_nr=batch_nr)

        # shut down the augmenter's worker processes after the epoch
        batchgen._finish()
    def get_next_loader(self, train=True, build_new=False):
        """
        Return a MultiThreadedAugmenter for the current cross-validation fold.

        :param train: serve the training batch generator if True, otherwise
            the validation one
        :param build_new: rebuild the fold's batch generators first and
            advance the fold counter after serving
        :return: the shared MultiThreadedAugmenter, or None once all folds
            have been consumed
        """

        if self.n <= self.n_folds:
            # t0 = time.time()
            if build_new:
                self._build_generator()
            # t1 = time.time()
            # print("\n\n\nDuration: {0} {1}\n\n\n".format(np.round(t1-t0,2), self.n))

            if self.bg_augmenter is None:
                # first use: create the augmenter for the requested split
                if train:
                    self.bg_augmenter = MultiThreadedAugmenter(
                        self.train_batchgenerator,
                        self.image_transform,
                        self.num_workers,
                        pin_memory=True)
                else:
                    self.bg_augmenter = MultiThreadedAugmenter(
                        self.valid_batchgenerator,
                        self.image_transform,
                        self.num_workers,
                        pin_memory=True)
            else:
                # reuse the existing workers, just swap the generator
                if train:
                    self.bg_augmenter.set_generator(self.train_batchgenerator)
                else:
                    self.bg_augmenter.set_generator(self.valid_batchgenerator)

            if build_new:
                self.n += 1

            return self.bg_augmenter
        else:
            return None
    def test_return_incomplete_multi_threaded(self):
        data = list(range(123))
        batch_size = 12
        num_workers = 3

        # return_incomplete=False: the trailing partial batch is dropped, so
        # 123 items yield 10 full batches of 12 (120 items in total).
        dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, 1,
                             return_incomplete=False, shuffle=False,
                             infinite=False)
        mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)
        num_items = 0
        num_batches = 0
        for batch in mt:
            num_batches += 1
            assert len(batch) == batch_size
            num_items += batch_size

        self.assertTrue(num_items == 120)
        self.assertTrue(num_batches == 10)

        # return_incomplete=True: the short final batch is kept, so all 123
        # items are delivered across 11 batches.
        dl = DummyDataLoader(deepcopy(data), batch_size, num_workers, 1,
                             return_incomplete=True, shuffle=False,
                             infinite=False)
        mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)
        num_items = 0
        num_batches = 0
        for batch in mt:
            num_batches += 1
            num_items += len(batch)

        self.assertTrue(num_items == 123)
        self.assertTrue(num_batches == 11)
# ----- 示例#9 (Example 9) — 0 -----
def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes,
                                  num_workers=5, num_cached_per_worker=2,
                                  do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                                  do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                                  do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """
    Build a multithreaded, augmented 2D training batch generator working on
    352x352 patches.

    :param patient_data_train: training data handed to BatchGenerator_2D
    :param BATCH_SIZE: batch size
    :param num_classes: number of classes for the one-hot seg encoding
    :param num_workers: number of augmentation worker processes
    :param num_cached_per_worker: batches cached per worker queue
    :param do_elastic_transform: enable elastic deformation (alpha, sigma)
    :param do_rotation: enable rotation within a_x / a_y / a_z
    :param do_scale: enable scaling within scale_range
    :param seeds: None (unseeded), 'range' (seeds 0..num_workers-1) or an
        explicit list with one seed per worker
    :return: a started MultiThreadedAugmenter
    """
    # resolve per-worker seeds
    if seeds is None:
        seeds = [None]*num_workers
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
    data_gen_train = BatchGenerator_2D(patient_data_train, BATCH_SIZE, num_batches=None, seed=False,
                                       PATCH_SIZE=(352, 352))

    tr_transforms = []
    tr_transforms.append(Mirror((2, 3)))
    # with probability 0.67 apply the spatial transform, otherwise fall back
    # to a plain random crop of the same 352x352 size
    tr_transforms.append(RndTransform(SpatialTransform((352, 352), list(np.array((352, 352))//2),
                                                       do_elastic_transform, alpha,
                                                       sigma,
                                                       do_rotation, a_x, a_y,
                                                       a_z,
                                                       do_scale, scale_range, 'constant', 0, 3, 'constant',
                                                       0, 0,
                                                       random_crop=False), prob=0.67,
                                      alternative_transform=RandomCropTransform((352, 352))))
    tr_transforms.append(ConvertSegToOnehotTransform(range(num_classes), seg_channel=0, output_key='seg_onehot'))

    tr_composed = Compose(tr_transforms)
    tr_mt_gen = MultiThreadedAugmenter(data_gen_train, tr_composed, num_workers, num_cached_per_worker, seeds)
    # start the worker processes so batches get pre-cached
    tr_mt_gen.restart()
    return tr_mt_gen
# ----- 示例#10 (Example 10) — 0 -----
    def test_no_crash(self):
        """
        Smoke test: repeatedly pulling batches must not raise.
        """
        mt_dl = MultiThreadedAugmenter(self.dl_images, None, self.num_threads,
                                       1, None, False)

        for _ in range(20):
            mt_dl.next()
# ----- 示例#11 (Example 11) — 0 -----
    def get_batchgen(self, seed=1):
        """
        Create DataLoader and Batchgenerator

        Parameters
        ----------
        seed : int
            seed for Random Number Generator

        Returns
        -------
        MultiThreadedAugmenter
            Batchgenerator

        Raises
        ------
        AssertionError
            :attr:`BaseDataManager.n_batches` is smaller than or equal to zero

        """
        assert self.n_batches > 0

        loader = self.data_loader_cls(self.dataset,
                                      batch_size=self.batch_size,
                                      num_batches=self.n_batches,
                                      seed=seed,
                                      sampler=self.sampler)

        # every augmentation worker receives the same seed
        worker_seeds = [seed] * self.n_process_augmentation
        return MultiThreadedAugmenter(loader,
                                      self.transforms,
                                      self.n_process_augmentation,
                                      num_cached_per_queue=2,
                                      seeds=worker_seeds)
    def test_return_all_indices_multi_threaded_shuffle_True(self):
        """
        With shuffling enabled, each epoch must cover every index exactly
        once, but not in the original order.
        """
        data = list(range(123))
        num_workers = 3

        for batch_size in (1, 3, 75, 12, 23):
            dl = DummyDataLoader(deepcopy(data),
                                 batch_size,
                                 num_workers,
                                 1,
                                 return_incomplete=True,
                                 shuffle=True,
                                 infinite=False)
            mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)

            for _ in range(3):
                seen = []
                for batch in mt:
                    seen += batch

                # one epoch covers the whole dataset ...
                assert len(seen) == len(data)
                # ... in a shuffled order ...
                assert not all([a == b for a, b in zip(seen, data)])
                # ... containing every index exactly once
                seen.sort()
                assert all([a == b for a, b in zip(seen, data)])
# ----- 示例#13 (Example 13) — 0 -----
    def load_dataset(self):
        """
        Load the training file list, split off a deterministic validation
        fold and build the multithreaded training data generator.

        Side effects: sets ``self.num_batches_per_epoch``,
        ``self.train_data_loader`` and ``self.val``.
        """
        train = get_file_list(self.args.data_path, self.args.train_ids_path)
        #val = get_file_list(self.args.data_path,self.args.valid_ids_path)

        # reproducible 5-fold split; fold 0 of this seed is held out as val
        train, val = get_split_deterministic(train,
                                             fold=0,
                                             num_splits=5,
                                             random_state=12345)

        # elementwise maximum over all training-patient shapes and the
        # configured patch size
        shapes = [brats_dataloader.load_patient(i)[0].shape[1:] for i in train]
        max_shape = np.max(shapes, 0)
        max_shape = list(np.max((max_shape, self.args.patch_size), 0))

        dataloader_train = brats_dataloader(train,
                                            self.args.batch_size,
                                            max_shape,
                                            self.args.num_threads,
                                            return_incomplete=True,
                                            infinite=False)

        tr_transforms = get_train_transform(self.args.patch_size)

        tr_gen = MultiThreadedAugmenter(dataloader_train,
                                        tr_transforms,
                                        num_processes=self.args.num_threads,
                                        num_cached_per_queue=3,
                                        seeds=None,
                                        pin_memory=False)

        # one epoch = enough batches to visit every training case once
        self.num_batches_per_epoch = int(
            math.ceil(len(train) / self.args.batch_size))
        self.train_data_loader = tr_gen
        self.val = val
def get_train_val_generators(fold):
    """
    Build augmented training and validation batch generators for one
    cross-validation fold.

    NOTE(review): relies on module-level globals (dataset, BATCH_SIZE,
    num_classes, INPUT_PATCH_SIZE, num_workers, workers_seeds, split_seed)
    defined elsewhere in this file.

    :param fold: fold index, forwarded to get_split
    :return: tuple (training generator, validation MultiThreadedAugmenter)
    """
    tr_keys, te_keys = get_split(fold, split_seed)
    train_data = {i: dataset[i] for i in tr_keys}
    val_data = {i: dataset[i] for i in te_keys}

    data_gen_train = create_data_gen_train(
        train_data,
        BATCH_SIZE,
        num_classes,
        INPUT_PATCH_SIZE,
        num_workers=num_workers,
        do_elastic_transform=True,
        alpha=(0., 350.),
        sigma=(14., 17.),
        do_rotation=True,
        a_x=(0, 2. * np.pi),
        a_y=(-0.000001, 0.00001),
        a_z=(-0.000001, 0.00001),
        do_scale=True,
        scale_range=(0.7, 1.3),
        seeds=workers_seeds)  # new se has no brain mask

    # validation: no augmentation, only one-hot encoding of the segmentation
    data_gen_validation = BatchGenerator(val_data,
                                         BATCH_SIZE,
                                         num_batches=None,
                                         seed=False,
                                         PATCH_SIZE=INPUT_PATCH_SIZE)
    val_transforms = []
    val_transforms.append(
        ConvertSegToOnehotTransform(range(4), 0, 'seg_onehot'))
    data_gen_validation = MultiThreadedAugmenter(data_gen_validation,
                                                 Compose(val_transforms), 1, 2,
                                                 [0])
    return data_gen_train, data_gen_validation
# ----- 示例#15 (Example 15) — 0 -----
    def test_return_incomplete_multi_threaded(self):
        """
        With 123 items and batch size 12: dropping the last incomplete batch
        must yield 10 batches / 120 unique items, keeping it must yield 11
        batches covering all 123 items.
        """
        data = list(range(123))
        batch_size = 12
        num_workers = 3

        # return_incomplete=False: the trailing partial batch is dropped
        dl = DummyDataLoader(deepcopy(data),
                             batch_size,
                             num_workers,
                             1,
                             return_incomplete=False,
                             shuffle=False,
                             infinite=False)
        mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)
        all_return = []
        total = 0
        ctr = 0
        for i in mt:
            ctr += 1
            assert len(i) == batch_size
            total += len(i)
            all_return += i

        self.assertTrue(total == 120)
        self.assertTrue(ctr == 10)
        # no index may be delivered twice
        self.assertTrue(len(np.unique(all_return)) == total)

        # return_incomplete=True: the short final batch is kept
        dl = DummyDataLoader(deepcopy(data),
                             batch_size,
                             num_workers,
                             1,
                             return_incomplete=True,
                             shuffle=False,
                             infinite=False)
        mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, False)
        all_return = []
        total = 0
        ctr = 0
        for i in mt:
            ctr += 1
            total += len(i)
            all_return += i

        self.assertTrue(total == 123)
        self.assertTrue(ctr == 11)
        # every one of the 123 indices appears exactly once
        self.assertTrue(len(np.unique(all_return)) == len(data))
# ----- 示例#16 (Example 16) — 0 -----
    def test_image_pipeline(self):
        """
        Smoke test: running the image pipeline for a while must not crash.
        """
        from batchgenerators.transforms import MirrorTransform, TransposeAxesTransform, Compose

        composed = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1),
                                   p_per_sample=0.5),
        ])

        mt = MultiThreadedAugmenter(self.dl_images, composed, 4, 1, None, False)

        for _ in range(50):
            mt.next()

        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)
# ----- 示例#17 (Example 17) — 0 -----
def create_data_gen_train(patient_data_train, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE, contrast_range=(0.75, 1.5),
                          gamma_range = (0.6, 2),
                                  num_workers=5, num_cached_per_worker=3,
                                  do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                                  do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                                  do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """
    Build a multithreaded, augmented 3D training batch generator.

    :param patient_data_train: training data for BatchGenerator3D_random_sampling
    :param INPUT_PATCH_SIZE: patch size handed to the spatial transform
    :param num_classes: number of classes for the one-hot seg encoding
    :param BATCH_SIZE: batch size
    :param contrast_range: range for ContrastAugmentationTransform
    :param gamma_range: range for GammaTransform
    :param num_workers: number of augmentation worker processes
    :param num_cached_per_worker: batches cached per worker queue
    :param do_elastic_transform: enable elastic deformation (alpha, sigma)
    :param do_rotation: enable rotation within a_x / a_y / a_z
    :param do_scale: enable scaling within scale_range
    :param seeds: None (unseeded), 'range' (seeds 0..num_workers-1) or an
        explicit list with one seed per worker
    :return: a started MultiThreadedAugmenter
    """
    # resolve per-worker seeds
    if seeds is None:
        seeds = [None]*num_workers
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
    # sampler draws fixed (160, 192, 160) patches; the spatial transform
    # below presumably produces INPUT_PATCH_SIZE patches from them
    data_gen_train = BatchGenerator3D_random_sampling(patient_data_train, BATCH_SIZE, num_batches=None, seed=False,
                                                          patch_size=(160, 192, 160), convert_labels=True)
    # NOTE: transform order is significant (mask generation before the
    # mask-aware normalizations, one-hot conversion last)
    tr_transforms = []
    tr_transforms.append(DataChannelSelectionTransform([0, 1, 2, 3]))
    tr_transforms.append(GenerateBrainMaskTransform())
    tr_transforms.append(MirrorTransform())
    tr_transforms.append(SpatialTransform(INPUT_PATCH_SIZE, list(np.array(INPUT_PATCH_SIZE)//2.),
                                       do_elastic_deform=do_elastic_transform, alpha=alpha, sigma=sigma,
                                       do_rotation=do_rotation, angle_x=a_x, angle_y=a_y, angle_z=a_z,
                                       do_scale=do_scale, scale=scale_range, border_mode_data='nearest',
                                       border_cval_data=0, order_data=3, border_mode_seg='constant', border_cval_seg=0,
                                       order_seg=0, random_crop=True))
    tr_transforms.append(BrainMaskAwareStretchZeroOneTransform((-5, 5), True))
    tr_transforms.append(ContrastAugmentationTransform(contrast_range, True))
    tr_transforms.append(GammaTransform(gamma_range, False))
    tr_transforms.append(BrainMaskAwareStretchZeroOneTransform(per_channel=True))
    tr_transforms.append(BrightnessTransform(0.0, 0.1, True))
    tr_transforms.append(SegChannelSelectionTransform([0]))
    tr_transforms.append(ConvertSegToOnehotTransform(range(num_classes), 0, "seg_onehot"))

    gen_train = MultiThreadedAugmenter(data_gen_train, Compose(tr_transforms), num_workers, num_cached_per_worker,
                                       seeds)
    # start the worker processes so batches get pre-cached
    gen_train.restart()
    return gen_train
def get_no_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params, border_val_seg=-1):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation
    :param dataloader_train:
    :param dataloader_val:
    :param patch_size: unused here (kept for drop-in signature compatibility)
    :param params: config dict; only channel-selection and threading keys
        are read in this variant
    :param border_val_seg: unused here (kept for drop-in signature compatibility)
    :return: tuple (train MultiThreadedAugmenter, val MultiThreadedAugmenter)
    """
    tr_transforms = []

    # optional channel selection, controlled via the params dict
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # presumably maps the ignore label -1 to 0 — confirm in RemoveLabelTransform
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=range(params.get('num_threads')), pin_memory=True)
    batchgenerator_train.restart()

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses half the training threads (at least one)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads')//2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=range(max(params.get('num_threads')//2, 1)), pin_memory=True)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
    def test_order(self):
        """
        Coordinating workers in a multiprocessing environment is difficult.
        Even multithreaded, DummyDL must still deliver the numbers 0..99 in
        ascending order.
        """
        mt = MultiThreadedAugmenter(self.dl, None, self.num_threads, 1, None,
                                    False)

        produced = [item for item in mt]

        assert len(produced) == 100
        original_order = deepcopy(produced)
        produced.sort()
        # sorting must not change anything (already ascending) ...
        assert all((a == b for a, b in zip(produced, original_order)))
        # ... and the values are exactly 0..99
        assert all((a == b for a, b in zip(produced, np.arange(0, 100))))
def get_data_augmenter(data, batch_size=1,
                       mode=DataLoader.Mode.NORMAL,
                       volumetric=True,
                       normalization_range=None,
                       vector_generator=None,
                       input_shape=None,
                       sample_count=1,
                       transforms=None,
                       threads=1,
                       seed=None):
    """
    Wrap *data* in a DataLoader and return a MultiThreadedAugmenter over it.

    All keyword arguments except ``transforms``, ``threads`` and ``seed`` are
    forwarded to DataLoader unchanged. PrepareForTF is always appended as the
    final transform.
    """
    if transforms is None:
        transforms = []
    # never spawn more worker threads than there are batches
    threads = min(int(np.ceil(len(data) / batch_size)), threads)
    loader = DataLoader(data=data,
                        batch_size=batch_size,
                        mode=mode,
                        volumetric=volumetric,
                        normalization_range=normalization_range,
                        vector_generator=vector_generator,
                        input_shape=input_shape,
                        sample_count=sample_count,
                        number_of_threads_in_multithreaded=threads,
                        seed=seed)
    pipeline = Compose(transforms + [PrepareForTF()])
    return MultiThreadedAugmenter(loader, pipeline, threads)
# ----- 示例#21 (Example 21) — 0 -----
def main():
    # --------- Parse arguments ---------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config_0.yaml',
                        help='Path to the configuration file.')
    parser.add_argument('--dice', action='store_true')
    # be aware of this argument!!!
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('-i', '--inference', default='', type=str, metavar='PATH',
                        help='run inference on data set and save results')

    # 1e-8 works well for lung masks but seems to prevent
    # rapid learning for nodule masks
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=1)
    args = parser.parse_args()
    # ---------- get the config file(config.yaml) --------------------------------------------
    config = get_config(args.config)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = os.path.join('./work', (datestr() + ' ' + config['filename']))
    nll = True
    if config['dice']:
        nll = False

    weight_decay = config['weight_decay']
    num_threads_for_kits19 = config['num_of_threads']
    patch_size = (160, 160, 128)
    num_batch_per_epoch = config['num_batch_per_epoch']
    setproctitle.setproctitle(args.save)
    start_epoch = 1
    # -------- Record best kidney segmentation dice -------------------------------------------
    best_tk = 0.0
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    # Embed attention module
    model = vnet.VNet(elu=False, nll=nll,
                      attention=config['attention'], nclass=3)  # mark
    batch_size = config['ngpu'] * config['batchSz']
    save_iter = config['model_save_iter']
    # batch_size = args.ngpu*args.batchSz
    gpu_ids = range(config['ngpu'])
    # print(gpu_ids)
    model.apply(weights_init)
    # ------- Resume training from saved model -----------------------------------------------
    if config['resume']:
        if os.path.isfile(config['resume']):
            print("=> loading checkpoint '{}'".format(config['resume']))
            checkpoint = torch.load(config['resume'])
            # .tar files
            if config['resume'].endswith('.tar'):
                # print(checkpoint, "tar")
                start_epoch = checkpoint['epoch']
                best_tk = checkpoint['best_tk']
                checkpoint_model = checkpoint['model_state_dict']
                model.load_state_dict(
                    {k.replace('module.', ''): v for k, v in checkpoint_model.items()})
            # .pkl files for the whole model
            else:
                # print(checkpoint, "pkl")
                model.load_state_dict(checkpoint.state_dict())
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config['resume']))
            exit(-1)
    else:
        pass
    # ------- Which loss function to use ------------------------------------------------------
    if nll:
        training = train_bg
        validate = test_bg
        # class_balance = True
    else:
        training = train_bg_dice
        validate = test_bg_dice
        # class_balance = False
    # -----------------------------------------------------------------------------------------
    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    # -------- Set on GPU ---------------------------------------------------------------------
    if args.cuda:
        model = model.cuda()

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    # create the output directory
    os.makedirs(args.save)
    # save the config file to the output folder
    shutil.copy(args.config, os.path.join(args.save, 'config.yaml'))

    # kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    # ------ Load Training and Validation set --------------------------------------------
    preprocessed_folders = "/home/data_share/npy_data/"
    patients = get_list_of_patients(
        preprocessed_data_folder=preprocessed_folders)
    # split num_split cross-validation sets
    # train, val = get_split_deterministic(
    #     patients, fold=0, num_splits=5, random_state=12345)
    train, val = patients[0:147], patients[147:189]

    # VALIDATION DATA CANNOT BE LOADED IN CASE DUE TO THE LARGE SHAPE...
    # PRINT VALIDATION CASES FOR LATER TEST USE!!
    print("Validation cases:\n", val)
    # set max shape for validation set 
    shapes = [Kits2019DataLoader3D.load_patient(
        i)[0].shape[1:] for i in val]
    max_shape = np.max(shapes, 0)
    max_shape = np.max((max_shape, patch_size), 0)
    # data loading + augmentation
    dataloader_train = Kits2019DataLoader3D(
        train, batch_size, patch_size, num_threads_for_kits19)
    dataloader_validation = Kits2019DataLoader3D(
        val, batch_size * 2, patch_size, num_threads_for_kits19)
    tr_transforms = get_train_transform(patch_size, prob=config['prob'])
    # whether to use single/multiThreadedAugmenter ------------------------------------------
    if num_threads_for_kits19 > 1:
        tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms, 
                                        num_processes=num_threads_for_kits19,
                                        num_cached_per_queue=3,seeds=None, pin_memory=True)
        val_gen = MultiThreadedAugmenter(dataloader_validation, None,
                                         num_processes=max(1, num_threads_for_kits19//2), 
                                         num_cached_per_queue=1, seeds=None, pin_memory=False)
        
        tr_gen.restart()
        val_gen.restart()
    else:
        tr_gen = SingleThreadedAugmenter(dataloader_train, transform=tr_transforms)
        val_gen = SingleThreadedAugmenter(dataloader_validation, transform=None)
    # ------- Set learning rate scheduler ----------------------------------------------------
    lr_schdl = lr_scheduler.LR_Scheduler(mode=config['lr_policy'], base_lr=config['lr'],
                                         num_epochs=config['nEpochs'], iters_per_epoch=num_batch_per_epoch,
                                         lr_step=config['step_size'], warmup_epochs=config['warmup_epochs'])
    
    # ------ Choose Optimizer ----------------------------------------------------------------
    if config['opt'] == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=config['lr'],
                              momentum=0.99, weight_decay=weight_decay)
    elif config['opt'] == 'adam':
        optimizer = optim.Adam(
            model.parameters(), lr=config['lr'], weight_decay=weight_decay)
    elif config['opt'] == 'rmsprop':
        optimizer = optim.RMSprop(
            model.parameters(), lr=config['lr'], weight_decay=weight_decay)
    lr_plateu = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, verbose=True, threshold=1e-3, patience=5)
    # ------- Apex Mixed Precision Acceleration ----------------------------------------------
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = nn.parallel.DataParallel(model, device_ids=gpu_ids)
    # ------- Save training data -------------------------------------------------------------
    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    trainF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')
    testF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n ')
    # ------- Training Pipeline --------------------------------------------------------------
    for epoch in range(start_epoch, config['nEpochs'] + start_epoch):
        torch.cuda.empty_cache()
        training(args, epoch, model, tr_gen, optimizer, trainF, config, lr_schdl)
        torch.cuda.empty_cache()
        print('==>lr decay to:', optimizer.param_groups[0]['lr'])
        print('testing validation set...')
        composite_dice = validate(args, epoch, model, val_gen, optimizer, testF, config, lr_plateu)
        torch.cuda.empty_cache()
        # save model with best result and routinely
        if composite_dice > best_tk or epoch % config['model_save_iter'] == 0:
            # model_name = 'vnet_epoch_step1_' + str(epoch) + '.pkl'
            model_name = 'vnet_step1_' + str(epoch) + '.tar'
            # torch.save(model, os.path.join(args.save, model_name))
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_tk': best_tk
            }, os.path.join(args.save, model_name))
            best_tk = composite_dice
    # ----------------------------------------------------------------------------------------
    trainF.close()
    testF.close()
def get_no_augmentation(dataloader_train,
                        dataloader_val,
                        params=default_3D_augmentation_params,
                        deep_supervision_scales=None,
                        soft_ds=False,
                        classes=None,
                        pin_memory=True,
                        regions=None):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation
    """
    def _ds_transform():
        # Deep-supervision downsampling of 'target': soft (needs class list) or hard.
        if soft_ds:
            assert classes is not None
            return DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                                'target', classes)
        return DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                            input_key='target', output_key='target')

    data_channels = params.get("selected_data_channels")
    seg_channels = params.get("selected_seg_channels")

    # ---- training pipeline: bookkeeping transforms only, no augmentation ----
    train_pipeline = []
    if data_channels is not None:
        train_pipeline.append(DataChannelSelectionTransform(data_channels))
    if seg_channels is not None:
        train_pipeline.append(SegChannelSelectionTransform(seg_channels))
    train_pipeline.append(RemoveLabelTransform(-1, 0))
    train_pipeline.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        train_pipeline.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    if deep_supervision_scales is not None:
        train_pipeline.append(_ds_transform())
    train_pipeline.append(NumpyToTensor(['data', 'target'], 'float'))

    n_train_workers = params.get('num_threads')
    batchgenerator_train = MultiThreadedAugmenter(dataloader_train,
                                                  Compose(train_pipeline),
                                                  n_train_workers,
                                                  params.get("num_cached_per_thread"),
                                                  seeds=range(n_train_workers),
                                                  pin_memory=pin_memory)
    batchgenerator_train.restart()

    # ---- validation pipeline: same transforms, but RemoveLabelTransform first ----
    val_pipeline = [RemoveLabelTransform(-1, 0)]
    if data_channels is not None:
        val_pipeline.append(DataChannelSelectionTransform(data_channels))
    if seg_channels is not None:
        val_pipeline.append(SegChannelSelectionTransform(seg_channels))
    val_pipeline.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_pipeline.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    if deep_supervision_scales is not None:
        val_pipeline.append(_ds_transform())
    val_pipeline.append(NumpyToTensor(['data', 'target'], 'float'))

    # validation uses half the workers (at least one)
    n_val_workers = max(params.get('num_threads') // 2, 1)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val,
                                                Compose(val_pipeline),
                                                n_val_workers,
                                                params.get("num_cached_per_thread"),
                                                seeds=range(n_val_workers),
                                                pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
示例#23
0
                             scale=params.get("scale_range"),
                             border_mode_data=params.get("border_mode_data"),
                             border_cval_data=0,
                             order_data=order_data,
                             border_mode_seg="constant",
                             border_cval_seg=border_val_seg,
                             order_seg=order_seg,
                             random_crop=params.get("random_crop"),
                             p_el_per_sample=params.get("p_eldef"),
                             p_scale_per_sample=params.get("p_scale"),
                             p_rot_per_sample=params.get("p_rot"),
                             independent_scale_for_each_axis=params.get(
                                 "independent_scale_factor_for_each_axis")))
        tr_transforms = Compose(tr_transforms)
        batchgenerator_train = MultiThreadedAugmenter(
            dataloader_train,
            tr_transforms,
            params.get('num_threads'),
            params.get("num_cached_per_thread"),
            pin_memory=True)
        train_loader = batchgenerator_train
        train_batch = next(train_loader)
        print(
            train_batch.keys()
        )  # dict_keys(['data', 'target']), each with torch.Size([2, 1, 112, 240, 272])
        print((train_batch['seg'] - train_batch['weak_label']).sum())
        train_batch = next(train_loader)
        print((train_batch['seg'] - train_batch['weak_label']).sum())
        ipdb.set_trace()

        print(sum(1 for _ in train_loader))
def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """
    Build the default training/validation augmentation pipelines and wrap both dataloaders
    in MultiThreadedAugmenters.

    :param dataloader_train: dataloader yielding dicts with 'data' and 'seg' keys for training
    :param dataloader_val: dataloader for validation (only bookkeeping transforms, no spatial aug)
    :param patch_size: target spatial size produced by SpatialTransform
    :param params: augmentation hyperparameter dict (e.g. default_3D_augmentation_params)
    :param border_val_seg: constant fill value for seg outside the original image extent
    :param pin_memory: forwarded to MultiThreadedAugmenter
    :param seeds_train: per-worker seeds for the training augmenter (None = unseeded)
    :param seeds_val: per-worker seeds for the validation augmenter (None = unseeded)
    :param regions: if not None, convert 'target' into region-based (multi-label) maps
    :return: (batchgenerator_train, batchgenerator_val)
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # BUGFIX: was `... and not None and ...` (a typo that always evaluated to True in the
        # middle term); intended and equivalent guard is the `is not None` + truthiness check.
        if params.get("cascade_do_cascade_augmentations") is not None and params.get(
                "cascade_do_cascade_augmentations"):
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
                dont_do_if_covers_more_than_X_percent=params.get("cascade_remove_conn_comp_fill_with_other_class_p")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    tr_transforms = Compose(tr_transforms)
    # from batchgenerators.dataloading import SingleThreadedAugmenter
    # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)

    # validation: only bookkeeping transforms (no spatial/intensity augmentation)
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
示例#25
0
def run(fold=0):
    """Train a 3D segmentation network (Lasagne/Theano) on one fold of a 5-fold CV split.

    NOTE(review): this is Python 2 code (print statements, cPickle) — kept as-is.

    :param fold: index of the cross-validation fold to train on (0-4)
    """
    print fold
    # =================================================================================================================
    I_AM_FOLD = fold
    # fixed seeds for numpy and lasagne so runs are reproducible
    np.random.seed(65432)
    lasagne.random.set_rng(np.random.RandomState(98765))
    sys.setrecursionlimit(2000)
    BATCH_SIZE = 2
    INPUT_PATCH_SIZE =(128, 128, 128)
    num_classes=4

    # output directory layout: results_folder/<EXPERIMENT_NAME>/fold<k>/
    EXPERIMENT_NAME = "final"
    results_dir = os.path.join(paths.results_folder)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    results_dir = os.path.join(results_dir, EXPERIMENT_NAME)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    results_dir = os.path.join(results_dir, "fold%d"%I_AM_FOLD)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)

    # training hyperparameters; lr decays exponentially per epoch (see set_value below)
    n_epochs = 300
    lr_decay = np.float32(0.985)
    base_lr = np.float32(0.0005)
    n_batches_per_epoch = 100
    n_test_batches = 10
    n_feedbacks_per_epoch = 10.
    num_workers = 6
    workers_seeds = [123, 1234, 12345, 123456, 1234567, 12345678]

    # =================================================================================================================

    all_data = load_dataset()
    # sort keys so the KFold split is deterministic given the fixed random_state
    keys_sorted = np.sort(all_data.keys())

    crossval_folds = KFold(len(all_data.keys()), n_folds=5, shuffle=True, random_state=123456)

    # walk the folds until we reach the one this process is responsible for
    ctr = 0
    for train_idx, test_idx in crossval_folds:
        print len(train_idx), len(test_idx)
        if ctr == I_AM_FOLD:
            train_keys = [keys_sorted[i] for i in train_idx]
            test_keys = [keys_sorted[i] for i in test_idx]
            break
        ctr += 1

    train_data = {i:all_data[i] for i in train_keys}
    test_data = {i:all_data[i] for i in test_keys}

    # heavy augmentation for the first training phase (weakened at epochs 100 and 175 below)
    data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                           contrast_range=(0.75, 1.5), gamma_range = (0.8, 1.5),
                                           num_workers=num_workers, num_cached_per_worker=2,
                                           do_elastic_transform=True, alpha=(0., 1300.), sigma=(10., 13.),
                                           do_rotation=True, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                                           do_scale=True, scale_range=(0.75, 1.25), seeds=workers_seeds)

    # validation: random patches, no augmentation beyond normalization/selection transforms
    data_gen_validation = BatchGenerator3D_random_sampling(test_data, BATCH_SIZE, num_batches=None, seed=False,
                                                           patch_size=INPUT_PATCH_SIZE, convert_labels=True)
    val_transforms = []
    val_transforms.append(GenerateBrainMaskTransform())
    val_transforms.append(BrainMaskAwareStretchZeroOneTransform(clip_range=(-5, 5), per_channel=True))
    val_transforms.append(SegChannelSelectionTransform([0]))
    val_transforms.append(ConvertSegToOnehotTransform(range(4), 0, "seg_onehot"))
    val_transforms.append(DataChannelSelectionTransform([0, 1, 2, 3]))
    data_gen_validation = MultiThreadedAugmenter(data_gen_validation, Compose(val_transforms), 2, 2)

    # symbolic inputs: 5D image tensor, seg flattened to (voxels, classes) one-hot matrix
    x_sym = T.tensor5()
    seg_sym = T.matrix()

    net, seg_layer = build_net(x_sym, INPUT_PATCH_SIZE, num_classes, 4, 16, batch_size=BATCH_SIZE,
                               do_instance_norm=True)
    output_layer_for_loss = net

    # add some weight decay
    l2_loss = lasagne.regularization.regularize_network_params(output_layer_for_loss, lasagne.regularization.l2) * 1e-5

    # the distinction between prediction_train and test is important only if we enable dropout (batch norm/inst norm
    # does not use or save moving averages)
    prediction_train = lasagne.layers.get_output(output_layer_for_loss, x_sym, deterministic=False,
                                                 batch_norm_update_averages=False, batch_norm_use_averages=False)

    # negative soft dice over foreground classes only ([:, 1:] drops background)
    loss_vec = - soft_dice_per_img_in_batch(prediction_train, seg_sym, BATCH_SIZE)[:, 1:]

    loss = loss_vec.mean()
    loss += l2_loss
    acc_train = T.mean(T.eq(T.argmax(prediction_train, axis=1), seg_sym.argmax(-1)), dtype=theano.config.floatX)

    prediction_test = lasagne.layers.get_output(output_layer_for_loss, x_sym, deterministic=True,
                                                batch_norm_update_averages=False, batch_norm_use_averages=False)
    loss_val = - soft_dice_per_img_in_batch(prediction_test, seg_sym, BATCH_SIZE)[:, 1:]

    loss_val = loss_val.mean()
    loss_val += l2_loss
    acc = T.mean(T.eq(T.argmax(prediction_test, axis=1), seg_sym.argmax(-1)), dtype=theano.config.floatX)

    # learning rate has to be a shared variable because we decrease it with every epoch
    params = lasagne.layers.get_all_params(output_layer_for_loss, trainable=True)
    learning_rate = theano.shared(base_lr)
    updates = lasagne.updates.adam(T.grad(loss, params), params, learning_rate=learning_rate, beta1=0.9, beta2=0.999)

    # per-class hard dice, averaged over the batch, for validation reporting
    dc = hard_dice_per_img_in_batch(prediction_test, seg_sym.argmax(1), num_classes, BATCH_SIZE).mean(0)

    train_fn = theano.function([x_sym, seg_sym], [loss, acc_train, loss_vec], updates=updates)
    val_fn = theano.function([x_sym, seg_sym], [loss_val, acc, dc])

    all_val_dice_scores=None

    # bookkeeping for the loss/accuracy/dice curves plotted by printLosses
    all_training_losses = []
    all_validation_losses = []
    all_validation_accuracies = []
    all_training_accuracies = []
    val_dice_scores = []
    epoch = 0

    while epoch < n_epochs:
        # curriculum: weaken augmentation strength at epochs 100 and 175
        if epoch == 100:
            data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                                   contrast_range=(0.85, 1.25), gamma_range = (0.8, 1.5),
                                                   num_workers=6, num_cached_per_worker=2,
                                                   do_elastic_transform=True, alpha=(0., 1000.), sigma=(10., 13.),
                                                   do_rotation=True, a_x=(0., 2*np.pi), a_y=(-np.pi/8., np.pi/8.),
                                                   a_z=(-np.pi/8., np.pi/8.), do_scale=True, scale_range=(0.85, 1.15),
                                                   seeds=workers_seeds)

        if epoch == 175:
            data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                                   contrast_range=(0.9, 1.1), gamma_range = (0.85, 1.3),
                                                   num_workers=6, num_cached_per_worker=2,
                                                   do_elastic_transform=True, alpha=(0., 750.), sigma=(10., 13.),
                                                   do_rotation=True, a_x=(0., 2*np.pi), a_y=(-0.00001, 0.00001),
                                                   a_z=(-0.00001, 0.00001), do_scale=True, scale_range=(0.85, 1.15),
                                                   seeds=workers_seeds)

        epoch_start_time = time.time()
        # exponential lr decay: base_lr * lr_decay^epoch
        learning_rate.set_value(np.float32(base_lr* lr_decay**(epoch)))
        print "epoch: ", epoch, " learning rate: ", learning_rate.get_value()
        train_loss = 0
        train_acc_tmp = 0
        train_loss_tmp = 0
        batch_ctr = 0
        for data_dict in data_gen_train:
            data = data_dict["data"].astype(np.float32)
            # flatten one-hot seg to (num_voxels, num_classes) to match seg_sym
            seg = data_dict["seg_onehot"].astype(np.float32).transpose(0, 2, 3, 4, 1).reshape((-1, num_classes))
            # print intermediate feedback n_feedbacks_per_epoch times per epoch
            if batch_ctr != 0 and batch_ctr % int(np.floor(n_batches_per_epoch/n_feedbacks_per_epoch)) == 0:
                print "number of batches: ", batch_ctr, "/", n_batches_per_epoch
                print "training_loss since last update: ", \
                    train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch), " train accuracy: ", \
                    train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch)
                all_training_losses.append(train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
                all_training_accuracies.append(train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
                train_loss_tmp = 0
                train_acc_tmp = 0
                if len(val_dice_scores) > 0:
                    all_val_dice_scores = np.concatenate(val_dice_scores, axis=0).reshape((-1, num_classes))
                # plotting is best-effort; ignore failures (e.g. headless display)
                try:
                    printLosses(all_training_losses, all_training_accuracies, all_validation_losses,
                                all_validation_accuracies, os.path.join(results_dir, "%s.png" % EXPERIMENT_NAME),
                                n_feedbacks_per_epoch, val_dice_scores=all_val_dice_scores,
                                val_dice_scores_labels=["brain", "1", "2", "3", "4", "5"])
                except:
                    pass
            # NOTE(review): train_fn returns [loss, acc_train, loss_vec]; the names here are
            # swapped (loss_vec receives the scalar loss, l the per-image vector) — confirm intent
            loss_vec, acc, l = train_fn(data, seg)

            loss = loss_vec.mean()
            train_loss += loss
            train_loss_tmp += loss
            train_acc_tmp += acc
            batch_ctr += 1
            if batch_ctr >= n_batches_per_epoch:
                break
        all_training_losses.append(train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
        all_training_accuracies.append(train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
        train_loss /= n_batches_per_epoch
        print "training loss average on epoch: ", train_loss

        # ---- validation: n_test_batches random batches ----
        val_loss = 0
        accuracies = []
        valid_batch_ctr = 0
        all_dice = []
        for data_dict in data_gen_validation:
            data = data_dict["data"].astype(np.float32)
            seg = data_dict["seg_onehot"].astype(np.float32).transpose(0, 2, 3, 4, 1).reshape((-1, num_classes))
            # w marks which classes are present in this batch; absent classes get the
            # sentinel dice value 2 so they can be excluded from the per-class means below
            w = np.zeros(num_classes, dtype=np.float32)
            w[np.unique(seg.argmax(-1))] = 1
            loss, acc, dice = val_fn(data, seg)
            dice[w==0] = 2
            all_dice.append(dice)
            val_loss += loss
            accuracies.append(acc)
            valid_batch_ctr += 1
            if valid_batch_ctr >= n_test_batches:
                break
        all_dice = np.vstack(all_dice)
        dice_means = np.zeros(num_classes)
        for i in range(num_classes):
            # mean dice per class, ignoring batches where the class was absent (sentinel 2)
            dice_means[i] = all_dice[all_dice[:, i]!=2, i].mean()
        val_loss /= n_test_batches
        print "val loss: ", val_loss
        print "val acc: ", np.mean(accuracies), "\n"
        print "val dice: ", dice_means
        print "This epoch took %f sec" % (time.time()-epoch_start_time)
        val_dice_scores.append(dice_means)
        all_validation_losses.append(val_loss)
        all_validation_accuracies.append(np.mean(accuracies))
        all_val_dice_scores = np.concatenate(val_dice_scores, axis=0).reshape((-1, num_classes))
        try:
            printLosses(all_training_losses, all_training_accuracies, all_validation_losses, all_validation_accuracies,
                        os.path.join(results_dir, "%s.png" % EXPERIMENT_NAME), n_feedbacks_per_epoch,
                        val_dice_scores=all_val_dice_scores, val_dice_scores_labels=["brain", "1", "2", "3", "4", "5"])
        except:
            pass
        # checkpoint network parameters and the training curves every epoch
        with open(os.path.join(results_dir, "%s_Params.pkl" % (EXPERIMENT_NAME)), 'w') as f:
            cPickle.dump(lasagne.layers.get_all_param_values(output_layer_for_loss), f)
        with open(os.path.join(results_dir, "%s_allLossesNAccur.pkl"% (EXPERIMENT_NAME)), 'w') as f:
            cPickle.dump([all_training_losses, all_training_accuracies, all_validation_losses,
                          all_validation_accuracies, val_dice_scores], f)
        epoch += 1
示例#26
0
    dataloader_train = BraTS2017DataLoader3D(train, batch_size, max_shape, 1)

    # during training I like to run a validation from time to time to see where I am standing. This is not a correct
    # validation because just like training this is patch-based but it's good enough. We don't do augmentation for the
    # validation, so patch_size is used as shape target here
    dataloader_validation = BraTS2017DataLoader3D(val, batch_size, patch_size,
                                                  1)

    tr_transforms = get_train_transform(patch_size)

    # finally we can create multithreaded transforms that we can actually use for training
    # we don't pin memory here because this is pytorch specific.
    tr_gen = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        num_processes=num_threads_for_brats_example,
        num_cached_per_queue=3,
        seeds=None,
        pin_memory=False)
    # we need less processes for vlaidation because we dont apply transformations
    val_gen = MultiThreadedAugmenter(dataloader_validation,
                                     None,
                                     num_processes=max(
                                         1,
                                         num_threads_for_brats_example // 2),
                                     num_cached_per_queue=1,
                                     seeds=None,
                                     pin_memory=False)

    # lets start the MultiThreadedAugmenter. This is not necessary but allows them to start generating training
    # batches while other things run in the main thread
示例#27
0
                                             shuffle=False,
                                             return_incomplete=True)
    dataloader_validation = BraTS2017DataLoader3D(val,
                                                  batch_size,
                                                  None,
                                                  1,
                                                  infinite=False,
                                                  shuffle=False,
                                                  return_incomplete=True)

    tr_transforms = get_train_transform(patch_size)

    tr_gen = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        num_processes=num_threads_for_brats_example,
        num_cached_per_queue=3,
        seeds=None,
        pin_memory=False)

    #tr_gen.restart()

    num_batches_per_epoch = 100
    num_validation_batches_per_epoch = 20
    num_epochs = 5
    # let's run this to get a time on how long it takes
    time_per_epoch = []
    start = time()
    for epoch in range(num_epochs):

        start_epoch = time()
示例#28
0
# Validation dataloader; patients_val / batch_size / patch_size / in_channels are
# defined earlier in this script (outside this view).
val_dl = BRATSDataLoader(patients_val,
                         batch_size=batch_size,
                         patch_size=patch_size,
                         in_channels=in_channels)

#%%
# training-time augmentation pipeline built for the configured patch size
tr_transforms = get_train_transform(patch_size)

#%%
# finally we can create multithreaded transforms that we can actually use for training
# we don't pin memory here because this is pytorch specific.
tr_gen = MultiThreadedAugmenter(
    train_dl,
    tr_transforms,
    num_processes=4,  # num_processes=4
    num_cached_per_queue=3,
    seeds=None,
    pin_memory=False)
# we need fewer processes for validation because we don't apply transformations
val_gen = MultiThreadedAugmenter(
    val_dl,
    None,  # no transforms on validation data
    num_processes=max(1, 4 // 2),  # num_processes=max(1, 4 // 2)
    num_cached_per_queue=1,
    seeds=None,
    pin_memory=False)

#%%
# start the worker processes now so batches are pre-generated in the background
tr_gen.restart()
val_gen.restart()
def get_default_augmentation_withEDT(dataloader_train,
                                     dataloader_val,
                                     patch_size,
                                     idx_of_edts,
                                     params=default_3D_augmentation_params,
                                     border_val_seg=-1,
                                     pin_memory=True,
                                     seeds_train=None,
                                     seeds_val=None):
    """
    Build multithreaded train/val augmentation pipelines in which the Euclidean
    distance transform (EDT) channels are moved from 'data' into a separate
    'bound' key, so that intensity augmentations (gamma) never touch them.

    :param dataloader_train: batch loader yielding training dicts ('data'/'seg')
    :param dataloader_val: batch loader yielding validation dicts
    :param patch_size: target spatial size passed to SpatialTransform
    :param idx_of_edts: channel indices (within 'data') holding the EDTs
    :param params: augmentation parameter dict. NOTE: this default is a shared
        module-level dict — callers must not mutate it in place.
    :param border_val_seg: constant fill value for segmentation outside borders
    :param pin_memory: forwarded to MultiThreadedAugmenter
    :param seeds_train: per-worker RNG seeds for the training augmenter
    :param seeds_val: per-worker RNG seeds for the validation augmenter
    :return: tuple (batchgenerator_train, batchgenerator_val)
    """
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())
    """
    ##############################################################
    ##############################################################
    Here we insert moving the EDT to a different key so that it does not get intensity transformed
    ##############################################################
    ##############################################################
    """
    tr_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # BUGFIX: the original condition read
        #   params.get("advanced_pyramid_augmentations") and not None and params.get(...)
        # where "and not None" always evaluates to True. The intended check is
        # "is not None" followed by a truthiness test (same net behavior, but
        # the expression now says what it means).
        if params.get(
                "advanced_pyramid_augmentations") is not None and params.get(
                    "advanced_pyramid_augmentations"):
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(channel_idx=list(
                    range(-len(params.get("all_segmentation_labels")), 0)),
                                                   p_per_sample=0.4,
                                                   key="data",
                                                   strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(
                        range(-len(params.get("all_segmentation_labels")), 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.0,
                    dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    # 'bound' (the EDT) is converted to a tensor alongside data and target.
    tr_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # Validation: no spatial/intensity augmentation, only channel selection,
    # label cleanup and the same EDT key relocation as in training.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    """
    ##############################################################
    ##############################################################
    Here we insert moving the EDT to a different key
    ##############################################################
    ##############################################################
    """
    val_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),  # fewer workers: no heavy transforms
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
# ---- Example #30 ----
def get_insaneDA_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True):
    """
    Build aggressive ("insane DA") train/val augmentation pipelines: spatial
    deformation plus noise, blur, brightness, contrast, simulated low
    resolution, (inverted) gamma and mirroring, with optional cascade
    augmentations and deep-supervision target downsampling.

    :param dataloader_train: batch loader yielding training dicts ('data'/'seg')
    :param dataloader_val: batch loader yielding validation dicts
    :param patch_size: target spatial size passed to SpatialTransform
    :param params: augmentation parameter dict. NOTE: this default is a shared
        module-level dict — callers must not mutate it in place.
    :param border_val_seg: constant fill value for segmentation outside borders
    :param seeds_train: per-worker RNG seeds for the training augmenter
    :param seeds_val: per-worker RNG seeds for the validation augmenter
    :param order_seg: interpolation order for segmentations
    :param order_data: interpolation order for image data
    :param deep_supervision_scales: if not None, downsample 'target' per scale
    :param soft_ds: use soft (one-hot, DSTransform3) deep supervision targets;
        requires `classes`
    :param classes: class labels, required when soft_ds is True
    :param pin_memory: forwarded to MultiThreadedAugmenter
    :return: tuple (batchgenerator_train, batchgenerator_val)
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        # pseudo-2D mode: axis 0 must be excluded from low-res simulation
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # BUGFIX: the original condition read
        #   params.get("cascade_do_cascade_augmentations") and not None and params.get(...)
        # where "and not None" always evaluates to True. The intended check is
        # "is not None" followed by a truthiness test (same net behavior, but
        # the expression now says what it means).
        if params.get("cascade_do_cascade_augmentations"
                      ) is not None and params.get(
                          "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # NOTE(review): the two kwargs below appear to read swapped
                # param keys (fill_with_other_class_p is fed the
                # "...max_size_percent_threshold" key and vice versa). This
                # matches upstream nnU-Net code, so it is preserved as-is —
                # confirm against the transform's semantics before changing.
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p")
                    ))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # Validation: no augmentation, only channel selection, label cleanup and
    # the same target preparation (one-hot cascade channels, deep supervision).
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),  # fewer workers: no heavy transforms
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val