def __init__(self, structure, n_chans_in, n_chans_out):
    super().__init__()
    first_channels = structure[0][0][0]
    last_channels = structure[0][-1][-1]

    self.unet = nn.Sequential(
        # lift the input to the first channel count expected by the FPN
        nn.Conv2d(n_chans_in, first_channels, kernel_size=3, padding=1, bias=False),
        # U-Net body: residual blocks, max-pool downsampling, additive skip connections
        layers.FPN(
            layers.ResBlock2d,
            nn.MaxPool2d(kernel_size=2, ceil_mode=True),
            nn.Identity(),
            lambda left, down: torch.add(*layers.interpolate_to_left(left, down, 1)),
            structure, kernel_size=3, padding=1,
        ),
    )
    # output head: residual block followed by a 1x1 pre-activation conv producing the logits
    self.out = nn.Sequential(
        layers.ResBlock2d(last_channels, last_channels, kernel_size=3, padding=1),
        layers.PreActivation2d(last_channels, n_chans_out, kernel_size=1),
    )
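
# A minimal usage sketch, not taken from the original code: the class owning this __init__
# (hypothetically named UNet2D below) is assumed to expose `self.unet` and `self.out` as
# built above, so the two blocks can be chained by hand. The 3-level `structure` is also an
# illustrative assumption, chosen so that additive merging lines up: each deeper level ends
# with the same number of channels as the "left" branch it is added to.
import torch

structure = [
    [[16, 16, 16], [16, 16, 16]],  # top level: down branch / up branch
    [[16, 32, 32], [32, 32, 16]],  # middle level
    [32, 64, 32],                  # bottleneck
]
model = UNet2D(structure, n_chans_in=1, n_chans_out=2)

with torch.no_grad():
    logits = model.out(model.unet(torch.randn(4, 1, 128, 128)))  # expected shape: (4, 2, 128, 128)
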
def __init__(self):
    super().__init__()
    self.init_block = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, padding=1),
        layers.PreActivation2d(32, 32, kernel_size=3, padding=1),
    )

    # don't need upsampling here because it will be done in `merge_branches`
    upsample = nn.Sequential  # same as nn.Identity, but supported by older versions
    downsample = partial(nn.MaxPool2d, kernel_size=2)

    self.midline_head = layers.InterpolateToInput(
        nn.Sequential(
            downsample(),
            # unet
            layers.FPN(
                layers.ResBlock2d, downsample=downsample, upsample=upsample,
                merge=merge_branches, structure=STRUCTURE, kernel_size=3, padding=1,
            ),
            # final logits
            layers.PreActivation2d(32, 1, kernel_size=1),
        ),
        mode='bilinear',
    )

    self.limits_head = layers.InterpolateToInput(
        nn.Sequential(
            layers.ResBlock2d(32, 32, kernel_size=3, padding=1),
            downsample(),
            layers.ResBlock2d(32, 32, kernel_size=3, padding=1),

            # 2D feature map to 1D feature map
            nn.AdaptiveMaxPool2d((None, 1)),
            layers.Reshape('0', '1', -1),

            nn.Conv1d(32, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(8, 1, kernel_size=3, padding=1),
        ),
        mode='linear', axes=0,
    )
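
# Assuming the forward pass feeds the shared `init_block` features to both heads, an input
# of shape (B, 1, H, W) would yield midline logits of shape (B, 1, H, W) (interpolated back
# to the input size bilinearly) and limit logits of shape (B, 1, H) (one value per row).
# The least obvious step is the 2D -> 1D conversion inside `limits_head`: adaptive
# max-pooling collapses the width axis and `Reshape` drops it, so each row of the feature
# map becomes one position of a 1D signal for the Conv1d layers. A self-contained sketch
# of that step alone, assuming `layers` is `dpipe.layers` and using illustrative shapes:
import torch
from torch import nn
from dpipe import layers

to_1d = nn.Sequential(
    nn.AdaptiveMaxPool2d((None, 1)),  # (B, C, H, W) -> (B, C, H, 1)
    layers.Reshape('0', '1', -1),     # (B, C, H, 1) -> (B, C, H)
)
print(to_1d(torch.randn(2, 32, 100, 80)).shape)  # torch.Size([2, 32, 100])
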
    # infinitely sample random (image, ground truth) pairs by id
    load_by_random_id(train_dataset.load_image, train_dataset.load_gt, ids=train_ids),
    # extract patches, centering them on a tumor voxel with probability 0.5
    unpack_args(tumor_sampling, patch_size=PATCH_SIZE, tumor_p=.5),
    # augmentation: random flip along the first spatial axis
    # (image axis 1 and gt axis 0 coincide because the image has a leading modalities axis)
    random_apply(.5, unpack_args(lambda image, gt: (np.flip(image, 1), np.flip(gt, 0)))),
    batch_size=CONFIG['batch_size'], batches_per_epoch=CONFIG['batches_per_epoch']
)

# model
model = nn.Sequential(
    nn.Conv3d(dataset.n_modalities, 8, kernel_size=3, padding=1),
    layers.FPN(
        layers.ResBlock3d,
        downsample=nn.MaxPool3d(2, ceil_mode=True),
        upsample=nn.Identity,
        # concatenating skip connections, as in the original U-Net
        merge=lambda left, down: torch.cat(layers.interpolate_to_left(left, down, 'trilinear'), dim=1),
        structure=[
            [[8, 8, 8], [16, 8, 8]],
            [[8, 16, 16], [32, 16, 8]],
            [[16, 32, 32], [64, 32, 16]],
            [[32, 64, 64], [128, 64, 32]],
            [[64, 128, 128], [256, 128, 64]],
            [[128, 256, 256], [512, 256, 128]],
            [256, 512, 256],
        ],
        kernel_size=3,
        padding=1,
    ),
    layers.PreActivation3d(8, dataset.n_classes, kernel_size=1),
).to(CONFIG['device'])

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])

# predict
@patches_grid(patch_size=PATCH_SIZE, stride=7 * PATCH_SIZE // 8)
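# NOTE: the function body below is a hypothetical sketch, not the original code. `patches_grid`
# splits the full volume into overlapping patches, calls the wrapped function on each patch and
# stitches the per-patch predictions back together; the softmax activation and the assumption
# that each patch arrives as an (n_modalities, *spatial) array are illustrative only.
def predict(image):
    x = torch.from_numpy(image[None]).float().to(CONFIG['device'])  # add a batch dimension
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)
    return probs[0].cpu().numpy()  # per-class probabilities for this patch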