def __init__(self, structure, n_chans_in, n_chans_out):
    """Assemble the 2D network: a stem convolution, an FPN, and an output head.

    Parameters
    ----------
    structure
        Channel layout forwarded unchanged to ``layers.FPN``. Two entries are
        read here directly: ``structure[0][0][0]`` becomes the stem
        convolution's output width, and ``structure[0][-1][-1]`` becomes the
        width of the output head's input.
    n_chans_in
        Number of channels of the network input.
    n_chans_out
        Number of channels produced by the final 1x1 pre-activation layer.
    """
    super().__init__()

    top_level = structure[0]
    stem_width = top_level[0][0]
    head_width = top_level[-1][-1]

    def merge(left, down):
        # The result of interpolate_to_left is unpacked into torch.add;
        # presumably it yields (left, down resized to left) -- TODO confirm.
        return torch.add(*layers.interpolate_to_left(left, down, 1))

    # NOTE: modules are instantiated in exactly the original order so that
    # random parameter initialisation consumes the RNG identically.
    stem = nn.Conv2d(n_chans_in, stem_width, kernel_size=3, padding=1, bias=False)
    pyramid = layers.FPN(
        layers.ResBlock2d,
        nn.MaxPool2d(kernel_size=2, ceil_mode=True),
        nn.Identity(),
        merge,
        structure,
        kernel_size=3,
        padding=1,
    )
    self.unet = nn.Sequential(stem, pyramid)

    refine = layers.ResBlock2d(head_width, head_width, kernel_size=3, padding=1)
    project = layers.PreActivation2d(head_width, n_chans_out, kernel_size=1)
    self.out = nn.Sequential(refine, project)
train_ids, val_ids, test_ids = split[int(FOLD)]

# batch iterator
batch_iter = Infinite(
    load_by_random_id(train_dataset.load_image, train_dataset.load_gt, ids=train_ids),
    unpack_args(tumor_sampling, patch_size=PATCH_SIZE, tumor_p=.5),
    # With probability .5 flip image along axis 1 and gt along axis 0.
    # Presumably the image has a leading modality axis (the first conv takes
    # dataset.n_modalities channels), so both flips hit the same spatial axis.
    random_apply(.5, unpack_args(lambda image, gt: (np.flip(image, 1), np.flip(gt, 0)))),
    batch_size=CONFIG['batch_size'],
    batches_per_epoch=CONFIG['batches_per_epoch'],
)

# model
model = nn.Sequential(
    nn.Conv3d(dataset.n_modalities, 8, kernel_size=3, padding=1),
    layers.FPN(
        layers.ResBlock3d,
        downsample=nn.MaxPool3d(2, ceil_mode=True),
        # Pass a module *instance*, like ``downsample`` above. Calling the
        # class ``nn.Identity`` on a feature map would construct a new
        # Identity module (its __init__ ignores arguments) instead of
        # returning the tensor, breaking the merge below.
        upsample=nn.Identity(),
        # Concatenate the skip connection with the interpolated deeper map
        # along the channel axis, which is why each right-path input width
        # below equals the left output plus the incoming deeper width.
        merge=lambda left, down: torch.cat(
            layers.interpolate_to_left(left, down, 'trilinear'), dim=1),
        structure=[
            [[8, 8, 8], [16, 8, 8]],
            [[8, 16, 16], [32, 16, 8]],
            [[16, 32, 32], [64, 32, 16]],
            [[32, 64, 64], [128, 64, 32]],
            [[64, 128, 128], [256, 128, 64]],
            [[128, 256, 256], [512, 256, 128]],
            [256, 512, 256],
        ],
        kernel_size=3,
        padding=1,
    ),
    layers.PreActivation3d(8, dataset.n_classes, kernel_size=1),
).to(CONFIG['device'])

criterion = nn.CrossEntropyLoss()
random_apply( .5, unpack_args(lambda image, gt: (np.flip(image, 1), np.flip(gt, 0)))), apply_at(1, prepend_dims), apply_at(1, np.float32), batch_size=CONFIG['batch_size'], batches_per_epoch=CONFIG['batches_per_epoch']) # model model = nn.Sequential( nn.Conv3d(dataset.n_modalities, 8, kernel_size=3, padding=1), layers.FPN(layers.ResBlock3d, downsample=nn.MaxPool3d(2, ceil_mode=True), upsample=nn.Identity, merge=lambda left, down: torch.cat( layers.interpolate_to_left(left, down, 'trilinear'), dim=1), structure=[[[8, 8, 8], [16, 8, 8]], [[8, 16, 16], [32, 16, 8]], [[16, 32, 32], [64, 32, 16]], [[32, 64, 64], [128, 64, 32]], [[64, 128, 128], [256, 128, 64]], [[128, 256, 256], [512, 256, 128]], [256, 512, 256]], kernel_size=3, padding=1), layers.PreActivation3d(8, 1, kernel_size=1)).to(CONFIG['device']) criterion = nn.BCEWithLogitsLoss() optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr']) # predict @patches_grid(patch_size=PATCH_SIZE, stride=7 * PATCH_SIZE // 8)