def build_trees_from_mst(emst, edges_u, edges_v, alpha, coordinate_scale,
                         offset, voxel_size):
    """Build a directed graph from the rows of a minimum spanning tree.

    Args:
        emst: Rows ``(node_a, node_b, weight, ...)``; converted with
            ``np.array``. The weight is presumably the Euclidean edge length.
        edges_u, edges_v: Endpoint coordinates aligned row-by-row with
            ``emst``. Only their trailing ``len(voxel_size)`` entries are used.
        alpha: Rows whose weight (``row[2]``) is greater than this are skipped.
        coordinate_scale: Divisor applied to the endpoint coordinates.
        offset: Divided by ``voxel_size`` and added to the scaled coordinates.
        voxel_size: Its length is the number of spatial dimensions.

    Returns:
        ``nx.DiGraph``. Each node has a ``pos`` attribute (a
        ``daisy.Coordinate`` left-padded with zeros to 3 entries). Each edge
        has ``d``, the row weight.

    Raises:
        AssertionError: If a node id reappears with a different position.
    """
    graph = nx.DiGraph()
    ndims = len(voxel_size)

    def to_position(point):
        # Keep the spatial tail, rescale and shift it, then zero-pad to 3D.
        spatial = (point[-ndims:] / coordinate_scale) + (offset / voxel_size)
        return daisy.Coordinate((0, ) * (3 - ndims) + tuple(spatial))

    def place(node, position):
        if node in graph.nodes:
            assert graph.nodes[node]["pos"] == position, "locations don't match"
        else:
            graph.add_node(node, pos=position)

    rows = zip(np.array(emst), np.array(edges_u), np.array(edges_v))
    for edge, point_u, point_v in rows:
        if edge[2] > alpha:
            continue
        pos_u = to_position(point_u)
        pos_v = to_position(point_v)
        place(edge[0], pos_u)
        place(edge[1], pos_v)
        graph.add_edge(edge[0], edge[1], d=edge[2])
    return graph
def build_trees(node_ids, locations, edges, node_attrs=None, edge_attrs=None):
    """Build an undirected graph from node arrays and an edge list.

    Args:
        node_ids: Sequence of node ids. Its positions index ``locations`` and
            the columns of ``node_attrs``.
        locations: Per-node positions, aligned with ``node_ids``.
        edges: Rows whose first and last entries are node ids. Rows that
            reference an unknown id are skipped.
        node_attrs: Optional ``{name: column}``. ``column[i]`` is stored on
            node ``node_ids[i]``.
        edge_attrs: Optional ``{name: column}``. ``column[k]`` is stored on
            the edge built from row ``k``. It is indexed for every row,
            including skipped ones.

    Returns:
        ``nx.Graph``. Each node has a ``location`` attribute (a
        ``daisy.Coordinate``) plus its ``node_attrs``.
    """
    node_attrs = {} if node_attrs is None else node_attrs
    edge_attrs = {} if edge_attrs is None else edge_attrs

    index_of = {node_id: idx for idx, node_id in enumerate(node_ids)}
    graph = nx.Graph()
    location_of = {
        int(node_id): loc for node_id, loc in zip(node_ids, locations)
    }

    for edge_idx, row in enumerate(edges):
        u = index_of.get(int(row[0]), -1)
        v = index_of.get(int(row[-1]), -1)
        attrs = {name: column[edge_idx] for name, column in edge_attrs.items()}
        if -1 in (u, v):
            continue

        pos_u = daisy.Coordinate(tuple(location_of[node_ids[u]]))
        pos_v = daisy.Coordinate(tuple(location_of[node_ids[v]]))
        for idx, position in ((u, pos_u), (v, pos_v)):
            node_id = node_ids[idx]
            if node_id in graph.nodes:
                continue
            extra = {name: column[idx] for name, column in node_attrs.items()}
            graph.add_node(node_id, location=position, **extra)

        graph.add_edge(node_ids[u], node_ids[v], **attrs)
    return graph
def solve(predict_config, worker_config, data_config, graph_config,
          solve_config, num_block_workers, block_size, roi_offset, roi_size,
          context, solve_block, base_dir, experiment, train_number,
          predict_number, graph_number, solve_number, queue,
          singularity_container, mount_dirs, **kwargs):
    """Run the solve step block-wise with daisy.

    The ROI ``roi_offset``/``roi_size`` is grown by ``context`` on both
    sides. Write blocks of ``block_size`` are read with the same context.
    Workers are launched through ``start_worker``, which receives all the
    configs, the queue/container settings, ``solve_block`` and the setup
    directory
    ``<base_dir>/<experiment>/04_solve/setup_t{train}_p{predict}_g{graph}_s{solve}``.

    Extra keyword arguments are accepted and ignored.
    """
    source_roi = daisy.Roi(daisy.Coordinate(roi_offset),
                           daisy.Coordinate(roi_size))

    experiment_dir = os.path.join(base_dir, experiment)
    setup_name = "04_solve/setup_t{}_p{}_g{}_s{}".format(
        train_number, predict_number, graph_number, solve_number)
    solve_setup_dir = os.path.join(experiment_dir, setup_name)

    block_write_roi = daisy.Roi((0, 0, 0), block_size)
    block_read_roi = block_write_roi.grow(context, context)
    total_roi = source_roi.grow(context, context)

    logger.info("Solving in %s", total_roi)

    def spawn_worker():
        # NOTE(review): takes no block argument, so daisy presumably treats it
        # as a worker-start function; confirm for the daisy version in use.
        return start_worker(predict_config, worker_config, data_config,
                            graph_config, solve_config, queue,
                            singularity_container, mount_dirs, solve_block,
                            solve_setup_dir)

    daisy.run_blockwise(total_roi,
                        block_read_roi,
                        block_write_roi,
                        process_function=spawn_worker,
                        num_workers=num_block_workers,
                        fit='shrink')

    logger.info("Finished solving, parameters id is %s", solve_number)
