Example #1
0
    def __init__(self, config):
        """Build the foreground U-Net and its single-channel output head.

        Args:
            config: Mapping of configuration values. Required keys:
                ``INPUT_SHAPE``, ``OUTPUT_SHAPE``, ``VOXEL_SIZE``,
                ``NUM_FMAPS_FOREGROUND``, ``FMAP_INC_FACTORS_FOREGROUND``,
                ``DOWNSAMPLE_FACTORS``, ``KERNEL_SIZE_UP`` and ``ACTIVATION``.
                Optional ``KERNEL_SIZE_DOWN`` sets the kernel sizes of the
                down-sampling path. When it is absent, ``KERNEL_SIZE_UP`` is
                reused there, as before.
        """
        super(ForegroundUnet, self).__init__()

        # CONFIGS
        self.input_shape = tuple(config["INPUT_SHAPE"])
        self.output_shape = tuple(config["OUTPUT_SHAPE"])
        self.voxel_size = tuple(config["VOXEL_SIZE"])
        self.num_fmaps = config["NUM_FMAPS_FOREGROUND"]
        self.fmap_inc_factors = config["FMAP_INC_FACTORS_FOREGROUND"]
        self.downsample_factors = tuple(
            tuple(x) for x in config["DOWNSAMPLE_FACTORS"])
        self.kernel_size_up = config["KERNEL_SIZE_UP"]
        # The down path used to always reuse the up-path kernel sizes. Allow a
        # dedicated setting, but keep that as the fallback.
        self.kernel_size_down = config.get("KERNEL_SIZE_DOWN",
                                           self.kernel_size_up)
        self.activation = config["ACTIVATION"]

        # LAYERS
        self.unet = UNet(
            in_channels=2,
            num_fmaps=self.num_fmaps,
            fmap_inc_factors=self.fmap_inc_factors,
            downsample_factors=self.downsample_factors,
            kernel_size_down=self.kernel_size_down,
            kernel_size_up=self.kernel_size_up,
            activation=self.activation,
            # input extent in physical units: shape (voxels) * voxel size
            fov=tuple(a * b
                      for a, b in zip(self.input_shape, self.voxel_size)),
            voxel_size=self.voxel_size,
            num_heads=1,
            constant_upsample=True,
        )
        # Size-1 kernel with one entry per spatial dimension of the input, so
        # the head is not tied to 3D inputs.
        self.conv_layer = ConvPass(
            in_channels=self.num_fmaps,
            out_channels=1,
            kernel_sizes=[[1 for _ in self.input_shape]],
            activation="Sigmoid",
        )
Example #2
0
    def __init__(self, config):
        """Build the embedding U-Net, its output head and an optional aux head.

        Args:
            config: Mapping of configuration values. Required keys:
                ``INPUT_SHAPE``, ``OUTPUT_SHAPE``, ``EMBEDDING_DIMS``,
                ``NUM_FMAPS_EMBEDDING``, ``FMAP_INC_FACTORS_EMBEDDING``,
                ``DOWNSAMPLE_FACTORS``, ``KERNEL_SIZE_UP``, ``VOXEL_SIZE``,
                ``ACTIVATION``, ``NORMALIZE_EMBEDDINGS``, ``AUX_TASK`` and
                ``AUX_NEIGHBORHOOD``. Optional ``KERNEL_SIZE_DOWN`` sets the
                kernel sizes of the down-sampling path. When it is absent,
                ``KERNEL_SIZE_UP`` is reused there, as before.
        """
        super(EmbeddingUnet, self).__init__()

        # CONFIGS
        self.input_shape = tuple(config["INPUT_SHAPE"])
        self.output_shape = tuple(config["OUTPUT_SHAPE"])
        self.embedding_dims = config["EMBEDDING_DIMS"]
        self.num_fmaps = config["NUM_FMAPS_EMBEDDING"]
        self.fmap_inc_factors = config["FMAP_INC_FACTORS_EMBEDDING"]
        self.downsample_factors = tuple(
            tuple(x) for x in config["DOWNSAMPLE_FACTORS"])
        self.kernel_size_up = config["KERNEL_SIZE_UP"]
        # The down path used to always reuse the up-path kernel sizes. Allow a
        # dedicated setting, but keep that as the fallback.
        self.kernel_size_down = config.get("KERNEL_SIZE_DOWN",
                                           self.kernel_size_up)
        self.voxel_size = config["VOXEL_SIZE"]
        self.activation = config["ACTIVATION"]
        self.normalize_embeddings = config["NORMALIZE_EMBEDDINGS"]
        self.aux_task = config["AUX_TASK"]
        # used as the number of output channels of the aux head below
        self.neighborhood = config["AUX_NEIGHBORHOOD"]

        # LAYERS
        # UNET
        self.unet = UNet(
            in_channels=2,
            num_fmaps=self.num_fmaps,
            fmap_inc_factors=self.fmap_inc_factors,
            downsample_factors=self.downsample_factors,
            kernel_size_down=self.kernel_size_down,
            kernel_size_up=self.kernel_size_up,
            activation=self.activation,
            # NOTE(review): fov is the input shape in voxels and voxel_size is
            # not passed, unlike ForegroundUnet, which scales by voxel size.
            # Confirm which one is intended.
            fov=self.input_shape,
            num_heads=1,
            constant_upsample=True,
        )
        # FINAL CONV LAYER: size-1 kernel, one entry per spatial dimension
        self.conv_layer = ConvPass(
            in_channels=self.num_fmaps,
            out_channels=self.embedding_dims,
            kernel_sizes=[[1 for _ in self.input_shape]],
            activation="Tanh",
        )
        # AUX LAYER: only created when AUX_TASK is truthy, so self.aux_layer
        # does not exist otherwise.
        if self.aux_task:
            self.aux_layer = ConvPass(
                in_channels=self.num_fmaps,
                out_channels=self.neighborhood,
                kernel_sizes=[[1 for _ in self.input_shape]],
                activation="Tanh",
            )
