def _get_points(
    self, inside: np.ndarray, slope: np.ndarray, bb: Roi
) -> Tuple[Dict[int, Node], List[Tuple[int, int]]]:
    """Build a two-node line through ``inside`` and hand it to ``_graph_points``.

    The slope is divided by its largest component, then a point one
    bounding-box shape away from ``inside`` is computed in each direction
    along the scaled slope. Each far point is passed, together with
    ``inside`` and ``bb``, through ``self._resample_relative``; the two
    results become nodes 0 ("down") and 1 ("up"), joined by one edge.

    Returns:
        Whatever ``self._graph_points`` returns for the two nodes and the
        single edge.
    """
    scaled_slope = slope / max(slope)
    reach = np.array(bb.get_shape()) * scaled_slope

    # "down" end first, then "up" end; node ids follow that order.
    far_points = (inside - reach, inside + reach)
    intercepts = [
        self._resample_relative(inside, far_point, bb)
        for far_point in far_points
    ]

    # line
    points = {
        Node(
            id=node_id,
            location=intercept,
            attrs={"node_type": 0, "radius": 0},
        )
        for node_id, intercept in enumerate(intercepts)
    }
    return self._graph_points(points, [Edge(0, 1)])
def __init__(self):
    """Create ``self.graph``: three unconnected nodes inside a fixed ROI."""
    # node id -> coordinate value repeated on all three axes
    diagonal_positions = {1: 1, 2: 500, 3: 550}
    nodes = [
        Node(node_id, np.array([value, value, value]))
        for node_id, value in diagonal_positions.items()
    ]
    roi = Roi((-500, -500, -500), (1500, 1500, 1500))
    self.graph = Graph(nodes, [], GraphSpec(roi=roi))
def test_points_equal(self):
    """``points_equal`` accepts identical nodes and rejects differing locations."""

    def single_node(node_id, coords):
        return [Node(id=node_id, location=np.array(coords))]

    # Same id, same location: must compare equal.
    matching_a = single_node(1, [0, 1])
    matching_b = single_node(1, [0, 1])
    self.assertTrue(self.points_equal(matching_a, matching_b))

    # Same id, swapped coordinates: must compare unequal.
    differing_a = single_node(2, [1, 2])
    differing_b = single_node(2, [2, 1])
    self.assertFalse(self.points_equal(differing_a, differing_b))
def nodes(self):
    """Return five nodes with ids 0-4, node ``i`` located at ``(i, i, i)``.

    Locations are created with the dtype given by ``self.spec.dtype``.
    """
    return [
        Node(index, location=np.array([index] * 3, dtype=self.spec.dtype))
        for index in range(5)
    ]
def __init__(self):
    """Create ``self.graph``: three unconnected nodes in a 1000^3 ROI at the origin."""
    # node id -> coordinate value repeated on all three axes
    diagonal_positions = {1: 1, 2: 500, 3: 550}
    nodes = [
        Node(id=node_id, location=np.array([value, value, value]))
        for node_id, value in diagonal_positions.items()
    ]
    spec = GraphSpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)))
    self.graph = Graph(nodes, [], spec)
def __init__(self):
    """Create a three-node chain graph with float locations.

    The vertices, edges and spec are kept on private (name-mangled)
    attributes and the assembled graph is exposed as ``self.graph``.
    """
    self.dtype = float
    # node id -> coordinate value repeated on all three axes
    diagonal_positions = ((1, 1), (2, 500), (3, 550))
    self.__vertices = [
        Node(id=node_id, location=np.array([value] * 3, dtype=self.dtype))
        for node_id, value in diagonal_positions
    ]
    self.__edges = [Edge(u, u + 1) for u in (1, 2)]
    roi = Roi(Coordinate([-500, -500, -500]), Coordinate([1500, 1500, 1500]))
    self.__spec = GraphSpec(roi=roi)
    self.graph = Graph(self.__vertices, self.__edges, self.__spec)
