Beispiel #1
0
    def process(self, batch, request):
        """Merge the per-block ground-truth and MST graphs into two graphs.

        For every entry of ``self.specs`` the graph stored under the first
        ``"ground_truth"`` key and the graph under the first ``"mst_pred"``
        key are converted to undirected networkx graphs and appended with
        ``nx.disjoint_union``, which relabels nodes to consecutive integers.
        Every merged node then gets an ``"id"`` attribute equal to its label.

        Returns:
            A new ``gp.Batch`` holding ``self.gt`` and ``self.mst`` as
            undirected graphs with an unbounded 3D ROI.
        """
        outputs = gp.Batch()

        merged_gt = nx.Graph()
        merged_mst = nx.Graph()

        for spec in self.specs.values():
            gt_part = batch[spec["ground_truth"][0]].to_nx_graph(
            ).to_undirected()
            mst_part = batch[spec["mst_pred"][0]].to_nx_graph().to_undirected()
            merged_gt = nx.disjoint_union(merged_gt, gt_part)
            merged_mst = nx.disjoint_union(merged_mst, mst_part)

        # Store each node's post-union integer label as an explicit attribute.
        for merged in (merged_gt, merged_mst):
            for node_id, node_attrs in merged.nodes.items():
                node_attrs["id"] = node_id

        def unbounded_spec():
            # A fresh spec per graph, with an unbounded 3D ROI.
            return gp.GraphSpec(roi=gp.Roi((None, ) * 3, (None, ) * 3),
                                directed=False)

        outputs[self.gt] = gp.Graph.from_nx_graph(merged_gt, unbounded_spec())
        outputs[self.mst] = gp.Graph.from_nx_graph(merged_mst,
                                                   unbounded_spec())
        return outputs
Beispiel #2
0
 def process(self, batch, request):
     """Drop a leading singleton axis from ``self.array``.

     Mirrors ``torch.Tensor.squeeze(0)``: axis 0 is removed only when it has
     length 1; otherwise the data passes through unchanged. This is done with
     NumPy directly. ``torch.from_numpy`` rejects arrays with negative strides
     and, in older torch releases, unsigned dtypes such as uint16/uint32.

     Returns:
         A new ``gp.Batch`` holding a deep copy of the input array whose data
         is the (possibly) squeezed input data.
     """
     outputs = gp.Batch()
     source = batch[self.array]
     outputs[self.array] = copy.deepcopy(source)
     data = source.data
     # ndarray.squeeze(axis=0) raises when axis 0 is not length 1 (where
     # torch's squeeze(0) is a no-op), so guard it explicitly.
     if data.ndim > 0 and data.shape[0] == 1:
         data = data.squeeze(axis=0)
     outputs[self.array].data = data
     return outputs
Beispiel #3
0
 def process(self, batch, request):
     """Drop a leading singleton axis from each of ``self.arrays`` in the batch.

     Keys missing from the batch are skipped. Like ``torch.Tensor.squeeze(0)``,
     axis 0 is removed only when it has length 1. This is done with NumPy
     directly, avoiding ``torch.from_numpy`` failures on negative strides and
     (in older torch) unsigned dtypes.

     Returns:
         A new ``gp.Batch`` with a deep copy of every present array, whose
         data is the (possibly) squeezed input data.
     """
     outputs = gp.Batch()
     for key in self.arrays:
         if key not in batch:
             continue
         source = batch[key]
         outputs[key] = copy.deepcopy(source)
         data = source.data
         # Guard: ndarray.squeeze(axis=0) raises if axis 0 is not length 1.
         if data.ndim > 0 and data.shape[0] == 1:
             data = data.squeeze(axis=0)
         outputs[key].data = data
     return outputs
Beispiel #4
0
    def provide(self, request):
        """Generate a random raw volume and an all-ones label volume.

        Raw data is drawn uniformly from [0, 256) with ``np.random.randint``
        and gets two leading singleton axes. Labels are filled with ones and
        get one leading singleton axis. For both arrays the voxel shape is the
        requested ROI shape divided by ``self.voxel_size``.
        """
        outputs = gp.Batch()

        raw_roi = request[self.raw].roi
        raw_spec = copy.deepcopy(self.array_spec_raw)
        raw_spec.roi = raw_roi
        raw_voxels = raw_roi.get_shape() / self.voxel_size
        raw_data = np.random.randint(0, 256, raw_voxels, dtype=raw_spec.dtype)
        outputs[self.raw] = gp.Array(raw_data, raw_spec)
        # prepend two singleton axes after the array has been constructed
        outputs[self.raw].data = outputs[self.raw].data[np.newaxis, np.newaxis]

        labels_roi = request[self.labels].roi
        labels_spec = copy.deepcopy(self.array_spec_labels)
        labels_spec.roi = labels_roi
        labels_voxels = labels_roi.get_shape() / self.voxel_size
        outputs[self.labels] = gp.Array(
            np.ones(labels_voxels, dtype=labels_spec.dtype), labels_spec)
        # prepend one singleton axis
        outputs[self.labels].data = outputs[self.labels].data[np.newaxis]

        return outputs