Example #3
0
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# cages and render parameters
cage1 = Cage("../data/example_cage", 1)
psf = GaussianPSF(intensity=0.125, sigma=(1.0, 1.0))
# NOTE(review): min and max density are equal, so density is effectively
# fixed. Confirm this is intended.
min_density = 1e-5
max_density = 1e-5

# create model, loss, and optimizer
# The output head must consume exactly the U-Net's num_fmaps channels. Keep the
# width in one variable so the two cannot drift apart when it is tuned.
num_fmaps = 24  # TODO: revisit network width once training works
downsample_factors = [
    [1, 2, 2],
    [1, 2, 2],
    [1, 2, 2],
]
# With N downsampling steps there are N + 1 levels on the way down and N on the
# way up (originally hard-coded as 4 and 3 for three steps).
num_levels = len(downsample_factors) + 1
unet = UNet(
    in_channels=1,
    num_fmaps=num_fmaps,
    fmap_inc_factor=4,  # TODO: revisit (an earlier note suggested 3)
    downsample_factors=downsample_factors,
    kernel_size_down=[[[3, 3, 3], [3, 3, 3]]] * num_levels,
    kernel_size_up=[[[3, 3, 3], [3, 3, 3]]] * (num_levels - 1),
    padding='valid')
model = torch.nn.Sequential(
    unet,
    ConvPass(num_fmaps, 1, [(1, 1, 1)], activation='Sigmoid'))
# BCELoss expects probabilities, which the Sigmoid head provides.
loss = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