def test_shift_points5(self):
    """Check ``ShiftAugment.shift_points`` along axis 1 with voxel size (3, 2).

    Five input nodes sit at ``(3, 2 * i)``. The expected result keeps ids
    0, 2 and 4, with nodes 0 and 4 moved to ``x == 6``, and carries a spec
    equal to ``GraphSpec(request_roi)``.
    """
    source_nodes = [
        Node(id=index, location=np.array([3, 2 * index])) for index in range(5)
    ]
    source_graph = Graph(
        source_nodes, [], GraphSpec(Roi(offset=(0, 0), shape=(15, 10)))
    )
    request_roi = Roi(offset=(3, 0), shape=(9, 10))
    per_row_shift = np.array(
        [[3, 0], [-3, 0], [0, 0], [-3, 0], [3, 0]], dtype=int
    )
    voxel_size = Coordinate((3, 2))

    expected_nodes = [
        Node(id=node_id, location=np.array(coords))
        for node_id, coords in ((0, [6, 0]), (2, [3, 4]), (4, [6, 8]))
    ]

    result = ShiftAugment.shift_points(
        source_graph,
        request_roi,
        per_row_shift,
        shift_axis=1,
        lcm_voxel_size=voxel_size,
    )

    self.assertTrue(self.points_equal(result.nodes, expected_nodes))
    self.assertTrue(result.spec == GraphSpec(request_roi))
def __init__(self):
    """Create ``self.graph`` holding a single node at ``(50, 70, 100)``."""
    only_node = Node(id=1, location=np.array([50, 70, 100]))
    spec = GraphSpec(roi=Roi((-200, -200, -200), (400, 400, 478)))
    self.graph = Graph([only_node], [], spec)
def test_nodes():
    """Locations assigned while iterating ``graph.nodes`` are seen on a later pass."""

    def f32(values):
        return np.array(values, dtype=np.float32)

    initial_locations = {
        1: f32([1, 1, 1]),
        2: f32([500, 500, 500]),
        3: f32([550, 550, 550]),
    }
    replacement_locations = {
        1: f32([0, 0, 0]),
        2: f32([50, 50, 50]),
        3: f32([55, 55, 55]),
    }

    nodes = [
        Node(id=node_id, location=location)
        for node_id, location in initial_locations.items()
    ]
    edges = [Edge(1, 2), Edge(2, 3)]
    roi = Roi(Coordinate([-500, -500, -500]), Coordinate([1500, 1500, 1500]))
    graph = Graph(nodes, edges, GraphSpec(roi=roi))

    # First pass writes the new locations...
    for node in graph.nodes:
        node.location = replacement_locations[node.id]

    # ...and a separate second pass must observe them.
    for node in graph.nodes:
        expected = replacement_locations[node.id]
        assert all(np.isclose(node.location, expected))
