Beispiel #1
0
def get_test_data_sources(setup_config):
    """Assemble a synthetic source pipeline and return it with its keys.

    ``INPUT_SHAPE`` and ``VOXEL_SIZE`` are read from ``setup_config``. A
    ``TestImageSource`` and a ``TestPointSource`` (both sized three times the
    input size) are chained with ``MergeProvider``, ``RandomLocation``
    (given the ``MATCHED`` key as ``ensure_nonempty``), ``RasterizeSkeleton``
    (``MATCHED`` -> ``LABELS``) and ``Normalize`` on ``RAW``.

    Args:
        setup_config: Mapping providing ``"INPUT_SHAPE"`` and ``"VOXEL_SIZE"``.

    Returns:
        The tuple ``(data_sources, raw, labels, nonempty_placeholder, matched)``.
    """
    shape_in_voxels = Coordinate(setup_config["INPUT_SHAPE"])
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    input_size = shape_in_voxels * voxel_size

    # The first voxel-size component is used as the world-unit scale factor.
    micron_scale = voxel_size[0]

    # New keys.
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    matched = GraphKey("MATCHED")
    nonempty_placeholder = GraphKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    # Synthetic providers; both are three times the input size.
    raw_spec = ArraySpec(interpolatable=True,
                         voxel_size=voxel_size,
                         dtype=np.uint16)
    image_source = TestImageSource(
        array=raw,
        array_specs={raw: raw_spec},
        size=input_size * 3,
        voxel_size=voxel_size,
    )
    point_source = TestPointSource(
        points=[matched, nonempty_placeholder],
        directed=False,
        size=input_size * 3,
        num_points=333,
    )

    # Build the chain one node at a time, in the same left-to-right order.
    data_sources = (image_source, point_source) + MergeProvider()
    data_sources = data_sources + RandomLocation(
        ensure_nonempty=matched,
        ensure_centered=True,
        point_balance_radius=10 * micron_scale,
    )
    label_spec = ArraySpec(interpolatable=False,
                           voxel_size=voxel_size,
                           dtype=np.uint64)
    data_sources = data_sources + RasterizeSkeleton(
        points=matched,
        array=labels,
        array_spec=label_spec,
    )
    data_sources = data_sources + Normalize(raw)

    return data_sources, raw, labels, nonempty_placeholder, matched
Beispiel #2
0
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names,
                          loss_tensor_names):
    """Build the embedding/foreground training pipeline and run it.

    Per-sample sources (raw N5 volume plus consensus and skeletonization
    graphs) are merged, matched with ``TopologicalMatcher``, rasterized into
    labels, augmented, and fed to ``gp.tensorflow.Train``; profiling stats are
    printed and snapshots are written afterwards.

    Args:
        n_iterations: Not read anywhere in this function; the loop count is
            taken from ``setup_config["NUM_ITERATIONS"]`` instead.
            NOTE(review): confirm whether callers expect this to be honored.
        setup_config: Mapping with the keys read below (shapes, voxel size,
            cache/worker counts, snapshot/checkpoint/profile intervals,
            matching parameters, ``SAMPLES_PATH`` and ``MONGO_URL``).
        mknet_tensor_names: Mapping with the network tensor names
            ``"loss_weights"``, ``"raw"``, ``"gt_labels"``, ``"embedding"``
            and ``"fg"``; also forwarded to ``create_custom_loss``.
        loss_tensor_names: Mapping with the loss tensor names used as
            ``Train`` outputs (``"fg_pred"``, ``"maxima"``, ``"gt_fg"``,
            ``"emst"``, ``"edges_u"``, ``"edges_v"``, ``"ratio_pos"``,
            ``"ratio_neg"``, ``"dist"``, ``"num_pos_pairs"``,
            ``"num_neg_pairs"``).

    Returns:
        None. The function is run for its side effects (training, snapshots).
    """
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    # Shapes are given in voxels; sizes are shape * voxel_size.
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    # Only used below to derive the number embedded in the snapshot filename.
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    # keys produced by the Crop / BinarizeGt / BalanceLabels stages below
    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    # NOTE(review): this adds labels_fg to `request` a second time (it was
    # already added above), not to `snapshot_request`. With the lines below
    # commented out, `snapshot_request` is passed to Snapshot empty.
    # Possibly a typo for `snapshot_request.add(...)` -- confirm intent.
    request.add(labels_fg, output_size)

    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    # One source chain per selected sample directory; the tuple of chains is
    # combined with gp.RandomProvider() when the pipeline is assembled below.
    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched) + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        # Only these two sample directories under SAMPLES_PATH are used.
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

    pipeline = (
        data_sources + gp.RandomProvider() + Crop(labels, labels_fg) +
        BinarizeGt(labels_fg, labels_fg_bin) +
        gp.BalanceLabels(labels_fg_bin, loss_weights) +
        gp.PreCache(cache_size=cache_size, num_workers=num_workers) +
        gp.tensorflow.Train(
            "train_net",
            # The value returned by create_custom_loss is passed as the
            # optimizer argument; no separate loss tensor is given.
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            # The second placeholder is filled with the literal text "{id}",
            # so the resulting filename still contains an "{id}" field.
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        ))

    # Request one batch per iteration; the count comes from setup_config,
    # not from the n_iterations argument.
    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