def get_source_roi(data_dir, sample):
    """Return ``(voxel_size, source_roi)`` describing a sample's raw data.

    If ``data_dir/sample`` is a file, or ``sample`` names a ``.zarr``/``.n5``
    container, the sample directory is the parent of ``sample``. Otherwise
    it is ``data_dir/sample`` itself.

    Metadata is taken from ``attributes.json`` (keys ``resolution``,
    ``shape``, ``offset``; the ROI shape is ``shape * resolution``) when that
    file exists. Otherwise it comes from the ``volumes/raw`` dataset of
    ``timelapse.zarr``.

    Raises:
        RuntimeError: If neither source exists in the sample directory.
    """
    sample_path = os.path.join(data_dir, sample)
    names_container = (os.path.isfile(sample_path)
                       or sample.endswith((".zarr", ".n5")))
    relative_dir = os.path.dirname(sample) if names_container else sample
    sample_dir = os.path.abspath(os.path.join(data_dir, relative_dir))

    attributes_file = os.path.join(sample_dir, 'attributes.json')
    timelapse_container = os.path.join(sample_dir, 'timelapse.zarr')

    if os.path.isfile(attributes_file):
        with open(attributes_file, 'r') as f:
            attributes = json.load(f)
        voxel_size = daisy.Coordinate(attributes['resolution'])
        shape = daisy.Coordinate(attributes['shape'])
        offset = daisy.Coordinate(attributes['offset'])
        return voxel_size, daisy.Roi(offset, shape * voxel_size)

    if os.path.isdir(timelapse_container):
        raw = daisy.open_ds(timelapse_container, 'volumes/raw')
        return raw.voxel_size, raw.roi

    raise RuntimeError(
        "Can't find attributes.json or timelapse.zarr in %s" % sample_dir)
def read_data_config(data_config):
    """Read the ``[Data]`` section of an INI config file into a dict.

    Args:
        data_config: Path of the config file (passed to
            ``ConfigParser.read``).

    Returns:
        dict with the keys ``sample``, ``roi_offset``, ``roi_size``, ``roi``,
        ``location_attr``, ``penalty_attr``, ``target_edge_len`` (int),
        ``consensus_db``, ``subdivided_db`` and ``db_host``. The two database
        names are prefixed with ``mouselight-<sample>-``.

    ``roi_offset`` and ``roi_size`` are comma-separated integers, with or
    without spaces after the commas, or the literal ``None``. ``None`` yields
    a 3D coordinate of ``None`` entries (unbounded).
    """
    def parse_coordinate(text):
        if text == "None":
            return daisy.Coordinate([None] * 3)
        # int() ignores surrounding whitespace, so "1,2,3" and "1, 2, 3"
        # both parse (the old split(", ") rejected the former).
        return daisy.Coordinate(tuple(int(part) for part in text.split(",")))

    config = configparser.ConfigParser()
    config.read(data_config)

    cfg_dict = {}

    # Data
    cfg_dict["sample"] = config.get("Data", "sample")
    cfg_dict["roi_offset"] = parse_coordinate(config.get("Data", "roi_offset"))
    cfg_dict["roi_size"] = parse_coordinate(config.get("Data", "roi_size"))
    cfg_dict["roi"] = daisy.Roi(cfg_dict["roi_offset"], cfg_dict["roi_size"])
    cfg_dict["location_attr"] = config.get("Data", "location_attr")
    cfg_dict["penalty_attr"] = config.get("Data", "penalty_attr")
    cfg_dict["target_edge_len"] = int(config.get("Data", "target_edge_len"))

    # Database
    cfg_dict["consensus_db"] = (
        f"mouselight-{cfg_dict['sample']}-{config.get('Data', 'consensus_db')}")
    cfg_dict["subdivided_db"] = (
        f"mouselight-{cfg_dict['sample']}-{config.get('Data', 'subdivided_db')}")
    cfg_dict["db_host"] = config.get("Data", "db_host")

    return cfg_dict
def build_trees(node_ids, locations, edges, voxel_size):
    """Build a directed graph with voxel-space node positions.

    Args:
        node_ids: Node ids (converted with ``int``), aligned with
            ``locations``.
        locations: Per-node positions. They are divided by ``voxel_size``.
        edges: Rows whose first and last entries are node ids. Rows
            containing ``-1`` are skipped.
        voxel_size: Divisor applied to every location.

    Returns:
        ``nx.DiGraph``. Nodes carry ``pos``; edges carry ``d``, the
        Euclidean norm of the position difference.

    Raises:
        AssertionError: If a node id reappears with a different position.
    """
    graph = nx.DiGraph()
    location_of = dict(zip((int(node_id) for node_id in node_ids), locations))

    for row in edges:
        source, target = int(row[0]), int(row[-1])
        if source == -1 or target == -1:
            continue

        pos_source = daisy.Coordinate(tuple(location_of[source])) / voxel_size
        pos_target = daisy.Coordinate(tuple(location_of[target])) / voxel_size

        for node, position in ((source, pos_source), (target, pos_target)):
            if node in graph.nodes:
                assert graph.nodes[node]["pos"] == position, "locations don't match"
            else:
                graph.add_node(node, pos=position)

        graph.add_edge(source, target,
                       d=np.linalg.norm(pos_source - pos_target))
    return graph
