Example #1
0
def add_augmentation_pipeline(
        pipeline,
        raw,
        simple=None,
        elastic=None,
        blur=None,
        noise=None):
    """Append the requested augmentation nodes to ``pipeline``.

    Every optional argument is either ``None`` (augmentation skipped) or a
    dict of keyword arguments forwarded to the corresponding node. Nodes are
    appended in the fixed order simple -> elastic -> blur -> noise.

    Args:
        pipeline: the pipeline (or provider) to extend.
        raw: key passed as first positional argument to ``Blur`` and
            ``gp.NoiseAugment``.
        simple: kwargs for ``gp.SimpleAugment`` or ``None``.
        elastic: kwargs for ``gp.ElasticAugment`` or ``None``.
        blur: kwargs for ``Blur`` or ``None``.
        noise: kwargs for ``gp.NoiseAugment`` or ``None``.

    Returns:
        The extended pipeline; the very same ``pipeline`` object when every
        optional argument is ``None``.
    """
    # Ordered (kwargs, factory) pairs. The factories are lambdas so that a
    # node is only constructed when its augmentation was actually requested.
    steps = (
        (simple, lambda kw: gp.SimpleAugment(**kw)),
        (elastic, lambda kw: gp.ElasticAugment(**kw)),
        (blur, lambda kw: Blur(raw, **kw)),
        (noise, lambda kw: gp.NoiseAugment(raw, **kw)),
    )
    for node_kwargs, make_node in steps:
        if node_kwargs is None:
            continue
        pipeline = pipeline + make_node(node_kwargs)
    return pipeline
Example #2
0
    def _augmentation_pipeline(self, raw, source):
        if 'elastic' in self.params and self.params['elastic']:
            source = source + gp.ElasticAugment(
                **self.params["elastic_params"])

        if 'blur' in self.params and self.params['blur']:
            source = source + Blur(raw, **self.params["blur_params"])

        if 'simple' in self.params and self.params['simple']:
            source = source + gp.SimpleAugment(**self.params["simple_params"])

        if 'noise' in self.params and self.params['noise']:
            source = source + gp.NoiseAugment(raw, **
                                              self.params['noise_params'])
        return source
