Example #1
0
def pipeline(source,
             iters=1,
             transformers=(),
             batch_size=1,
             combiner=combine_to_arrays):
    """Build an ``Infinite`` batch iterator from ``source``.

    Each item from ``source`` is passed through ``transformers`` in order,
    then ``combiner`` assembles ``batch_size`` items into a batch; ``iters``
    batches make up one epoch.  ``buffer_size=1`` keeps prefetching minimal.
    """
    return Infinite(
        source,
        *transformers,
        batch_size=batch_size,
        batches_per_epoch=iters,
        buffer_size=1,
        combiner=combiner,
    )
Example #2
0
# Patch size in voxels fed to the 3D network: 64x64 in-plane, 32 slices.
PATCH_SIZE = np.array([64, 64, 32])

# dataset
raw_dataset = BraTS2013(BRATS_PATH)
# Crop each scan to the brain bounding box, then min-max scale the image
# (axes=0 — presumably per-modality normalization; confirm against
# min_max_scale's signature).
dataset = apply(CropToBrain(raw_dataset), load_image=partial(min_max_scale, axes=0))
# Resample to the configured slice spacing; cache_methods memoizes the
# dataset's loading methods so repeated loads of the same id are cheap.
train_dataset = cache_methods(ChangeSliceSpacing(dataset, new_slice_spacing=CONFIG['source_slice_spacing']))

# cross validation
split = load_json(SPLIT_PATH)
# Each fold entry is a (train, val, test) triple of id lists.
train_ids, val_ids, test_ids = split[int(FOLD)]

# batch iterator
batch_iter = Infinite(
    # Draw a random training id and load its (image, gt) pair.
    load_by_random_id(train_dataset.load_image, train_dataset.load_gt, ids=train_ids),
    # Sample a PATCH_SIZE patch, centering on tumor with probability .5.
    unpack_args(tumor_sampling, patch_size=PATCH_SIZE, tumor_p=.5),
    # Random flip augmentation: image axis 1 corresponds to gt axis 0
    # because the image carries a leading modality axis (see Conv3d below).
    random_apply(.5, unpack_args(lambda image, gt: (np.flip(image, 1), np.flip(gt, 0)))),
    batch_size=CONFIG['batch_size'], batches_per_epoch=CONFIG['batches_per_epoch']
)