Beispiel #3
0
def get_test_data_sources(setup_config):
    """Assemble a synthetic matching pipeline and return it with its keys.

    ``INPUT_SHAPE``, ``VOXEL_SIZE``, ``USE_GUROBI`` and ``FUSION_PIPELINE``
    are read from ``setup_config``. A ``TestImageSource`` and two
    ``TestPointSource`` providers (each sized three times the input size) are
    chained with ``MergeProvider``, ``RandomLocation``, ``TopologicalMatcher``
    (skeletonization/consensus -> ``MATCHED``), ``RasterizeSkeleton``
    (``MATCHED`` -> ``LABELS``) and ``Normalize`` on ``RAW``.

    Args:
        setup_config: Mapping providing the four keys listed above.

    Returns:
        The tuple ``(data_sources, raw, labels, nonempty_placeholder, matched)``.
    """
    shape_in_voxels = Coordinate(setup_config["INPUT_SHAPE"])
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    input_size = shape_in_voxels * voxel_size

    # The first voxel-size component is used as the world-unit scale factor.
    micron_scale = voxel_size[0]

    # New keys.
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    consensus = GraphKey("CONSENSUS")
    skeletonization = GraphKey("SKELETONIZATION")
    matched = GraphKey("MATCHED")
    nonempty_placeholder = GraphKey("NONEMPTY")
    labels = ArrayKey("LABELS")
    use_gurobi = setup_config["USE_GUROBI"]

    # The fusion pipeline keys the random location on the placeholder graph,
    # otherwise the consensus graph is used.
    ensure_nonempty = (nonempty_placeholder
                       if setup_config["FUSION_PIPELINE"] else consensus)

    raw_spec = ArraySpec(interpolatable=True,
                         voxel_size=voxel_size,
                         dtype=np.uint16)
    image_source = TestImageSource(
        array=raw,
        array_specs={raw: raw_spec},
        size=input_size * 3,
        voxel_size=voxel_size,
    )
    consensus_source = TestPointSource(
        points=[consensus, nonempty_placeholder],
        directed=True,
        size=input_size * 3,
        num_points=30,
    )
    skeleton_source = TestPointSource(
        points=[skeletonization],
        directed=False,
        size=input_size * 3,
        num_points=333,
    )

    # Build the chain one node at a time, in the same left-to-right order.
    data_sources = (image_source, consensus_source,
                    skeleton_source) + MergeProvider()
    data_sources = data_sources + RandomLocation(
        ensure_nonempty=ensure_nonempty,
        ensure_centered=True,
        point_balance_radius=10 * micron_scale,
    )
    data_sources = data_sources + TopologicalMatcher(
        skeletonization,
        consensus,
        matched,
        match_distance_threshold=50 * micron_scale,
        max_gap_crossing=30 * micron_scale,
        try_complete=False,
        use_gurobi=use_gurobi,
        expected_edge_len=1000,
    )
    label_spec = ArraySpec(interpolatable=False,
                           voxel_size=voxel_size,
                           dtype=np.uint64)
    data_sources = data_sources + RasterizeSkeleton(
        points=matched,
        array=labels,
        array_spec=label_spec,
    )
    data_sources = data_sources + Normalize(raw)

    return data_sources, raw, labels, nonempty_placeholder, matched
