Example #1
def build_trees_from_mst(emst, edges_u, edges_v, alpha, coordinate_scale,
                         offset, voxel_size):
    trees = nx.DiGraph()
    ndims = len(voxel_size)
    for edge, u, v in zip(np.array(emst), np.array(edges_u),
                          np.array(edges_v)):
        if edge[2] > alpha:
            continue
        pos_u = daisy.Coordinate((0, ) * (3 - ndims) +
                                 tuple((u[-ndims:] / coordinate_scale) +
                                       (offset / voxel_size)))
        pos_v = daisy.Coordinate((0, ) * (3 - ndims) +
                                 tuple((v[-ndims:] / coordinate_scale) +
                                       (offset / voxel_size)))
        if edge[0] not in trees.nodes:
            trees.add_node(edge[0], pos=pos_u)
        else:
            assert trees.nodes[
                edge[0]]["pos"] == pos_u, "locations don't match"
        if edge[1] not in trees.nodes:
            trees.add_node(edge[1], pos=pos_v)
        else:
            assert trees.nodes[
                edge[1]]["pos"] == pos_v, "locations don't match"
        trees.add_edge(edge[0], edge[1], d=edge[2])
    return trees
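
For context, a minimal sketch of how this function might be called. The toy arrays below (shapes and values) are illustrative assumptions, not part of the original code:

import daisy
import networkx as nx
import numpy as np

# emst rows are (u_id, v_id, distance); edges_u/edges_v hold the embedding
# positions of each edge's two endpoints.
emst = np.array([[0, 1, 0.4], [1, 2, 0.9]])
edges_u = np.array([[5.0, 10.0], [6.0, 12.0]])
edges_v = np.array([[6.0, 12.0], [7.0, 20.0]])

trees = build_trees_from_mst(
    emst, edges_u, edges_v,
    alpha=0.5,                    # edges with distance > alpha are skipped
    coordinate_scale=1.0,
    offset=np.array([0, 0]),
    voxel_size=np.array([1, 1]))  # 2-D input, so positions are padded to 3-D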
Example #2
def build_trees(node_ids, locations, edges, node_attrs=None, edge_attrs=None):
    if node_attrs is None:
        node_attrs = {}
    if edge_attrs is None:
        edge_attrs = {}

    node_to_index = {n: i for i, n in enumerate(node_ids)}
    trees = nx.Graph()
    pbs = {
        int(node_id): node_location
        for node_id, node_location in zip(node_ids, locations)
    }
    for i, row in enumerate(edges):
        u = node_to_index.get(int(row[0]), -1)
        v = node_to_index.get(int(row[-1]), -1)

        # skip edges whose endpoints are not among the given nodes
        if u == -1 or v == -1:
            continue

        e_attrs = {attr: values[i] for attr, values in edge_attrs.items()}

        pos_u = daisy.Coordinate(tuple(pbs[node_ids[u]]))
        pos_v = daisy.Coordinate(tuple(pbs[node_ids[v]]))

        if node_ids[u] not in trees.nodes:
            u_attrs = {attr: values[u] for attr, values in node_attrs.items()}
            trees.add_node(node_ids[u], location=pos_u, **u_attrs)
        if node_ids[v] not in trees.nodes:
            v_attrs = {attr: values[v] for attr, values in node_attrs.items()}
            trees.add_node(node_ids[v], location=pos_v, **v_attrs)

        trees.add_edge(node_ids[u], node_ids[v], **e_attrs)
    return trees
Example #3
def solve(predict_config, worker_config, data_config, graph_config,
          solve_config, num_block_workers, block_size, roi_offset, roi_size,
          context, solve_block, base_dir, experiment, train_number,
          predict_number, graph_number, solve_number, queue,
          singularity_container, mount_dirs, **kwargs):

    source_roi = daisy.Roi(daisy.Coordinate(roi_offset),
                           daisy.Coordinate(roi_size))

    solve_setup_dir = os.path.join(
        base_dir, experiment,
        "04_solve/setup_t{}_p{}_g{}_s{}".format(train_number, predict_number,
                                                graph_number, solve_number))

    block_write_roi = daisy.Roi((0, 0, 0), block_size)
    block_read_roi = block_write_roi.grow(context, context)
    total_roi = source_roi.grow(context, context)

    logger.info("Solving in %s", total_roi)

    daisy.run_blockwise(
        total_roi,
        block_read_roi,
        block_write_roi,
        process_function=lambda: start_worker(
            predict_config, worker_config, data_config, graph_config,
            solve_config, queue, singularity_container, mount_dirs,
            solve_block, solve_setup_dir),
        num_workers=num_block_workers,
        fit='shrink')

    logger.info("Finished solving, parameters id is %s", solve_number)
Example #4
def get_source_roi(data_dir, sample):

    sample_path = os.path.join(data_dir, sample)

    # get absolute paths
    if os.path.isfile(sample_path) or sample.endswith((".zarr", ".n5")):
        sample_dir = os.path.abspath(
            os.path.join(data_dir, os.path.dirname(sample)))
    else:
        sample_dir = os.path.abspath(os.path.join(data_dir, sample))

    if os.path.isfile(os.path.join(sample_dir, 'attributes.json')):

        with open(os.path.join(sample_dir, 'attributes.json'), 'r') as f:
            attributes = json.load(f)
        voxel_size = daisy.Coordinate(attributes['resolution'])
        shape = daisy.Coordinate(attributes['shape'])
        offset = daisy.Coordinate(attributes['offset'])
        source_roi = daisy.Roi(offset, shape * voxel_size)

        return voxel_size, source_roi

    elif os.path.isdir(os.path.join(sample_dir, 'timelapse.zarr')):

        a = daisy.open_ds(os.path.join(sample_dir, 'timelapse.zarr'),
                          'volumes/raw')

        return a.voxel_size, a.roi

    else:

        raise RuntimeError(
            "Can't find attributes.json or timelapse.zarr in %s" % sample_dir)
Example #5
def read_data_config(data_config):
    config = configparser.ConfigParser()
    config.read(data_config)

    cfg_dict = {}

    # Data
    cfg_dict["sample"] = config.get("Data", "sample")
    # Don't think I need these
    offset = config.get("Data", "roi_offset")
    size = config.get("Data", "roi_size")
    cfg_dict["roi_offset"] = daisy.Coordinate(
        tuple(int(x) for x in offset.split(", ")) if not offset == "None" else [None] * 3
    )
    cfg_dict["roi_size"] = daisy.Coordinate(
        tuple(int(x) for x in size.split(", ")) if not size == "None" else [None] * 3
    )
    cfg_dict["roi"] = daisy.Roi(cfg_dict["roi_offset"], cfg_dict["roi_size"])
    cfg_dict["location_attr"] = config.get("Data", "location_attr")
    cfg_dict["penalty_attr"] = config.get("Data", "penalty_attr")
    cfg_dict["target_edge_len"] = int(config.get("Data", "target_edge_len"))

    # Database
    cfg_dict[
        "consensus_db"
    ] = f"mouselight-{cfg_dict['sample']}-{config.get('Data', 'consensus_db')}"
    cfg_dict[
        "subdivided_db"
    ] = f"mouselight-{cfg_dict['sample']}-{config.get('Data', 'subdivided_db')}"
    cfg_dict["db_host"] = config.get("Data", "db_host")

    return cfg_dict
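
A minimal, illustrative config that read_data_config() can parse; the option values are placeholders that simply mirror the config.get(...) calls above:

sample_cfg = """
[Data]
sample = example_sample
roi_offset = 0, 0, 0
roi_size = 1000, 1000, 1000
location_attr = location
penalty_attr = penalty
target_edge_len = 1000
consensus_db = consensus
subdivided_db = subdivided
db_host = mongodb://localhost:27017
"""

with open("data_config.ini", "w") as f:
    f.write(sample_cfg)
cfg = read_data_config("data_config.ini")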
Example #6
def build_trees(node_ids, locations, edges, voxel_size):
    trees = nx.DiGraph()
    pbs = {
        int(node_id): node_location
        for node_id, node_location in zip(node_ids, locations)
    }
    for row in edges:
        u = int(row[0])
        v = int(row[-1])

        if u == -1 or v == -1:
            continue

        pos_u = daisy.Coordinate(tuple(pbs[u])) / voxel_size
        pos_v = daisy.Coordinate(tuple(pbs[v])) / voxel_size

        if u not in trees.nodes:
            trees.add_node(u, pos=pos_u)
        else:
            assert trees.nodes[u]["pos"] == pos_u, "locations don't match"
        if v not in trees.nodes:
            trees.add_node(v, pos=pos_v)
        else:
            assert trees.nodes[v]["pos"] == pos_v, "locations don't match"

        trees.add_edge(u, v, d=np.linalg.norm(pos_u - pos_v))
    return trees
Example #7
    def test_load_klb(self):
        self.write_test_klb()
        data = daisy.open_ds(self.klb_file, None)
        self.assertEqual(data.roi.get_offset(), daisy.Coordinate((0, 0, 0, 0)))
        self.assertEqual(data.roi.get_shape(), daisy.Coordinate(
            (1, 20, 10, 10)))
        self.assertEqual(data.voxel_size, daisy.Coordinate((1, 2, 1, 1)))
        self.remove_test_klb()
Example #8
def extract_edges_blockwise(db_host,
                            db_name,
                            sample,
                            edge_move_threshold,
                            block_size,
                            num_workers,
                            frames=None,
                            frame_context=1,
                            data_dir='../01_data',
                            use_pv_distance=False,
                            **kwargs):

    voxel_size, source_roi = get_source_roi(data_dir, sample)

    # limit to specific frames, if given
    if frames:
        begin, end = frames
        begin -= frame_context
        end += frame_context
        crop_roi = daisy.Roi((begin, None, None, None),
                             (end - begin, None, None, None))
        source_roi = source_roi.intersect(crop_roi)

    # block size in world units
    block_write_roi = daisy.Roi((0, ) * 4, daisy.Coordinate(block_size))

    pos_context = daisy.Coordinate((0, ) + (edge_move_threshold, ) * 3)
    neg_context = daisy.Coordinate((1, ) + (edge_move_threshold, ) * 3)
    logger.debug("Set neg context to %s", neg_context)

    input_roi = source_roi.grow(neg_context, pos_context)
    block_read_roi = block_write_roi.grow(neg_context, pos_context)

    print("Following ROIs in world units:")
    print("Input ROI       = %s" % input_roi)
    print("Block read  ROI = %s" % block_read_roi)
    print("Block write ROI = %s" % block_write_roi)
    print("Output ROI      = %s" % source_roi)

    print("Starting block-wise processing...")

    # process block-wise
    daisy.run_blockwise(input_roi,
                        block_read_roi,
                        block_write_roi,
                        process_function=lambda b: extract_edges_in_block(
                            db_name,
                            db_host,
                            edge_move_threshold,
                            b,
                            use_pv_distance=use_pv_distance),
                        check_function=lambda b: check_function(
                            b, 'extract_edges', db_name, db_host),
                        num_workers=num_workers,
                        processes=True,
                        read_write_conflict=False,
                        fit='shrink')
Example #9
def ingest_consensus(sample: str, consensus_dir: Path, transform_file: Path,
                     url: str):
    """
    Storing data in MongoDB using psuedo world coords (1,.3,.3) microns rather than the
    slightly off floats found in the transform.txt file.
    """

    mongo_graph_provider = daisy.persistence.MongoDbGraphProvider(
        f"mouselight-{sample}-consensus", url, directed=True, mode="w")
    graph = mongo_graph_provider.get_graph(
        daisy.Roi(daisy.Coordinate([None, None, None]),
                  daisy.Coordinate([None, None, None])))
    consensus_graphs = []
    for consensus_neuron in tqdm(consensus_dir.iterdir(),
                                 "Consensus neurons: "):
        if (not consensus_neuron.is_dir()
                or not (consensus_neuron / "consensus.swc").exists()):
            continue
        consensus_graph = parse_consensus(
            consensus_neuron / "consensus.swc",
            consensus_neuron / "dendrite.swc",
            transform_file,  # assumption: parse_consensus accepts the transform file path
            offset=np.array([0, 0, 0]),
            resolution=np.array([300, 300, 1000]),
            transpose=[2, 1, 0],
        )
        for node in consensus_graph.nodes:
            consensus_graph.nodes[node]["position"] = consensus_graph.nodes[
                node]["location"].tolist()
            del consensus_graph.nodes[node]["location"]
        consensus_graphs.append(consensus_graph)
    logger.info("Consolidating consensus graphs!")
    consensus_graph = nx.disjoint_union_all(consensus_graphs)

    data = {}
    for node_id, attrs in consensus_graph.nodes.items():
        node_id = int(np.int64(node_id))
        node_ids = data.setdefault("id", [])
        node_ids.append(node_id)
        for key, value in attrs.items():
            dlist = data.setdefault(key, [])
            dlist.append(value)

    logger.info(
        f"Writing {len(consensus_graph.nodes)} nodes and {len(consensus_graph.edges)} edges!"
    )

    bulk_write_nodes(url, f"mouselight-{sample}-consensus", "nodes", data)
    bulk_write_edges(
        url,
        f"mouselight-{sample}-consensus",
        "edges",
        ("u", "v"),
        list(consensus_graph.edges),
        True,
    )
Example #10
    def __init__(self, positions, dataset, size):

        self.dataset = dataset
        self.data = dataset.open_daisy()

        self.voxel_size = daisy.Coordinate(dataset.voxel_size)
        self.size = daisy.Coordinate(size)
        self.size_nm = self.size * self.voxel_size

        self.positions = self.filter_positions(positions)
        self.transform = None
Example #11
    def test_overwrite_attrs(self):
        self.write_test_klb()
        data = daisy.open_ds(self.klb_file,
                             None,
                             attr_filename=self.attrs_file)
        self.assertEqual(data.roi.get_offset(),
                         daisy.Coordinate((0, 100, 10, 0)))
        self.assertEqual(data.roi.get_shape(), daisy.Coordinate(
            (1, 10, 20, 20)))
        self.assertEqual(data.voxel_size, daisy.Coordinate((1, 1, 2, 2)))
        self.remove_test_klb()
Example #12
def get_raw(locs, size, voxel_size, data_container, data_set):
    """
    Get raw crops from the specified
    dataset.

    locs(``list of tuple of ints``):

        list of centers of location of interest

    size(``tuple of ints``):
        
        size of cropout in voxel

    voxel_size(``tuple of ints``):

        size of a voxel

    data_container(``string``):

        path to data container (e.g. zarr file)

    data_set(``string``):

        corresponding data_set name, (e.g. raw)

    """

    raw = []
    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(voxel_size)
    size_nm = (size * voxel_size)
    dataset = daisy.open_ds(data_container, data_set)

    for loc in locs:
        # normalize to a Coordinate so the arithmetic below is well-defined
        loc = daisy.Coordinate(tuple(loc))
        offset_nm = loc - (size / 2 * voxel_size)
        roi = daisy.Roi(offset_nm, size_nm).snap_to_grid(voxel_size,
                                                         mode='closest')

        # compare against the size in world units, not in voxels
        if roi.get_shape()[0] != size_nm[0]:
            roi.set_shape(size_nm)

        if not dataset.roi.contains(roi):
            logger.warning("Location %s is not fully contained in dataset",
                           loc)
            # keep the two-value return shape on failure
            return None, None

        raw.append(dataset[roi].to_ndarray())

    raw = np.stack(raw)
    raw = raw.astype(np.float32)
    raw_normalized = raw / 255.0
    raw_normalized = raw_normalized * 2.0 - 1.0
    return raw, raw_normalized
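
An illustrative call; the container path, dataset name, and coordinates are placeholders:

locs = [daisy.Coordinate((1000, 2000, 2000)),
        daisy.Coordinate((1200, 2400, 2400))]  # crop centers in world units
raw, raw_normalized = get_raw(
    locs,
    size=(32, 32, 32),          # in voxels
    voxel_size=(40, 4, 4),
    data_container="sample.zarr",
    data_set="volumes/raw")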
Example #13
def extract_edges(
        db_host,
        db_name,
        soft_mask_container,
        soft_mask_dataset,
        roi_offset,
        roi_size,
        distance_threshold,
        block_size,
        num_block_workers,
        graph_number,
        **kwargs):

    # Define Rois:
    source_roi = daisy.Roi(roi_offset, roi_size)
    block_write_roi = daisy.Roi(
        (0,) * 3,
        daisy.Coordinate(block_size))

    pos_context = daisy.Coordinate((distance_threshold,)*3)
    neg_context = daisy.Coordinate((distance_threshold,)*3)
    logger.debug("Set pos context to %s", pos_context)
    logger.debug("Set neg context to %s", neg_context)

    input_roi = source_roi.grow(neg_context, pos_context)
    block_read_roi = block_write_roi.grow(neg_context, pos_context)

    logger.info("Following ROIs in world units:")
    logger.info("Input ROI       = %s" % input_roi)
    logger.info("Block read  ROI = %s" % block_read_roi)
    logger.info("Block write ROI = %s" % block_write_roi)
    logger.info("Output ROI      = %s" % source_roi)

    logger.info("Starting block-wise processing...")

    # process block-wise
    daisy.run_blockwise(
        input_roi,
        block_read_roi,
        block_write_roi,
        process_function=lambda b: extract_edges_in_block(
            db_name,
            db_host,
            soft_mask_container,
            soft_mask_dataset,
            distance_threshold,
            graph_number,
            b),
        num_workers=num_block_workers,
        processes=True,
        read_write_conflict=False,
        fit='shrink')
Example #14
    def test_none(self):

        a = daisy.Coordinate((None, 1, 2))
        b = daisy.Coordinate((3, 4, None))

        assert a + b == (None, 5, None)
        assert a - b == (None, -3, None)
        assert a / b == (None, 0, None)
        assert a // b == (None, 0, None)
        assert b / a == (None, 4, None)
        assert b // a == (None, 4, None)
        assert abs(a) == (None, 1, 2)
        assert abs(-a) == (None, 1, 2)
Example #15
def visualize_npy(npy_file: Path, voxel_size):
    voxel_size = daisy.Coordinate(voxel_size)

    viewer = neuroglancer.Viewer()
    with viewer.txn() as s:
        v = np.load(npy_file)
        m = daisy.Array(
            v,
            daisy.Roi(daisy.Coordinate([0, 0, 0]), daisy.Coordinate(v.shape)),
            daisy.Coordinate([1, 1, 1]),
        )
        add_layer(s, m, f"npy array")
    print(viewer)
    input("Hit ENTER to quit!")
Example #16
def get_raw_parallel(locs, size, voxel_size, data_container, data_set):
    """
    Get raw crops from the specified
    dataset.

    locs(``list of tuple of ints``):

        list of centers of location of interest

    size(``tuple of ints``):
        
        size of cropout in voxel

    voxel_size(``tuple of ints``):

        size of a voxel

    data_container(``string``):

        path to data container (e.g. zarr file)

    data_set(``string``):

        corresponding data_set name, (e.g. raw)

    """

    pool = Pool(processes=len(locs))
    raw = []
    size = daisy.Coordinate(size)
    voxel_size = daisy.Coordinate(voxel_size)
    size_nm = (size * voxel_size)
    dataset = daisy.open_ds(data_container, data_set)

    raw_workers = [
        pool.apply_async(fetch_from_ds,
                         (dataset, loc, voxel_size, size, size_nm))
        for loc in locs
    ]
    raw = [w.get(timeout=60) for w in raw_workers]
    pool.close()
    pool.join()

    raw = np.stack(raw)
    raw = raw.astype(np.float32)
    raw_normalized = raw / 255.0
    raw_normalized = raw_normalized * 2.0 - 1.0
    return raw, raw_normalized
Example #17
    def run_test_graph_write_roi(self, provider_factory):

        graph_provider = provider_factory('w')
        graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]

        graph.add_node(2, comment="without position")
        graph.add_node(42, position=(1, 1, 1))
        graph.add_node(23, position=(5, 5, 5), swip='swap')
        graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
        graph.add_edge(42, 23)
        graph.add_edge(57, 23)
        graph.add_edge(2, 42)

        write_roi = daisy.Roi((0, 0, 0), (6, 6, 6))
        graph.write_nodes(roi=write_roi)
        graph.write_edges(roi=write_roi)

        graph_provider = provider_factory('r')
        compare_graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]

        nodes = sorted(list(graph.nodes()))
        nodes.remove(2)  # node 2 has no position and will not be queried
        nodes.remove(57)  # node 57 is outside of the write_roi
        compare_nodes = compare_graph.nodes(data=True)
        compare_nodes = [
            node_id for node_id, data in compare_nodes if len(data) > 0
        ]
        compare_nodes = sorted(list(compare_nodes))
        edges = sorted(list(graph.edges()))
        edges.remove((2, 42))  # node 2 has no position and will not be queried
        compare_edges = sorted(list(compare_graph.edges()))

        self.assertEqual(nodes, compare_nodes)
        self.assertEqual(edges, compare_edges)
Example #18
def get_site_fragment_lut(fragments, sites, roi):
    """Get the fragment IDs of all the sites that are contained in the given
    ROI."""

    sites = list(sites)

    if len(sites) == 0:
        logging.info(f"No sites in {roi}, skipping")
        return None, None

    logging.info(
        f"Getting fragment IDs for {len(sites)} synaptic sites in {roi}...")

    # for only a few sites, direct per-site lookup is faster; bulk-copy the
    # fragments into memory only when there are many sites
    if len(sites) >= 15:

        logging.info("Copying fragments into memory...")
        fragments = fragments[roi]
        fragments.materialize()

    logging.info(f"Getting fragment IDs for synaptic sites in {roi}...")

    fragment_ids = np.array([
        fragments[daisy.Coordinate((site['z'], site['y'], site['x']))]
        for site in sites
    ])
    site_ids = np.array([site['id'] for site in sites], dtype=np.uint64)

    fg_mask = fragment_ids != 0
    fragment_ids = fragment_ids[fg_mask]
    site_ids = site_ids[fg_mask]

    lut = np.array([site_ids, fragment_ids])

    return lut, (fg_mask == 0).sum()
Example #19
def downscale_block(in_array, out_array, factor, block):

    dims = len(factor)
    in_data = in_array.to_ndarray(block.read_roi, fill_value=0)

    in_shape = daisy.Coordinate(in_data.shape[-dims:])
    assert in_shape.is_multiple_of(factor)

    n_channels = len(in_data.shape) - dims
    if n_channels >= 1:
        factor = (1, ) * n_channels + factor

    if in_data.dtype == np.uint64:
        slices = tuple(slice(k // 2, None, k) for k in factor)
        out_data = in_data[slices]
    else:
        out_data = skimage.measure.block_reduce(in_data, factor, np.mean)

    try:
        out_array[block.write_roi] = out_data
    except Exception:
        print("Failed to write to %s" % block.write_roi)
        raise

    return 0
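
A sketch of how a block function like this is typically driven with daisy.run_blockwise. The paths, downscale factor, and block size are assumptions chosen for illustration:

import daisy

factor = (2, 2, 2)
in_array = daisy.open_ds("volumes.zarr", "raw")
out_array = daisy.prepare_ds(
    "volumes.zarr", "raw_2x",
    total_roi=in_array.roi,
    voxel_size=in_array.voxel_size * daisy.Coordinate(factor),
    dtype=in_array.dtype)

# block size in world units; assumed to be chosen so that each block's
# voxel shape is a multiple of the downscale factor
block_roi = daisy.Roi((0, 0, 0), (256, 256, 256))

daisy.run_blockwise(
    in_array.roi,
    block_roi,
    block_roi,
    process_function=lambda b: downscale_block(in_array, out_array, factor, b),
    fit='shrink')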
Example #20
def parallel_lsd_agglomerate(lsds, fragments, rag_provider, lsd_extractor,
                             block_size, context, num_workers):
    '''Agglomerate fragments in parallel using only the shape descriptors.

    Args:

        lsds (`class:daisy.Array`):

            An array containing the LSDs.

        fragments (`class:daisy.Array`):

            An array containing fragments.

        rag_provider (`class:SharedRagProvider`):

            A RAG provider to read nodes from and write found edges to.

        lsd_extractor (``LsdExtractor``):

            The local shape descriptor object used to compute the difference
            between the segmentation and the target LSDs.

        block_size (``tuple`` of ``int``):

            The size of the blocks to process in parallel, in world units.

        context (``tuple`` of ``int``):

            The context to consider for agglomeration, in world units.

        num_workers (``int``):

            The number of parallel workers.

    Returns:

        True, if all tasks succeeded.
    '''

    assert fragments.data.dtype == np.uint64

    shape = lsds.shape[1:]
    context = daisy.Coordinate(context)

    total_roi = lsds.roi.grow(context, context)
    read_roi = daisy.Roi((0, ) * lsds.roi.dims(),
                         block_size).grow(context, context)
    write_roi = daisy.Roi((0, ) * lsds.roi.dims(), block_size)

    return daisy.run_blockwise(
        total_roi,
        read_roi,
        write_roi,
        lambda b: agglomerate_in_block(lsds, fragments, rag_provider,
                                       lsd_extractor, b),
        lambda b: block_done(b, rag_provider),
        num_workers=num_workers,
        read_write_conflict=False,
        fit='shrink')
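
A hedged call-site sketch following the docstring above; the dataset paths are placeholders, and rag_provider and lsd_extractor are assumed to have been constructed elsewhere (e.g. with the lsd package):

import daisy

lsds = daisy.open_ds("predictions.zarr", "volumes/lsds")
fragments = daisy.open_ds("fragments.zarr", "volumes/fragments")

success = parallel_lsd_agglomerate(
    lsds, fragments, rag_provider, lsd_extractor,
    block_size=(2000, 2000, 2000),  # world units
    context=(400, 400, 400),        # world units
    num_workers=8)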
Example #21
def relabel_connected_components(array_in, array_out, block_size, num_workers):
    '''Relabel connected components in an array in parallel.

    Args:

        array_in (``daisy.Array``):

            The array to relabel.

        array_out (``daisy.Array``):

            The array to write to. Should initially be empty (i.e., all zeros).

        block_size (``daisy.Coordinate``):

            The size of the blocks to relabel in, in world units.

        num_workers (``int``):

            The number of workers to use.
    '''

    block_size = daisy.Coordinate(block_size)

    segment_blockwise(
        array_in,
        array_out,
        block_size=block_size,
        context=array_in.voxel_size,
        num_workers=num_workers,
        segment_function=label_connected_components)
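
A minimal usage sketch; the container and dataset names are placeholders. The output dataset is freshly prepared so it starts out all zeros, as the docstring requires:

import daisy

array_in = daisy.open_ds("volumes.zarr", "mask")
array_out = daisy.prepare_ds(
    "volumes.zarr", "labels",
    total_roi=array_in.roi,
    voxel_size=array_in.voxel_size,
    dtype="uint64")

relabel_connected_components(array_in, array_out,
                             block_size=(640, 640, 640),  # world units
                             num_workers=8)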
Example #22
def overlay_segmentation(db_host,
                         db_name,
                         roi_offset,
                         roi_size,
                         selected_attr,
                         solved_attr,
                         edge_collection,
                         segmentation_container,
                         segmentation_dataset,
                         segmentation_number,
                         voxel_size=(40, 4, 4)):

    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)
    graph_roi = daisy.Roi(roi_offset, roi_size)

    segmentation = daisy.open_ds(segmentation_container, segmentation_dataset)
    intersection_roi = segmentation.roi.intersect(graph_roi).snap_to_grid(
        voxel_size)

    nx_graph = graph_provider.get_graph(intersection_roi,
                                        nodes_filter={selected_attr: True},
                                        edges_filter={selected_attr: True})

    for node_id, data in nx_graph.nodes(data=True):
        node_position = daisy.Coordinate((data["z"], data["y"], data["x"]))
        nx_graph.nodes[node_id]["segmentation_{}".format(
            segmentation_number)] = segmentation[node_position]

    nx_graph.write_nodes(roi=intersection_roi)
Example #23
    def run_test_graph_write_attributes(self, provider_factory):

        graph_provider = provider_factory('w')
        graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]

        graph.add_node(2, comment="without position")
        graph.add_node(42, position=(1, 1, 1))
        graph.add_node(23, position=(5, 5, 5), swip='swap')
        graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
        graph.add_edge(42, 23)
        graph.add_edge(57, 23)
        graph.add_edge(2, 42)

        graph.write_nodes(attributes=['position', 'swip'])
        graph.write_edges()

        graph_provider = provider_factory('r')
        compare_graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]

        nodes = []
        for node, data in graph.nodes(data=True):
            if node == 2:
                continue
            if 'zap' in data:
                del data['zap']
            data['position'] = list(data['position'])
            nodes.append((node, data))

        compare_nodes = compare_graph.nodes(data=True)
        compare_nodes = [(node_id, data) for node_id, data in compare_nodes
                         if len(data) > 0]
        self.assertCountEqual(nodes, compare_nodes)
Example #24
    def test_graph_multiple_separate_collections(self):
        attributes = {'3': ['selected'], '4': ['swip']}
        graph_provider = self.get_mongo_graph_provider('w', attributes,
                                                       attributes)
        roi = daisy.Roi((0, 0, 0), (10, 10, 10))
        graph = graph_provider[roi]

        graph.add_node(2, position=(2, 2, 2), swip='swap')
        graph.add_node(42, position=(1, 1, 1), selected=False, swip='swim')
        graph.add_node(23, position=(5, 5, 5), selected=True)
        graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), selected=True)
        graph.add_edge(42, 23)
        graph.add_edge(57, 23, selected=True, swip='swap')
        graph.add_edge(2, 42, selected=True)
        graph.add_edge(42, 2, swip='swim')

        graph.write_nodes()
        graph.write_edges()

        graph_provider = self.get_mongo_graph_provider('r', attributes,
                                                       attributes)
        compare_graph = graph_provider[roi]

        self.assertFalse('selected' in compare_graph.nodes[2])
        self.assertEqual('swap', compare_graph.nodes[2]['swip'])
        self.assertEqual(False, compare_graph.nodes[42]['selected'])
        self.assertEqual('swim', compare_graph.nodes[42]['swip'])
        self.assertFalse('swip' in compare_graph.nodes[57])
        self.assertEqual(True, compare_graph.edges[2, 42]['selected'])
        self.assertFalse('swip' in compare_graph.edges[2, 42])
        self.assertFalse('selected' in compare_graph.edges[42, 23])
        self.assertEqual('swim', compare_graph.edges[42, 2]['swip'])
Example #25
    def run_test_graph_io(self, provider_factory):

        graph_provider = provider_factory('w')

        graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]

        graph.add_node(2, comment="without position")
        graph.add_node(42, position=(1, 1, 1))
        graph.add_node(23, position=(5, 5, 5), swip='swap')
        graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
        graph.add_edge(42, 23)
        graph.add_edge(57, 23)
        graph.add_edge(2, 42)

        graph.write_nodes()
        graph.write_edges()

        graph_provider = provider_factory('r')
        compare_graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]

        nodes = sorted(list(graph.nodes()))
        nodes.remove(2)  # node 2 has no position and will not be queried
        compare_nodes = sorted(list(compare_graph.nodes()))

        edges = sorted(list(graph.edges()))
        edges.remove((2, 42))  # node 2 has no position and will not be queried
        compare_edges = sorted(list(compare_graph.edges()))

        self.assertEqual(nodes, compare_nodes)
        self.assertEqual(edges, compare_edges)
Example #26
    def test_graph_read_unbounded_roi(self):
        graph_provider = self.get_mongo_graph_provider('w')
        roi = daisy.Roi((0, 0, 0), (10, 10, 10))
        unbounded_roi = daisy.Roi((None, None, None), (None, None, None))

        graph = graph_provider[roi]

        graph.add_node(2, position=(2, 2, 2), selected=True, test='test')
        graph.add_node(42, position=(1, 1, 1), selected=False, test='test2')
        graph.add_node(23, position=(5, 5, 5), selected=True, test='test2')
        graph.add_node(57,
                       position=daisy.Coordinate((7, 7, 7)),
                       selected=True,
                       test='test')

        graph.add_edge(42, 23, selected=False, a=100, b=3)
        graph.add_edge(57, 23, selected=True, a=100, b=2)
        graph.add_edge(2, 42, selected=True, a=101, b=3)

        graph.write_nodes()
        graph.write_edges()

        graph_provider = self.get_mongo_graph_provider('r+')
        limited_graph = graph_provider.get_graph(unbounded_roi,
                                                 node_attrs=['selected'],
                                                 edge_attrs=['c'])

        seen = []
        for node, data in limited_graph.nodes(data=True):
            self.assertFalse('test' in data)
            self.assertTrue('selected' in data)
            data['selected'] = True
            seen.append(node)

        self.assertCountEqual(seen, [2, 42, 23, 57])
Example #27
    def run_test_graph_connected_components(self, provider_factory):

        graph_provider = provider_factory('w')
        graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]

        graph.add_node(2, comment="without position")
        graph.add_node(42, position=(1, 1, 1))
        graph.add_node(23, position=(5, 5, 5), swip='swap')
        graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
        graph.add_edge(57, 23)
        graph.add_edge(2, 42)

        components = graph.get_connected_components()
        self.assertEqual(len(components), 2)
        c1, c2 = components
        n1 = sorted(list(c1.nodes()))
        n2 = sorted(list(c2.nodes()))

        compare_n1 = [2, 42]
        compare_n2 = [23, 57]

        if 2 in n2:
            n1, n2 = n2, n1

        self.assertCountEqual(n1, compare_n1)
        self.assertCountEqual(n2, compare_n2)
Example #28
def test_graph():

    graph_provider = daisy.persistence.MongoDbGraphProvider(
        'test_daisy_graph',
        '10.40.4.51',
        nodes_collection='nodes',
        edges_collection='edges',
        mode='w')

    graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]

    graph.add_node(2, comment="without position")
    graph.add_node(42, position=(1, 1, 1))
    graph.add_node(23, position=(5, 5, 5), swip='swap')
    graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), zap='zip')
    graph.add_edge(42, 23)

    for i in range(10000):
        graph.add_node(i + 100,
                       position=(random.randint(0, 10), random.randint(0, 10),
                                 random.randint(0, 10)))

    start = time.time()
    graph.write_nodes()
    graph.write_edges()
    print("Wrote graph in %.3fs" % (time.time() - start))

    start = time.time()
    graph = graph_provider[daisy.Roi((0, 0, 0), (10, 10, 10))]
    print("Read graph in %.3fs" % (time.time() - start))
    def test_graph_separate_collection_simple(self):
        attributes = {'1': ['selected']}
        graph_provider = self.get_mongo_graph_provider('w', attributes,
                                                       attributes)
        roi = daisy.Roi((0, 0, 0), (10, 10, 10))
        graph = graph_provider[roi]

        graph.add_node(2, position=(2, 2, 2), selected=True)
        graph.add_node(42, position=(1, 1, 1), selected=False)
        graph.add_node(23, position=(5, 5, 5), selected=True)
        graph.add_node(57, position=daisy.Coordinate((7, 7, 7)), selected=True)
        graph.add_edge(42, 23, selected=False)
        graph.add_edge(57, 23, selected=True)
        graph.add_edge(2, 42, selected=True)

        graph.write_nodes()
        graph.write_edges()

        graph_provider = self.get_mongo_graph_provider('r', attributes,
                                                       attributes)
        compare_graph = graph_provider[roi]

        self.assertEqual(True, compare_graph.nodes[2]['selected'])
        self.assertEqual(False, compare_graph.nodes[42]['selected'])
        self.assertEqual(True, compare_graph.edges[2, 42]['selected'])
        self.assertEqual(False, compare_graph.edges[42, 23]['selected'])
Example #30
    def get_crops(self, center_positions, size):
        """
        Args:

            center_positions (list of tuple of ints): [(z,y,x)] in nm

            size (tuple of ints): size of the crop, in voxels
        """

        crops = []
        size = daisy.Coordinate(size)
        voxel_size = daisy.Coordinate(self.voxel_size)
        size_nm = (size * voxel_size)
        dataset = daisy.open_ds(self.container, self.dataset)

        dataset_resolution = None
        try:
            dataset_resolution = dataset.voxel_size
        except AttributeError:
            pass

        if dataset_resolution is not None:
            if not np.all(dataset.voxel_size == self.voxel_size):
                raise ValueError(
                    f"Dataset {dataset} resolution mismatch: "
                    f"{dataset.voxel_size} vs {self.voxel_size}"
                )

        for position in center_positions:
            position = daisy.Coordinate(tuple(position))
            offset_nm = position - ((size / 2) * voxel_size)
            roi = daisy.Roi(offset_nm, size_nm).snap_to_grid(voxel_size,
                                                             mode='closest')

            if roi.get_shape()[0] != size_nm[0]:
                roi.set_shape(size_nm)

            if not dataset.roi.contains(roi):
                raise ValueError(
                    f"Location {position} is not fully contained in dataset")

            crops.append(dataset[roi].to_ndarray())

        crops_batched = np.stack(crops)
        crops_batched = crops_batched.astype(np.float32)

        return crops_batched