Example #3
0
def random_point_pairs_pipeline(model,
                                loss,
                                optimizer,
                                dataset,
                                augmentation_parameters,
                                point_density,
                                out_dir,
                                normalize_factor=None,
                                checkpoint_interval=5000,
                                snapshot_interval=5000):
    """Build a gunpowder training pipeline over pairs of augmented views.

    For each of ``n_samples`` images stored under ``train/raw/<i>`` in
    ``dataset.filename``, the same image is read twice (``raw_0``/``raw_1``).
    Each copy is paired with a random point set (``points_0``/``points_1``);
    both point sources share one ``RandomPointGenerator``. Each view has its
    own augmentation chain (simple, elastic, intensity, noise). The two views
    are then merged, cropped to the image ROI, padded and sampled at a random
    location.

    One sample source is chosen per batch by ``gp.RandomProvider``.
    ``PrepareBatch`` and ``RejectArray`` turn the points into
    ``locations_0``/``locations_1`` arrays. The batch is then stacked,
    pre-cached and fed to ``gp.torch.Train``. ``gp.Snapshot`` periodically
    writes the raw and embedding arrays.

    Args:
        model: model passed to ``gp.torch.Train``; its ``in_shape`` and
            ``out_shape`` attributes set the requested raw / point shapes.
        loss: loss passed to ``gp.torch.Train`` with ``loss_inputs`` named
            ``emb_0``, ``emb_1``, ``locations_0`` and ``locations_1``.
        optimizer: optimizer passed to ``gp.torch.Train``.
        dataset: object whose ``filename`` is opened both with
            ``daisy.open_ds`` and as a ``gp.ZarrSource``.
        augmentation_parameters: NOTE(review): currently unused — the
            augmentations below are hard-coded.
        point_density: ``density`` forwarded to ``RandomPointGenerator``.
        out_dir: checkpoints are written with basename ``<out_dir>/model``.
        normalize_factor: NOTE(review): currently unused.
        checkpoint_interval: ``save_every`` of ``gp.torch.Train``.
        snapshot_interval: ``every`` of ``gp.Snapshot``.

    Returns:
        ``(pipeline, request)``. Presumably the caller builds the pipeline
        and requests batches with ``request``.
    """

    # Per-view keys: raw image, random point graph, point locations as a
    # non-spatial array, and the predicted embedding.
    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO parse this key from somewhere
    key = 'train/raw/0'

    # Sample 0 is opened only to log its ROI and to read the voxel size that
    # is reused for the embedding arrays.
    data = daisy.open_ds(dataset.filename, key)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    # NOTE(review): the ZarrSources below force voxel_size=(1, 1); this
    # matches emb_voxel_size only if the stored voxel size is (1, 1) — confirm.
    emb_voxel_size = voxel_size

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    # Raw images at the network input shape, points at the output shape.
    # Locations are non-spatial (presumably filled in by PrepareBatch).
    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    # Embeddings are only requested for snapshots, over the points' ROIs.
    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # Let's hardcode this for now
    # TODO read actual number from zarr file keys
    # NOTE(review): batch_size, dim and padding are hard-coded as well.
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    sources = []
    for i in range(n_samples):

        # Two sources reading the same image, one per view. Pad(raw, None)
        # presumably lifts the bound on the provided ROI — confirm against
        # the gunpowder Pad documentation.
        ds_key = f'train/raw/{i}'
        image_sources = tuple(
            gp.ZarrSource(
                dataset.filename, {raw: ds_key},
                {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))}) +
            gp.Pad(raw, None) for raw in [raw_0, raw_1])

        # One generator shared by both point sources; repetitions=2
        # presumably makes it serve the same points to both views — confirm.
        random_point_generator = RandomPointGenerator(density=point_density,
                                                      repetitions=2)

        point_sources = tuple(
            (RandomPointSource(points_0,
                               dim,
                               random_point_generator=random_point_generator),
             RandomPointSource(points_1,
                               dim,
                               random_point_generator=random_point_generator)))

        # TODO: get augmentation parameters from some config file!
        # Each view: merge its image and point source, then apply its own
        # simple / elastic / intensity / noise augmentation chain.
        points_and_image_sources = tuple(
            (img_source, point_source) + gp.MergeProvider() + \
            gp.SimpleAugment() + \
            gp.ElasticAugment(
                spatial_dims=2,
                control_point_spacing=(10, 10),
                jitter_sigma=(0.0, 0.0),
                rotation_interval=(0, math.pi/2)) + \
            gp.IntensityAugment(r,
                                scale_min=0.8,
                                scale_max=1.2,
                                shift_min=-0.2,
                                shift_max=0.2,
                                clip=False) + \
            gp.NoiseAugment(r, var=0.01, clip=False)
            for r, img_source, point_source
            in zip([raw_0, raw_1], image_sources, point_sources))

        # Merge both augmented views into one provider for this sample.
        sample_source = points_and_image_sources + gp.MergeProvider()

        # Crop both views to this sample's ROI, pad by `padding`, then
        # sample a random location.
        data = daisy.open_ds(dataset.filename, ds_key)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        sample_source += gp.Crop(raw_0, source_roi)
        sample_source += gp.Crop(raw_1, source_roi)
        sample_source += gp.Pad(raw_0, padding)
        sample_source += gp.Pad(raw_1, padding)
        sample_source += gp.RandomLocation()
        sources.append(sample_source)

    sources = tuple(sources)

    # Pick one sample source per batch, then add a leading channel axis to
    # both raw arrays (see the shape notes below).
    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Unsqueeze([raw_0, raw_1])

    pipeline += PrepareBatch(raw_0, raw_1, points_0, points_1, locations_0,
                             locations_1)

    # Presumably rejects batches whose location arrays are empty.
    # TODO: clarify how PrepareBatch relates to Stack.
    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)

    # batch content
    # raw_0:          (1, h, w)
    # raw_1:          (1, h, w)
    # locations_0:    (n, 2)
    # locations_1:    (n, 2)

    pipeline += gp.Stack(batch_size)

    # batch content
    # raw_0:          (b, 1, h, w)
    # raw_1:          (b, 1, h, w)
    # locations_0:    (b, n, 2)
    # locations_1:    (b, n, 2)

    # Batches are assembled ahead of time by 10 workers.
    pipeline += gp.PreCache(num_workers=10)

    # Model outputs 2 and 3 are mapped to emb_0 / emb_1 (presumably indices
    # into the model's forward() return value — confirm).
    pipeline += gp.torch.Train(
        model,
        loss,
        optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)

    # Periodically write raw and embeddings (locations are not saved).
    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
            # locations_0 : 'locations_0',
            # locations_1 : 'locations_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