def test_load_klb(self):
    """Check the ROI and voxel size daisy reports for the test KLB file
    when no attributes file is given."""
    self.write_test_klb()
    try:
        data = daisy.open_ds(self.klb_file, None)
        self.assertEqual(data.roi.get_offset(),
                         daisy.Coordinate((0, 0, 0, 0)))
        self.assertEqual(data.roi.get_shape(),
                         daisy.Coordinate((1, 20, 10, 10)))
        self.assertEqual(data.voxel_size, daisy.Coordinate((1, 2, 1, 1)))
    finally:
        # Remove the fixture even if an assertion fails, so later tests
        # do not see a stale file.
        self.remove_test_klb()
def extract_edges_blockwise(db_host, db_name, sample, edge_move_threshold,
                            block_size, num_workers, frames=None,
                            frame_context=1, data_dir='../01_data',
                            use_pv_distance=False, **kwargs):
    """Run ``extract_edges_in_block`` block-wise over a 4D sample.

    The first axis is the frame axis. If ``frames=(begin, end)`` is given,
    the source ROI is cropped to those frames, padded by ``frame_context`` on
    both sides. Blocks (``block_size`` in world units) are read with
    ``edge_move_threshold`` context on the three remaining axes, plus one
    extra frame on the negative side. Finished blocks are detected with
    ``check_function(b, 'extract_edges', db_name, db_host)``.

    Extra keyword arguments are accepted and ignored.
    """
    voxel_size, source_roi = get_source_roi(data_dir, sample)

    if frames:
        begin, end = frames
        begin, end = begin - frame_context, end + frame_context
        crop_roi = daisy.Roi((begin, None, None, None),
                             (end - begin, None, None, None))
        source_roi = source_roi.intersect(crop_roi)

    block_write_roi = daisy.Roi((0, ) * 4, daisy.Coordinate(block_size))

    spatial_context = (edge_move_threshold, ) * 3
    pos_context = daisy.Coordinate((0, ) + spatial_context)
    neg_context = daisy.Coordinate((1, ) + spatial_context)
    logger.debug("Set neg context to %s", neg_context)

    input_roi = source_roi.grow(neg_context, pos_context)
    block_read_roi = block_write_roi.grow(neg_context, pos_context)

    print("Following ROIs in world units:")
    for label, roi in (("Input ROI", input_roi),
                       ("Block read ROI", block_read_roi),
                       ("Block write ROI", block_write_roi),
                       ("Output ROI", source_roi)):
        print("%s = %s" % (label, roi))
    print("Starting block-wise processing...")

    def process_block(b):
        return extract_edges_in_block(db_name, db_host, edge_move_threshold,
                                      b, use_pv_distance=use_pv_distance)

    def is_block_done(b):
        return check_function(b, 'extract_edges', db_name, db_host)

    daisy.run_blockwise(input_roi,
                        block_read_roi,
                        block_write_roi,
                        process_function=process_block,
                        check_function=is_block_done,
                        num_workers=num_workers,
                        processes=True,
                        read_write_conflict=False,
                        fit='shrink')
def ingest_consensus(sample: str, consensus_dir: Path, transform_file: Path,
                     url: str):
    """Write the consensus neurons of ``sample`` to MongoDB.

    Storing data in MongoDB using pseudo world coords (1, .3, .3) microns
    rather than the slightly off floats found in the transform.txt file.

    Every sub-directory of ``consensus_dir`` that contains a
    ``consensus.swc`` is parsed with ``parse_consensus``. The resulting
    graphs are merged with ``nx.disjoint_union_all``, and their nodes and
    edges are bulk-written to the ``nodes``/``edges`` collections of
    ``mouselight-<sample>-consensus``.

    Raises:
        ValueError: If no sub-directory contains a ``consensus.swc``.
    """
    db_name = f"mouselight-{sample}-consensus"
    mongo_graph_provider = daisy.persistence.MongoDbGraphProvider(
        db_name, url, directed=True, mode="w")
    # NOTE(review): the returned graph is unused; the call is kept in case
    # opening it has side effects the ingestion relies on.
    mongo_graph_provider.get_graph(
        daisy.Roi(daisy.Coordinate([None, None, None]),
                  daisy.Coordinate([None, None, None])))

    consensus_graphs = []
    for consensus_neuron in tqdm(consensus_dir.iterdir(),
                                 "Consensus neurons: "):
        if (not consensus_neuron.is_dir()
                or not (consensus_neuron / "consensus.swc").exists()):
            continue
        consensus_graph = parse_consensus(
            consensus_neuron / "consensus.swc",
            consensus_neuron / "dendrite.swc",
            # Previously an undefined name `transform` was passed here while
            # the transform_file parameter went unused.
            transform_file,
            offset=np.array([0, 0, 0]),
            resolution=np.array([300, 300, 1000]),
            transpose=[2, 1, 0],
        )
        for node in consensus_graph.nodes:
            # Store positions as plain lists under "position".
            consensus_graph.nodes[node]["position"] = consensus_graph.nodes[
                node]["location"].tolist()
            del consensus_graph.nodes[node]["location"]
        consensus_graphs.append(consensus_graph)

    if not consensus_graphs:
        raise ValueError(
            f"No consensus neurons with a consensus.swc found in {consensus_dir}")

    logger.info("Consolidating consensus graphs!")
    consensus_graph = nx.disjoint_union_all(consensus_graphs)

    # Column-oriented node table: one list per attribute, aligned with "id".
    # NOTE(review): assumes every node has the same attribute keys; a node
    # missing a key would misalign that column.
    data = {}
    for node_id, attrs in consensus_graph.nodes.items():
        data.setdefault("id", []).append(int(np.int64(node_id)))
        for key, value in attrs.items():
            data.setdefault(key, []).append(value)

    logger.info(
        f"Writing {len(consensus_graph.nodes)} nodes and {len(consensus_graph.edges)} edges!"
    )

    bulk_write_nodes(url, db_name, "nodes", data)
    bulk_write_edges(
        url,
        db_name,
        "edges",
        ("u", "v"),
        list(consensus_graph.edges),
        True,
    )