def train_distance_pipeline(n_iterations, setup_config, mknet_tensor_names,
                            loss_tensor_names):
    """Build the distance-target training pipeline and run it.

    For each selected sample, a raw N5 volume and the consensus and
    skeletonization graphs are merged, matched with ``TopologicalMatcher``,
    rasterized into labels, turned into a distance array (``AddDistance`` +
    ``TanhSaturate``), masked with ``ThresholdMask`` and augmented. The
    combined sources are cropped, passed to ``gp.tensorflow.Train`` and
    followed by profiling output and snapshots.

    Args:
        n_iterations: Not read by this function; the loop count comes from
            ``setup_config["NUM_ITERATIONS"]``.
        setup_config: Mapping with the keys read below.
        mknet_tensor_names: Mapping with the tensor names ``"optimizer"``,
            ``"fg_loss"``, ``"raw"``, ``"gt_distances"``, ``"loss_weights"``
            and ``"fg_pred"``.
        loss_tensor_names: Not read by this function.

    Returns:
        None. The function is run for its side effects (training, snapshots).
    """
    # ---- configuration (read in a fixed order; every key is required) ----
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    # cache_size / num_workers belong to the PreCache stage, which is
    # currently disabled; they are still read so a missing key fails here.
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    max_label_dist = setup_config["MAX_LABEL_DIST"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    # Shapes are given in voxels; sizes are shape * voxel_size.
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    separation = (np.array(seperate_by)).tolist()

    # ---- keys ----
    # data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    # derived training targets
    dist = gp.ArrayKey("DIST")
    dist_mask = gp.ArrayKey("DIST_MASK")
    dist_cropped = gp.ArrayKey("DIST_CROPPED")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    fg_dist = gp.ArrayKey("FG_DIST")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # ---- batch request (entries are added in this exact order) ----
    batch_request = gp.BatchRequest()
    for key, extent in (
        (dist_mask, output_size),
        (dist_cropped, output_size),
        (raw, input_size),
        (labels, input_size),
        (dist, input_size),
        (matched, input_size),
        (skeletonization, input_size),
        (consensus, input_size),
        (loss_weights, output_size),
    ):
        batch_request.add(key, extent)

    # ---- extra arrays requested only for snapshots ----
    snapshot_extras = gp.BatchRequest()
    snapshot_extras.add(raw, input_size)  # input_size request for positioning
    for tensor_key in (gradient_fg, fg_dist):
        snapshot_extras.add(tensor_key, output_size, voxel_size=voxel_size)

    def sample_chain(sample):
        """Return the source-to-augmentation chain for one sample directory."""
        chain = (
            gp.N5Source(
                filename=str(
                    (sample / "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size,
                                      dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider()
        chain = chain + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        )
        chain = chain + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        )
        chain = chain + RejectIfEmpty(matched, center_size=output_size)
        chain = chain + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        )
        chain = chain + gp.contrib.nodes.add_distance.AddDistance(
            labels,
            dist,
            dist_mask,
            max_distance=max_label_dist * micron_scale,
        )
        chain = chain + gp.contrib.nodes.tanh_saturate.TanhSaturate(
            dist, scale=micron_scale, offset=1)
        chain = chain + ThresholdMask(dist, loss_weights, 1e-4)
        # TODO: Do these need to be scaled by world units?
        chain = chain + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) is
        # intentionally left out of the chain here.
        chain = chain + gp.Normalize(raw)
        chain = chain + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        return chain

    # Only these sample directories under SAMPLES_PATH are used.
    selected_samples = ("2018-07-02", "2018-08-01")
    data_sources = tuple(
        sample_chain(sample_dir)
        for sample_dir in samples_path.iterdir()
        if sample_dir.name in selected_samples)

    # ---- training pipeline, assembled node by node ----
    pipeline = data_sources + gp.RandomProvider()
    pipeline = pipeline + Crop(dist, dist_cropped)
    # gp.PreCache(cache_size=cache_size, num_workers=num_workers) is disabled.
    pipeline = pipeline + gp.tensorflow.Train(
        "train_net_foreground",
        optimizer=mknet_tensor_names["optimizer"],
        loss=mknet_tensor_names["fg_loss"],
        inputs={
            mknet_tensor_names["raw"]: raw,
            mknet_tensor_names["gt_distances"]: dist_cropped,
            mknet_tensor_names["loss_weights"]: loss_weights,
        },
        outputs={mknet_tensor_names["fg_pred"]: fg_dist},
        gradients={mknet_tensor_names["fg_pred"]: gradient_fg},
        save_every=checkpoint_every,
        # summary=mknet_tensor_names["summaries"],
        log_dir="tensorflow_logs",
    )
    pipeline = pipeline + gp.PrintProfilingStats(every=profile_every)
    pipeline = pipeline + gp.Snapshot(
        additional_request=snapshot_extras,
        # The second placeholder is filled with the literal text "{id}", so
        # the resulting filename still contains an "{id}" field.
        output_filename="snapshot_{}_{}.hdf".format(
            int(np.min(separation)), "{id}"),
        dataset_names={
            # raw data
            raw: "volumes/raw",
            labels: "volumes/labels",
            # labeled data
            dist_cropped: "volumes/dist",
            # trees
            skeletonization: "points/skeletonization",
            consensus: "points/consensus",
            matched: "points/matched",
            # output volumes
            fg_dist: "volumes/fg_dist",
            gradient_fg: "volumes/gradient_fg",
            # output debug data
            dist_mask: "volumes/dist_mask",
            loss_weights: "volumes/loss_weights",
        },
        every=snapshot_every,
    )

    # Request one batch per configured iteration.
    with gp.build(pipeline):
        for _step in range(num_iterations):
            pipeline.request_batch(batch_request)
Beispiel #5
0
def get_mouselight_data_sources(setup_config: Dict[str, Any],
                                source_samples: List[str],
                                locations=False):
    """Build the per-sample source pipelines joined by a RandomProvider.

    For every directory under ``SAMPLES_PATH`` whose name is in
    ``source_samples``, a Zarr raw source and two graph providers (the matched
    graph and a filtered placeholder graph) are merged, a random location is
    chosen on the placeholder graph, raw is normalized, components are
    filtered and the matched graph is rasterized into ``LABELS``.

    Args:
        setup_config: Mapping with the keys read below.
        source_samples: Names of the sample directories to include.
        locations: Not read by this function; the alternative location
            strategies it used to select are disabled.

    Returns:
        The tuple ``(data_sources, raw, labels, nonempty_placeholder, matched)``.
    """
    # Source paths and accessibility
    raw_n5 = setup_config["RAW_N5"]
    mongo_url = setup_config["MONGO_URL"]
    samples_path = Path(setup_config["SAMPLES_PATH"])

    # Graph matching parameters
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    # Converted to a Path unless None; not used further in this function.
    matching_failures_dir = setup_config["MATCHING_FAILURES_DIR"]
    if matching_failures_dir is not None:
        matching_failures_dir = Path(matching_failures_dir)

    # Data properties
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    output_size = output_shape * voxel_size
    micron_scale = voxel_size[0]

    distance_attr = setup_config["DISTANCE_ATTRIBUTE"]
    target_distance = float(setup_config["MIN_DIST_TO_FALLBACK"])
    max_nonempty_points = int(setup_config["MAX_RANDOM_LOCATION_POINTS"])

    mongo_db_template = setup_config["MONGO_DB_TEMPLATE"]
    matched_source = setup_config.get("MATCHED_SOURCE", "matched")

    # New keys.
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    matched = gp.PointsKey("MATCHED")
    nonempty_placeholder = gp.PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    # Per sample: node count of its skeletonization graph plus one, passed
    # to FilterComponents below.
    node_offset = {}
    for sample_dir in samples_path.iterdir():
        if sample_dir.name not in source_samples:
            continue
        skeleton_graph = daisy.persistence.MongoDbGraphProvider(
            mongo_db_template.format(sample=sample_dir.name,
                                     source="skeletonization"),
            mongo_url,
        )
        node_offset[sample_dir.name] = skeleton_graph.num_nodes(None) + 1

    # Alternative strategies (gp.SpecifiedLocation driven by a
    # SPECIFIED_LOCATIONS pickle, or RandomLocations when `locations` is set)
    # are disabled; plain RandomLocation is always used.
    balance_radius = point_balance_radius * micron_scale

    def sample_chain(sample):
        """Return the source chain for one sample directory."""
        graph_db = mongo_db_template.format(sample=sample.name,
                                            source=matched_source)
        raw_source = gp.ZarrSource(
            filename=str((sample / raw_n5).absolute()),
            datasets={raw: "volume-rechunked"},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True,
                                  voxel_size=voxel_size,
                                  dtype=np.uint16)
            },
        )
        matched_graph = DaisyGraphProvider(
            graph_db,
            mongo_url,
            points=[matched],
            directed=True,
            node_attrs=[],
            edge_attrs=[],
        )
        # NOTE(review): node_attrs is hard-coded while dist_attribute comes
        # from DISTANCE_ATTRIBUTE -- confirm the two are meant to agree.
        placeholder_graph = FilteredDaisyGraphProvider(
            mongo_db_template.format(sample=sample.name,
                                     source=matched_source),
            mongo_url,
            points=[nonempty_placeholder],
            directed=True,
            node_attrs=["distance_to_fallback"],
            edge_attrs=[],
            num_nodes=max_nonempty_points,
            dist_attribute=distance_attr,
            min_dist=target_distance,
        )

        chain = (raw_source, matched_graph,
                 placeholder_graph) + gp.MergeProvider()
        chain = chain + RandomLocation(
            ensure_nonempty=nonempty_placeholder,
            ensure_centered=True,
            point_balance_radius=balance_radius,
        )
        chain = chain + gp.Normalize(raw)
        chain = chain + FilterComponents(
            matched, node_offset[sample.name], centroid_size=output_size)
        chain = chain + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.int64),
        )
        return chain

    per_sample = tuple(
        sample_chain(sample_dir)
        for sample_dir in samples_path.iterdir()
        if sample_dir.name in source_samples)
    data_sources = per_sample + gp.RandomProvider()

    return (data_sources, raw, labels, nonempty_placeholder, matched)