def _shift_and_crop(self, points: Graph, array: Array, direction: Coordinate, output_roi: Roi):
    """Crop ``points`` and ``array`` to a window moved by ``direction``.

    The window has the shape of ``output_roi`` and is centred on the centre
    of ``array.spec.roi`` plus ``direction``. ``array`` is cropped to the
    window and its ROI is then relabelled as ``output_roi``.

    Points whose location lies in the window are translated into
    ``output_roi`` coordinates. An edge ``parent_id -> point_id`` is kept only
    when the parent also lies in the window. Every weakly connected component
    of the resulting graph gets a fresh ``label_id`` (0, 1, 2, ... in
    networkx iteration order).

    Args:
        points: Source graph. It is read through ``points.data``, a mapping of
            point id to point objects exposing ``location``, ``point_id``,
            ``parent_id``, ``label_id``, ``radius``, ``point_type`` and
            ``copy()``. NOTE(review): this is the legacy ``Points``-style API;
            confirm the ``Graph`` type passed here still provides it.
            ``points.spec`` is copied, not modified.
        array: Array to crop.
        direction: Offset added to the centre of ``array.spec.roi``.
        output_roi: ROI assigned to both results. Its shape sets the crop
            size, and its begin is the origin of the translated points.

    Returns:
        ``(points, array)``: the new cropped/relabelled graph and array.
    """
    import copy  # local import keeps this block self-contained

    roi_shape = output_roi.get_shape()

    # Centre of the input array's ROI, moved by ``direction``.
    center = array.spec.roi.get_offset() + array.spec.roi.get_shape() // 2
    new_center = center + direction
    new_offset = new_center - roi_shape // 2
    new_roi = Roi(new_offset, roi_shape)

    array = array.crop(new_roi)
    array.spec.roi = output_roi

    # Copy the spec so that setting its ROI below does not overwrite the
    # caller's ``points.spec``.
    new_points_spec = copy.deepcopy(points.spec)
    new_points_spec.roi = new_roi

    # Shift the points that fall inside the window and collect them in a graph.
    new_points_graph = nx.DiGraph()
    for point in points.data.values():
        if not new_roi.contains(point.location):
            continue
        new_point = point.copy()
        # Move from the frame of ``new_roi`` into the frame of ``output_roi``.
        new_point.location = point.location - new_offset + output_roi.get_begin()
        new_points_graph.add_node(
            new_point.point_id,
            point_id=new_point.point_id,
            parent_id=new_point.parent_id,
            location=new_point.location,
            label_id=new_point.label_id,
            radius=new_point.radius,
            point_type=new_point.point_type,
        )
        # Keep the parent edge only if the parent exists and is also inside
        # the window.
        parent = points.data.get(new_point.parent_id, False)
        if parent and new_roi.contains(parent.location):
            new_points_graph.add_edge(new_point.parent_id, new_point.point_id)

    # Give every weakly connected component its own label.
    components = nx.weakly_connected_components(new_points_graph)
    for label, component in enumerate(components):
        for node in component:
            new_points_graph.nodes[node]["label_id"] = label

    # Rebuild point objects from the (relabelled) node attributes.
    new_points_data = {
        node_id: Node(
            attrs["location"],
            point_id=attrs["point_id"],
            point_type=attrs["point_type"],
            radius=attrs["radius"],
            parent_id=attrs["parent_id"],
            label_id=attrs["label_id"],
        )
        for node_id, attrs in new_points_graph.nodes.items()
    }
    points = Graph(new_points_data, new_points_spec)
    points.spec.roi = output_roi
    return points, array