def __init__(self, positions, dataset, size):
    """Wrap ``dataset`` for crops of ``size`` voxels.

    Args:
        positions: Candidate positions. They are passed through
            ``self.filter_positions``, which runs after all size attributes
            are set.
        dataset: Object providing ``open_daisy()`` and ``voxel_size``.
        size: Crop size in voxels. ``size_nm`` is ``size * voxel_size``
            (presumably nm, going by the attribute name).
    """
    self.dataset = dataset
    self.data = dataset.open_daisy()

    voxel_size = daisy.Coordinate(dataset.voxel_size)
    crop_size = daisy.Coordinate(size)
    self.voxel_size = voxel_size
    self.size = crop_size
    self.size_nm = crop_size * voxel_size

    self.positions = self.filter_positions(positions)
    self.transform = None
def test_overwrite_attrs(self):
    """Check that ``attr_filename`` overrides the ROI and voxel size daisy
    reports for the test KLB file."""
    self.write_test_klb()
    try:
        data = daisy.open_ds(self.klb_file, None,
                             attr_filename=self.attrs_file)
        self.assertEqual(data.roi.get_offset(),
                         daisy.Coordinate((0, 100, 10, 0)))
        self.assertEqual(data.roi.get_shape(),
                         daisy.Coordinate((1, 10, 20, 20)))
        self.assertEqual(data.voxel_size, daisy.Coordinate((1, 1, 2, 2)))
    finally:
        # Remove the fixture even if an assertion fails.
        self.remove_test_klb()
def get_raw(locs, size, voxel_size, data_container, data_set):
    """
    Get raw crops from the specified dataset.

    locs(``list of tuple of ints``):
        list of centers of location of interest
    size(``tuple of ints``):
        size of cropout in voxel
    voxel_size(``tuple of ints``):
        size of a voxel
    data_container(``string``):
        path to data container (e.g. zarr file)
    data_set(``string``):
        corresponding data_set name, (e.g. raw)

    Returns:
        ``(raw, raw_normalized)``: the crops stacked along a new first axis
        as float32, and ``raw / 255 * 2 - 1``. Returns ``None`` (after
        logging a warning) as soon as one crop is not fully contained in the
        dataset.
    """
    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(voxel_size)
    size_nm = size * voxel_size
    dataset = daisy.open_ds(data_container, data_set)

    raw = []
    for loc in locs:
        # Accept plain tuples: tuple - Coordinate is not defined.
        center = daisy.Coordinate(tuple(loc))
        offset_nm = center - (size / 2 * voxel_size)
        snapped = daisy.Roi(offset_nm, size_nm).snap_to_grid(voxel_size,
                                                           mode='closest')
        # Snapping may change the extent. Keep the snapped offset but force
        # the requested shape, so every crop can be stacked. The old check
        # compared the nm shape against the voxel size.
        roi = daisy.Roi(snapped.get_offset(), size_nm)
        if not dataset.roi.contains(roi):
            # Logger has no WARNING method; that call raised AttributeError.
            logger.warning("Location %s is not fully contained in dataset",
                           loc)
            return None
        raw.append(dataset[roi].to_ndarray())

    raw = np.stack(raw).astype(np.float32)
    raw_normalized = raw / 255.0 * 2.0 - 1.0
    return raw, raw_normalized
def extract_edges(
        db_host,
        db_name,
        soft_mask_container,
        soft_mask_dataset,
        roi_offset,
        roi_size,
        distance_threshold,
        block_size,
        num_block_workers,
        graph_number,
        **kwargs):
    """Run ``extract_edges_in_block`` block-wise over a 3D ROI.

    Blocks of ``block_size`` (world units) are read with ``distance_threshold``
    of context on every side. Workers run as processes, and blocks at the
    ROI border are shrunk (``fit='shrink'``).

    Extra keyword arguments are accepted and ignored.
    """
    source_roi = daisy.Roi(roi_offset, roi_size)
    block_write_roi = daisy.Roi((0, ) * 3, daisy.Coordinate(block_size))

    context = (distance_threshold, ) * 3
    pos_context = daisy.Coordinate(context)
    neg_context = daisy.Coordinate(context)
    logger.debug("Set pos context to %s", pos_context)
    logger.debug("Set neg context to %s", neg_context)

    input_roi = source_roi.grow(neg_context, pos_context)
    block_read_roi = block_write_roi.grow(neg_context, pos_context)

    logger.info("Following ROIs in world units:")
    for label, roi in (("Input ROI", input_roi),
                       ("Block read ROI", block_read_roi),
                       ("Block write ROI", block_write_roi),
                       ("Output ROI", source_roi)):
        logger.info("%s = %s", label, roi)
    logger.info("Starting block-wise processing...")

    def process_block(b):
        return extract_edges_in_block(db_name, db_host, soft_mask_container,
                                      soft_mask_dataset, distance_threshold,
                                      graph_number, b)

    daisy.run_blockwise(input_roi,
                        block_read_roi,
                        block_write_roi,
                        process_function=process_block,
                        num_workers=num_block_workers,
                        processes=True,
                        read_write_conflict=False,
                        fit='shrink')