# model
# NOTE(review): the definition continues past this excerpt — the FPN
# `structure` list and the Sequential call are truncated here.
model = nn.Sequential(
    nn.Conv3d(dataset.n_modalities, 8, kernel_size=3, padding=1),
    layers.FPN(
        layers.ResBlock3d, downsample=nn.MaxPool3d(2, ceil_mode=True), upsample=nn.Identity,
        merge=lambda left, down: torch.cat(layers.interpolate_to_left(left, down, 'trilinear'), dim=1),
        structure=[
            [[8, 8, 8], [16, 8, 8]],
            [[8, 16, 16], [32, 16, 8]],
            [[16, 32, 32], [64, 32, 16]],
            [[32, 64, 64], [128, 64, 32]],
            [[64, 128, 128], [256, 128, 64]],
Example #3
0
# dataset
# dataset
# Collapse the multi-class segmentation into a binary ground truth:
# classes listed in CONFIG['positive_classes'] become the foreground.
raw_dataset = BinaryGT(BraTS2013(BRATS_PATH), positive_classes=CONFIG['positive_classes'])
# Crop to the brain bounding box, min-max scale the image (axes=0 —
# presumably per-modality; confirm), and expose several slice spacings at
# once; cache_methods memoizes the loading methods.
dataset = cache_methods(ZooOfSpacings(
    apply(CropToBrain(raw_dataset), load_image=partial(min_max_scale, axes=0)),
    slice_spacings=CONFIG['slice_spacings']
))

# cross validation
split = load_json(SPLIT_PATH)
# Each fold entry is a (train, val, test) triple of id lists.
train_ids, val_ids, test_ids = split[int(FOLD)]

# batch iterator
batch_iter = Infinite(
    # Draw a random training id and load its (image, gt) pair.
    load_by_random_id(dataset.load_image, dataset.load_gt, ids=train_ids),
    # Train on individual 2D slices drawn at random from each volume.
    unpack_args(get_random_slice),
    # Random flip augmentation: image axis 1 corresponds to gt axis 0
    # because the image carries a leading modality axis (see Conv2d below).
    random_apply(.5, unpack_args(lambda image, gt: (np.flip(image, 1), np.flip(gt, 0)))),
    # Give the gt (element at index 1) a channel axis and cast to float32.
    apply_at(1, prepend_dims),
    apply_at(1, np.float32),
    # combine_pad presumably pads unequal-size slices into one batch array
    # — confirm against its implementation.
    batch_size=CONFIG['batch_size'], batches_per_epoch=CONFIG['batches_per_epoch'], combiner=combine_pad
)

# model
# NOTE(review): the definition continues past this excerpt — the FPN
# `structure` list and the Sequential call are truncated here.
model = nn.Sequential(
    nn.Conv2d(dataset.n_modalities, 8, kernel_size=3, padding=1),
    layers.FPN(
        layers.ResBlock2d, downsample=nn.MaxPool2d(2, ceil_mode=True), upsample=nn.Identity,
        merge=lambda left, down: torch.cat(layers.interpolate_to_left(left, down, 'bilinear'), dim=1),
        structure=[
            [[8, 8, 8], [16, 8, 8]],
            [[8, 16, 16], [32, 16, 8]],
            [[16, 32, 32], [64, 32, 16]],
            [[32, 64, 64], [128, 64, 32]],
Example #4
0
# Resample the dataset to the configured slice spacing; cache_methods
# memoizes the loading methods so repeated loads of the same id are cheap.
train_dataset = cache_methods(
    ChangeSliceSpacing(dataset,
                       new_slice_spacing=CONFIG['source_slice_spacing']))

# cross validation
split = load_json(SPLIT_PATH)
# Each fold entry is a (train, val, test) triple of id lists.
train_ids, val_ids, test_ids = split[int(FOLD)]

# batch iterator
batch_iter = Infinite(
    # Draw a random training id and load its (image, spacing, gt) triple.
    load_by_random_id(train_dataset.load_image,
                      train_dataset.load_spacing,
                      train_dataset.load_gt,
                      ids=train_ids),
    # Take n_slices random slices from the volume (presumably consecutive —
    # confirm against get_random_slices).
    unpack_args(get_random_slices, n_slices=CONFIG['n_slices']),
    # Random flip augmentation; the spacing value is passed through
    # unchanged. Image axis 1 corresponds to gt axis 0 because the image
    # carries a leading modality axis (see Conv2d below).
    random_apply(
        .5,
        unpack_args(lambda image, spacing, gt:
                    (np.flip(image, 1), spacing, np.flip(gt, 0)))),
    # Cast the spacing (element at index 1 of the triple) to float32.
    apply_at(1, np.float32),
    batch_size=CONFIG['batch_size'],
    batches_per_epoch=CONFIG['batches_per_epoch'],
    combiner=combine_pad)

# model
# NOTE(review): the definition continues past this excerpt — the
# layers.FPN(...) call is truncated here.
unet_2d = nn.Sequential(
    nn.Conv2d(dataset.n_modalities, 8, kernel_size=3, padding=1),
    layers.FPN(layers.ResBlock2d,
               downsample=nn.MaxPool2d(2, ceil_mode=True),
               upsample=nn.Identity,
               merge=lambda left, down: torch.cat(
                   layers.interpolate_to_left(left, down, 'bilinear'), dim=1),
Example #5
0
# Total number of training samples the whole run will consume.
samples_per_train = 40 * 32000  # 32000 batches of size 40

paths = gather_train(args.input)
# cache the loader
# Optionally memoize load_pair so each pair is read from disk at most once.
# lru_cache(None) is unbounded — assumes the dataset fits in memory.
if args.cache:
    load_pair = lru_cache(None)(load_pair)

# Check out https://deep-pipe.readthedocs.io/en/latest/tutorials/batch_iter.html
# for more details about the batch iterators we use

batch_iter = Infinite(
    # get a random pair of paths
    sample(paths),
    # load the image-contour pair
    unpack_args(load_pair),
    # get a random slice
    unpack_args(get_random_slice),
    # simple augmentation
    unpack_args(random_flip),
    batch_size=batch_size,
    # Derived so the total sample budget (samples_per_train) stays constant
    # regardless of the chosen batch_size and n_epochs.
    batches_per_epoch=samples_per_train // (batch_size * n_epochs),
    combiner=combiner)

model = to_device(Network(), args.device)
optimizer = Adam(model.parameters(), lr=lr)

# Here we use a general training function with a custom `train_step`.
# See the tutorial for more details: https://deep-pipe.readthedocs.io/en/latest/tutorials/training.html

# NOTE(review): the train(...) call continues past this excerpt.
train(
    train_step,
    batch_iter,
Example #6
0
def train(root,
          dataset,
          generator,
          discriminator,
          latent_dim,
          disc_iters,
          batch_size,
          batches_per_epoch,
          n_epochs,
          r1_weight,
          lr_discriminator,
          lr_generator,
          device='cuda'):
    """Train a GAN (``generator`` vs ``discriminator``) on ``dataset``.

    Logs, checkpoints and the final model weights are all written under
    ``root``.  Every iterator batch packs ``disc_iters`` discriminator
    sub-batches of size ``batch_size``.
    """
    root = Path(root)
    generator, discriminator = generator.to(device), discriminator.to(device)

    gen_opt = torch.optim.Adam(generator.parameters(), lr_generator)
    disc_opt = torch.optim.Adam(discriminator.parameters(), lr_discriminator)

    def make_groups(images):
        # Split the flat batch into disc_iters groups of equal size, each
        # image gaining a singleton channel axis.
        grouped = np.array(images).reshape(
            disc_iters, len(images) // disc_iters, 1, *images[0].shape)
        return [tuple(grouped)]

    batch_iter = Infinite(
        sample(dataset.ids),
        dataset.image,
        combiner=make_groups,
        buffer_size=10,
        batch_size=batch_size * disc_iters,
        batches_per_epoch=batches_per_epoch,
    )

    def val():
        # Render a fixed 144-sample grid from deterministic latents
        # (state=0) so validation images are comparable across epochs.
        generator.eval()
        latents = sample_on_sphere(latent_dim, 144, state=0).astype('float32')
        with torch.no_grad():
            generated = to_np(generator(to_var(latents, device=generator)))[:, 0]

        return {'generated__image': bytescale(stitch(generated, bytescale))}

    logger = TBLogger(root / 'logs')
    train_base(
        train_step,
        batch_iter,
        n_epochs,
        logger=logger,
        validate=val,
        checkpoints=Checkpoints(root / 'checkpoints',
                                [generator, discriminator, gen_opt, disc_opt]),
        generator=generator,
        discriminator=discriminator,
        gen_optimizer=gen_opt,
        disc_optimizer=disc_opt,
        latent_dim=latent_dim,
        r1_weight=r1_weight,
        time=TimeProfiler(logger.logger),
        tqdm=TQDM(False),
    )
    save_model_state(generator, root / 'generator.pth')
    save_model_state(discriminator, root / 'discriminator.pth')