Example #1
    # Classmethod resolver (decorator not shown in this snippet); requires
    # `from collections import defaultdict` and `import dask.bag as db`.
    def resolve(cls, items, primary, joins=None):
        ctx = Context(cls.impls)
        returns = []
        rix = {}  # maps each binding to its (tuple index, position) in the result rows

        for i, p in enumerate(primary):
            bag = ctx[p.binds]

            returns.append(bag)
            rix[p.binds.general] = 0, i

        # Group inner bags by their anchor bag so they can be expanded together.
        exps = defaultdict(set)

        for r in returns:
            if isinstance(r, InnerBag):
                exps[r.ancher].add(r)

        for a, bs in exps.items():
            exp = a.map(len)

            returns = [(db.zip(exp, r).map(lambda n, c: [c] * n)
                        if r not in bs else r.outer).concat()
                       for r in returns]

        base = db.zip(*returns)

        if joins:
            base = base.map(lambda x: (x,))
        else:
            return base

        for n, (p, (ks, js)) in enumerate(joins.items()):
            for m, k in enumerate(ks):
                rix[k.item.general] = n + 1, m + 1

            add = cls.resolve(None, js)

            # Factories bind the loop variable by value, avoiding the
            # late-binding pitfall of a bare lambda defined in a loop.
            def fltr(i):
                return lambda ii: ii[0][0][i] == ii[1][0]

            def fltn():
                return lambda ii: ii[0] + (ii[1],)

            i = primary.index(p)
            base = (base
                    .product(add)
                    .filter(fltr(i))
                    .map(fltn()))

        ri = [rix[i.item.general] for i in items]

        return base.map(lambda ii: tuple(ii[n][m] for n, m in ri))
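
The heart of this resolver is the pattern of zipping equally partitioned bags and projecting fields back out of the zipped tuples by index, as `resolve()` does via `rix`. A minimal, self-contained sketch of just that pattern, with made-up data:

import dask.bag as db

# Two bags with matching partitioning; db.zip pairs them element-wise.
names = db.from_sequence(["a", "b", "c"], npartitions=1)
values = db.from_sequence([1, 2, 3], npartitions=1)

pairs = db.zip(names, values)       # ("a", 1), ("b", 2), ("c", 3)
picked = pairs.map(lambda t: t[1])  # project one position back out of each tuple
print(picked.compute())             # [1, 2, 3]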
Example #2
def input_to_rowmatrix(raw_rdd, indices, norm):
    """
    Utility function for reading the matrix data: pairs each row index
    with its parsed, normalized row, yielding a bag of (index, row) tuples.
    """
    # Requires `import functools` and `import dask.bag as db`;
    # `parse_and_normalize` is defined elsewhere in the module.
    p_and_n = functools.partial(parse_and_normalize, norm=norm)
    numpy_rdd = db.zip(indices, raw_rdd).map(lambda x: (x[0], p_and_n(x[1])))
    return numpy_rdd
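
A hedged usage sketch: `parse_and_normalize` lives elsewhere in the original module, so the version below is a hypothetical stand-in just to make the call runnable.

import functools
import dask.bag as db

def parse_and_normalize(line, norm=False):
    # Hypothetical stand-in for the module's real parser: splits a
    # whitespace-delimited line into floats, optionally L1-normalized.
    row = [float(x) for x in line.split()]
    if norm:
        total = sum(abs(v) for v in row) or 1.0
        row = [v / total for v in row]
    return row

raw_rdd = db.from_sequence(["1 2 3", "4 5 6"], npartitions=1)
indices = db.from_sequence([0, 1], npartitions=1)
print(input_to_rowmatrix(raw_rdd, indices, norm=True).compute())
# [(0, [1/6, 2/6, 3/6]), (1, [4/15, 5/15, 6/15])]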
Example #3
def test_partitions_are_coerced_to_lists():
    # https://github.com/dask/dask/issues/6906
    # Note: `random` here is `dask.bag.random`, not the stdlib module.
    A = db.from_sequence([[1, 2], [3, 4, 5], [6], [7]])
    B = db.from_sequence(["a", "b", "c", "d"])

    a = random.choices(A.flatten(), k=B.count().compute()).repartition(4)

    C = db.zip(B, a).compute()
    assert len(C) == 4
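
The invariant behind this test is that `db.zip` pairs bags partition by partition, so both sides must have the same `npartitions` (here forced with `repartition(4)`). A minimal sketch of that invariant with made-up data:

import dask.bag as db

left = db.from_sequence(["a", "b", "c", "d"], npartitions=4)
right = db.from_sequence([1, 2, 3, 4], npartitions=4)

# Corresponding partitions are zipped element-wise.
print(db.zip(left, right).compute())
# [('a', 1), ('b', 2), ('c', 3), ('d', 4)]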
Example #4
def _get_rays_d(lengths, stepSize, start_positions, scaled_look_vecs, Nproc=2):
    # `helper` is defined elsewhere in the source module; `Nproc` is unused
    # here because dask schedules the work itself.
    import dask.bag as db
    L = db.from_sequence(lengths)
    S = db.from_sequence(start_positions)
    Sv = db.from_sequence(scaled_look_vecs)
    Ss = db.from_sequence([stepSize] * len(lengths))

    # Zip the per-ray inputs so each element is one
    # (length, start, look_vec, step) tuple, then map the helper over them.
    data = db.zip(L, S, Sv, Ss)

    positions_l = db.map(helper, data)
    return positions_l.compute()
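
A sketch of how this might be driven, with a hypothetical `helper` (the real one is defined elsewhere in the source project) that consumes one (length, start, look_vec, step) tuple per ray:

def helper(ray):
    # Hypothetical stand-in: walk along a 1-D look vector in fixed steps.
    length, start, look_vec, step = ray
    return [start + i * step * look_vec for i in range(int(length // step))]

positions = _get_rays_d(lengths=[2.0, 3.0],
                        stepSize=1.0,
                        start_positions=[0.0, 10.0],
                        scaled_look_vecs=[1.0, -1.0])
print(positions)  # [[0.0, 1.0], [10.0, 9.0, 8.0]]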
Example #5
def run(src_dir, dst_dir):
    client = get_client()

    # load data
    tiff_paths = find_src_files(src_dir, "tif")
    raw_data = tiff_paths.map(read_tiff)

    # create destination
    create_dst_dir(dst_dir)

    # # downsample
    # bin4_data = raw_data.map(partial(downsample_naive, ratio=(1, 4, 4)))
    # bin4_data = bin4_data.map(da.rechunk)
    # bin4_data = client.persist(bin4_data)
    #
    # logger.info("downsampling")
    # progress(bin4_data)

    logger.info("persist data on cluster")
    bin4_data = client.persist(raw_data, priority=-10)
    progress(bin4_data)

    # save intermediate result
    zarr_paths = tiff_paths.map(partial(build_zarr_path, dst_dir))
    name_data = db.zip(zarr_paths, bin4_data)
    futures = name_data.starmap(write_zarr, path="raw")

    logger.info("save as zarr")
    future = client.compute(futures, priority=10)
    progress(future)

    # convert to h5 for ingestion
    h5_paths = zarr_paths.map(partial(build_h5_path, dst_dir))
    src_dst = db.zip(zarr_paths, h5_paths)
    futures = src_dst.starmap(convert_hdf5)

    logger.info("convert zarr to h5")
    future = client.compute(futures, priority=20)
    progress(future)
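
The recurring pattern here is db.zip(...).starmap(func, **kwargs): zip aligns two bags element-wise, and starmap unpacks each resulting tuple into positional arguments. A minimal sketch with a hypothetical stand-in for the write step:

import dask.bag as db

def write_pair(path, data, label="raw"):
    # Hypothetical stand-in for a write_zarr-style consumer.
    return f"{path}:{label}:{sum(data)}"

paths = db.from_sequence(["a.zarr", "b.zarr"], npartitions=2)
chunks = db.from_sequence([[1, 2], [3, 4]], npartitions=2)

results = db.zip(paths, chunks).starmap(write_pair, label="raw")
print(results.compute())  # ['a.zarr:raw:3', 'b.zarr:raw:7']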
Example #6
File: bags.py  Project: Jxt1/arlo
def residual_vis_bag(vis_bag, model_vis_bag):
    """Calculate residual visibility

    Call directly - don't use via bag.map

    :param vis_bag: Bag containing visibilities
    :param model_vis_bag: Bag of model visibilities, one per visibility in vis_bag
    :return: Bag of residual visibilities
    """

    def subtract_vis_zip(vis_zip_bag):
        return subtract_visibility(vis_zip_bag[0], vis_zip_bag[1])

    return bag.zip(vis_bag, model_vis_bag).map(subtract_vis_zip)
def restore_bag(comp_bag, psf_bag, res_bag, **kwargs):
    """ Restore a bag of images to obtain a bag of restored images

    Call directly - don't use via bag.map

    :param comp_bag: Bag of component images
    :param psf_bag: Bag of (psf Image, weight) tuples
    :param res_bag: Bag of (residual Image, weight) tuples
    :param kwargs:
    :return: Bag of Images
    """
    def restore(cpr_zip, **kwargs):
        # The comp is just an Image, while the psf and residual are actually (Image, weight) tuples.
        return restore_cube(cpr_zip[0], cpr_zip[1][0], cpr_zip[2][0], **kwargs)

    return bag.zip(comp_bag, psf_bag, res_bag).map(restore, **kwargs)
def deconvolve_bag(dirty_bag, psf_bag, model_bag, **kwargs):
    """ Deconvolve a bag of images to obtain a bag of models

    Call directly - don't use via bag.map

    :param dirty_bag: Bag of (dirty Image, weight) tuples
    :param psf_bag: Bag of (psf Image, weight) tuples
    :param model_bag: Bag of model images (unused in this implementation)
    :param kwargs:
    :return: Bag of Images
    """
    def deconvolve(dp_zip, **kwargs):
        # The dirty and psf are actually (Image, weight) tuples.
        result = deconvolve_cube(dp_zip[0][0], dp_zip[1][0], **kwargs)
        return result[0]

    # We zip up the dirty and psf bags and call the deconvolve adapter
    return bag.zip(dirty_bag, psf_bag).map(deconvolve, **kwargs)
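
All three functions share one shape: bag.zip(...) followed by a map whose callback indexes into the zipped tuple. A minimal sketch of the residual case, using plain numbers as hypothetical stand-ins for the Visibility objects:

from dask import bag

def subtract_visibility(vis, model_vis):
    # Hypothetical stand-in; the real function operates on Visibility objects.
    return vis - model_vis

vis_bag = bag.from_sequence([10, 20, 30], npartitions=1)
model_vis_bag = bag.from_sequence([1, 2, 3], npartitions=1)

residual = residual_vis_bag(vis_bag, model_vis_bag)
print(residual.compute())  # [9, 18, 27]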
Example #9
# `npartitions` is supplied via pytest parametrization in the original test suite.
def test_zip(npartitions, hi=1000):
    evens = db.from_sequence(range(0, hi, 2), npartitions=npartitions)
    odds = db.from_sequence(range(1, hi, 2), npartitions=npartitions)
    pairs = db.zip(evens, odds)
    assert pairs.npartitions == npartitions
    assert list(pairs) == list(zip(range(0, hi, 2), range(1, hi, 2)))
Example #11
    def execute(self):
        self._init_service()
        mgr_client = self.mgr_client

        options = self.config["stitchedmeshes"]

        server, uuid, instance = self.input_service.base_service.instance_triple
        is_supervoxels = self.input_service.base_service.supervoxels
        bodies = load_body_list(options["bodies"], is_supervoxels)

        logger.info(f"Input is {len(bodies)} bodies")
        os.makedirs(options["output-directory"], exist_ok=True)

        def make_bricks(coord_and_block):
            coord_zyx, block_vol = coord_and_block
            logical_box = np.array((coord_zyx, coord_zyx + block_vol.shape))
            return Brick(logical_box,
                         logical_box,
                         block_vol,
                         location_id=(logical_box // 64))

        rescale = (2**options["scale"]) * options["extra-rescale"]

        def create_brick_mesh(brick):
            mesh = Mesh.from_binary_vol(brick.volume, brick.physical_box)
            if rescale != 1.0:
                mesh.vertices_zyx *= rescale
            return mesh

        def create_combined_mesh(meshes):
            mesh = concatenate_meshes(meshes, False)
            if options["stitch"]:
                mesh.stitch_adjacent_faces(drop_unused_vertices=True,
                                           drop_duplicate_faces=True)
            mesh.laplacian_smooth(options["smoothing-iterations"])
            mesh.simplify(options["decimation-fraction"], in_memory=True)
            return mesh

        in_flight = 0

        # Support synchronous testing with a fake 'as_completed' object
        if hasattr(self.client, 'DEBUG'):
            result_futures = as_completed_synchronous()
        else:
            result_futures = as_completed()

        def pop_result():
            nonlocal in_flight
            r = next(result_futures)
            in_flight -= 1

            try:
                return r.result()
            except Exception as ex:
                if options["error-mode"] == "raise":
                    raise
                body = int(r.key)
                return (body, 0, 'error', str(ex))

        USER = getpass.getuser()
        results = []
        try:
            for i, body in enumerate(bodies):
                logger.info(f"Mesh #{i}: Body {body}: Starting")

                def fetch_sparsevol():
                    with mgr_client.access_context(server, True, 1, 0):
                        ns = default_node_service(server, uuid,
                                                  'flyemflows-stitchedmeshes',
                                                  USER)
                        coords_zyx, blocks = ns.get_sparselabelmask(
                            body, instance, options["scale"], is_supervoxels)
                        return list(coords_zyx.copy()), list(blocks.copy())

                # This leaves all blocks and bricks in a single partition,
                # but we're about to do a shuffle anyway when the bricks are realigned.
                coords, blocks = delayed(fetch_sparsevol, nout=2)()
                coords, blocks = db.from_delayed(coords), db.from_delayed(
                    blocks)
                bricks = db.zip(coords, blocks).map(make_bricks)

                mesh_grid = Grid((64, 64, 64), halo=options["block-halo"])
                wall = BrickWall(None, (64, 64, 64), bricks)
                wall = wall.realign_to_new_grid(mesh_grid)

                brick_meshes = wall.bricks.map(create_brick_mesh)
                consolidated_brick_meshes = brick_meshes.repartition(1)
                combined_mesh = delayed(create_combined_mesh)(
                    consolidated_brick_meshes)

                def write_mesh(mesh):
                    output_dir = options["output-directory"]
                    fmt = options["format"]
                    output_path = f'{output_dir}/{body}.{fmt}'
                    mesh.serialize(output_path)
                    return (body, len(mesh.vertices_zyx), 'success', '')

                # We hide the body ID in the task name, so that we can record it in pop_result
                task = delayed(write_mesh)(combined_mesh,
                                           dask_key_name=f'{body}')
                result_futures.add(self.client.compute(task))
                in_flight += 1

                assert in_flight <= options["concurrent-bodies"]
                while in_flight == options["concurrent-bodies"]:
                    body, vertices, result, msg = pop_result()
                    if result == "error":
                        logger.warning(
                            f"Body {body}: Failed to generate mesh: {msg}")
                    results.append((body, vertices, result, msg))

            # Flush the last batch of tasks
            while in_flight > 0:
                body, vertices, result, msg = pop_result()
                if result == "error":
                    logger.warning(
                        f"Body {body}: Failed to generate mesh: {msg}")
                results.append((body, vertices, result, msg))
        finally:
            stats_df = pd.DataFrame(
                results, columns=['body', 'vertices', 'result', 'msg'])
            stats_df.to_csv('mesh-stats.csv', index=False, header=True)

            failed_df = stats_df.query("result != 'success'")
            if len(failed_df) > 0:
                logger.warning(
                    f"Failed to create meshes for {len(failed_df)} bodies.  See mesh-stats.csv"
                )
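
The bag construction at the core of this workflow is the delayed-to-bag handoff: one delayed call returns two aligned sequences, each becomes a single-partition bag via db.from_delayed, and db.zip pairs them element-wise. A minimal sketch of just that step, with made-up names and data:

import dask.bag as db
from dask import delayed

def fetch():
    # Stands in for fetch_sparsevol: two aligned lists from one call.
    return [0, 1, 2], ["a", "b", "c"]

coords, blocks = delayed(fetch, nout=2)()
coords, blocks = db.from_delayed(coords), db.from_delayed(blocks)

pairs = db.zip(coords, blocks)  # both bags have a single partition
print(pairs.map(lambda cb: f"{cb[0]}-{cb[1]}").compute())
# ['0-a', '1-b', '2-c']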