def test_none(self):
    """Coordinate arithmetic yields None wherever an operand entry is None."""
    a = daisy.Coordinate((None, 1, 2))
    b = daisy.Coordinate((3, 4, None))

    # Each check is a thunk so operations are evaluated one at a time, in order.
    checks = [
        (lambda: a + b, (None, 5, None)),
        (lambda: a - b, (None, -3, None)),
        (lambda: a / b, (None, 0, None)),
        (lambda: a // b, (None, 0, None)),
        (lambda: b / a, (None, 4, None)),
        (lambda: b // a, (None, 4, None)),
        (lambda: abs(a), (None, 1, 2)),
        (lambda: abs(-a), (None, 1, 2)),
    ]
    for compute, expected in checks:
        assert compute() == expected
def visualize_npy(npy_file: Path, voxel_size):
    """Show a ``.npy`` array in a neuroglancer viewer until ENTER is pressed.

    Args:
        npy_file: File loaded with ``np.load``.
        voxel_size: Voxel size of the array. Its length is the number of
            spatial dimensions, taken as the trailing axes of the array. The
            ROI starts at the origin and spans ``spatial_shape * voxel_size``.
            It was previously ignored in favour of a hard-coded ``(1, 1, 1)``;
            for that value the result is unchanged.
    """
    voxel_size = daisy.Coordinate(voxel_size)
    dims = len(voxel_size)

    viewer = neuroglancer.Viewer()
    with viewer.txn() as s:
        v = np.load(npy_file)
        roi = daisy.Roi(daisy.Coordinate((0, ) * dims),
                        daisy.Coordinate(v.shape[-dims:]) * voxel_size)
        m = daisy.Array(v, roi, voxel_size)
        add_layer(s, m, "npy array")
    print(viewer)
    input("Hit ENTER to quit!")
def get_raw_parallel(locs, size, voxel_size, data_container, data_set):
    """
    Get raw crops from the specified dataset, one worker process per crop.

    locs(``list of tuple of ints``):
        list of centers of location of interest
    size(``tuple of ints``):
        size of cropout in voxel
    voxel_size(``tuple of ints``):
        size of a voxel
    data_container(``string``):
        path to data container (e.g. zarr file)
    data_set(``string``):
        corresponding data_set name, (e.g. raw)

    Returns:
        ``(raw, raw_normalized)``: crops from ``fetch_from_ds`` stacked as
        float32, and ``raw / 255 * 2 - 1``.

    Raises:
        ValueError: If ``locs`` is empty. A pool of size 0 cannot be created
            and there would be nothing to stack.
        multiprocessing.TimeoutError: If a crop takes longer than 60 s.
    """
    if len(locs) == 0:
        raise ValueError("locs must contain at least one location")

    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(voxel_size)
    size_nm = size * voxel_size
    dataset = daisy.open_ds(data_container, data_set)

    pool = Pool(processes=len(locs))
    try:
        pending = [
            pool.apply_async(fetch_from_ds,
                             (dataset, loc, voxel_size, size, size_nm))
            for loc in locs
        ]
        raw = [job.get(timeout=60) for job in pending]
    except BaseException:
        # A failed or timed-out fetch must not leave worker processes behind.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()

    raw = np.stack(raw).astype(np.float32)
    raw_normalized = raw / 255.0 * 2.0 - 1.0
    return raw, raw_normalized
def run_test_graph_write_roi(self, provider_factory):
    """Writing with ``roi`` persists only nodes/edges inside that ROI.

    Node 2 has no position, so it is not returned by the ROI query. Node 57
    lies outside the write ROI. Only nodes that come back with data are
    compared.
    """
    full_roi = daisy.Roi((0, 0, 0), (10, 10, 10))
    writer = provider_factory('w')
    graph = writer[full_roi]

    graph.add_node(2, comment="without position")
    graph.add_node(42, position=(1, 1, 1))
    graph.add_node(23, position=(5, 5, 5), swip='swap')
    graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
    for u, v in ((42, 23), (57, 23), (2, 42)):
        graph.add_edge(u, v)

    write_roi = daisy.Roi((0, 0, 0), (6, 6, 6))
    graph.write_nodes(roi=write_roi)
    graph.write_edges(roi=write_roi)

    reader = provider_factory('r')
    stored = reader[daisy.Roi((0, 0, 0), (10, 10, 10))]

    expected_nodes = sorted(graph.nodes())
    expected_nodes.remove(2)  # no position, never queried
    expected_nodes.remove(57)  # outside the write ROI
    stored_nodes = sorted(
        node_id for node_id, attrs in stored.nodes(data=True) if attrs)

    expected_edges = sorted(graph.edges())
    expected_edges.remove((2, 42))  # node 2 has no position
    stored_edges = sorted(stored.edges())

    self.assertEqual(expected_nodes, stored_nodes)
    self.assertEqual(expected_edges, stored_edges)
def get_site_fragment_lut(fragments, sites, roi):
    """Map synaptic sites in ``roi`` to the fragment IDs found at their position.

    Args:
        fragments: Array indexable by ``roi`` and by a ``daisy.Coordinate``
            built from each site's ``(z, y, x)``.
        sites: Iterable of mappings with ``'id'``, ``'z'``, ``'y'``, ``'x'``.
        roi: ROI the sites belong to. With 15 or more sites, the fragments in
            this ROI are materialized once instead of being read per site.

    Returns:
        ``(lut, num_background)``. ``lut`` is a 2-row array ``[site_ids,
        fragment_ids]`` restricted to sites whose fragment ID is non-zero.
        ``num_background`` counts the dropped sites. Returns ``(None, None)``
        when there are no sites.
    """
    sites = list(sites)
    if not sites:
        logging.info(f"No sites in {roi}, skipping")
        return None, None

    logging.info(
        f"Getting fragment IDs for {len(sites)} synaptic sites in {roi}...")

    # For a few sites, direct lookups are faster than copying into memory.
    if len(sites) >= 15:
        logging.info("Copying fragments into memory...")
        fragments = fragments[roi]
        fragments.materialize()

    logging.info(f"Getting fragment IDs for synaptic sites in {roi}...")
    fragment_ids = np.array([
        fragments[daisy.Coordinate((site['z'], site['y'], site['x']))]
        for site in sites
    ])
    site_ids = np.array([site['id'] for site in sites], dtype=np.uint64)

    foreground = fragment_ids != 0
    lut = np.array([site_ids[foreground], fragment_ids[foreground]])
    return lut, (~foreground).sum()
def downscale_block(in_array, out_array, factor, block):
    """Downscale one daisy block of ``in_array`` by ``factor`` into ``out_array``.

    Reads ``block.read_roi`` (missing data filled with 0) and asserts that
    the trailing ``len(factor)`` axes are divisible by ``factor``. Leading
    axes (channels) are kept at full size.

    ``uint64`` data is subsampled by taking one voxel per window, at offset
    ``k // 2``; presumably these are label IDs that must not be averaged.
    Every other dtype is averaged with ``skimage.measure.block_reduce``.

    Returns:
        0. Errors while writing ``block.write_roi`` are printed and re-raised.
    """
    dims = len(factor)
    in_data = in_array.to_ndarray(block.read_roi, fill_value=0)
    assert daisy.Coordinate(in_data.shape[-dims:]).is_multiple_of(factor)

    channel_axes = len(in_data.shape) - dims
    if channel_axes > 0:
        # Do not reduce the leading (channel) axes.
        factor = (1, ) * channel_axes + factor

    if in_data.dtype == np.uint64:
        out_data = in_data[tuple(slice(k // 2, None, k) for k in factor)]
    else:
        out_data = skimage.measure.block_reduce(in_data, factor, np.mean)

    try:
        out_array[block.write_roi] = out_data
    except Exception:
        print("Failed to write to %s" % block.write_roi)
        raise

    return 0
def parallel_lsd_agglomerate(lsds, fragments, rag_provider, lsd_extractor,
                             block_size, context, num_workers):
    '''Agglomerate fragments in parallel using only the shape descriptors.

    Args:

        lsds (`class:daisy.Array`):

            An array containing the LSDs.

        fragments (`class:daisy.Array`):

            An array containing fragments. Its data must be ``uint64``.

        rag_provider (`class:SharedRagProvider`):

            A RAG provider to read nodes from and write found edges to.

        lsd_extractor (``LsdExtractor``):

            The local shape descriptor object used to compute the difference
            between the segmentation and the target LSDs.

        block_size (``tuple`` of ``int``):

            The size of the blocks to process in parallel, in world units.

        context (``tuple`` of ``int``):

            The context to consider for agglomeration, in world units.

        num_workers (``int``):

            The number of parallel workers.

    Returns:

        True, if all tasks succeeded.
    '''

    assert fragments.data.dtype == np.uint64

    context = daisy.Coordinate(context)
    dims = lsds.roi.dims()

    total_roi = lsds.roi.grow(context, context)
    write_roi = daisy.Roi((0, ) * dims, block_size)
    read_roi = write_roi.grow(context, context)

    return daisy.run_blockwise(
        total_roi,
        read_roi,
        write_roi,
        lambda b: agglomerate_in_block(lsds, fragments, rag_provider,
                                       lsd_extractor, b),
        lambda b: block_done(b, rag_provider),
        num_workers=num_workers,
        read_write_conflict=False,
        fit='shrink')
def relabel_connected_components(array_in, array_out, block_size, num_workers):
    '''Relabel connected components in an array in parallel.

    Args:

        array_in (``daisy.Array``):

            The array to relabel.

        array_out (``daisy.Array``):

            The array to write to. Should initially be empty (i.e., all
            zeros).

        block_size (``daisy.Coordinate``):

            The size of the blocks to relabel in, in world units.

        num_workers (``int``):

            The number of workers to use.

    Blocks are processed by ``segment_blockwise`` with
    ``label_connected_components``, using one voxel of context.
    '''
    segment_blockwise(
        array_in,
        array_out,
        block_size=daisy.Coordinate(block_size),
        context=array_in.voxel_size,
        num_workers=num_workers,
        segment_function=label_connected_components)
def overlay_segmentation(db_host,
                         db_name,
                         roi_offset,
                         roi_size,
                         selected_attr,
                         solved_attr,
                         edge_collection,
                         segmentation_container,
                         segmentation_dataset,
                         segmentation_number,
                         voxel_size=(40, 4, 4)):
    """Store, on each selected graph node, the segment ID under its position.

    Reads the nodes with ``selected_attr == True`` inside the intersection
    of the graph ROI and the segmentation ROI, snapped to ``voxel_size``.
    Each node's (z, y, x) position is looked up in the segmentation. The
    value is stored as ``segmentation_<segmentation_number>`` and the nodes
    are written back to the database.

    ``solved_attr`` is not used here; it is kept for interface compatibility.
    """
    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)
    graph_roi = daisy.Roi(roi_offset, roi_size)

    segmentation = daisy.open_ds(segmentation_container, segmentation_dataset)
    intersection_roi = segmentation.roi.intersect(graph_roi).snap_to_grid(
        voxel_size)

    nx_graph = graph_provider.get_graph(intersection_roi,
                                        nodes_filter={selected_attr: True},
                                        edges_filter={selected_attr: True})

    segmentation_attr = "segmentation_{}".format(segmentation_number)
    for node_id, data in nx_graph.nodes(data=True):
        # Previously mis-parenthesised as Coordinate((z, y), x): a TypeError.
        node_position = daisy.Coordinate((data["z"], data["y"], data["x"]))
        # NOTE(review): this may be a numpy scalar; confirm the Mongo writer
        # can encode it for this segmentation dtype.
        nx_graph.nodes[node_id][segmentation_attr] = segmentation[node_position]

    # Write through the sub-graph, which holds the updated node data. The
    # provider call passed the ROI as its first positional argument.
    nx_graph.write_nodes(roi=intersection_roi)
def run_test_graph_write_attributes(self, provider_factory):
    """Writing with ``attributes=[...]`` persists only those node attributes.

    Node 2 (no position) is excluded from the comparison. 'zap' was not
    written, so it is removed from the expected data, and positions are
    compared as lists. Only nodes that come back with data are compared.
    """
    roi = daisy.Roi((0, 0, 0), (10, 10, 10))
    writer = provider_factory('w')
    graph = writer[roi]

    graph.add_node(2, comment="without position")
    graph.add_node(42, position=(1, 1, 1))
    graph.add_node(23, position=(5, 5, 5), swip='swap')
    graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
    for u, v in ((42, 23), (57, 23), (2, 42)):
        graph.add_edge(u, v)

    graph.write_nodes(attributes=['position', 'swip'])
    graph.write_edges()

    reader = provider_factory('r')
    stored = reader[daisy.Roi((0, 0, 0), (10, 10, 10))]

    expected = []
    for node_id, attrs in graph.nodes(data=True):
        if node_id == 2:
            continue
        attrs.pop('zap', None)
        attrs['position'] = list(attrs['position'])
        expected.append((node_id, attrs))

    stored_nodes = [(node_id, attrs)
                    for node_id, attrs in stored.nodes(data=True) if attrs]
    self.assertCountEqual(expected, stored_nodes)
def test_graph_multiple_separate_collections(self):
    """Attributes split across separate collections round-trip per node/edge."""
    attributes = {'3': ['selected'], '4': ['swip']}
    writer = self.get_mongo_graph_provider('w', attributes, attributes)
    roi = daisy.Roi((0, 0, 0), (10, 10, 10))
    graph = writer[roi]

    graph.add_node(2, position=(2, 2, 2), swip='swap')
    graph.add_node(42, position=(1, 1, 1), selected=False, swip='swim')
    graph.add_node(23, position=(5, 5, 5), selected=True)
    graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), selected=True)
    graph.add_edge(42, 23)
    graph.add_edge(57, 23, selected=True, swip='swap')
    graph.add_edge(2, 42, selected=True)
    graph.add_edge(42, 2, swip='swim')

    graph.write_nodes()
    graph.write_edges()

    reader = self.get_mongo_graph_provider('r', attributes, attributes)
    stored = reader[roi]

    nodes, edges = stored.nodes, stored.edges
    self.assertNotIn('selected', nodes[2])
    self.assertEqual('swap', nodes[2]['swip'])
    self.assertEqual(False, nodes[42]['selected'])
    self.assertEqual('swim', nodes[42]['swip'])
    self.assertNotIn('swip', nodes[57])
    self.assertEqual(True, edges[2, 42]['selected'])
    self.assertNotIn('swip', edges[2, 42])
    self.assertNotIn('selected', edges[42, 23])
    self.assertEqual('swim', edges[42, 2]['swip'])
def run_test_graph_io(self, provider_factory):
    """A written graph reads back with the same nodes and edges.

    Node 2 has no position, so neither it nor its edge is returned by the
    ROI query.
    """
    roi = daisy.Roi((0, 0, 0), (10, 10, 10))
    writer = provider_factory('w')
    graph = writer[roi]

    graph.add_node(2, comment="without position")
    graph.add_node(42, position=(1, 1, 1))
    graph.add_node(23, position=(5, 5, 5), swip='swap')
    graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
    for u, v in ((42, 23), (57, 23), (2, 42)):
        graph.add_edge(u, v)

    graph.write_nodes()
    graph.write_edges()

    reader = provider_factory('r')
    stored = reader[daisy.Roi((0, 0, 0), (10, 10, 10))]

    expected_nodes = sorted(graph.nodes())
    expected_nodes.remove(2)
    expected_edges = sorted(graph.edges())
    expected_edges.remove((2, 42))

    self.assertEqual(expected_nodes, sorted(stored.nodes()))
    self.assertEqual(expected_edges, sorted(stored.edges()))
def test_graph_read_unbounded_roi(self):
    """An unbounded ROI returns every node, with only the requested attrs."""
    roi = daisy.Roi((0, 0, 0), (10, 10, 10))
    unbounded_roi = daisy.Roi((None, None, None), (None, None, None))

    writer = self.get_mongo_graph_provider('w')
    graph = writer[roi]

    graph.add_node(2, position=(2, 2, 2), selected=True, test='test')
    graph.add_node(42, position=(1, 1, 1), selected=False, test='test2')
    graph.add_node(23, position=(5, 5, 5), selected=True, test='test2')
    graph.add_node(57, position=daisy.Coordinate((7, 7, 7)),
                   selected=True, test='test')
    graph.add_edge(42, 23, selected=False, a=100, b=3)
    graph.add_edge(57, 23, selected=True, a=100, b=2)
    graph.add_edge(2, 42, selected=True, a=101, b=3)

    graph.write_nodes()
    graph.write_edges()

    reader = self.get_mongo_graph_provider('r+')
    limited_graph = reader.get_graph(unbounded_roi,
                                     node_attrs=['selected'],
                                     edge_attrs=['c'])

    seen = []
    for node_id, attrs in limited_graph.nodes(data=True):
        self.assertNotIn('test', attrs)
        self.assertIn('selected', attrs)
        attrs['selected'] = True
        seen.append(node_id)

    self.assertCountEqual(seen, [2, 42, 23, 57])
def run_test_graph_connected_components(self, provider_factory):
    """The graph splits into the components {2, 42} and {23, 57}."""
    writer = provider_factory('w')
    graph = writer[daisy.Roi((0, 0, 0), (10, 10, 10))]

    graph.add_node(2, comment="without position")
    graph.add_node(42, position=(1, 1, 1))
    graph.add_node(23, position=(5, 5, 5), swip='swap')
    graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
    graph.add_edge(57, 23)
    graph.add_edge(2, 42)

    components = graph.get_connected_components()
    self.assertEqual(len(components), 2)

    first, second = components
    nodes_a = sorted(first.nodes())
    nodes_b = sorted(second.nodes())
    # Component order is not fixed; put the one containing node 2 first.
    if 2 in nodes_b:
        nodes_a, nodes_b = nodes_b, nodes_a

    self.assertCountEqual(nodes_a, [2, 42])
    self.assertCountEqual(nodes_b, [23, 57])
def test_graph():
    """Time writing ~10k random nodes to MongoDB and reading them back.

    NOTE(review): talks to a hard-coded database host; it needs network
    access to that server.
    """
    graph_provider = daisy.persistence.MongoDbGraphProvider(
        'test_daisy_graph',
        '10.40.4.51',
        nodes_collection='nodes',
        edges_collection='edges',
        mode='w')
    graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]

    graph.add_node(2, comment="without position")
    graph.add_node(42, position=(1, 1, 1))
    graph.add_node(23, position=(5, 5, 5), swip='swap')
    graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
    graph.add_edge(42, 23)

    for node_id in range(100, 10100):
        position = tuple(random.randint(0, 10) for _ in range(3))
        graph.add_node(node_id, position=position)

    write_start = time.time()
    graph.write_nodes()
    graph.write_edges()
    print("Wrote graph in %.3fs" % (time.time() - write_start))

    read_start = time.time()
    graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]
    print("Read graph in %.3fs" % (time.time() - read_start))