def seperate_using_kdtrees(
    self,
    base_batch: Batch,
    add_batch: Batch,
    output_roi: Roi,
    final=False,
    goal: float = 0,
    epsilon: float = 0.1,
):
    """Search for a shift that keeps two batches' points about ``goal`` apart.

    Points come from ``graphs[self.point_source]`` of each batch, falling back
    to ``graphs[self.nonempty_placeholder]``. Their locations are made
    relative to their graph's ROI begin and indexed in KD-trees.

    Each attempt places an ``output_roi``-shaped window at ``+shift`` from
    the input centre for ``add_batch`` and at ``-shift`` for ``base_batch``.
    Attempts where neither window contains any point are skipped after a
    random jitter. Otherwise, add points displaced by ``-2 * shift`` are
    compared with base points lying within distance ``goal``.

    A shift is accepted when no such pair lies inside both windows, or when
    the smallest pair distance lies in
    ``[goal - goal*epsilon - epsilon, goal + goal*epsilon + epsilon]``.
    Otherwise the shift is nudged by the mean weighted pair vector plus
    uniform jitter in ``[-0.5, 0.5)`` per axis. It is then clipped to
    ``±(input_radius - output_radius - 1)``.

    Args:
        base_batch: Batch whose window is placed at ``-shift``.
        add_batch: Batch whose window is placed at ``+shift``.
        output_roi: Only its shape is used, as the window size.
        final: If True, return the last shift even when the search fails.
        goal: Target minimum distance between the two point sets.
        epsilon: Relative and absolute tolerance around ``goal``.

    Returns:
        One of:

        - An accepted ``Coordinate`` shift.
        - ``Coordinate([0, 0, 0])`` when either batch has no points.
        - ``None`` when the update stalls (norm < 1e-2).
        - After ``self.shift_attempts * 10`` failed attempts: the last shift
          as a ``Coordinate`` if ``final``, else ``None``.

    On failure, if a ``test_output`` directory exists and contains no
    ``distances.obj``, both batches are pickled there for debugging.
    """
    points_add = add_batch.graphs.get(
        self.point_source, add_batch.graphs.get(self.nonempty_placeholder, None)
    )
    points_base = base_batch.graphs.get(
        self.point_source, base_batch.graphs.get(self.nonempty_placeholder, None)
    )

    if len(list(points_add.nodes)) < 1 or len(list(points_base.nodes)) < 1:
        return Coordinate([0, 0, 0])

    # Locations relative to each graph's ROI begin, so both start at [0, 0, 0].
    add_locations = np.array(
        [point.location - points_add.spec.roi.get_begin() for point in points_add.nodes]
    )
    add_tree = cKDTree(add_locations)
    base_locations = np.array(
        [point.location - points_base.spec.roi.get_begin() for point in points_base.nodes]
    )
    base_tree = cKDTree(base_locations)

    input_shape = points_base.spec.roi.get_shape()
    output_shape = output_roi.get_shape()
    input_radius = input_shape / 2
    output_radius = output_shape / 2

    current_shift = np.array([0, 0, 0], dtype=float)  # in voxels
    radius = max(output_radius)
    max_shift = input_radius - output_radius - Coordinate([1, 1, 1])

    # Accepted band for the smallest pair distance.
    lower = goal - goal * epsilon - epsilon
    upper = goal + goal * epsilon + epsilon

    for _ in range(self.shift_attempts * 10):
        shift_attempt = Coordinate(current_shift)
        add_clipped_roi = Roi(input_radius - output_radius + shift_attempt, output_shape)
        base_clipped_roi = Roi(input_radius - output_radius - shift_attempt, output_shape)

        # Indices of points inside each shifted window (Chebyshev ball).
        # Sets make the membership tests below O(1) instead of list scans.
        clipped_add_points = set(
            add_tree.query_ball_point(input_radius + shift_attempt, radius, p=float("inf"))
        )
        clipped_base_points = set(
            base_tree.query_ball_point(input_radius - shift_attempt, radius, p=float("inf"))
        )

        if not clipped_add_points and not clipped_base_points:
            logger.debug("no points in centered roi!")
            # Jitter the shift. Continuing with it unchanged would repeat this
            # exact empty query on every remaining attempt.
            current_shift += np.random.random(3) - 0.5
            np.clip(current_shift, -max_shift, max_shift, out=current_shift)
            continue

        # The windows move in opposite directions, so displace add points by
        # twice the shift and find base points closer than ``goal``.
        current_check = add_locations - shift_attempt * 2
        points_too_close = base_tree.query_ball_point(current_check, goal)

        # Accumulate the next shift direction from the offending pairs.
        direction = np.zeros([3])
        min_dist = float("inf")
        count = 0
        for node_a, neighbors in enumerate(points_too_close):
            if node_a not in clipped_add_points or not add_clipped_roi.contains(
                add_locations[node_a, :]
            ):
                continue
            add_offset = add_locations[node_a, :] - add_clipped_roi.get_begin()
            for neighbor in neighbors:
                if neighbor not in clipped_base_points or not base_clipped_roi.contains(
                    base_locations[neighbor, :]
                ):
                    continue
                vector = (base_locations[neighbor, :] - base_clipped_roi.get_begin()) - add_offset
                mag = np.linalg.norm(vector)
                min_dist = min(min_dist, mag)
                # Damped direction, not a true unit vector; the +1 presumably
                # avoids dividing by zero for coincident points.
                unit_vector = vector / (mag + 1)
                # want to move at most n units if mag is 0, or 0 units if mag is n
                direction += (goal - mag) * unit_vector
                count += 1

        if count == 0 or lower <= min_dist <= upper:
            logger.debug(f"shift: {shift_attempt} worked with {min_dist} and {count}")
            return shift_attempt
        logger.debug(
            f"min dist {min_dist} not in {lower, upper} "
            f"with shift: {current_shift}"
        )

        direction /= count  # count > 0 here, otherwise we returned above
        if np.linalg.norm(direction) < 1e-2:
            logger.debug("Moving too slow. Probably stuck!")
            return None
        current_shift += direction + (np.random.random(3) - 0.5)
        np.clip(current_shift, -max_shift, max_shift, out=current_shift)

    logger.debug(f"Request failed at {current_shift}. New Request!")
    debug_dir = Path("test_output")
    # NOTE(review): both dumps are gated on ``distances.obj`` rather than on
    # the file being written; kept as-is — confirm intent.
    if debug_dir.exists() and not (debug_dir / "distances.obj").exists():
        with open(debug_dir / "batch_base.obj", "wb") as f:
            pickle.dump(base_batch, f)
        with open(debug_dir / "batch_add.obj", "wb") as f:
            pickle.dump(add_batch, f)

    # Return a Coordinate, matching the type of every other shift returned.
    return Coordinate(current_shift) if final else None