Example #1
    def process(self, batch, request):

        outputs = gp.Batch()

        gt_graph = nx.Graph()
        mst_graph = nx.Graph()

        for block, block_specs in self.specs.items():
            ground_truth_key = block_specs["ground_truth"][0]
            mst_key = block_specs["mst_pred"][0]
            block_gt_graph = batch[ground_truth_key].to_nx_graph(
            ).to_undirected()
            block_mst_graph = batch[mst_key].to_nx_graph().to_undirected()
            gt_graph = nx.disjoint_union(gt_graph, block_gt_graph)
            mst_graph = nx.disjoint_union(mst_graph, block_mst_graph)

        for node, attrs in gt_graph.nodes.items():
            attrs["id"] = node
        for node, attrs in mst_graph.nodes.items():
            attrs["id"] = node

        outputs[self.gt] = gp.Graph.from_nx_graph(
            gt_graph,
            gp.GraphSpec(roi=gp.Roi((None, ) * 3, (None, ) * 3),
                         directed=False))
        outputs[self.mst] = gp.Graph.from_nx_graph(
            mst_graph,
            gp.GraphSpec(roi=gp.Roi((None, ) * 3, (None, ) * 3),
                         directed=False),
        )
        return outputs
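The `attrs["id"] = node` passes above are needed because `nx.disjoint_union` relabels all nodes to consecutive integers, which silently invalidates any stored "id" attribute. A minimal, self-contained illustration of that relabeling (pure networkx, no gunpowder):

import networkx as nx

a = nx.Graph()
a.add_node(10, id=10)
b = nx.Graph()
b.add_node(10, id=10)

merged = nx.disjoint_union(a, b)
print(list(merged.nodes))  # [0, 1] -- both nodes were relabeled
for node, attrs in merged.nodes.items():
    attrs["id"] = node     # restore consistency, as process() does above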
Example #2
 def prepare(self, request):
     deps = gp.BatchRequest()
     deps[self.mst] = gp.GraphSpec(roi=self.roi)
     deps[self.gt] = gp.GraphSpec(roi=self.roi)
     if self.connectivity is not None:
         deps[self.connectivity] = gp.GraphSpec(roi=self.roi)
     return deps
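For context, a minimal sketch of how a `prepare` like this pairs with `process` in a `gp.BatchFilter` (the `PassThrough` node and its single key are illustrative assumptions, not from this codebase): the `deps` returned by `prepare` become the upstream request, and `process` returns only the newly produced outputs.

import gunpowder as gp

class PassThrough(gp.BatchFilter):  # hypothetical node, for illustration
    def __init__(self, graph_key, roi):
        self.graph_key = graph_key
        self.roi = roi

    def prepare(self, request):
        # everything listed here is requested from upstream
        deps = gp.BatchRequest()
        deps[self.graph_key] = gp.GraphSpec(roi=self.roi)
        return deps

    def process(self, batch, request):
        # return only what this node produced; gunpowder merges it back
        outputs = gp.Batch()
        outputs[self.graph_key] = batch[self.graph_key]
        return outputs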
Example #3
    def setup(self):

        # provide points in an infinite ROI
        self.graph_spec = gp.GraphSpec(
            roi=gp.Roi(offset=(0, ) * self.dims, shape=(None, ) * self.dims))

        self.provides(self.graph_key, self.graph_spec)
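A shape of `(None,) * dims` makes the provided ROI unbounded, so downstream nodes may request any finite region from this source. A quick sketch of that Roi behavior (sizes are arbitrary):

import gunpowder as gp

unbounded = gp.Roi(offset=(0, 0, 0), shape=(None, None, None))
finite = gp.Roi((5, 5, 5), (64, 64, 64))
print(unbounded.intersect(finite))  # the finite ROI survives intact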
Example #4
    def setup(self):

        self.ndims = self.data.shape[1]

        if self.points_spec is not None:
            self.provides(self.points, self.points_spec)
        elif isinstance(self.points, gp.ArrayKey):
            self.provides(self.points, gp.ArraySpec(voxel_size=(1,)))
        elif isinstance(self.points, gp.GraphKey):
            logger.debug(f"ndims: {self.ndims}")
            min_bb = gp.Coordinate(
                np.floor(np.amin(self.data[:, :self.ndims], 0)))
            max_bb = gp.Coordinate(
                np.ceil(np.amax(self.data[:, :self.ndims], 0)) + 1)

            roi = gp.Roi(min_bb, max_bb - min_bb)
            logger.debug(f"Bounding Box: {roi}")

            self.provides(self.points, gp.GraphSpec(roi=roi))

        if self.labels is not None:
            assert isinstance(self.labels, gp.ArrayKey), (
                f"Label key must be an ArrayKey, was given {type(self.labels)}"
            )

            if self.labels_spec is not None:
                self.provides(self.labels, self.labels_spec)
            else:
                self.provides(self.labels, gp.ArraySpec(voxel_size=(1,)))
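The bounding-box computation above, reduced to a self-contained sketch; the `+ 1` after `np.ceil` makes the ROI end exclusive while still covering the largest point:

import numpy as np
import gunpowder as gp

data = np.array([[0.0, 0.0, 0.0], [10.4, 20.2, 30.9]])
min_bb = gp.Coordinate(np.floor(np.amin(data, 0)))
max_bb = gp.Coordinate(np.ceil(np.amax(data, 0)) + 1)
roi = gp.Roi(min_bb, max_bb - min_bb)
print(roi)  # offset (0, 0, 0), shape (12, 22, 32)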
Example #5
 def setup(self):
     if str(self.snapshot_file).endswith(".h5") or str(self.snapshot_file).endswith(
         ".hdf"
     ):
         data = h5py.File(self.snapshot_file, "r")
     elif str(self.snapshot_file).endswith(".zarr"):
         data = zarr.open(self.snapshot_file, "r")
     else:
         raise ValueError(
             f"Unsupported snapshot file format: {self.snapshot_file}")
     for key, path in self.datasets.items():
         if isinstance(key, gp.ArrayKey):
             try:
                 x = data[path]
             except KeyError:
                 raise KeyError(f"Could not find {path}")
             spec = self.spec_from_dataset(x)
             self.provides(key, spec)
         elif isinstance(key, gp.GraphKey):
             try:
                 locations = data[f"{path}-locations"]
             except KeyError:
                 raise KeyError(f"Could not find {path}-locations")
             spec = gp.GraphSpec(
                 gp.Roi((None,) * len(locations[0]), (None,) * len(locations[0])),
                 directed=self.directed.get(key),
             )
             self.provides(key, spec)
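A hedged construction sketch for the node that owns this `setup` (it appears as `SnapshotSource` in examples #11 and #18; the file name and dataset paths here are illustrative):

import gunpowder as gp

raw = gp.ArrayKey("RAW")
gt = gp.GraphKey("GT")

source = SnapshotSource(
    "snapshots/validation.zarr",  # .h5/.hdf files are handled too
    datasets={raw: "volumes/raw", gt: "points/gt"},
    directed={gt: False},
)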
Example #6
    def provide(self, request):

        roi = request[self.graph_key].roi

        random_points = self.random_point_generator.get_random_points(roi)

        batch = gp.Batch()
        batch[self.graph_key] = gp.Graph(
            [gp.Node(id=i, location=l) for i, l in random_points.items()], [],
            gp.GraphSpec(roi=roi, directed=False))

        return batch
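A hedged request sketch against this provider, assuming a `RandomPointSource` node built from the `provide` above and a `random_point_generator` that maps a ROI to an `{id: location}` dict:

import gunpowder as gp

graph_key = gp.GraphKey("RANDOM_POINTS")
source = RandomPointSource(graph_key)  # hypothetical constructor

request = gp.BatchRequest()
request[graph_key] = gp.GraphSpec(roi=gp.Roi((0, 0, 0), (100, 100, 100)))

with gp.build(source):
    batch = source.request_batch(request)
    print(len(list(batch[graph_key].nodes)))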
Example #7
    def setup(self):
        self.enable_autoskip()
        all_rois = []
        for block, block_specs in self.specs.items():
            ground_truth = block_specs["ground_truth"]
            mst_pred = block_specs["mst_pred"]

            for key, spec in [ground_truth, mst_pred]:
                current_spec = self.spec[key].copy()
                current_spec.roi = spec.roi
                self.updates(key, current_spec)
                all_rois.append(current_spec.roi)

        self.total_roi = all_rois[0]
        for roi in all_rois[1:]:
            self.total_roi = self.total_roi.union(roi)
        self.provides(
            self.mst,
            gp.GraphSpec(roi=gp.Roi((None, ) * 3, (None, ) * 3),
                         directed=False))
        self.provides(
            self.gt,
            gp.GraphSpec(roi=gp.Roi((None, ) * 3, (None, ) * 3),
                         directed=False))
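The ROI union fold at the end of this `setup` computes one ROI covering all blocks; the same fold, self-contained:

import gunpowder as gp

rois = [gp.Roi((0, 0, 0), (10, 10, 10)), gp.Roi((5, 5, 5), (10, 10, 10))]
total_roi = rois[0]
for roi in rois[1:]:
    total_roi = total_roi.union(roi)
print(total_roi)  # offset (0, 0, 0), shape (15, 15, 15)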
Example #8
    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = gp.Batch()

        # If an Array is requested, we randomly choose
        # `num_points` of the stored points
        if isinstance(self.points, gp.ArrayKey):
            points = np.random.choice(self.data.shape[0], self.num_points)
            data = self.data[points][np.newaxis]
            if self.scale is not None:
                data = data * self.scale
            if self.label_data is not None:
                labels = self.label_data[points]
            batch[self.points] = gp.Array(data, self.spec[self.points])

        else:
            # If a Graph is requested, we must select the points
            # within the requested ROI

            min_bb = request[self.points].roi.get_begin()
            max_bb = request[self.points].roi.get_end()

            logger.debug("Points source got request for %s",
                         request[self.points].roi)

            point_filter = np.ones((self.data.shape[0], ), dtype=bool)
            for d in range(self.ndims):
                point_filter = np.logical_and(point_filter,
                                              self.data[:, d] >= min_bb[d])
                point_filter = np.logical_and(point_filter,
                                              self.data[:, d] < max_bb[d])

            points_data, labels = self._get_points(point_filter)
            logger.debug(f"Found {len(points_data)} points")
            points_spec = gp.GraphSpec(roi=request[self.points].roi.copy())
            batch.graphs[self.points] = gp.Graph(points_data, [], points_spec)

        # Labels will always be an Array
        if self.label_data is not None:
            batch[self.labels] = gp.Array(labels, self.spec[self.labels])

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
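The per-dimension ROI filter above can also be written as a single vectorized comparison; a self-contained numpy sketch of the same logic:

import numpy as np

data = np.array([[1.0, 1.0, 1.0], [5.0, 5.0, 5.0], [9.0, 9.0, 9.0]])
min_bb = np.array([0.0, 0.0, 0.0])
max_bb = np.array([6.0, 6.0, 6.0])

point_filter = np.all((data >= min_bb) & (data < max_bb), axis=1)
print(point_filter)  # [ True  True False]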
Example #9
    def graph_from_path(self, graph_key, data, path):
        saved_ids = data[f"{path}-ids"]
        saved_edges = data[f"{path}-edges"]
        saved_locations = data[f"{path}-locations"]
        node_attrs = [
            (attr, data[f"{path}/node_attrs/{attr}"])
            for attr in self.node_attrs.get(graph_key, [])
        ]
        attrs = [attr for attr, _ in node_attrs]
        attr_values = zip(
            *[values for _, values in node_attrs], (None,) * len(saved_locations)
        )
        nodes = [
            gp.Node(
                node_id,
                location=location,
                attrs={attr: value for attr, value in zip(attrs, values)},
            )
            for node_id, location, values in zip(
                saved_ids, saved_locations, attr_values
            )
        ]

        edge_attrs = [
            (attr, data[f"{path}/edge_attrs/{attr}"])
            for attr in self.edge_attrs.get(graph_key, [])
        ]
        attrs = [attr for attr, _ in edge_attrs]
        attr_values = zip(
            *[values for _, values in edge_attrs], (None,) * len(saved_edges)
        )
        edges = [
            gp.Edge(u, v, attrs={attr: value for attr, value in zip(attrs, values)})
            for (u, v), values in zip(saved_edges, attr_values)
        ]
        return gp.Graph(
            nodes,
            edges,
            gp.GraphSpec(
                gp.Roi(
                    (None,) * len(saved_locations[0]), (None,) * len(saved_locations[0])
                ),
                directed=self.directed.get(graph_key),
            ),
        )
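The trailing `(None,) * len(...)` column in `attr_values` is a zip trick: with no stored attributes, `zip()` over an empty argument list would yield nothing at all, so the padding column guarantees one tuple per node (or edge); the later `zip(attrs, values)` drops the padding because `attrs` is shorter. A minimal sketch:

node_attrs = []  # no stored attributes
num_nodes = 3
attr_values = zip(*[values for _, values in node_attrs], (None,) * num_nodes)
print(list(attr_values))  # [(None,), (None,), (None,)] -- one tuple per node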
Example #10
def get_requests(config, blocks, raw, emb_pred, labels, gt):
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape
    diff = input_size - output_size

    cube_rois = [get_cube_roi(config, block) for block in blocks]

    requests = []
    for cube_roi in cube_rois:
        context_roi = cube_roi.grow(diff // 2, diff // 2)
        request = gp.BatchRequest()
        request[raw] = gp.ArraySpec(roi=context_roi)
        request[emb_pred] = gp.ArraySpec(roi=cube_roi)
        request[labels] = gp.ArraySpec(roi=cube_roi)
        request[gt] = gp.GraphSpec(roi=cube_roi)
        requests.append(request)
    return requests
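A hedged usage sketch (the config values and keys are illustrative; `get_cube_roi` is assumed to come from this codebase):

import gunpowder as gp

config = {
    "VOXEL_SIZE": (10, 3, 3),
    "INPUT_SHAPE": (32, 128, 128),
    "OUTPUT_SHAPE": (20, 88, 88),
}
raw = gp.ArrayKey("RAW")
emb_pred = gp.ArrayKey("EMB_PRED")
labels = gp.ArrayKey("LABELS")
gt = gp.GraphKey("GT")

requests = get_requests(config, [1, 2], raw, emb_pred, labels, gt)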
Example #11
def validation_data_sources_from_snapshots(config, blocks):
    validation_blocks = Path(config["VALIDATION_BLOCKS"])

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    block_pipelines = []
    for block in blocks:

        pipelines = (
            SnapshotSource(
                validation_blocks / f"block_{block}.hdf",
                {
                    labels: "volumes/labels",
                    ground_truth: "points/gt"
                },
                directed={ground_truth: True},
            ),
            SnapshotSource(validation_blocks / f"block_{block}.hdf",
                           {raw: "volumes/raw"}),
        )

        cube_roi = get_cube_roi(config, block)

        request = gp.BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        block_pipelines.append((pipelines, request))
    return block_pipelines, (raw, labels, ground_truth)
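A hedged sketch of consuming the returned per-block pairs: each `pipelines` entry is a tuple of two sources, so adding `gp.MergeProvider()` merges them before building (config and blocks as used throughout these examples):

import gunpowder as gp

block_pipelines, (raw, labels, ground_truth) = \
    validation_data_sources_from_snapshots(config, blocks)

for pipelines, request in block_pipelines:
    pipeline = pipelines + gp.MergeProvider()
    with gp.build(pipeline):
        batch = pipeline.request_batch(request)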
Example #12
def validation_data_sources_recomputed(config, blocks):
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    validation_dirs = {}
    for group in benchmark_datasets_path.iterdir():
        if "validation" in group.name and group.is_dir():
            for validation_dir in group.iterdir():
                validation_num = int(validation_dir.name.split("_")[-1])
                if validation_num in blocks:
                    validation_dirs[validation_num] = validation_dir

    validation_dirs = [validation_dirs[block] for block in blocks]

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    validation_pipelines = []
    for validation_dir in validation_dirs:
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        pipeline = ((
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size)
                },
            ),
            nl.gunpowder.nodes.MouselightSwcFileSource(
                validation_dir,
                [ground_truth],
                transform_file=transform_template.format(sample=sample),
                ignore_human_nodes=False,
                scale=voxel_size,
                transpose=[2, 1, 0],
                points_spec=[
                    gp.PointsSpec(roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ))
                ],
            ),
        ) + gp.nodes.MergeProvider() + gp.Normalize(
            raw, dtype=np.float32) + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            ) + nl.gunpowder.GrowLabels(labels, radii=[neuron_width * 1000]))

        request = gp.BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        print(f"input_roi has shape: {input_roi.get_shape()}")
        print(f"cube_roi has shape: {cube_roi.get_shape()}")
        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        validation_pipelines.append((pipeline, request))
    return validation_pipelines, (raw, labels, ground_truth)
Example #13
def validation_pipeline(config):
    """
    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={
                raw: "volume-rechunked",
                raw_clahed: "volume-rechunked"
            },
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                raw_clahed:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
            },
        ) + gp.Normalize(raw, dtype=np.float32) +
                      gp.Normalize(raw_clahed, dtype=np.float32) +
                      scipyCLAHE([raw_clahed], [20, 64, 64]))
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()

        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow((input_size - output_size) // 2,
                                          (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = ((swc_source, raw_source) + gp.nodes.MergeProvider() +
                    gp.SpecifiedLocation(locations=[cube_roi.get_center()]) +
                    gp.Crop(raw, roi=input_roi) +
                    gp.Crop(raw_clahed, roi=input_roi) +
                    gp.Crop(ground_truth, roi=cube_roi_shifted) +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    gp.Crop(labels, roi=cube_roi_shifted) + gp.Snapshot(
                        {
                            raw: f"volumes/{block}/raw",
                            raw_clahed: f"volumes/{block}/raw_clahe",
                            ground_truth: f"points/{block}/ground_truth",
                            labels: f"volumes/{block}/labels",
                        },
                        additional_request=additional_request,
                        output_dir="validations",
                        output_filename="validations.hdf",
                    ))

        validation_pipelines.append(pipeline)

    validation_pipeline = (tuple(pipeline
                                 for pipeline in validation_pipelines) +
                           gp.MergeProvider() + gp.PrintProfilingStats())
    return validation_pipeline, specs
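A hedged driver sketch: since every key is pulled through `additional_request` inside `gp.Snapshot`, an empty top-level request should be enough to trigger each block's snapshot (this assumes the Snapshot nodes request their keys themselves):

import gunpowder as gp

pipeline, specs = validation_pipeline(config)

request = gp.BatchRequest()  # Snapshot pulls its keys via additional_request
with gp.build(pipeline):
    pipeline.request_batch(request)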
Example #14
    def process(self, batch, request):
        num_thresholds = self.num_thresholds
        threshold_range = self.threshold_range

        outputs = gp.Batch()

        gt_graph = batch[self.gt].to_nx_graph().to_undirected()
        mst_graph = batch[self.mst].to_nx_graph().to_undirected()
        if self.connectivity is not None:
            connectivity_graph = batch[
                self.connectivity].to_nx_graph().to_undirected()

        # assert mst_graph.number_of_nodes() > 0, f"mst_graph is empty!"

        if self.details is not None:
            matching_details_graph = nx.Graph()
            # this block must always run: node_offset is needed below, and the
            # `+ [-1]` fallback already covers an empty mst_graph (the previous
            # `if mst_graph.number_of_nodes() == 0` guard skipped it whenever
            # the MST was non-empty, leaving the details graph empty)
            node_offset = max([node for node in mst_graph.nodes] + [-1]) + 1
            label_offset = len(list(nx.connected_components(mst_graph))) + 1

            for node, attrs in mst_graph.nodes.items():
                matching_details_graph.add_node(node, **copy.deepcopy(attrs))
            for edge, attrs in mst_graph.edges.items():
                matching_details_graph.add_edge(edge[0], edge[1],
                                                **copy.deepcopy(attrs))
            for node, attrs in gt_graph.nodes.items():
                matching_details_graph.add_node(node + node_offset,
                                                **copy.deepcopy(attrs))
                matching_details_graph.nodes[node + node_offset]["id"] = (
                    node + node_offset)
            for edge, attrs in gt_graph.edges.items():
                matching_details_graph.add_edge(edge[0] + node_offset,
                                                edge[1] + node_offset,
                                                **copy.deepcopy(attrs))

        edges = [(edge, attrs[self.edge_threshold_attr])
                 for edge, attrs in mst_graph.edges.items()]
        edges = list(sorted(edges, key=lambda x: x[1]))
        edge_lens = [e[1] for e in edges]
        # min_threshold = edges[0][1]
        if len(edge_lens) > 0:
            min_threshold = edge_lens[int(len(edge_lens) * threshold_range[0])]
            max_threshold = edge_lens[int(len(edge_lens) * threshold_range[1])
                                      - 1]
        else:
            min_threshold = 0
            max_threshold = 1
        thresholds = np.linspace(min_threshold,
                                 max_threshold,
                                 num=num_thresholds)

        current_threshold_mst = nx.Graph()

        edge_deque = deque(edges)
        edit_distances = []
        split_costs = []
        merge_costs = []
        false_pos_costs = []
        false_neg_costs = []
        num_nodes = []
        num_edges = []

        best_score = None
        best_graph = None

        for threshold_index, threshold in enumerate(thresholds):
            logger.info(f"Using threshold: {threshold}")
            while len(edge_deque) > 0 and edge_deque[0][1] <= threshold:
                (u, v), _ = edge_deque.popleft()
                attrs = mst_graph.edges[(u, v)]
                current_threshold_mst.add_edge(u, v)
                current_threshold_mst.add_node(u, **mst_graph.nodes[u])
                current_threshold_mst.add_node(v, **mst_graph.nodes[v])

            if self.connectivity is not None:
                temp = nx.Graph()
                next_node = max([node
                                 for node in connectivity_graph.nodes]) + 1
                for i, cc in enumerate(
                        nx.connected_components(current_threshold_mst)):
                    component_subgraph = current_threshold_mst.subgraph(cc)
                    for node in component_subgraph.nodes:
                        temp.add_node(node,
                                      **dict(connectivity_graph.nodes[node]))
                        temp.nodes[node]["component"] = i
                    for edge in connectivity_graph.edges:
                        if (edge[0] in temp.nodes and edge[1] in temp.nodes
                                and temp.nodes[edge[0]]["component"]
                                == temp.nodes[edge[1]]["component"]):
                            temp.add_edge(
                                edge[0], edge[1],
                                **dict(connectivity_graph.edges[edge]))
                        elif False:  # disabled branch: clone shortest connectivity paths (kept for reference)
                            path = nx.shortest_path(connectivity_graph,
                                                    edge[0], edge[1])
                            cloned_path = []
                            for node in path:
                                if node in temp.nodes:
                                    cloned_path.append(node)
                                else:
                                    cloned_path.append(next_node)
                                    next_node += 1
                            path_len = len(cloned_path) - 1
                            for i, j in zip(range(path_len),
                                            range(1, path_len + 1)):
                                u = cloned_path[i]
                                if u not in temp.nodes:
                                    temp.add_node(
                                        u,
                                        **dict(
                                            connectivity_graph.nodes[path[i]]))
                                v = cloned_path[j]
                                if v not in temp.nodes:
                                    temp.add_node(
                                        v,
                                        **dict(
                                            connectivity_graph.nodes[path[j]]))

                                temp.add_edge(
                                    u,
                                    v,
                                    **dict(connectivity_graph.edges[path[i],
                                                                    path[j]]),
                                )

            else:
                temp = copy.deepcopy(current_threshold_mst)
                for i, cc in enumerate(nx.connected_components(temp)):
                    for node in cc:
                        attrs = temp.nodes[node]
                        attrs["component"] = i

            # remove small connected_components
            false_pos_nodes = []
            for cc in nx.connected_components(temp):
                cc_graph = temp.subgraph(cc)
                min_loc = None
                max_loc = None
                for node, attrs in cc_graph.nodes.items():
                    node_loc = attrs[self.location_attr]
                    if min_loc is None:
                        min_loc = node_loc
                    else:
                        min_loc = np.min(np.array([node_loc, min_loc]), axis=0)
                    if max_loc is None:
                        max_loc = node_loc
                    else:
                        max_loc = np.max(np.array([node_loc, max_loc]), axis=0)
                if np.linalg.norm(min_loc -
                                  max_loc) < self.small_component_threshold:
                    false_pos_nodes += list(cc)
            for node in false_pos_nodes:
                temp.remove_node(node)

            nodes_x = list(temp.nodes)
            nodes_y = list(gt_graph.nodes)

            node_labels_x = {
                node: attrs["component"]
                for node, attrs in temp.nodes.items()
            }

            node_labels_y = {
                node: component
                for component, nodes in enumerate(
                    nx.connected_components(gt_graph)) for node in nodes
            }

            edges_yx = get_edges_xy(
                gt_graph,
                temp,
                location_attr=self.location_attr,
                node_match_threshold=self.comatch_threshold,
            )
            edges_xy = [(v, u) for u, v in edges_yx]

            result = match_components(
                nodes_x,
                nodes_y,
                edges_xy,
                copy.deepcopy(node_labels_x),
                copy.deepcopy(node_labels_y),
            )

            if self.details is not None:
                # add a match details graph to the batch
                # details is a graph containing nodes from both mst and gt
                # to access details of a node, use `details.nodes[node]["details"]`
                # where the details are a numpy array of shape (num_thresholds, 7).
                # the 7 values stored per threshold are: selected, success, fp, fn, merge, split, gt

                # NODES
                # success, mst: n matches to only nodes of 1 label, which matches its own label
                # success, gt: n matches to only nodes of 1 label, which matches its own label
                # fp: n in mst matches to nothing
                # fn: n in gt matches to nothing
                # merge: n in mst matches to a node with label not matching its own
                # split: n in gt matches to a node with label not matching its own
                # selected: n in mst in thresholded graph

                # EDGES
                # success, mst: both endpoints successful
                # success, gt: both endpoints successful
                # fp: both endpoints fp
                # fn: both endpoints fn
                # merge: e in mst: only one endpoint successful
                # split: e in gt: only one endpoint successful
                # selected: e in mst in thresholded graph
                (label_matches, node_matches, splits, merges, fps,
                 fns) = result
                # create lookup tables:
                x_label_match_lut = {}
                y_label_match_lut = {}
                for a, b in label_matches:
                    x_matches = x_label_match_lut.setdefault(a, set())
                    x_matches.add(b)
                    y_matches = y_label_match_lut.setdefault(b, set())
                    y_matches.add(a)
                x_node_match_lut = {}
                y_node_match_lut = {}
                for a, b in node_matches:
                    x_matches = x_node_match_lut.setdefault(a, set())
                    x_matches.add(b)
                    y_matches = y_node_match_lut.setdefault(b, set())
                    y_matches.add(a)

                for node, attrs in matching_details_graph.nodes.items():
                    gt = int(node >= node_offset)
                    mst = 1 - gt

                    if gt == 1:
                        node = node - node_offset

                    selected = gt or (node in temp.nodes())

                    if selected:
                        success, fp, fn, merge, split, label_pair = self.node_matching_result(
                            node,
                            gt,
                            x_label_match_lut,
                            y_label_match_lut,
                            x_node_match_lut,
                            y_node_match_lut,
                            node_labels_x,
                            node_labels_y,
                        )
                    else:
                        success, fp, fn, merge, split, label_pair = (
                            0,
                            0,
                            0,
                            0,
                            0,
                            (-1, -1),
                        )

                    data = attrs.setdefault(
                        "details", np.zeros((len(thresholds), 7), dtype=bool))
                    data[threshold_index] = [
                        selected,
                        success,
                        fp,
                        fn,
                        merge,
                        split,
                        gt,
                    ]
                    label_pairs = attrs.setdefault("label_pair", [])
                    label_pairs.append(label_pair)
                    assert len(label_pairs) == threshold_index + 1
                for (u, v), attrs in matching_details_graph.edges.items():
                    (
                        u_selected,
                        u_success,
                        u_fp,
                        u_fn,
                        u_merge,
                        u_split,
                        u_gt,
                    ) = matching_details_graph.nodes[u]["details"][
                        threshold_index]

                    (
                        v_selected,
                        v_success,
                        v_fp,
                        v_fn,
                        v_merge,
                        v_split,
                        v_gt,
                    ) = matching_details_graph.nodes[v]["details"][
                        threshold_index]

                    assert u_gt == v_gt
                    e_gt = u_gt

                    u_label_pair = matching_details_graph.nodes[u][
                        "label_pair"][threshold_index]
                    v_label_pair = matching_details_graph.nodes[v][
                        "label_pair"][threshold_index]

                    e_selected = u_selected and v_selected
                    e_success = (e_selected and u_success and v_success
                                 and (u_label_pair == v_label_pair))
                    e_fp = u_fp and v_fp
                    e_fn = u_fn and v_fn
                    e_merge = e_selected and (not e_success) and (
                        not e_fp) and not e_gt
                    e_split = e_selected and (not e_success) and (
                        not e_fn) and e_gt
                    assert not (e_success and e_merge)
                    assert not (e_success and e_split)

                    data = attrs.setdefault(
                        "details", np.zeros((len(thresholds), 7), dtype=bool))
                    if e_success:
                        label_pairs = attrs.setdefault("label_pair", [])
                        label_pairs.append(u_label_pair)
                        assert len(label_pairs) == threshold_index + 1
                    else:
                        label_pairs = attrs.setdefault("label_pair", [])
                        label_pairs.append((-1, -1))
                        assert len(label_pairs) == threshold_index + 1
                    data[threshold_index] = [
                        e_selected,
                        e_success,
                        e_fp,
                        e_fn,
                        e_merge,
                        e_split,
                        e_gt,
                    ]

            edit_distance, (
                split_cost,
                merge_cost,
                false_pos_cost,
                false_neg_cost,
            ) = psudo_graph_edit_distance(
                result[1],
                node_labels_x,
                node_labels_y,
                temp,
                gt_graph,
                self.location_attr,
                node_spacing=self.edit_distance_node_spacing,
                details=True,
            )

            edit_distances.append(edit_distance)
            split_costs.append(split_cost)
            merge_costs.append(merge_cost)
            false_pos_costs.append(false_pos_cost)
            false_neg_costs.append(false_neg_cost)
            num_nodes.append(len(temp.nodes))
            num_edges.append(len(temp.edges))

            # save the best version:
            if best_score is None:
                best_score = edit_distance
                best_graph = copy.deepcopy(temp)
            elif edit_distance < best_score:
                best_score = edit_distance
                best_graph = copy.deepcopy(temp)

        outputs[self.output] = gp.Array(
            np.array([
                edit_distances,
                thresholds,
                num_nodes,
                num_edges,
                split_costs,
                merge_costs,
                false_pos_costs,
                false_neg_costs,
            ]),
            gp.ArraySpec(nonspatial=True),
        )
        if self.output_graph is not None:
            outputs[self.output_graph] = gp.Graph.from_nx_graph(
                best_graph,
                gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False))
        if self.details is not None:
            outputs[self.details] = gp.Graph.from_nx_graph(
                matching_details_graph,
                gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False),
            )
        return outputs
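The threshold sweep above depends on the edges being sorted ascending by score: the deque is consumed monotonically, so each threshold only adds the newly admitted edges instead of rebuilding the thresholded graph from scratch. The core pattern, self-contained:

from collections import deque

import numpy as np

scored_edges = [(("a", "b"), 0.2), (("b", "c"), 0.7), (("c", "d"), 0.9)]
edge_deque = deque(sorted(scored_edges, key=lambda x: x[1]))

for threshold in np.linspace(0.0, 1.0, num=3):
    while len(edge_deque) > 0 and edge_deque[0][1] <= threshold:
        (u, v), score = edge_deque.popleft()
        print(f"threshold {threshold:.1f}: admit edge {u}-{v}")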
Example #15
 def setup(self):
     spec = gp.GraphSpec(roi=self.spec[self.array_key].roi)
     self.provides(self.graph_key, spec)
Example #16
 def setup(self):
     spec = gp.GraphSpec(roi=self.spec[self.array_key].roi)
     self.provides(self.graph_key, spec)
     self.center = np.array(self.spec[self.array_key].roi.get_center())
Example #17
def validation_pipeline(config):
    """
    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    candidate_threshold = config["NMS_THRESHOLD"]
    candidate_spacing = min(config["NMS_WINDOW_SIZE"]) * micron_scale
    coordinate_scale = config["COORDINATE_SCALE"] * np.array(
        voxel_size) / micron_scale

    emb_model = get_emb_model(config)
    fg_model = get_fg_model(config)

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")
        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        mst = gp.GraphKey(f"MST_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={raw: "volume-rechunked"},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size)
            },
        ) + gp.Normalize(raw, dtype=np.float32) + mCLAHE([raw], [20, 64, 64]))
        emb_source, emb = add_emb_pred(config, raw_source, raw, block,
                                       emb_model)
        pred_source, fg = add_fg_pred(config, emb_source, raw, block, fg_model)
        pred_source = add_scan(pred_source, {
            raw: input_size,
            emb: output_size,
            fg: output_size
        })
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        block_spec = specs.setdefault(block, {})
        block_spec["raw"] = (raw, gp.ArraySpec(input_roi))
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec["ground_truth"] = (ground_truth, gp.GraphSpec(cube_roi))
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi)
        block_spec["labels"] = (labels, gp.ArraySpec(cube_roi))
        additional_request[labels] = gp.ArraySpec(roi=cube_roi)
        block_spec["fg_pred"] = (fg, gp.ArraySpec(cube_roi))
        additional_request[fg] = gp.ArraySpec(roi=cube_roi)
        block_spec["emb_pred"] = (emb, gp.ArraySpec(cube_roi))
        additional_request[emb] = gp.ArraySpec(roi=cube_roi)
        block_spec["candidates"] = (candidates, gp.ArraySpec(cube_roi))
        additional_request[candidates] = gp.ArraySpec(roi=cube_roi)
        block_spec["mst_pred"] = (mst, gp.GraphSpec(cube_roi))
        additional_request[mst] = gp.GraphSpec(roi=cube_roi)

        pipeline = ((swc_source, pred_source) + gp.nodes.MergeProvider() +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    Skeletonize(fg, candidates, candidate_spacing,
                                candidate_threshold) + EMST(
                                    emb,
                                    candidates,
                                    mst,
                                    distance_attr=distance_attr,
                                    coordinate_scale=coordinate_scale,
                                ) + gp.Snapshot(
                                    {
                                        raw: f"volumes/{raw}",
                                        ground_truth: f"points/{ground_truth}",
                                        labels: f"volumes/{labels}",
                                        fg: f"volumes/{fg}",
                                        emb: f"volumes/{emb}",
                                        candidates: f"volumes/{candidates}",
                                        mst: f"points/{mst}",
                                    },
                                    additional_request=additional_request,
                                    output_dir="snapshots",
                                    output_filename="{id}.hdf",
                                    edge_attrs={mst: [distance_attr]},
                                ))

        validation_pipelines.append(pipeline)

    full_gt = gp.GraphKey("FULL_GT")
    full_mst = gp.GraphKey("FULL_MST")
    score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines) +
        gp.MergeProvider() + MergeGraphs(specs, full_gt, full_mst) +
        Evaluate(full_gt, full_mst, score, edge_threshold_attr=distance_attr) +
        gp.PrintProfilingStats())
    return validation_pipeline, score
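A hedged driver sketch: the merged pipeline's final product is the nonspatial `score` array, so that is the only key the top-level request needs (snapshots are again written via `additional_request`):

import gunpowder as gp

pipeline, score = validation_pipeline(config)

request = gp.BatchRequest()
request[score] = gp.ArraySpec(nonspatial=True)

with gp.build(pipeline):
    batch = pipeline.request_batch(request)
    print(batch[score].data)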
Example #18
def emb_validation_pipeline(
    config,
    snapshot_file,
    candidates_path,
    raw_path,
    gt_path,
    candidates_mst_path=None,
    candidates_mst_dense_path=None,
    path_stat="max",
):
    checkpoint = config["EMB_EVAL_CHECKPOINT"]
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    coordinate_scale = config["COORDINATE_SCALE"] * np.array(
        voxel_size) / micron_scale
    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]

    edge_threshold_0 = config["EVAL_EDGE_THRESHOLD_0"]
    component_threshold_0 = config["COMPONENT_THRESHOLD_0"]
    component_threshold_1 = config["COMPONENT_THRESHOLD_1"]

    clip_limit = config["CLAHE_CLIP_LIMIT"]
    normalize = config["CLAHE_NORMALIZE"]

    validation_pipelines = []
    specs = {}

    emb_model = get_emb_model(config)
    emb_model.eval()

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates_1 = gp.ArrayKey(f"CANDIDATES_1_{block}")

        raw = gp.ArrayKey(f"RAW_{block}")
        mst_0 = gp.GraphKey(f"MST_0_{block}")
        mst_dense_0 = gp.GraphKey(f"MST_DENSE_0_{block}")
        mst_1 = gp.GraphKey(f"MST_1_{block}")
        mst_dense_1 = gp.GraphKey(f"MST_DENSE_1_{block}")
        mst_2 = gp.GraphKey(f"MST_2_{block}")
        mst_dense_2 = gp.GraphKey(f"MST_DENSE_2_{block}")
        gt = gp.GraphKey(f"GT_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")
        optimal_mst = gp.GraphKey(f"OPTIMAL_MST_{block}")

        # Volume Source
        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                candidates_1: candidates_path.format(block=block),
            },
        )

        # Graph Source
        graph_datasets = {gt: gt_path.format(block=block)}
        graph_directionality = {gt: False}
        edge_attrs = {}
        if candidates_mst_path is not None:
            graph_datasets[mst_0] = candidates_mst_path.format(block=block)
            graph_directionality[mst_0] = False
            edge_attrs[mst_0] = [distance_attr]
        if candidates_mst_dense_path is not None:
            graph_datasets[mst_dense_0] = candidates_mst_dense_path.format(
                block=block)
            graph_directionality[mst_dense_0] = False
            edge_attrs[mst_dense_0] = [distance_attr]
        gt_source = SnapshotSource(
            snapshot_file,
            datasets=graph_datasets,
            directed=graph_directionality,
            edge_attrs=edge_attrs,
        )

        if config["EVAL_CLAHE"]:
            raw_source = raw_source + scipyCLAHE(
                [raw],
                gp.Coordinate([20, 64, 64]) * voxel_size,
                clip_limit=clip_limit,
                normalize=normalize,
            )

        emb_source, emb, neighborhood = add_emb_pred(config, raw_source, raw,
                                                     block, emb_model)

        reference_sizes = {
            raw: input_size,
            emb: output_size,
            candidates_1: output_size
        }
        if neighborhood is not None:
            reference_sizes[neighborhood] = output_size

        emb_source = add_scan(emb_source, reference_sizes)

        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow((input_size - output_size) // 2,
                                          (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        block_spec[candidates_1] = gp.ArraySpec(cube_roi_shifted)
        block_spec[emb] = gp.ArraySpec(cube_roi_shifted)
        if neighborhood is not None:
            block_spec[neighborhood] = gp.ArraySpec(cube_roi_shifted)
        block_spec[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_0] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_dense_0] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)
        block_spec[mst_1] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_dense_1] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)
        block_spec[mst_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        # block_spec[mst_dense_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)
        block_spec[optimal_mst] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)

        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        additional_request[candidates_1] = gp.ArraySpec(cube_roi_shifted)
        additional_request[emb] = gp.ArraySpec(cube_roi_shifted)
        if neighborhood is not None:
            additional_request[neighborhood] = gp.ArraySpec(cube_roi_shifted)
        additional_request[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[mst_0] = gp.GraphSpec(cube_roi_shifted,
                                                 directed=False)
        additional_request[mst_dense_0] = gp.GraphSpec(cube_roi_shifted,
                                                       directed=False)
        additional_request[mst_1] = gp.GraphSpec(cube_roi_shifted,
                                                 directed=False)
        additional_request[mst_dense_1] = gp.GraphSpec(cube_roi_shifted,
                                                       directed=False)
        additional_request[mst_2] = gp.GraphSpec(cube_roi_shifted,
                                                 directed=False)
        # additional_request[mst_dense_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[details] = gp.GraphSpec(cube_roi_shifted,
                                                   directed=False)
        additional_request[optimal_mst] = gp.GraphSpec(cube_roi_shifted,
                                                       directed=False)

        pipeline = (emb_source, gt_source) + gp.MergeProvider()

        if candidates_mst_path is not None and candidates_mst_dense_path is not None:
            # mst_0 provided, just need to calculate distances.
            pass
        elif config["EVAL_MINIMAX_EMBEDDING_DIST"]:
            # No mst_0 provided, must first calculate mst_0 and dense mst_0
            pipeline += MiniMaxEmbeddings(
                emb,
                candidates_1,
                decimated=mst_0,
                dense=mst_dense_0,
                distance_attr=distance_attr,
            )

        else:
            # mst/mst_dense not provided. Simply use euclidean distance on candidates
            pipeline += EMST(
                emb,
                candidates_1,
                mst_0,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )
            pipeline += EMST(
                emb,
                candidates_1,
                mst_dense_0,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )

        pipeline += ThresholdEdges(
            (mst_0, mst_1),
            edge_threshold_0,
            component_threshold_0,
            msts_dense=(mst_dense_0, mst_dense_1),
            distance_attr=distance_attr,
        )

        pipeline += ComponentWiseEMST(
            emb,
            mst_1,
            mst_2,
            distance_attr=distance_attr,
            coordinate_scale=coordinate_scale,
        )

        # pipeline += ScoreEdges(
        #     mst, mst_dense, emb, distance_attr=distance_attr, path_stat=path_stat
        # )

        pipeline += Evaluate(
            gt,
            mst_2,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold_1,
            # connectivity=mst_1,
            output_graph=optimal_mst,
        )

        if config["EVAL_SNAPSHOT"]:
            snapshot_datasets = {
                raw: f"volumes/raw",
                emb: f"volumes/embeddings",
                candidates_1: f"volumes/candidates_1",
                mst_0: f"points/mst_0",
                mst_dense_0: f"points/mst_dense_0",
                mst_1: f"points/mst_1",
                mst_dense_1: f"points/mst_dense_1",
                # mst_2: f"points/mst_2",
                gt: f"points/gt",
                details: f"points/details",
                optimal_mst: f"points/optimal_mst",
            }
            if neighborhood is not None:
                snapshot_datasets[neighborhood] = f"volumes/neighborhood"
            pipeline += gp.Snapshot(
                snapshot_datasets,
                output_dir=config["EVAL_SNAPSHOT_DIR"],
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(
                    checkpoint=checkpoint,
                    block=block,
                    coordinate_scale=",".join(
                        [str(x) for x in coordinate_scale]),
                ),
                edge_attrs={
                    mst_0: [distance_attr],
                    mst_dense_0: [distance_attr],
                    mst_1: [distance_attr],
                    mst_dense_1: [distance_attr],
                    # mst_2: [distance_attr],
                    # optimal_mst: [distance_attr], # it is unclear how to add distances if using connectivity graph
                    # mst_dense_2: [distance_attr],
                    details: ["details", "label_pair"],
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    validation_pipeline = (tuple(pipeline
                                 for pipeline in validation_pipelines) +
                           gp.MergeProvider() +
                           MergeScores(final_score, specs) +
                           gp.PrintProfilingStats())
    return validation_pipeline, final_score
Example #19
def pre_computed_fg_validation_pipeline(config, snapshot_file, raw_path,
                                        gt_path, fg_path):
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    candidate_spacing = config["CANDIDATE_SPACING"]
    candidate_threshold = config["CANDIDATE_THRESHOLD"]

    distance_attr = config["DISTANCE_ATTR"]
    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]

    component_threshold = config["COMPONENT_THRESHOLD_1"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        raw = gp.ArrayKey(f"RAW_{block}")
        mst = gp.GraphKey(f"MST_{block}")
        gt = gp.GraphKey(f"GT_{block}")
        fg = gp.ArrayKey(f"FG_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")

        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                fg: fg_path.format(block=block),
            },
        )
        gt_source = SnapshotSource(
            snapshot_file,
            datasets={gt: gt_path.format(block=block)},
            directed={gt: False},
        )

        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow((input_size - output_size) // 2,
                                          (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        block_spec[candidates] = gp.ArraySpec(cube_roi_shifted)
        block_spec[fg] = gp.ArraySpec(cube_roi_shifted)
        block_spec[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)

        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        additional_request[candidates] = gp.ArraySpec(cube_roi_shifted)
        additional_request[fg] = gp.ArraySpec(cube_roi_shifted)
        additional_request[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[mst] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)
        additional_request[details] = gp.GraphSpec(cube_roi_shifted,
                                                   directed=False)

        pipeline = ((raw_source, gt_source) + gp.MergeProvider() + Skeletonize(
            fg, candidates, candidate_spacing, candidate_threshold) +
                    MiniMax(fg, candidates, mst, distance_attr=distance_attr))

        pipeline += Evaluate(
            gt,
            mst,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold,
        )

        if config["EVAL_SNAPSHOT"]:
            pipeline += gp.Snapshot(
                {
                    raw: f"volumes/raw",
                    fg: f"volumes/foreground",
                    candidates: f"volumes/candidates",
                    mst: f"points/mst",
                    gt: f"points/gt",
                    details: f"points/details",
                },
                output_dir="eval_results",
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(
                    block=block),
                edge_attrs={
                    mst: [distance_attr],
                    details: ["details", "label_pair"]
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    validation_pipeline = (tuple(pipeline
                                 for pipeline in validation_pipelines) +
                           gp.MergeProvider() +
                           MergeScores(final_score, specs) +
                           gp.PrintProfilingStats())
    return validation_pipeline, final_score