def test_graph_separate_collection_simple(self):
    """A 'selected' attribute stored in a separate collection round-trips."""
    attributes = {'1': ['selected']}
    writer = self.get_mongo_graph_provider('w', attributes, attributes)
    roi = daisy.Roi((0, 0, 0), (10, 10, 10))
    graph = writer[roi]

    for node_id, position, selected in ((2, (2, 2, 2), True),
                                        (42, (1, 1, 1), False),
                                        (23, (5, 5, 5), True),
                                        (57, daisy.Coordinate((7, 7, 7)), True)):
        graph.add_node(node_id, position=position, selected=selected)
    for u, v, selected in ((42, 23, False), (57, 23, True), (2, 42, True)):
        graph.add_edge(u, v, selected=selected)

    graph.write_nodes()
    graph.write_edges()

    reader = self.get_mongo_graph_provider('r', attributes, attributes)
    stored = reader[roi]

    self.assertEqual(True, stored.nodes[2]['selected'])
    self.assertEqual(False, stored.nodes[42]['selected'])
    self.assertEqual(True, stored.edges[2, 42]['selected'])
    self.assertEqual(False, stored.edges[42, 23]['selected'])
def get_crops(self, center_positions, size):
    """Cut crops of ``size`` voxels centred on each position.

    Args:
        center_positions (list of tuple of ints): [(z,y,x)] in nm
        size (tuple of ints): size of the crop, in voxels

    Returns:
        float32 numpy array stacking the crops along a new first axis.

    Raises:
        ValueError: If the dataset reports a voxel size different from
            ``self.voxel_size``.
        Warning: If a crop is not fully contained in the dataset.
    """
    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(self.voxel_size)
    size_nm = size * voxel_size
    dataset = daisy.open_ds(self.container, self.dataset)

    # Only validate the resolution when the dataset exposes one.
    dataset_resolution = getattr(dataset, "voxel_size", None)
    # Compare as Coordinates so list/array inputs for self.voxel_size work.
    if (dataset_resolution is not None
            and daisy.Coordinate(dataset_resolution) != voxel_size):
        # The old message read dataset.resolution, which does not exist.
        raise ValueError(
            f"Dataset {dataset} resolution mismatch {dataset_resolution} vs {self.voxel_size}"
        )

    crops = []
    for position in center_positions:
        position = daisy.Coordinate(tuple(position))
        offset_nm = position - ((size / 2) * voxel_size)
        snapped = daisy.Roi(offset_nm, size_nm).snap_to_grid(voxel_size,
                                                           mode='closest')
        # Snapping may change the extent in any dimension. Keep the snapped
        # offset but force the requested shape so the crops can be stacked.
        roi = daisy.Roi(snapped.get_offset(), size_nm)
        if not dataset.roi.contains(roi):
            raise Warning(
                f"Location {position} is not fully contained in dataset")
        crops.append(dataset[roi].to_ndarray())

    return np.stack(crops).astype(np.float32)