# declare gunpowder arrays
raw = gp.ArrayKey('RAW')
seg = gp.ArrayKey('SEGMENTATION')
out_cage_map = gp.ArrayKey('OUT_CAGE_MAP')
out_density_map = gp.ArrayKey('OUT_DENSITY_MAP')
Example #4
0
def train_until(max_iteration):
    """Train a point-detection network with a gunpowder + torch pipeline.

    A U-Net followed by ``Convolve`` to one output channel is trained with an
    MSE loss against rasterized points. Batches come from positive
    (``pos_samples``) and negative (``neg_samples``) Zarr files. Both are
    module-level names defined elsewhere. Checkpoints are saved every 1000
    iterations and snapshots written every 500.

    Args:
        max_iteration: Number of batches to request, i.e. training iterations.
    """
    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(1, 3, 3), (1, 3, 3), (3, 3, 3)]

    unet = UNet(in_channels,
                num_fmaps,
                fmap_inc_factors,
                downsample_factors,
                constant_upsample=True)

    # The head consumes the U-Net's num_fmaps output channels. Use the variable
    # instead of repeating the literal 12 so the two stay in sync.
    model = Convolve(unet, num_fmaps, 1)

    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    # start of gunpowder part:

    raw = gp.ArrayKey('RAW')
    points = gp.GraphKey('POINTS')
    groundtruth = gp.ArrayKey('RASTER')
    prediction = gp.ArrayKey('PRED_POINT')
    grad = gp.ArrayKey('GRADIENT')

    voxel_size = gp.Coordinate((40, 4, 4))

    input_shape = (96, 430, 430)
    output_shape = (60, 162, 162)

    # request sizes are in world units: shape (voxels) * voxel size
    input_size = gp.Coordinate(input_shape) * voxel_size
    output_size = gp.Coordinate(output_shape) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(points, output_size)
    request.add(groundtruth, output_size)
    request.add(prediction, output_size)
    request.add(grad, output_size)

    # Positive samples: AddCenterPoint (project helper) fills POINTS, and the
    # random location is chosen so that POINTS is non-empty.
    pos_sources = tuple(
        gp.ZarrSource(filename, {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddCenterPoint(points, raw) + gp.Pad(raw, None) +
        gp.RandomLocation(ensure_nonempty=points)
        for filename in pos_samples) + gp.RandomProvider()
    # Negative samples: AddNoPoint (project helper) provides POINTS, and any
    # random location is accepted.
    neg_sources = tuple(
        gp.ZarrSource(filename, {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddNoPoint(points, raw) + gp.RandomLocation()
        for filename in neg_samples) + gp.RandomProvider()

    # Draw from the positive branch 90% of the time, negative 10%.
    data_sources = (pos_sources, neg_sources)
    data_sources += gp.RandomProvider(probabilities=[0.9, 0.1])
    data_sources += gp.Normalize(raw)

    train_pipeline = data_sources
    train_pipeline += gp.ElasticAugment(control_point_spacing=[4, 40, 40],
                                        jitter_sigma=[0, 2, 2],
                                        rotation_interval=[0, math.pi / 2.0],
                                        prob_slip=0.05,
                                        prob_shift=0.05,
                                        max_misalign=10,
                                        subsample=8)
    train_pipeline += gp.SimpleAugment(transpose_only=[1, 2])

    train_pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1,
                                          z_section_wise=True)
    train_pipeline += gp.RasterizePoints(
        points,
        groundtruth,
        array_spec=gp.ArraySpec(voxel_size=voxel_size),
        settings=gp.RasterizationSettings(radius=(100, 100, 100), mode='peak'))
    train_pipeline += gp.PreCache(cache_size=40, num_workers=10)

    # Prepend two singleton axes (presumably batch and channel) for torch.
    train_pipeline += Reshape(raw, (1, 1) + input_shape)
    train_pipeline += Reshape(groundtruth, (1, 1) + output_shape)

    train_pipeline += gp_torch.Train(model=model,
                                     loss=loss,
                                     optimizer=optimizer,
                                     inputs={'x': raw},
                                     outputs={0: prediction},
                                     loss_inputs={
                                         0: prediction,
                                         1: groundtruth
                                     },
                                     gradients={0: grad},
                                     save_every=1000,
                                     log_dir='log')

    # Back to plain spatial shapes for the snapshot.
    train_pipeline += Reshape(raw, input_shape)
    train_pipeline += Reshape(groundtruth, output_shape)
    train_pipeline += Reshape(prediction, output_shape)
    train_pipeline += Reshape(grad, output_shape)

    train_pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            groundtruth: 'volumes/groundtruth',
            prediction: 'volumes/prediction',
            grad: 'volumes/gradient'
        },
        every=500,
        output_filename='test_{iteration}.hdf')
    train_pipeline += gp.PrintProfilingStats(every=10)

    with gp.build(train_pipeline):
        for _ in range(max_iteration):
            train_pipeline.request_batch(request)
Example #5
0
def train_until(max_iteration):
    """Train a U-Net to predict affinities with a gunpowder + torch pipeline.

    A U-Net followed by ``Convolve`` to 3 output channels is trained with an
    MSE loss against affinities derived from ``volumes/labels/neuron_ids``.
    It relies on the module-level names ``data_dir``, ``samples``,
    ``probabilities`` and ``neighborhood``, which are defined elsewhere.

    Args:
        max_iteration: Number of batches to request, i.e. training iterations.
    """
    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(2, 2, 2), (2, 2, 2), (3, 3, 3)]

    unet = UNet(in_channels, num_fmaps, fmap_inc_factors, downsample_factors)

    # NOTE(review): 3 output channels presumably matches len(neighborhood).
    # Confirm against the module-level neighborhood.
    model = Convolve(unet, num_fmaps, 3)

    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.5e-4,
                                 betas=(0.95, 0.999))

    test_input_shape = Coordinate((196, ) * 3)
    test_output_shape = Coordinate((84, ) * 3)

    raw = ArrayKey('RAW')
    labels = ArrayKey('GT_LABELS')
    labels_mask = ArrayKey('GT_LABELS_MASK')
    affs = ArrayKey('PREDICTED_AFFS')
    gt_affs = ArrayKey('GT_AFFS')
    gt_affs_scale = ArrayKey('GT_AFFS_SCALE')
    affs_gradient = ArrayKey('AFFS_GRADIENT')

    voxel_size = Coordinate((8, ) * 3)
    # Shapes are already Coordinates, so multiplying gives world-unit sizes.
    input_size = test_input_shape * voxel_size
    output_size = test_output_shape * voxel_size

    #max labels padding calculated
    labels_padding = Coordinate((376, 536, 536))

    request = BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_mask, output_size)
    request.add(gt_affs, output_size)
    request.add(gt_affs_scale, output_size)

    # Predictions are only requested for snapshots, with the gt_affs ROI.
    snapshot_request = BatchRequest({affs: request[gt_affs]})

    data_sources = tuple(
        ZarrSource(
            os.path.join(data_dir, sample), {
                raw: 'volumes/raw',
                labels: 'volumes/labels/neuron_ids',
                labels_mask: 'volumes/labels/mask',
            }, {
                raw: ArraySpec(interpolatable=True),
                labels: ArraySpec(interpolatable=False),
                labels_mask: ArraySpec(interpolatable=False)
            }) + Normalize(raw) + Pad(raw, None) +
        Pad(labels, labels_padding) + Pad(labels_mask, labels_padding) +
        RandomLocation(min_masked=0.5, mask=labels_mask) for sample in samples)

    train_pipeline = data_sources

    train_pipeline += RandomProvider(probabilities=probabilities)

    # rotation-only elastic pass (no jitter, slip, shift or misalignment)
    train_pipeline += ElasticAugment(control_point_spacing=[40, 40, 40],
                                     jitter_sigma=[0, 0, 0],
                                     rotation_interval=[0, math.pi / 2.0],
                                     prob_slip=0,
                                     prob_shift=0,
                                     max_misalign=0,
                                     subsample=8)

    train_pipeline += SimpleAugment()

    train_pipeline += ElasticAugment(control_point_spacing=[40, 40, 40],
                                     jitter_sigma=[2, 2, 2],
                                     rotation_interval=[0, math.pi / 2.0],
                                     prob_slip=0.01,
                                     prob_shift=0.01,
                                     max_misalign=1,
                                     subsample=8)

    train_pipeline += IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1)

    train_pipeline += GrowBoundary(labels, labels_mask, steps=1)

    train_pipeline += AddAffinities(neighborhood,
                                    labels=labels,
                                    affinities=gt_affs)

    # NOTE(review): gt_affs_scale is computed and requested but never passed
    # to the (unweighted) MSELoss. Confirm whether a weighted loss was
    # intended.
    train_pipeline += BalanceLabels(gt_affs, gt_affs_scale)

    # map raw from [0, 1] to [-1, 1] (undone after training below)
    train_pipeline += IntensityScaleShift(raw, 2, -1)

    # AddAffinities emits uint8 affinities by default. Normalize without an
    # explicit factor rescales uint8 data by 1/255, which would turn the 0/1
    # targets into 0/0.0039. factor=1.0 only casts them to float32.
    train_pipeline += Normalize(gt_affs, factor=1.0)

    # raw gets two leading axes, gt_affs one (it already has a channel axis).
    train_pipeline += Unsqueeze([raw, gt_affs])
    train_pipeline += Unsqueeze([raw])

    train_pipeline += PreCache(cache_size=40, num_workers=10)

    train_pipeline += Train(model=model,
                            loss=loss,
                            optimizer=optimizer,
                            inputs={'x': raw},
                            loss_inputs={
                                0: affs,
                                1: gt_affs
                            },
                            outputs={0: affs},
                            save_every=1000,
                            log_dir='log')

    # undo the Unsqueeze steps above (affs come out of the model batched)
    train_pipeline += Squeeze([raw])
    train_pipeline += Squeeze([raw, gt_affs, affs])

    # map raw back from [-1, 1] to [0, 1] for the snapshot
    train_pipeline += IntensityScaleShift(raw, 0.5, 0.5)

    train_pipeline += Snapshot(
        {
            raw: 'volumes/raw',
            labels: 'volumes/labels/neuron_ids',
            gt_affs: 'volumes/gt_affinities',
            affs: 'volumes/pred_affinities',
            labels_mask: 'volumes/labels/mask'
        },
        dataset_dtypes={
            labels: np.uint64,
            gt_affs: np.float32
        },
        every=1,
        output_filename='batch_{iteration}.zarr',
        additional_request=snapshot_request)

    train_pipeline += PrintProfilingStats(every=1)

    with build(train_pipeline) as b:
        for _ in range(max_iteration):
            b.request_batch(request)