Beispiel #5
0
    def provide(self, request):
        """Fill every requested array by calling ``self.func`` on its shape.

        Each requested ROI is converted to voxel units with the provided
        spec's voxel size; no offset shift is applied. ``self.func`` is
        called with the resulting shape, and its return value becomes the
        array data. The data is paired with a copy of the provided spec
        restricted to the requested ROI. Time spent here is added to the
        batch's profiling stats.
        """
        timing = gp.profiling.Timing(self)
        timing.start()

        batch = gp.Batch()

        for key, req_spec in request.array_specs.items():
            logger.debug("Reading %s in %s...", key, req_spec.roi)

            provided = self.spec[key]
            voxel_roi = req_spec.roi / provided.voxel_size

            out_spec = provided.copy()
            out_spec.roi = req_spec.roi
            batch.arrays[key] = gp.Array(self.func(voxel_roi.get_shape()),
                                         out_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
    def provide(self, request):
        """Return an all-zero array for every requested array key.

        Each array uses a copy of the provider spec with its ROI set to the
        requested ROI. The data shape is the ROI shape divided by the spec's
        voxel size (when one is set). The data dtype is the spec's dtype (when
        one is set), so data and spec agree; otherwise it is float64.
        """
        batch = gp.Batch()
        for (array_key, request_spec) in request.array_specs.items():
            array_spec = self.spec[array_key].copy()
            array_spec.roi = request_spec.roi
            roi_shape = array_spec.roi.get_shape()
            # Python 3 print call; the original Python 2 print statement was a
            # SyntaxError under Python 3.
            print("array_spec: ", roi_shape)
            voxel_size = array_spec.voxel_size
            # Array data is indexed in voxels, not world units.
            data_shape = roi_shape if voxel_size is None else roi_shape / voxel_size
            data = np.zeros(data_shape, dtype=array_spec.dtype)
            batch.arrays[array_key] = gp.Array(data, array_spec)
        return batch
Beispiel #7
0
    def provide(self, request):
        """Synthesize a random segmentation and a matching raw image.

        Seeds are drawn as random linear skeletons (``skelerator``) into a
        volume of shape ``(1,) + raw ROI shape``. Seed voxels that differ from
        the maximum of their size-4 neighbourhood are cleared. The remaining
        seeds are flooded with ``cwatershed`` over a distance transform that
        is perturbed by Gaussian-smoothed noise. The result, with the leading
        axis dropped and minus 1, is the ground truth. Raw intensities are a
        per-label ramp in steps of ``255 / self.n_objects`` (uint8).

        Returns:
            A ``gp.Batch`` with ``self.raw`` (full raw ROI) and ``self.gt``
            (cropped to the requested gt ROI).
        """

        voxel_size = self.spec[self.raw].voxel_size
        # NOTE(review): shape has three entries (range(3) below) and the
        # voxel sizes are forced to (1, 1) further down, so this presumably
        # expects a 2D ROI with unit voxels -- confirm.
        shape = gp.Coordinate((1, ) + request[self.raw].roi.get_shape())

        noise = np.abs(np.random.randn(*shape))
        smoothed_noise = gaussian_filter(noise, sigma=self.smoothness)

        seeds = np.zeros(shape, dtype=int)
        for i in range(self.n_objects):
            # The first object is drawn from 100 points; the others from
            # self.points_per_skeleton points.
            if i == 0:
                num_points = 100
            else:
                num_points = self.points_per_skeleton
            points = np.stack(
                [
                    np.random.randint(0, shape[dim], num_points)
                    for dim in range(3)
                ],
                axis=1,
            )
            tree = skelerator.Tree(points)
            skeleton = skelerator.Skeleton(tree, [1, 1, 1],
                                           "linear",
                                           generate_graph=False)
            # Object i is drawn with label i + 1 (0 stays background).
            seeds = skeleton.draw(seeds, np.array([0, 0, 0]), i + 1)

        # Keep only seed voxels equal to their local (size-4) maximum, so
        # lower labels next to a higher label are cleared.
        seeds[maximum_filter(seeds, size=4) != seeds] = 0
        seeds_dt = distance_transform_edt(seeds == 0) + 5.0 * smoothed_noise
        # Drop the leading singleton axis; presumably maps seed labels
        # 1..n_objects to 0..n_objects-1 (uint64 -- a 0 would wrap around).
        gt_data = cwatershed(seeds_dt, seeds).astype(np.uint64)[0] - 1

        labels = np.unique(gt_data)

        # The k-th label (ascending) gets intensity k * 255 / n_objects,
        # truncated to uint8 on assignment.
        raw_data = np.zeros_like(gt_data, dtype=np.uint8)
        value = 0
        for label in labels:
            raw_data[gt_data == label] = value
            value += 255.0 / self.n_objects

        spec = request[self.raw].copy()
        spec.voxel_size = (1, 1)
        raw = gp.Array(raw_data, spec)

        spec = request[self.gt].copy()
        spec.voxel_size = (1, 1)
        # Position of the gt ROI relative to the raw ROI, in voxels.
        gt_crop = (request[self.gt].roi -
                   request[self.raw].roi.get_begin()) / voxel_size
        gt_crop = gt_crop.to_slices()
        gt = gp.Array(gt_data[gt_crop], spec)

        batch = gp.Batch()
        batch[self.raw] = raw
        batch[self.gt] = gt

        return batch
Beispiel #8
0
    def provide(self, request):
        """Provide an undirected, edge-free graph of random nodes in the ROI.

        ``self.random_point_generator.get_random_points(roi)`` is iterated as
        a mapping from node id to location; each entry becomes one node.
        """
        graph_roi = request[self.graph_key].roi
        sampled = self.random_point_generator.get_random_points(graph_roi)

        batch = gp.Batch()
        nodes = [
            gp.Node(id=node_id, location=location)
            for node_id, location in sampled.items()
        ]
        batch[self.graph_key] = gp.Graph(
            nodes, [], gp.GraphSpec(roi=graph_roi, directed=False))

        return batch
Beispiel #9
0
    def provide(self, request):
        roi_array = request[gp.ArrayKeys.M_PRED].roi
        batch = gp.Batch()
        batch.arrays[gp.ArrayKeys.M_PRED] = gp.Array(
            self.m_pred[(roi_array / self.voxel_size).to_slices()],
            spec=gp.ArraySpec(roi=roi_array, voxel_size=self.voxel_size))
        slices = (roi_array / self.voxel_size).to_slices()
        batch.arrays[gp.ArrayKeys.D_PRED] = gp.Array(
            self.d_pred[:, slices[0], slices[1], slices[2]],
            spec=gp.ArraySpec(roi=roi_array, voxel_size=self.voxel_size))

        return batch
Beispiel #10
0
    def process(self, batch, request):
        """Binarize ``self.in_array`` into ``self.out_array``.

        Voxels whose value differs from ``self.target`` become True, all
        others False. If ``self.in_array`` is not in the batch, ``None`` is
        returned and nothing is produced.
        """
        if self.in_array not in batch:
            return

        source = batch[self.in_array]
        spec = source.spec.copy()
        # ``np.bool`` was a deprecated alias of the builtin ``bool`` and was
        # removed in NumPy 1.24; the builtin gives the same ``np.bool_`` dtype.
        spec.dtype = bool

        outputs = gp.Batch()
        outputs[self.out_array] = gp.Array(source.data != self.target, spec)

        return outputs
 def process(self, batch, request):
     """Stack the per-block score arrays into one non-spatial array.

     Every batch entry whose key string contains ``"SCORE"`` is a score for
     the block number that follows the first underscore in that string. A
     later key with the same block number replaces an earlier one. Scores
     for blocks 1 through 25 are stacked in ascending block order.
     """
     scores_by_block = {
         int(str(key).split("_")[1]): array.data
         for key, array in batch.items()
         if "SCORE" in str(key)
     }
     ordered = [
         scores_by_block[block_id]
         for block_id in sorted(scores_by_block)
         if 1 <= block_id <= 25
     ]
     outputs = gp.Batch()
     outputs[self.output] = gp.Array(np.array(ordered),
                                     gp.ArraySpec(nonspatial=True))
     return outputs
Beispiel #12
0
    def process(self, batch, request):
        """Gaussian-smooth ``self.array`` and crop it to the requested ROI.

        The array's data is replaced in place by ``filters.gaussian`` output
        (zero padding via ``mode='constant'``, value range preserved, no
        channel axis) before cropping.

        Returns:
            A new ``gp.Batch`` holding the smoothed array cropped to
            ``request[self.array].roi``.
        """
        array = batch.arrays[self.array]

        gaussian_kwargs = dict(sigma=self.sigma,
                               mode='constant',
                               preserve_range=True)
        try:
            # scikit-image >= 0.19: channel_axis=None means "no channel axis".
            smoothed = filters.gaussian(array.data,
                                        channel_axis=None,
                                        **gaussian_kwargs)
        except TypeError:
            # Older scikit-image has no channel_axis. `multichannel` is
            # deprecated since 0.19 and rejected by later releases, so it is
            # only used as this fallback.
            smoothed = filters.gaussian(array.data,
                                        multichannel=False,
                                        **gaussian_kwargs)
        array.data = smoothed

        batch = gp.Batch()
        batch[self.array] = array.crop(request[self.array].roi)

        return batch
Beispiel #13
0
    def process(self, batch, request):
        """Rescale ``self.source`` by ``self.factor`` into ``self.target``.

        The source data is resized with ``rescale``. It is wrapped in a copy
        of the provided spec for ``self.target``, whose ROI is set to the
        requested target ROI.
        """
        resized = rescale(batch.arrays[self.source].data, self.factor)

        target_spec = self.spec[self.target].copy()
        target_spec.roi = request[self.target].roi

        outputs = gp.Batch()
        outputs.arrays[self.target] = gp.Array(resized, target_spec)
        return outputs
Beispiel #14
0
    def provide(self, request):
        """Provide points (as an Array or a Graph) and, optionally, labels.

        If ``self.points`` is an ``ArrayKey``, ``self.num_points`` row indices
        are drawn from ``self.data`` with ``np.random.choice``. With the
        default arguments this samples uniformly, with replacement. The
        selected rows are optionally multiplied by ``self.scale`` and returned
        with a leading singleton axis.

        Otherwise the filter keeps the rows of ``self.data`` whose first
        ``self.ndims`` coordinates lie in the requested ROI (begin inclusive,
        end exclusive). That filter is passed to ``self._get_points``, and the
        result is returned as a graph without edges.

        When ``self.label_data`` is set, the matching labels are added as an
        Array under ``self.labels``.
        """
        timing = Timing(self)
        timing.start()

        batch = gp.Batch()

        # Array requested: sample self.num_points rows at random
        if isinstance(self.points, gp.ArrayKey):
            points = np.random.choice(self.data.shape[0], self.num_points)
            data = self.data[points][np.newaxis]
            if self.scale is not None:
                data = data * self.scale
            if self.label_data is not None:
                labels = self.label_data[points]
            batch[self.points] = gp.Array(data, self.spec[self.points])

        else:
            # Graph requested: keep only the points inside the request ROI
            min_bb = request[self.points].roi.get_begin()
            max_bb = request[self.points].roi.get_end()

            logger.debug("Points source got request for %s",
                         request[self.points].roi)

            # ``np.bool`` was removed in NumPy 1.24; builtin ``bool`` is the
            # identical dtype.
            point_filter = np.ones((self.data.shape[0], ), dtype=bool)
            for d in range(self.ndims):
                point_filter = np.logical_and(point_filter,
                                              self.data[:, d] >= min_bb[d])
                point_filter = np.logical_and(point_filter,
                                              self.data[:, d] < max_bb[d])

            points_data, labels = self._get_points(point_filter)
            logger.debug("Found %d points", len(points_data))
            points_spec = gp.GraphSpec(roi=request[self.points].roi.copy())
            batch.graphs[self.points] = gp.Graph(points_data, [], points_spec)

        # Labels are always provided as an Array
        if self.label_data is not None:
            batch[self.labels] = gp.Array(labels, self.spec[self.labels])

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Beispiel #15
0
    def process(self, batch, request):
        """Compute training targets and weights from ground truth and mask.

        ``self.predictor`` builds the target from the ground truth, and the
        weights from ground truth, target and mask. Both results are cropped
        to their requested ROIs and tagged with the ground truth's voxel size.

        Returns:
            A new ``gp.Batch`` with ``self.target_key`` and ``self.weights_key``.
        """
        output = gp.Batch()

        gt_array = NumpyArray.from_gp_array(batch[self.gt_key])
        target_array = self.predictor.create_target(gt_array)
        mask_array = NumpyArray.from_gp_array(batch[self.mask_key])
        weight_array = self.predictor.create_weight(
            gt_array, target_array, mask=mask_array
        )

        # Work on copies: assigning voxel_size on request[...] directly would
        # mutate the caller's request in place.
        target_spec = request[self.target_key].copy()
        target_spec.voxel_size = gt_array.voxel_size
        output[self.target_key] = gp.Array(target_array[target_spec.roi],
                                           target_spec)

        weights_spec = request[self.weights_key].copy()
        weights_spec.voxel_size = gt_array.voxel_size
        output[self.weights_key] = gp.Array(weight_array[weights_spec.roi],
                                            weights_spec)
        return output
Beispiel #16
0
    def provide(self, request):
        """Synthesize thin skeleton labels and a noisy binary raw image.

        For each of ``self.n_objects`` objects, a skeleton through two random
        points is drawn with ``skelerator`` (label ``i + 1``) into a volume of
        shape ``(1,) + raw ROI shape``. The labels are then thickened with a
        size-2 maximum filter. Raw data is ``gt > 0`` plus Gaussian noise
        (sigma 0.1), clipped to [0, 1], as float32.

        Returns:
            A ``gp.Batch`` with ``self.raw`` (full raw ROI) and ``self.gt``
            (cropped to the requested gt ROI).
        """
        voxel_size = self.spec[self.raw].voxel_size
        shape = gp.Coordinate((1, ) + request[self.raw].roi.get_shape())

        gt_data = np.zeros(shape, dtype=int)
        for i in range(self.n_objects):
            # Two random points per object, one coordinate per axis of shape.
            points = np.stack(
                [np.random.randint(0, shape[dim], 2) for dim in range(3)],
                axis=1)
            tree = skelerator.Tree(points)
            skeleton = skelerator.Skeleton(tree, [1, 1, 1],
                                           "linear",
                                           generate_graph=False)
            gt_data = skeleton.draw(gt_data, np.array([0, 0, 0]), i + 1)

        # Drop the leading singleton axis and thicken the drawn skeletons.
        gt_data = gt_data[0].astype(np.uint64)
        gt_data = maximum_filter(gt_data, size=2)

        raw_data = (gt_data > 0).astype(np.float32)
        raw_data = np.clip(
            raw_data + np.random.normal(scale=0.1, size=raw_data.shape), 0,
            1).astype(np.float32)

        # NOTE(review): voxel sizes are forced to (1, 1) while the crop below
        # divides by the provider's voxel_size; presumably both are unit.
        spec = request[self.raw].copy()
        spec.voxel_size = (1, 1)
        raw = gp.Array(raw_data, spec)

        spec = request[self.gt].copy()
        spec.voxel_size = (1, 1)
        # Position of the gt ROI relative to the raw ROI, in voxels.
        gt_crop = (request[self.gt].roi -
                   request[self.raw].roi.get_begin()) / voxel_size
        gt_crop = gt_crop.to_slices()
        gt = gp.Array(gt_data[gt_crop], spec)

        batch = gp.Batch()
        batch[self.raw] = raw
        batch[self.gt] = gt

        return batch
Beispiel #17
0
    def provide(self, request):
        """Generate a synthetic segmentation for every requested array.

        Each extent of the request ROI shape is rounded up to the next even
        number and passed to ``create_segmentation``. The result is cropped
        back to the requested shape, in as many dimensions as the ROI has.
        """
        batch = gp.Batch()

        for (array_key, request_spec) in request.array_specs.items():

            array_spec = self.spec[array_key].copy()
            array_spec.roi = request_spec.roi
            shape = array_spec.roi.get_shape()

            # Round odd extents up to even ones (presumably required by
            # create_segmentation -- TODO confirm).
            even_shape = gp.Coordinate([s + 1 if s % 2 != 0 else s
                                        for s in shape])

            # NOTE(review): self.seed is passed unchanged on every call, so
            # presumably every batch gets the same segmentation.
            data = create_segmentation(
                shape=even_shape,
                n_objects=self.n_objects,
                points_per_skeleton=self.points_per_skeleton,
                interpolation=self.interpolation,
                smoothness=self.smoothness,
                noise_strength=self.noise_strength,
                seed=self.seed)
            segmentation = data["segmentation"]

            # Crop back to the requested shape; works for any number of
            # dimensions (previously hard-coded to three).
            segmentation = segmentation[tuple(slice(0, s) for s in shape)]

            batch.arrays[array_key] = gp.Array(segmentation, array_spec)
        return batch
Beispiel #18
0
    def provide(self, request):
        """Read the requested ROI of ``self.array`` into a new batch.

        A leading channel axis is added when ``self.array.axes`` has no
        ``"c"``. The read time is recorded in the batch's profiling stats.

        Raises:
            ValueError: if the data read contains any NaN.
        """
        output = gp.Batch()

        timing_provide = Timing(self, "provide")
        timing_provide.start()

        read_spec = self.array_spec.copy()
        read_spec.roi = request[self.key].roi

        data = self.array[read_spec.roi]
        has_channel_axis = "c" in self.array.axes
        if not has_channel_axis:
            data = np.expand_dims(data, 0)
        contains_nan = np.any(np.isnan(data))
        if contains_nan:
            raise ValueError("INPUT DATA CAN'T BE NAN")
        output[self.key] = gp.Array(data, spec=read_spec)

        timing_provide.stop()
        output.profiling_stats.add(timing_provide)

        return output
Beispiel #19
0
    def process(self, batch, request):
        """Simulate cages on the raw data and emit cage and density maps.

        ``simulate_random_cages`` is called on brembow ``Volume`` wrappers of
        the raw and segmentation arrays. Its cage and density maps are
        wrapped as gunpowder arrays on copies of the segmentation spec
        (uint64 and float32 respectively) and cropped to their requested
        ROIs.

        Returns:
            A new ``gp.Batch`` with ``self.raw`` (data read back from the
            simulated raw volume), ``self.cage_map`` and ``self.density_map``.
        """
        raw = batch[self.raw]
        seg = batch[self.seg]

        print(f"RAW: {raw}")
        print(f"SEG: {seg}")

        # NOTE(review): the raw output is read back from simulated_raw after
        # the call, so simulate_random_cages presumably modifies it in place.
        simulated_raw = Volume(raw.data, raw.spec.voxel_size)
        seg_volume = Volume(seg.data, seg.spec.voxel_size)
        cage_map, density_map = simulate_random_cages(
            simulated_raw, seg_volume, self.cages, self.min_density,
            self.max_density, self.psf, True, True, self.no_cage_probability)

        raw_spec = raw.spec.copy()
        cage_map_spec = seg.spec.copy()
        cage_map_spec.dtype = np.uint64
        density_map_spec = seg.spec.copy()
        density_map_spec.dtype = np.float32

        print(cage_map_spec)
        cage_map_array = gp.Array(data=cage_map, spec=cage_map_spec).crop(
            request[self.cage_map].roi)
        density_map_array = gp.Array(
            data=density_map,
            spec=density_map_spec).crop(request[self.density_map].roi)

        processed = gp.Batch()
        processed[self.raw] = gp.Array(data=simulated_raw.data, spec=raw_spec)
        processed[self.cage_map] = cage_map_array
        processed[self.density_map] = density_map_array
        return processed
Beispiel #20
0
    def process(self, batch, request):
        """Remove connected components that reach into the central ROI.

        The central ROI is the graph's ROI, grown symmetrically (or shrunk,
        if the difference is negative) to ``self.centroid_size`` when that is
        set. A component is removed when any of its nodes has
        ``id >= self.node_offset`` and lies inside the central ROI. Nodes are
        removed from ``batch[self.points]`` in place.

        Returns:
            A new ``gp.Batch`` holding the filtered graph. Previously the
            method built this batch but never returned it.
        """
        outputs = gp.Batch()
        graph = batch[self.points]

        full_roi = graph.spec.roi
        size = full_roi.get_shape()
        small_roi = full_roi.copy()
        if self.centroid_size is not None:
            # Split the size difference evenly between both sides.
            diff = self.centroid_size - size
            diff = diff / gp.Coordinate([2] * len(diff))
            small_roi = small_roi.grow(diff, diff)

        centered_graph = graph.crop(small_roi)

        # Materialize the components first: the loop removes nodes.
        wccs = list(graph.connected_components)
        for wcc in wccs:
            fallbacks = [x < self.node_offset for x in wcc]
            contained = [centered_graph.contains(x) for x in wcc]
            if not all(a or not b for a, b in zip(fallbacks, contained)):
                for node in wcc:
                    graph.remove_node(gp.Node(id=node, location=None))

        outputs[self.points] = graph
        return outputs
    def process(self, batch, request):
        """Compute StarDist ray distances from ``self.label_key`` labels.

        ``star_dist3d_custom`` is evaluated on the label data with this
        node's ray, grid, anisotropy and mode settings. Its last axis is
        moved to the front, because gunpowder expects channels first.

        Returns:
            A new ``gp.Batch`` with the distances under ``self.stardist_key``.
            Its spec is derived via ``self._updated_spec`` and restricted to
            the requested ROI.
        """
        labels = batch.arrays[self.label_key]
        distances = star_dist3d_custom(labels.data,
                                       self.rays,
                                       self.unlabeled_id,
                                       self.max_dist,
                                       invalid_value=self.invalid_value,
                                       grid=self.grid,
                                       voxel_size=self.anisotropy,
                                       mode=self.sd_mode)
        # Subsampling by self.ss_grid is not needed: `grid` is already passed.
        channels_first = np.moveaxis(distances, -1, 0)

        out_spec = self._updated_spec(labels.spec)
        out_spec.roi = request[self.stardist_key].roi.copy()

        result = gp.Batch()
        result[self.stardist_key] = gp.Array(channels_first, out_spec)
        return result
Beispiel #22
0
    def provide(self, request):
        """Load arrays and graphs from a snapshot file and crop them to ``request``.

        ``self.snapshot_file`` must end in ``.h5``/``.hdf`` (opened read-only
        with h5py) or ``.zarr`` (opened read-only with zarr). Each entry of
        ``self.datasets`` maps an ``ArrayKey`` or ``GraphKey`` to a path in
        the file. Graphs get their connected components relabelled.

        Raises:
            ValueError: if the snapshot file has an unsupported extension.
                Previously this surfaced as an ``UnboundLocalError``.
        """
        outputs = gp.Batch()
        snapshot_path = str(self.snapshot_file)
        if snapshot_path.endswith((".h5", ".hdf")):
            # NOTE(review): this h5py handle is never closed explicitly; the
            # arrays built from it may reference it lazily -- confirm before
            # adding a close.
            data = h5py.File(self.snapshot_file, "r")
        elif snapshot_path.endswith(".zarr"):
            data = zarr.open(self.snapshot_file, "r")
        else:
            raise ValueError(
                f"Unsupported snapshot file {snapshot_path!r}: "
                "expected a .h5, .hdf or .zarr path")

        for key, path in self.datasets.items():
            if isinstance(key, gp.ArrayKey):
                outputs[key] = self.array_from_path(data, path)
            elif isinstance(key, gp.GraphKey):
                result = self.graph_from_path(key, data, path)
                result.relabel_connected_components()
                logger.debug(
                    f"Reading graph {key} with {result.num_vertices()} nodes, "
                    f"{result.num_edges()} edges, and {len(list(result.connected_components))} "
                    f"connected_components"
                )
                outputs[key] = result

        outputs = outputs.crop(request)
        return outputs
Beispiel #23
0
 def process(self, batch, request):
     """Return a deep copy of ``self.array`` with its data cast to int64.

     Both the data and the copied spec's dtype are set to ``np.int64``.
     """
     outputs = gp.Batch()
     source = batch[self.array]
     converted = copy.deepcopy(source)
     converted.data = source.data.astype(np.int64)
     converted.spec.dtype = np.int64
     outputs[self.array] = converted
     return outputs
Beispiel #24
0
    def process(self, batch, request):
        num_thresholds = self.num_thresholds
        threshold_range = self.threshold_range

        outputs = gp.Batch()

        gt_graph = batch[self.gt].to_nx_graph().to_undirected()
        mst_graph = batch[self.mst].to_nx_graph().to_undirected()
        if self.connectivity is not None:
            connectivity_graph = batch[
                self.connectivity].to_nx_graph().to_undirected()

        # assert mst_graph.number_of_nodes() > 0, f"mst_graph is empty!"

        if self.details is not None:
            matching_details_graph = nx.Graph()
            if mst_graph.number_of_nodes() == 0:
                node_offset = max([node
                                   for node in mst_graph.nodes] + [-1]) + 1
                label_offset = len(list(
                    nx.connected_components(mst_graph))) + 1

                for node, attrs in mst_graph.nodes.items():
                    matching_details_graph.add_node(node,
                                                    **copy.deepcopy(attrs))
                for edge, attrs in mst_graph.edges.items():
                    matching_details_graph.add_edge(edge[0], edge[1],
                                                    **copy.deepcopy(attrs))
                for node, attrs in gt_graph.nodes.items():
                    matching_details_graph.add_node(node + node_offset,
                                                    **copy.deepcopy(attrs))
                    matching_details_graph.nodes[node + node_offset]["id"] = (
                        node + node_offset)
                for edge, attrs in gt_graph.edges.items():
                    matching_details_graph.add_edge(edge[0] + node_offset,
                                                    edge[1] + node_offset,
                                                    **copy.deepcopy(attrs))

        edges = [(edge, attrs[self.edge_threshold_attr])
                 for edge, attrs in mst_graph.edges.items()]
        edges = list(sorted(edges, key=lambda x: x[1]))
        edge_lens = [e[1] for e in edges]
        # min_threshold = edges[0][1]
        if len(edge_lens) > 0:
            min_threshold = edge_lens[int(len(edge_lens) * threshold_range[0])]
            max_threshold = edge_lens[int(len(edge_lens) * threshold_range[1])
                                      - 1]
        else:
            min_threshold = 0
            max_threshold = 1
        thresholds = np.linspace(min_threshold,
                                 max_threshold,
                                 num=num_thresholds)

        current_threshold_mst = nx.Graph()

        edge_deque = deque(edges)
        edit_distances = []
        split_costs = []
        merge_costs = []
        false_pos_costs = []
        false_neg_costs = []
        num_nodes = []
        num_edges = []

        best_score = None
        best_graph = None

        for threshold_index, threshold in enumerate(thresholds):
            logger.warning(f"Using threshold: {threshold}")
            while len(edge_deque) > 0 and edge_deque[0][1] <= threshold:
                (u, v), _ = edge_deque.popleft()
                attrs = mst_graph.edges[(u, v)]
                current_threshold_mst.add_edge(u, v)
                current_threshold_mst.add_node(u, **mst_graph.nodes[u])
                current_threshold_mst.add_node(v, **mst_graph.nodes[v])

            if self.connectivity is not None:
                temp = nx.Graph()
                next_node = max([node
                                 for node in connectivity_graph.nodes]) + 1
                for i, cc in enumerate(
                        nx.connected_components(current_threshold_mst)):
                    component_subgraph = current_threshold_mst.subgraph(cc)
                    for node in component_subgraph.nodes:
                        temp.add_node(node,
                                      **dict(connectivity_graph.nodes[node]))
                        temp.nodes[node]["component"] = i
                    for edge in connectivity_graph.edges:
                        if (edge[0] in temp.nodes and edge[1] in temp.nodes
                                and temp.nodes[edge[0]]["component"]
                                == temp.nodes[edge[1]]["component"]):
                            temp.add_edge(
                                edge[0], edge[1],
                                **dict(connectivity_graph.edges[edge]))
                        elif False:
                            path = nx.shortest_path(connectivity_graph,
                                                    edge[0], edge[1])
                            cloned_path = []
                            for node in path:
                                if node in temp.nodes:
                                    cloned_path.append(node)
                                else:
                                    cloned_path.append(next_node)
                                    next_node += 1
                            path_len = len(cloned_path) - 1
                            for i, j in zip(range(path_len),
                                            range(1, path_len + 1)):
                                u = cloned_path[i]
                                if u not in temp.nodes:
                                    temp.add_node(
                                        u,
                                        **dict(
                                            connectivity_graph.nodes[path[i]]))
                                v = cloned_path[j]
                                if v not in temp.nodes:
                                    temp.add_node(
                                        v,
                                        **dict(
                                            connectivity_graph.nodes[path[j]]))

                                temp.add_edge(
                                    u,
                                    v,
                                    **dict(connectivity_graph.edges[path[i],
                                                                    path[j]]),
                                )

            else:
                temp = copy.deepcopy(current_threshold_mst)
                for i, cc in enumerate(nx.connected_components(temp)):
                    for node in cc:
                        attrs = temp.nodes[node]
                        attrs["component"] = i

            # remove small connected_components
            false_pos_nodes = []
            for cc in nx.connected_components(temp):
                cc_graph = temp.subgraph(cc)
                min_loc = None
                max_loc = None
                for node, attrs in cc_graph.nodes.items():
                    node_loc = attrs[self.location_attr]
                    if min_loc is None:
                        min_loc = node_loc
                    else:
                        min_loc = np.min(np.array([node_loc, min_loc]), axis=0)
                    if max_loc is None:
                        max_loc = node_loc
                    else:
                        max_loc = np.max(np.array([node_loc, max_loc]), axis=0)
                if np.linalg.norm(min_loc -
                                  max_loc) < self.small_component_threshold:
                    false_pos_nodes += list(cc)
            for node in false_pos_nodes:
                temp.remove_node(node)

            nodes_x = list(temp.nodes)
            nodes_y = list(gt_graph.nodes)

            node_labels_x = {
                node: attrs["component"]
                for node, attrs in temp.nodes.items()
            }

            node_labels_y = {
                node: component
                for component, nodes in enumerate(
                    nx.connected_components(gt_graph)) for node in nodes
            }

            edges_yx = get_edges_xy(
                gt_graph,
                temp,
                location_attr=self.location_attr,
                node_match_threshold=self.comatch_threshold,
            )
            edges_xy = [(v, u) for u, v in edges_yx]

            result = match_components(
                nodes_x,
                nodes_y,
                edges_xy,
                copy.deepcopy(node_labels_x),
                copy.deepcopy(node_labels_y),
            )

            if self.details is not None:
                # add a match details graph to the batch
                # details is a graph containing nodes from both mst and gt
                # to access details of a node, use `details.nodes[node]["details"]`
                # where the details returned are a numpy array of shape (num_thresholds, _).
                # the _ values stored per threshold are success, fp, fn, merge, split, selected, mst, gt

                # NODES
                # success, mst: n matches to only nodes of 1 label, which matches its own label
                # success, gt: n matches to only nodes of 1 label, which matches its own label
                # fp: n in mst matches to nothing
                # fn: n in gt matches to nothing
                # merge: n in mst matches to a node with label not matching its own
                # split: n in gt matches to a node with label not matching its own
                # selected: n in mst in thresholded graph

                # EDGES
                # success, mst: both endpoints successful
                # success, gt: both endpoints successful
                # fp: both endpoints fp
                # fn: both endpoints fn
                # merge: e in mst: only one endpoint successful
                # split: e in gt: only one endpoint successful
                # selected: e in mst in thresholded graph
                (label_matches, node_matches, splits, merges, fps,
                 fns) = result
                # create lookup tables:
                x_label_match_lut = {}
                y_label_match_lut = {}
                for a, b in label_matches:
                    x_matches = x_label_match_lut.setdefault(a, set())
                    x_matches.add(b)
                    y_matches = y_label_match_lut.setdefault(b, set())
                    y_matches.add(a)
                x_node_match_lut = {}
                y_node_match_lut = {}
                for a, b in node_matches:
                    x_matches = x_node_match_lut.setdefault(a, set())
                    x_matches.add(b)
                    y_matches = y_node_match_lut.setdefault(b, set())
                    y_matches.add(a)

                for node, attrs in matching_details_graph.nodes.items():
                    gt = int(node >= node_offset)
                    mst = 1 - gt

                    if gt == 1:
                        node = node - node_offset

                    selected = gt or (node in temp.nodes())

                    if selected:
                        success, fp, fn, merge, split, label_pair = self.node_matching_result(
                            node,
                            gt,
                            x_label_match_lut,
                            y_label_match_lut,
                            x_node_match_lut,
                            y_node_match_lut,
                            node_labels_x,
                            node_labels_y,
                        )
                    else:
                        success, fp, fn, merge, split, label_pair = (
                            0,
                            0,
                            0,
                            0,
                            0,
                            (-1, -1),
                        )

                    data = attrs.setdefault(
                        "details", np.zeros((len(thresholds), 7), dtype=bool))
                    data[threshold_index] = [
                        selected,
                        success,
                        fp,
                        fn,
                        merge,
                        split,
                        gt,
                    ]
                    label_pairs = attrs.setdefault("label_pair", [])
                    label_pairs.append(label_pair)
                    assert len(label_pairs) == threshold_index + 1
                for (u, v), attrs in matching_details_graph.edges.items():
                    (
                        u_selected,
                        u_success,
                        u_fp,
                        u_fn,
                        u_merge,
                        u_split,
                        u_gt,
                    ) = matching_details_graph.nodes[u]["details"][
                        threshold_index]

                    (
                        v_selected,
                        v_success,
                        v_fp,
                        v_fn,
                        v_merge,
                        v_split,
                        v_gt,
                    ) = matching_details_graph.nodes[v]["details"][
                        threshold_index]

                    assert u_gt == v_gt
                    e_gt = u_gt

                    u_label_pair = matching_details_graph.nodes[u][
                        "label_pair"][threshold_index]
                    v_label_pair = matching_details_graph.nodes[v][
                        "label_pair"][threshold_index]

                    e_selected = u_selected and v_selected
                    e_success = (e_selected and u_success and v_success
                                 and (u_label_pair == v_label_pair))
                    e_fp = u_fp and v_fp
                    e_fn = u_fn and v_fn
                    e_merge = e_selected and (not e_success) and (
                        not e_fp) and not e_gt
                    e_split = e_selected and (not e_success) and (
                        not e_fn) and e_gt
                    assert not (e_success and e_merge)
                    assert not (e_success and e_split)

                    data = attrs.setdefault(
                        "details", np.zeros((len(thresholds), 7), dtype=bool))
                    if e_success:
                        label_pairs = attrs.setdefault("label_pair", [])
                        label_pairs.append(u_label_pair)
                        assert len(label_pairs) == threshold_index + 1
                    else:
                        label_pairs = attrs.setdefault("label_pair", [])
                        label_pairs.append((-1, -1))
                        assert len(label_pairs) == threshold_index + 1
                    data[threshold_index] = [
                        e_selected,
                        e_success,
                        e_fp,
                        e_fn,
                        e_merge,
                        e_split,
                        e_gt,
                    ]

            edit_distance, (
                split_cost,
                merge_cost,
                false_pos_cost,
                false_neg_cost,
            ) = psudo_graph_edit_distance(
                result[1],
                node_labels_x,
                node_labels_y,
                temp,
                gt_graph,
                self.location_attr,
                node_spacing=self.edit_distance_node_spacing,
                details=True,
            )

            edit_distances.append(edit_distance)
            split_costs.append(split_cost)
            merge_costs.append(merge_cost)
            false_pos_costs.append(false_pos_cost)
            false_neg_costs.append(false_neg_cost)
            num_nodes.append(len(temp.nodes))
            num_edges.append(len(temp.edges))

            # save the best version:
            if best_score is None:
                best_score = edit_distance
                best_graph = copy.deepcopy(temp)
            elif edit_distance < best_score:
                best_score = edit_distance
                best_graph = copy.deepcopy(temp)

        outputs[self.output] = gp.Array(
            np.array([
                edit_distances,
                thresholds,
                num_nodes,
                num_edges,
                split_costs,
                merge_costs,
                false_pos_costs,
                false_neg_costs,
            ]),
            gp.ArraySpec(nonspatial=True),
        )
        if self.output_graph is not None:
            outputs[self.output_graph] = gp.Graph.from_nx_graph(
                best_graph,
                gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False))
        if self.details is not None:
            outputs[self.details] = gp.Graph.from_nx_graph(
                matching_details_graph,
                gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False),
            )
        return outputs