def train(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names):
    """Train the "train_net" TensorFlow graph on synthetic skeleton data.

    Data flow:
    - Point trees from ``nl.SyntheticLightLike`` are rasterized into ``raw``
      and ``labels``.
    - Both are grown; ``raw`` is mapped to float32 intensities and noise is
      added.
    - The result is pre-cached and fed to ``gp.tensorflow.Train``.
    - ``gp.Snapshot`` periodically writes inputs, network outputs, gradients
      and loss internals to ``{iteration}.hdf``.

    Args:
        n_iterations: the loop requests ``n_iterations + 1`` batches.
        setup_config: mapping of hyperparameters. Keys read: INPUT_SHAPE,
            OUTPUT_SHAPE, SKEL_GEN_RADIUS, THETAS (scaled by pi), SPLIT_PS,
            NOISE_VAR, N_OBJS, LABEL_RADII, RAW_RADII, RAW_INTENSITIES,
            CACHE_SIZE, NUM_WORKERS, SNAPSHOT_EVERY, CHECKPOINT_EVERY and
            HIDE_SIGNAL.
        mknet_tensor_names: mapping with keys "raw", "gt_labels",
            "embedding" and "fg", giving tensor names in the graph.
        loss_tensor_names: mapping with keys "emst", "edges_u", "edges_v",
            "ratio_pos", "ratio_neg", "dist", "num_pos_pairs" and
            "num_neg_pairs", giving the loss tensors to fetch.
    """
    cfg = setup_config

    # Network hyperparams
    input_shape = cfg["INPUT_SHAPE"]
    output_shape = cfg["OUTPUT_SHAPE"]

    # Skeleton generation hyperparams (THETAS are given in units of pi)
    skel_gen_radius = cfg["SKEL_GEN_RADIUS"]
    thetas = np.array(cfg["THETAS"]) * math.pi
    split_ps = cfg["SPLIT_PS"]
    noise_var = cfg["NOISE_VAR"]
    n_objs = cfg["N_OBJS"]

    # Skeleton variation hyperparams
    label_radii = cfg["LABEL_RADII"]
    raw_radii = cfg["RAW_RADII"]
    raw_intensities = cfg["RAW_INTENSITIES"]

    # Training hyperparams
    cache_size = cfg["CACHE_SIZE"]
    num_workers = cfg["NUM_WORKERS"]
    snapshot_every = cfg["SNAPSHOT_EVERY"]
    checkpoint_every = cfg["CHECKPOINT_EVERY"]

    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")
    gt_fg = gp.ArrayKey("GT_FG")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # Loss internals fetched from the TensorFlow graph
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # Every spatial array in this setup lives on a unit 2D grid.
    unit_voxel = gp.Coordinate((1, 1))

    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=unit_voxel)
    request.add(labels, output_shape, voxel_size=unit_voxel)
    request.add(point_trees, input_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    for spatial_key in (embedding, fg, gt_fg, maxima, gradient_embedding,
                        gradient_fg):
        snapshot_request.add(spatial_key, output_shape, voxel_size=unit_voxel)
    for loss_key in (emst, edges_u, edges_v, ratio_pos, ratio_neg, dist,
                     num_pos_pairs, num_neg_pairs):
        snapshot_request[loss_key] = gp.ArraySpec()

    def unbounded_uint64_spec():
        # A separate, unbounded 2D uint64 spec instance for each rasterizer.
        return gp.ArraySpec(
            roi=gp.Roi((None, ) * 2, (None, ) * 2),
            voxel_size=unit_voxel,
            dtype=np.uint64,
        )

    # Augmentations (gp.SimpleAugment, 2D gp.ElasticAugment) are currently
    # disabled between skeleton generation and rasterization.
    pipeline = nl.SyntheticLightLike(
        point_trees,
        dims=2,
        r=skel_gen_radius,
        n_obj=n_objs,
        thetas=thetas,
        split_ps=split_ps,
    )
    pipeline = pipeline + nl.RasterizeSkeleton(
        point_trees,
        raw,
        unbounded_uint64_spec(),
    )
    pipeline = pipeline + nl.RasterizeSkeleton(
        point_trees,
        labels,
        unbounded_uint64_spec(),
        use_component=True,
        n_objs=int(cfg["HIDE_SIGNAL"]),
    )
    pipeline = pipeline + nl.GrowLabels(labels, radii=label_radii)
    pipeline = pipeline + nl.GrowLabels(raw, radii=raw_radii)
    pipeline = pipeline + LabelToFloat32(raw, intensities=raw_intensities)
    pipeline = pipeline + gp.NoiseAugment(raw, var=noise_var)
    pipeline = pipeline + gp.PreCache(cache_size=cache_size,
                                      num_workers=num_workers)
    # "strided_slice_1:0" appears to be an auto-generated TF op name for the
    # maxima output; it breaks if the graph construction changes.
    pipeline = pipeline + gp.tensorflow.Train(
        "train_net",
        optimizer=create_custom_loss(mknet_tensor_names, cfg),
        loss=None,
        inputs={
            mknet_tensor_names["raw"]: raw,
            mknet_tensor_names["gt_labels"]: labels,
        },
        outputs={
            mknet_tensor_names["embedding"]: embedding,
            mknet_tensor_names["fg"]: fg,
            "strided_slice_1:0": maxima,
            "gt_fg:0": gt_fg,
            loss_tensor_names["emst"]: emst,
            loss_tensor_names["edges_u"]: edges_u,
            loss_tensor_names["edges_v"]: edges_v,
            loss_tensor_names["ratio_pos"]: ratio_pos,
            loss_tensor_names["ratio_neg"]: ratio_neg,
            loss_tensor_names["dist"]: dist,
            loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
            loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
        },
        gradients={
            mknet_tensor_names["embedding"]: gradient_embedding,
            mknet_tensor_names["fg"]: gradient_fg,
        },
        save_every=checkpoint_every,
        summary="Merge/MergeSummary:0",
        log_dir="tensorflow_logs",
    )
    pipeline = pipeline + gp.Snapshot(
        output_filename="{iteration}.hdf",
        dataset_names={
            raw: "volumes/raw",
            labels: "volumes/labels",
            point_trees: "point_trees",
            embedding: "volumes/embedding",
            fg: "volumes/fg",
            maxima: "volumes/maxima",
            gt_fg: "volumes/gt_fg",
            gradient_embedding: "volumes/gradient_embedding",
            gradient_fg: "volumes/gradient_fg",
            emst: "emst",
            edges_u: "edges_u",
            edges_v: "edges_v",
            ratio_pos: "ratio_pos",
            ratio_neg: "ratio_neg",
            dist: "dist",
            num_pos_pairs: "num_pos_pairs",
            num_neg_pairs: "num_neg_pairs",
        },
        dataset_dtypes={
            maxima: np.float32,
            gt_fg: np.float32,
        },
        every=snapshot_every,
        additional_request=snapshot_request,
    )
    # Profiling (gp.PrintProfilingStats) is currently disabled.

    with gp.build(pipeline):
        # NOTE(review): n_iterations + 1 batches are requested — confirm the
        # extra one is intended (e.g. to hit a final checkpoint/snapshot).
        for _ in range(n_iterations + 1):
            pipeline.request_batch(request)
            # The request object is reused; refresh its random seed so the
            # next batch is presumably not generated from the same seed.
            request._update_random_seed()
Example #5
0
def train(n_iterations):
    """Run the synthetic skeleton pipeline and write snapshots (no training).

    Data flow:
    - Point trees from ``nl.SyntheticLightLike`` are rasterized into
      ``labels`` and copied to ``raw``.
    - Both arrays are grown; ``raw`` is converted to float32 intensities and
      noise is added.
    - Snapshots go to ``{iteration}.hdf`` with ``every=100``.
    - Profiling stats are printed with ``every=10``.
    - The network / training stages are commented out.

    Relies on module-level configuration not defined here: INPUT_SHAPE,
    OUTPUT_SHAPE, SKEL_GEN_RADIUS, N_OBJS, THETAS, SPLIT_PS, LABEL_RADII,
    RAW_RADII, RAW_INTENSITIES and NOISE_VAR.

    Args:
        n_iterations: number of batches to request.
    """

    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")
    # gt_fg = gp.ArrayKey("GT_FG")
    # embedding = gp.ArrayKey("EMBEDDING")
    # fg = gp.ArrayKey("FG")
    # maxima = gp.ArrayKey("MAXIMA")
    # gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    # gradient_fg = gp.ArrayKey("GRADIENT_FG")
    # emst = gp.ArrayKey("EMST")
    # edges_u = gp.ArrayKey("EDGES_U")
    # edges_v = gp.ArrayKey("EDGES_V")

    request = gp.BatchRequest()
    request.add(raw, INPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(labels, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(point_trees, INPUT_SHAPE)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, INPUT_SHAPE)
    # snapshot_request.add(embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(gt_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(maxima, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(
    #     gradient_embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1))
    # )
    # snapshot_request.add(gradient_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()

    pipeline = (
        nl.SyntheticLightLike(
            point_trees,
            dims=2,
            r=SKEL_GEN_RADIUS,
            n_obj=N_OBJS,
            thetas=THETAS,
            split_ps=SPLIT_PS,
        )
        # + gp.SimpleAugment()
        # + gp.ElasticAugment([10, 10], [0.1, 0.1], [0, 2.0 * math.pi], spatial_dims=2)
        + nl.RasterizeSkeleton(
            point_trees,
            labels,
            gp.ArraySpec(
                roi=gp.Roi((None,) * 2, (None,) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
        )
        # raw starts as a copy of the rasterized labels and is grown with its
        # own radii before being turned into intensities.
        + gp.Copy(labels, raw)
        + nl.GrowLabels(labels, radii=LABEL_RADII)
        + nl.GrowLabels(raw, radii=RAW_RADII)
        + LabelToFloat32(raw, intensities=RAW_INTENSITIES)
        + gp.NoiseAugment(raw, var=NOISE_VAR)
        # + gp.PreCache(cache_size=40, num_workers=10)
        # + gp.tensorflow.Train(
        #     "train_net",
        #     optimizer=add_loss,
        #     loss=None,
        #     inputs={tensor_names["raw"]: raw, tensor_names["gt_labels"]: labels},
        #     outputs={
        #         tensor_names["embedding"]: embedding,
        #         tensor_names["fg"]: fg,
        #         "maxima:0": maxima,
        #         "gt_fg:0": gt_fg,
        #         emst_name: emst,
        #         edges_u_name: edges_u,
        #         edges_v_name: edges_v,
        #     },
        #     gradients={
        #         tensor_names["embedding"]: gradient_embedding,
        #         tensor_names["fg"]: gradient_fg,
        #     },
        # )
        + gp.Snapshot(
            output_filename="{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                labels: "volumes/labels",
                point_trees: "point_trees",
                # embedding: "volumes/embedding",
                # fg: "volumes/fg",
                # maxima: "volumes/maxima",
                # gt_fg: "volumes/gt_fg",
                # gradient_embedding: "volumes/gradient_embedding",
                # gradient_fg: "volumes/gradient_fg",
                # emst: "emst",
                # edges_u: "edges_u",
                # edges_v: "edges_v",
            },
            # dataset_dtypes={maxima: np.float32, gt_fg: np.float32},
            every=100,
            additional_request=snapshot_request,
        )
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(pipeline):
        for _ in range(n_iterations):
            pipeline.request_batch(request)
            # The same request object is reused every iteration. Refresh its
            # random seed so consecutive batches are not all produced from
            # one seed (assumes node randomness is seeded from the request,
            # as in gunpowder's BatchProvider — confirm for the version used).
            request._update_random_seed()