def _shift_and_crop(self, points: Graph, array: Array, direction: Coordinate, output_roi: Roi):
    """Shift ``array`` and ``points`` by ``direction`` and crop both to ``output_roi``.

    A window with the shape of ``output_roi`` is centred on the array's
    centre displaced by ``direction``. The array is cropped to that window
    and its spec ROI is relabelled as ``output_roi``. Points inside the
    window are copied, translated into ``output_roi`` coordinates, connected
    to their parent when the parent also lies inside the window, and
    relabelled so that every weakly connected component gets its own
    ``label_id``.

    The spec of the incoming ``points`` is copied before its ROI is changed,
    so the caller's graph spec is left untouched.

    Args:
        points: Graph whose ``data`` maps point ids to point objects.
        array: Array to crop.
        direction: Offset applied to the array centre.
        output_roi: ROI that both outputs are reported in.

    Returns:
        Tuple ``(points, array)`` holding the new graph and the cropped array.
    """
    # Shift and crop the array
    center = array.spec.roi.get_offset() + array.spec.roi.get_shape() // 2
    new_center = center + direction
    new_offset = new_center - output_roi.get_shape() // 2
    new_roi = Roi(new_offset, output_roi.get_shape())
    array = array.crop(new_roi)
    array.spec.roi = output_roi

    # Copy the spec: assigning to the shared object would silently rewrite
    # the ROI of the graph passed in by the caller.
    new_points_spec = points.spec.copy()
    new_points_spec.roi = new_roi
    new_points_graph = nx.DiGraph()

    # shift points and add them to a graph
    for point in points.data.values():
        if not new_roi.contains(point.location):
            continue

        new_point = point.copy()
        new_point.location = point.location - new_offset + output_roi.get_begin()
        new_points_graph.add_node(
            new_point.point_id,
            point_id=new_point.point_id,
            parent_id=new_point.parent_id,
            location=new_point.location,
            label_id=new_point.label_id,
            radius=new_point.radius,
            point_type=new_point.point_type,
        )

        # Only keep the edge when the parent exists and survives the crop.
        parent = points.data.get(new_point.parent_id, False)
        if parent and new_roi.contains(parent.location):
            new_points_graph.add_edge(new_point.parent_id, new_point.point_id)

    # relabel connected components
    for component_index, connected_component in enumerate(
        nx.weakly_connected_components(new_points_graph)
    ):
        for node in connected_component:
            new_points_graph.nodes[node]["label_id"] = component_index

    # store new graph data in points
    new_points_data = {
        point_id: Node(
            point["location"],
            point_id=point["point_id"],
            point_type=point["point_type"],
            radius=point["radius"],
            parent_id=point["parent_id"],
            label_id=point["label_id"],
        )
        for point_id, point in new_points_graph.nodes.items()
    }

    points = Graph(new_points_data, new_points_spec)
    points.spec.roi = output_roi
    return points, array
def test_transpose():
    """``SimpleAugment`` transposes a graph node and an array voxel consistently.

    One pipeline has transpose probabilities ``[0, 0]`` and must leave the
    node at ``(450, 550)``; the other has ``[1, 1]`` and must move it to
    ``(410, 590)``. In both cases the array voxel under the node must be 1.
    """
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")

    graph = Graph(
        [Node(id=1, location=np.array([450, 550]))],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )

    data = np.zeros([40, 30])
    data[17, 17] = 1
    array = Array(
        data, ArraySpec(roi=Roi((100, 200), (800, 600)), voxel_size=voxel_size)
    )

    def make_pipeline(transpose_probs):
        return (
            (GraphSource(graph_key, graph), ArraySource(array_key, array))
            + MergeProvider()
            + SimpleAugment(
                mirror_only=[],
                transpose_only=[0, 1],
                transpose_probs=transpose_probs,
            )
        )

    default_pipeline = make_pipeline([0, 0])
    transpose_pipeline = make_pipeline([1, 1])

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))

    def check(pipeline, expected_location):
        with build(pipeline):
            batch = pipeline.request_batch(request)

            nodes = list(batch[graph_key].nodes)
            assert len(nodes) == 1
            node = nodes[0]
            assert all(np.isclose(node.location, expected_location))

            # The voxel under the node must be the one that was set to 1.
            node_voxel_index = Coordinate(
                (node.location - batch[array_key].spec.roi.get_offset())
                / voxel_size
            )
            assert (
                batch[array_key].data[node_voxel_index] == 1
            ), f"Node at {np.where(batch[array_key].data == 1)} not {node_voxel_index}"

    check(default_pipeline, [450, 550])
    check(transpose_pipeline, [410, 590])
