Code example #1
0
# Dataset: BraTS2013 wrapped in BinaryGT (positive classes taken from CONFIG),
# then CropToBrain, with load_image post-processed by min_max_scale over
# axes=0, resampled via ZooOfSpacings at CONFIG's slice spacings, and finally
# wrapped with cache_methods.
raw_dataset = BinaryGT(
    BraTS2013(BRATS_PATH),
    positive_classes=CONFIG['positive_classes'],
)
dataset = cache_methods(
    ZooOfSpacings(
        apply(
            CropToBrain(raw_dataset),
            load_image=partial(min_max_scale, axes=0),
        ),
        slice_spacings=CONFIG['slice_spacings'],
    )
)

# Cross validation: the JSON at SPLIT_PATH is indexed by the integer fold
# number and each entry unpacks into (train, val, test) id collections.
split = load_json(SPLIT_PATH)
train_ids, val_ids, test_ids = split[int(FOLD)]

# Batch iterator over random training ids; each pipeline item is an
# (image, gt) pair.
batch_iter = Infinite(
    load_by_random_id(dataset.load_image, dataset.load_gt, ids=train_ids),
    unpack_args(get_random_slice),
    # With probability .5 mirror the pair: image along axis 1, gt along axis 0.
    # The one-axis offset assumes image carries a leading modality axis that
    # gt lacks — TODO confirm.
    random_apply(
        .5,
        unpack_args(lambda image, gt: (np.flip(image, axis=1), np.flip(gt, axis=0))),
    ),
    apply_at(1, prepend_dims),  # element 1 is gt
    apply_at(1, np.float32),
    batch_size=CONFIG['batch_size'],
    batches_per_epoch=CONFIG['batches_per_epoch'],
    combiner=combine_pad,
)

# model
model = nn.Sequential(
    nn.Conv2d(dataset.n_modalities, 8, kernel_size=3, padding=1),
    layers.FPN(
        layers.ResBlock2d, downsample=nn.MaxPool2d(2, ceil_mode=True), upsample=nn.Identity,
        merge=lambda left, down: torch.cat(layers.interpolate_to_left(left, down, 'bilinear'), dim=1),
        structure=[
            [[8, 8, 8], [16, 8, 8]],
            [[8, 16, 16], [32, 16, 8]],
            [[16, 32, 32], [64, 32, 16]],
            [[32, 64, 64], [128, 64, 32]],
Code example #2
0
File: unet_2_5d.py  Project: migonch/two_and_half_d
# Cross validation: the JSON at SPLIT_PATH is indexed by the integer fold
# number and each entry unpacks into (train, val, test) id collections.
split = load_json(SPLIT_PATH)
train_ids, val_ids, test_ids = split[int(FOLD)]

# Batch iterator over random training ids; each pipeline item is an
# (image, spacing, gt) triple.
batch_iter = Infinite(
    load_by_random_id(
        train_dataset.load_image,
        train_dataset.load_spacing,
        train_dataset.load_gt,
        ids=train_ids,
    ),
    unpack_args(get_random_slices, n_slices=CONFIG['n_slices']),
    # With probability .5 mirror image along axis 1 and gt along axis 0;
    # spacing passes through unchanged.
    random_apply(
        .5,
        unpack_args(
            lambda image, spacing, gt: (
                np.flip(image, axis=1),
                spacing,
                np.flip(gt, axis=0),
            )
        ),
    ),
    # NOTE(review): in the (image, spacing, gt) triple, index 1 is the
    # spacing, so only the spacing is cast here; gt (index 2) is neither
    # cast to float32 nor given an extra leading dim. Confirm this is
    # intended and not a stale index from an (image, gt) pipeline.
    apply_at(1, np.float32),
    batch_size=CONFIG['batch_size'],
    batches_per_epoch=CONFIG['batches_per_epoch'],
    combiner=combine_pad,
)

# model
unet_2d = nn.Sequential(
    nn.Conv2d(dataset.n_modalities, 8, kernel_size=3, padding=1),
    layers.FPN(layers.ResBlock2d,
               downsample=nn.MaxPool2d(2, ceil_mode=True),
               upsample=nn.Identity,
               merge=lambda left, down: torch.cat(
                   layers.interpolate_to_left(left, down, 'bilinear'), dim=1),
               structure=[[[8, 8, 8], [16, 8, 8]], [[8, 16, 16], [32, 16, 8]],
                          [[16, 32, 32], [64, 32, 16]],
                          [[32, 64, 64], [128, 64, 32]],
Code example #3
0
# Cross validation: the JSON at SPLIT_PATH is indexed by the integer fold
# number and each entry unpacks into (train, val, test) id collections.
split = load_json(SPLIT_PATH)
train_ids, val_ids, test_ids = split[int(FOLD)]

# Batch iterator over random training ids; each pipeline item is an
# (image, spacing, gt) triple.
batch_iter = Infinite(
    load_by_random_id(
        dataset.load_image,
        dataset.load_spacing,
        dataset.load_gt,
        ids=train_ids,
    ),
    unpack_args(get_random_slices, n_slices=CONFIG['n_slices']),
    # With probability .5 mirror image along axis 1 and gt along axis 0;
    # spacing passes through unchanged.
    random_apply(
        .5,
        unpack_args(
            lambda image, spacing, gt: (
                np.flip(image, axis=1),
                spacing,
                np.flip(gt, axis=0),
            )
        ),
    ),
    apply_at(2, prepend_dims),  # element 2 is gt
    apply_at([1, 2], np.float32),  # cast both spacing and gt
    batch_size=CONFIG['batch_size'],
    batches_per_epoch=CONFIG['batches_per_epoch'],
    combiner=combine_pad,
)

# model
unet_2d = nn.Sequential(
    nn.Conv2d(dataset.n_modalities, 8, kernel_size=3, padding=1),
    layers.FPN(layers.ResBlock2d,
               downsample=nn.MaxPool2d(2, ceil_mode=True),
               upsample=nn.Identity,
               merge=lambda left, down: torch.cat(
                   layers.interpolate_to_left(left, down, 'bilinear'), dim=1),
               structure=[[[8, 8, 8], [16, 8, 8]], [[8, 16, 16], [32, 16, 8]],
                          [[16, 32, 32], [64, 32, 16]],
Code example #4
0
 def wrapper(binary_gt_and_spacing, proba):
     """Evaluate the enclosing ``metric`` on a thresholded prediction.

     ``apply_at(0, np.bool_)`` is applied to ``binary_gt_and_spacing``
     (presumably casting only element 0, the ground truth, of a
     ``(gt, spacing)`` pair — TODO confirm against callers).
     ``proba`` is thresholded with ``>= .5``; for a NumPy array this yields a
     boolean prediction mask.

     Returns whatever ``metric`` returns.
     """
     gt_and_spacing = apply_at(0, np.bool_)(binary_gt_and_spacing)
     prediction = proba >= .5
     return metric(gt_and_spacing, prediction)