def test_shift_points3(self):
    """Check ``ShiftAugment.shift_points`` along axis 0 with unit voxel size.

    A single node at ``(0, 1)`` is expected at ``(0, 2)`` in the result,
    whose spec must equal ``GraphSpec(request_roi)``.
    """
    source_graph = Graph(
        [Node(id=1, location=np.array([0, 1]))],
        [],
        GraphSpec(Roi(offset=(0, 0), shape=(5, 5))),
    )
    request_roi = Roi(offset=(0, 1), shape=(5, 3))
    per_row_shift = np.array(
        [[0, 1], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int
    )
    voxel_size = Coordinate((1, 1))

    expected_graph = Graph(
        [Node(id=1, location=np.array([0, 2]))], [], GraphSpec(request_roi)
    )

    result = ShiftAugment.shift_points(
        source_graph,
        request_roi,
        per_row_shift,
        shift_axis=0,
        lcm_voxel_size=voxel_size,
    )

    self.assertTrue(self.points_equal(result.nodes, expected_graph.nodes))
    self.assertTrue(result.spec == GraphSpec(request_roi))
def setup(self):
    """Announce every key in ``self.points`` and build the chain graph served.

    The ROI starts at the origin and has shape ``self.size``. The graph has
    ``self.num_points`` nodes spaced evenly along the diagonal, scaled by
    the smallest entry of ``self.size``, with consecutive nodes joined by
    an edge.
    """
    origin = Coordinate([0] * len(self.size))
    roi = Roi(origin, self.size)

    for points_key in self.points:
        self.provides(points_key, GraphSpec(roi=roi, directed=self.directed))

    k = min(self.size)
    count = self.num_points
    nodes = []
    for i in range(count):
        diagonal = [i * k / self.num_points] * 3
        nodes.append(Node(id=i, location=np.array(diagonal)))

    edges = [Edge(u, v) for u, v in zip(range(count), range(1, count))]

    graph_spec = GraphSpec(roi=roi, directed=self.directed)
    self.graph = Graph(nodes, edges, graph_spec)
def provide(self, request):
    """Serve arrays whose voxels encode their position, plus grid graphs.

    For each requested array, every voxel holds the sum of its three voxel
    indices inside the requested ROI. For each requested graph, a node is
    placed at every location inside the ROI whose coordinates are multiples
    of 100, 10 and 10 on the three axes; node ids come from
    ``coordinate_to_id``.
    """
    batch = Batch()

    # have the pixels encode their position
    for array_key, requested_spec in request.array_specs.items():
        roi = requested_spec.roi
        roi_voxel = roi // self.spec[array_key].voxel_size
        begin, end = roi_voxel.get_begin(), roi_voxel.get_end()

        # the z,y,x coordinates of the ROI
        zz, yy, xx = np.meshgrid(
            *(range(begin[axis], end[axis]) for axis in range(3)),
            indexing='ij',
        )

        array_spec = self.spec[array_key].copy()
        array_spec.roi = roi
        batch.arrays[array_key] = Array(zz + yy + xx, array_spec)

    # node at x, y, z if x%100==0, y%10==0, z%10==0
    strides = [100, 10, 10]
    for graph_key, graph_spec in request.graph_specs.items():
        roi_begin = graph_spec.roi.get_begin()
        # Snap the ROI start down to the grid on every axis.
        start = roi_begin - tuple(x % s for x, s in zip(roi_begin, strides))
        axis_ranges = [
            range(a, b, s)
            for a, b, s in zip(start, graph_spec.roi.get_end(), strides)
        ]

        nodes = []
        for i, j, k in itertools.product(*axis_ranges):
            location = np.array([i, j, k])
            if not graph_spec.roi.contains(location):
                continue
            nodes.append(Node(id=coordinate_to_id(i, j, k), location=location))

        batch.graphs[graph_key] = Graph(nodes, [], graph_spec)

    return batch
def setup(self):
    """Create 1000 grid nodes and announce the test points and labels.

    Node ``i`` is placed using the three decimal digits of ``i`` as grid
    indices, each multiplied by 4.
    """
    self.points = []
    for i in range(1000):
        hundreds, remainder = divmod(i, 100)
        tens, ones = divmod(remainder, 10)
        location = np.array([hundreds % 10 * 4, tens * 4, ones * 4])
        self.points.append(Node(i, location))

    self.provides(
        GraphKeys.TEST_POINTS,
        GraphSpec(roi=Roi((-40, -40, -40), (120, 120, 120))),
    )

    labels_spec = ArraySpec(
        roi=Roi((-40, -40, -40), (120, 120, 120)),
        voxel_size=Coordinate((4, 1, 1)),
        interpolatable=False,
    )
    self.provides(ArrayKeys.TEST_LABELS, labels_spec)
def provide(self, request):
    """Return a batch whose ``GraphKeys.RAW`` graph is cropped to the request.

    The full graph has six nodes on the diagonal forming two chains
    (0-1-2 and 3-104-5); it is built with the requested ROI and then
    cropped to it.
    """
    # (node id, coordinate value repeated on all three axes)
    diagonal_nodes = ((0, 1), (1, 10), (2, 19), (3, 21), (104, 30), (5, 39))
    nodes = [
        Node(id=node_id, location=np.array((value, value, value)))
        for node_id, value in diagonal_nodes
    ]
    edges = [Edge(u, v) for u, v in ((0, 1), (1, 2), (3, 104), (104, 5))]

    spec = self.spec[GraphKeys.RAW].copy()
    spec.roi = request[GraphKeys.RAW].roi
    full_graph = Graph(nodes, edges, spec)

    outputs = Batch()
    outputs[GraphKeys.RAW] = full_graph.crop(spec.roi)
    return outputs
def setup(self):
    """Create six nodes along the second axis and announce points and labels.

    Node ``i`` sits at ``(0, 10 * i, 0)`` for ``i`` in 0..5.
    """
    self.points = [Node(i, np.array([0, 10 * i, 0])) for i in range(6)]

    # (offset, shape) used for both provided ROIs
    bounds = ((-100, -100, -100), (200, 200, 200))

    self.provides(GraphKeys.TEST_POINTS, GraphSpec(roi=Roi(*bounds)))

    labels_spec = ArraySpec(
        roi=Roi(*bounds),
        voxel_size=Coordinate((4, 1, 1)),
        interpolatable=False,
    )
    self.provides(ArrayKeys.TEST_LABELS, labels_spec)
def __get_pre_and_postsyn_locations(self, roi):
    """Randomly place presynaptic sites and their postsynaptic partners in ``roi``.

    One presynaptic location is drawn per 50-voxel cube of ``roi`` (using
    the voxel size of ``ArrayKeys.RAW``). Each is re-drawn until it is at
    least 250 units away from every previously accepted presynaptic
    location. Every presynaptic location gets the same number of
    postsynaptic partners; each partner is offset from it by 60..119 units
    per axis with a random sign and re-drawn until it lies inside ``roi``.

    NOTE(review): both rejection loops are unbounded. If ``roi`` is too
    small for the distance constraints they will not terminate.

    Args:
        roi: Region in which all locations must lie.

    Returns:
        Tuple ``(presyn_locs, postsyn_locs)`` of dicts mapping the integer
        location id to a ``Node`` whose attrs hold ``location_id``,
        ``synapse_id``, ``partner_ids`` and ``props``.
    """
    presyn_locs, postsyn_locs = {}, {}
    min_dist_between_presyn_locs = 250
    voxel_size_points = self.spec[ArrayKeys.RAW].voxel_size
    min_dist_pre_to_postsyn_loc, max_dist_pre_to_postsyn_loc = 60, 120

    # 1 synapse per 50vx^3 cube
    num_presyn_locations = int(
        roi.size() // np.prod(50 * np.asarray(voxel_size_points))
    )
    # `high` is exclusive, so this draws 1 or 2 postsyn partners.
    num_postsyn_locations = np.random.randint(low=1, high=3)

    loc_id = 0
    all_presyn_locs = []
    for nr_presyn_loc in range(num_presyn_locations):
        loc_id = loc_id + 1
        presyn_loc_id = loc_id

        # Re-draw until the candidate keeps the minimum distance to every
        # presynaptic location accepted so far.
        while True:
            presyn_location = np.asarray(
                [
                    np.random.randint(
                        low=roi.get_begin()[axis], high=roi.get_end()[axis]
                    )
                    for axis in range(3)
                ]
            )
            presyn_loc_too_close = any(
                np.linalg.norm(presyn_location - previous_loc)
                < min_dist_between_presyn_locs
                for previous_loc in all_presyn_locs
            )
            if not presyn_loc_too_close:
                break
        # Record the accepted location; without this the distance check
        # above would never have anything to compare against.
        all_presyn_locs.append(presyn_location)

        syn_id = nr_presyn_loc

        partner_ids = []
        for _ in range(num_postsyn_locations):
            loc_id = loc_id + 1
            partner_ids.append(loc_id)

            # Re-draw until the partner lies inside the ROI.
            while True:
                postsyn_location = presyn_location + np.random.choice(
                    (-1, 1), size=3, replace=True
                ) * np.random.randint(
                    min_dist_pre_to_postsyn_loc,
                    max_dist_pre_to_postsyn_loc,
                    size=3,
                )
                if roi.contains(Coordinate(postsyn_location)):
                    break

            postsyn_locs[int(loc_id)] = deepcopy(
                Node(
                    loc_id,
                    location=postsyn_location,
                    attrs={
                        "location_id": loc_id,
                        "synapse_id": syn_id,
                        "partner_ids": [presyn_loc_id],
                        "props": {},
                    },
                )
            )

        # deepcopy so the stored node does not share `partner_ids` with
        # the local list.
        presyn_locs[int(presyn_loc_id)] = deepcopy(
            Node(
                presyn_loc_id,
                location=presyn_location,
                attrs={
                    "location_id": presyn_loc_id,
                    "synapse_id": syn_id,
                    "partner_ids": partner_ids,
                    "props": {},
                },
            )
        )

    return presyn_locs, postsyn_locs
def test_output(self):
    """Request two ROIs from ``ExampleGraphSource + GrowFilter`` and compare nodes.

    For each request the ``original_id`` values of the returned nodes must
    match the expected ones, and the nodes, once sorted by location, must
    have locations close to the expected ones.
    """
    GraphKey("TEST_GRAPH")
    pipeline = ExampleGraphSource() + GrowFilter()

    def fetch_graph(roi):
        request = BatchRequest({GraphKeys.TEST_GRAPH: GraphSpec(roi=roi)})
        return pipeline.request_batch(request)[GraphKeys.TEST_GRAPH]

    def by_location(vertices):
        return sorted(vertices, key=lambda v: tuple(v.location))

    def check(graph, expected_vertices):
        seen_vertices = tuple(graph.nodes)
        self.assertCountEqual(
            [v.original_id for v in expected_vertices],
            [v.original_id for v in seen_vertices],
        )
        pairs = zip(by_location(expected_vertices), by_location(seen_vertices))
        for expected, actual in pairs:
            assert all(np.isclose(expected.location, actual.location))

    with build(pipeline):
        graph = fetch_graph(Roi((0, 0, 0), (50, 50, 50)))
        check(
            graph,
            (
                Node(id=1, location=np.array([1.0, 1.0, 1.0], dtype=float)),
                Node(
                    id=2,
                    location=np.array([50.0, 50.0, 50.0], dtype=float),
                    temporary=True,
                ),
            ),
        )

        graph = fetch_graph(Roi((25, 25, 25), (500, 500, 500)))
        check(
            graph,
            (
                Node(
                    id=1,
                    location=np.array([25.0, 25.0, 25.0], dtype=float),
                    temporary=True,
                ),
                Node(id=2, location=np.array([500.0, 500.0, 500.0], dtype=float)),
                Node(
                    id=3,
                    location=np.array([525.0, 525.0, 525.0], dtype=float),
                    temporary=True,
                ),
            ),
        )
def _toy_swc_points(self):
    """Return a small directed toy graph of 41 nodes in the plane z == 5.

    shape:

    -----------
    |
    |----------
    |
    -----------

    Nodes 0-10 form the backbone; three 10-node lines branch off it at
    backbone nodes 0, 5 and 10. Every node has radius 0 and node_type 0.
    """

    def make_node(node_id, first, second):
        return Node(
            id=node_id,
            location=np.array([first, second, 5]),
            attrs={"radius": 0, "node_type": 0},
        )

    # backbone
    points = [make_node(i, i, 0) for i in range(11)]

    # (id of first node on the line, fixed first coordinate, backbone node)
    # bottom line, mid line, top line
    branches = ((11, 0, 0), (21, 5, 5), (31, 10, 10))
    for first_id, fixed_coord, _ in branches:
        points.extend(
            make_node(first_id + step, fixed_coord, step + 1)
            for step in range(10)
        )

    # Self-edge on the root, then the backbone chain.
    edges = [Edge(0, 0)]
    edges.extend(Edge(i, i + 1) for i in range(10))
    # Each branch hangs off its backbone node and then runs as a chain.
    for first_id, _, backbone_id in branches:
        edges.append(Edge(backbone_id, first_id))
        edges.extend(Edge(i, i + 1) for i in range(first_id, first_id + 9))

    roi = Roi(Coordinate((-100, -100, -100)), Coordinate((300, 300, 300)))
    return Graph(points, edges, GraphSpec(roi=roi, directed=True))
def process(self, batch, request):
    """Fuse a "base" and an "add" volume (raw, labels and optionally points).

    Labels from the add volume are written on top of the relabelled base
    labels with fresh ids. Raw data is blended according to
    ``self.blend_mode`` ("intensity", "add" or "labels_mask"). When
    ``self.points_fused`` is requested, the add graph is appended to a copy
    of the base graph with its node ids offset past the base ids.

    Args:
        batch: Batch holding the base/add arrays (and graphs, if fused
            points are requested).
        request: The request; only checked for ``self.points_fused``.

    Returns:
        A new ``Batch`` with the fused raw and label arrays, any optional
        mask arrays configured for "labels_mask" mode, and the fused graph
        when requested.

    Raises:
        NotImplementedError: If ``self.blend_mode`` is not a known mode.
    """
    outputs = Batch()

    raw_base_spec = batch[self.raw_base].spec.copy()

    # Get base arrays
    raw_base_array = batch[self.raw_base].data
    labels_base_array = batch[self.labels_base].data

    # Get add arrays
    raw_add_array = batch[self.raw_add].data
    labels_add_array = batch[self.labels_add].data

    if self.scale_add_volume:
        # Shift the add volume so both volumes share the same median.
        raw_base_median = np.median(raw_base_array)
        raw_add_median = np.median(raw_add_array)
        diff = raw_base_median - raw_add_median
        raw_add_array = raw_add_array + diff

    # fuse labels
    fused_labels_array = self._relabel(labels_base_array)
    next_label_id = np.max(fused_labels_array) + 1

    # add_mask marks every voxel covered by a non-zero add label.
    add_mask = np.zeros_like(fused_labels_array, dtype=bool)
    for label in np.unique(labels_add_array):
        if label == 0:
            continue
        label_mask = labels_add_array == label

        # handle overlap
        # NOTE(review): `overlap` is a subset of `label_mask`, so the -1
        # written here is overwritten by the assignment below. Confirm
        # whether overlapping voxels were meant to stay marked.
        overlap = np.logical_and(fused_labels_array, label_mask)
        fused_labels_array[overlap] = -1

        # assign new label
        add_mask[label_mask] = True
        fused_labels_array[label_mask] = next_label_id
        next_label_id += 1

    # fuse raw
    if self.blend_mode == "intensity":
        # The add volume's own normalised intensity is the blending weight
        # (this replaces the boolean add_mask built above).
        add_mask = raw_add_array.astype(np.float32) / np.max(raw_add_array)
        raw_fused_array = add_mask * raw_add_array + (1 - add_mask) * raw_base_array
    elif self.blend_mode == "add":
        # Equal-weight sum of both volumes, each scaled by its own maximum.
        raw_fused_array = 0.5*raw_add_array / np.max(
            raw_add_array
        ) + 0.5*raw_base_array / np.max(raw_base_array)
        raw_fused_array = np.clip(raw_fused_array, 0, 1)
    elif self.blend_mode == "labels_mask":
        # Smooth the boolean add mask into a soft weight; sigma is
        # blend_smoothness divided by the voxel size per axis.
        soft_mask = np.zeros_like(add_mask, dtype="float32")
        ndimage.gaussian_filter(
            add_mask.astype("float32"),
            sigma=self.blend_smoothness / np.array(raw_base_spec.voxel_size),
            output=soft_mask,
            mode=self.gaussian_smooth_mode,
        )
        # Normalise by the maximum (floored at 1e-5 to avoid dividing by
        # zero), then double and clip to [0, 1].
        soft_mask /= np.clip(np.max(soft_mask), 1e-5, float("inf"))
        soft_mask = np.clip((soft_mask * 2), 0, 1)
        # Optional outputs, each emitted only when its key is configured.
        if self.soft_mask is not None:
            outputs.arrays[self.soft_mask] = Array(
                soft_mask,
                spec=ArraySpec(
                    roi=raw_base_spec.roi, voxel_size=raw_base_spec.voxel_size
                ),
            )
        if self.masked_base is not None:
            outputs.arrays[self.masked_base] = Array(
                raw_base_array * (soft_mask > 0.25), spec=raw_base_spec.copy()
            )
        if self.masked_add is not None:
            outputs.arrays[self.masked_add] = Array(
                raw_add_array * soft_mask,
                spec=ArraySpec(
                    roi=raw_base_spec.roi, voxel_size=raw_base_spec.voxel_size
                ),
            )
        if self.mask_maxed is not None:
            outputs.arrays[self.mask_maxed] = Array(
                np.maximum(
                    raw_base_array * (soft_mask > 0.25), raw_add_array * soft_mask
                ),
                spec=ArraySpec(
                    roi=raw_base_spec.roi, voxel_size=raw_base_spec.voxel_size
                ),
            )
        # Fused raw is the voxel-wise maximum of the weighted add volume
        # and the unmasked base volume.
        raw_fused_array = np.maximum(soft_mask * raw_add_array, raw_base_array)
        raw_fused_array = np.clip(raw_fused_array, 0, 1)
    else:
        raise NotImplementedError("Unknown blend mode %s." % self.blend_mode)

    # load specs
    # The fused raw keeps the base spec and the base data dtype; the fused
    # labels use the spec of the add labels.
    labels_add_spec = batch[self.labels_add].spec.copy()
    raw_base_spec = batch[self.raw_base].spec.copy()
    raw_dtype = batch[self.raw_base].data.dtype
    raw_base_spec.dtype = raw_dtype

    # return raw and labels for "fused" volume
    outputs.arrays[self.raw_fused] = Array(
        data=raw_fused_array.astype(raw_base_spec.dtype), spec=raw_base_spec
    )
    outputs.arrays[self.labels_fused] = Array(
        data=fused_labels_array, spec=labels_add_spec
    )

    # fuse points:
    if self.points_fused in request:
        # Offset the add graph's ids past the largest base id so the two
        # id ranges cannot collide.
        node_ids = [node.id for node in batch.graphs[self.points_base].nodes]
        num_nodes = len(node_ids)
        offset = 0 if num_nodes == 0 else max(node_ids) + 1

        fused_graph = batch.graphs[self.points_base].copy()

        for node in batch.graphs[self.points_add].nodes:
            # Copy the attributes so the node in the input batch keeps its id.
            attrs = deepcopy(node.all)
            attrs["id"] += offset
            fused_graph.add_node(Node.from_attrs(attrs))
        for edge in batch.graphs[self.points_add].edges:
            edge = Edge(edge.u + offset, edge.v + offset)
            fused_graph.add_edge(edge)
        outputs.graphs[self.points_fused] = fused_graph

    return outputs