Example #1
def plot_predictions(input_dir, basename, patch_info_file, patch_size,
                     outputfname, annotations, compression_factor, alpha,
                     segmentation, n_segmentation_classes, custom_segmentation,
                     annotation_col, scaling_factor, tif_file):
    """Overlays classification, regression and segmentation patch level predictions on top of whole slide image."""
    dask_arr_dict = {
        os.path.basename(f).split('.zarr')[0]: da.from_zarr(f)
        for f in glob.glob(os.path.join(input_dir, '*.zarr'))
        if os.path.basename(f).split('.zarr')[0] == basename
    }
    pred_plotter = PredictionPlotter(
        dask_arr_dict,
        patch_info_file,
        compression_factor=compression_factor,
        alpha=alpha,
        patch_size=patch_size,
        no_db=False,
        plot_annotation=annotations,
        segmentation=segmentation,
        n_segmentation_classes=n_segmentation_classes,
        input_dir=input_dir,
        annotation_col=annotation_col,
        scaling_factor=scaling_factor)
    if custom_segmentation:
        pred_plotter.add_custom_segmentation(basename, custom_segmentation)
    img = pred_plotter.generate_image(basename)
    pred_plotter.output_image(img, outputfname, tif_file)
Example #2
def read_zarr_dataset(path):
    """Read a zarr dataset, including an array or a group of arrays.

    Parameters
    ----------
    path : str
        Path to directory ending in '.zarr'. Path can contain either an array
        or a group of arrays in the case of pyramid data.

    Returns
    -------
    image : array-like
        Array or list of arrays
    shape : tuple
        Shape of array or first array in list
    """
    if os.path.exists(os.path.join(path, '.zarray')):
        # load zarr array
        image = da.from_zarr(path)
        shape = image.shape
    elif os.path.exists(os.path.join(path, '.zgroup')):
        # otherwise load all arrays inside the zarr group, useful for pyramid data
        image = []
        for subpath in sorted(os.listdir(path)):
            if not subpath.startswith('.'):
                image.append(read_zarr_dataset(os.path.join(path, subpath))[0])
        shape = image[0].shape
    else:
        raise ValueError(f"Not a zarr dataset or group: {path}")
    return image, shape
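A minimal usage sketch for the function above, assuming `read_zarr_dataset` is in scope together with `import os` and `import dask.array as da`; the file name is illustrative:

import numpy as np
import zarr

# write a small zarr array to disk; this creates the '.zarray' metadata file
zarr.save('example_image.zarr', np.arange(12).reshape(3, 4))

image, shape = read_zarr_dataset('example_image.zarr')
print(shape)                   # (3, 4)
print(image.sum().compute())   # lazy dask array, computed on demand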
Example #3
def load_omero_zarr(path):
    zarr_path = path if path.endswith("/") else f"{path}/"
    attrs_path = zarr_path + ".zattrs"
    root_attrs = requests.get(attrs_path).json()

    resolutions = ["0"]  # TODO: could be first alphanumeric dataset on err
    try:
        print('root_attrs', root_attrs)
        if 'multiscales' in root_attrs:
            datasets = root_attrs['multiscales'][0]['datasets']
            resolutions = [d['path'] for d in datasets]
        print('resolutions', resolutions)
    except Exception as e:
        raise e

    pyramid = []
    for resolution in resolutions:
        # data.shape is (t, c, z, y, x) by convention
        data = da.from_zarr(f"{zarr_path}{resolution}")
        chunk_sizes = [
            str(c[0]) + (" (+ %s)" % c[-1] if c[-1] != c[0] else '')
            for c in data.chunks
        ]
        print('resolution', resolution, 'shape (t, c, z, y, x)', data.shape,
              'chunks', chunk_sizes, 'dtype', data.dtype)
        pyramid.append(data)

    metadata = load_omero_metadata(zarr_path)

    return (pyramid, {'channel_axis': 1, **metadata})
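For reference, a sketch of the kind of `.zattrs` payload the snippet above parses; the structure follows the OME-NGFF `multiscales` layout, and the values here are purely illustrative:

root_attrs = {
    "multiscales": [
        {
            "version": "0.1",
            "name": "example",
            "datasets": [{"path": "0"}, {"path": "1"}, {"path": "2"}],
        }
    ]
}
resolutions = [d["path"] for d in root_attrs["multiscales"][0]["datasets"]]
assert resolutions == ["0", "1", "2"]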
Example #4
        def calc_random(view):
            import dask.array as da
            import numpy as np
            from glue.utils import view_shape
            x = da.from_zarr('/mnt/cephfs/smltar_numpyarr/zarr_data_full')
            z = x[12]
            # x1 = x[120]
            # x2 = x[121]
            # x3 = x[122]
            # x4 = x[123]

            # h = [x1,x2]
            # l = [x3,x4]

            z = z.compute()
            z = z.astype(int)
            # z = x[100:130]
            # m = x[1000:1030]
            # n = x[1500:1530]
            # p = x[1700:1730]
            # sum = (y + z - m + p) * n
            # print(sum)
            # fu = client.compute(sum)
            # #print(r.result())
            # re = fu.result()
            # result = np.random.random(view_shape((64,64,64), view))
            return z
Example #5
def test_rechunk_dask_array(tmp_path, shape, source_chunks, dtype,
                            target_chunks, max_mem):

    ### Create source array ###
    source_array = dsa.ones(shape, chunks=source_chunks, dtype=dtype)

    ### Create targets ###
    target_store = str(tmp_path / "target.zarr")
    temp_store = str(tmp_path / "temp.zarr")

    rechunked = api.rechunk(source_array,
                            target_chunks,
                            max_mem,
                            target_store,
                            temp_store=temp_store)
    assert isinstance(rechunked, api.Rechunked)

    target_array = zarr.open(target_store)

    assert target_array.chunks == tuple(target_chunks)

    result = rechunked.execute()
    assert isinstance(result, zarr.Array)
    a_tar = dsa.from_zarr(target_array)
    assert dsa.equal(a_tar, 1).all().compute()
Example #6
def test_rechunk_group(tmp_path):
    store_source = str(tmp_path / "source.zarr")
    group = zarr.group(store_source)
    group.attrs["foo"] = "bar"
    # 800 byte chunks
    a = group.ones("a", shape=(5, 10, 20), chunks=(1, 10, 20), dtype="f4")
    a.attrs["foo"] = "bar"
    b = group.ones("b", shape=(20, ), chunks=(10, ), dtype="f4")
    b.attrs["foo"] = "bar"

    target_store = str(tmp_path / "target.zarr")
    temp_store = str(tmp_path / "temp.zarr")

    max_mem = 1600  # should force a two-step plan for a
    target_chunks = {"a": (5, 10, 4), "b": (20, )}

    delayed = api.rechunk(group,
                          target_chunks,
                          max_mem,
                          target_store,
                          temp_store=temp_store)

    target_group = zarr.open(target_store)
    assert "a" in target_group
    assert "b" in target_group
    assert dict(group.attrs) == dict(target_group.attrs)

    dask.compute(delayed)
    for aname in target_chunks:
        a_tar = dsa.from_zarr(target_group[aname])
        assert dsa.equal(a_tar, 1).all().compute()
Example #7
 def _get_masks(self):
     masks = {}
     zgroup = self.store.open_mask_group()
     for point in ['c', 'w', 's']:
         mask_faces = dsa.from_zarr(zgroup['mask_' + point]).astype('bool')
         masks[point] = _faces_to_facets(mask_faces)
     return masks
Example #8
        def load_data(statistic, axis):
            import dask.array as da
            import numpy as np
            from glue.utils import view_shape
            x = da.from_zarr('/mnt/cephfs/zarr_data_full')
            f = 1500
            scale = 2

            lh = []
            for k in range(scale):
                lc = []
                for i in range(scale):
                    lr = []
                    for j in range(scale):
                        lr.append(x[f % 3500])
                        f = f + 1
                    lc.append(da.concatenate(lr))
                lh.append(da.concatenate(lc, 1))
            z = da.concatenate(lh, 2)

            if statistic == 'minimum':
                return da.min(z, axis).compute()
            elif statistic == 'maximum':
                return da.max(z, axis).compute()
            elif statistic == 'mean' or statistic == 'median':
                return da.mean(z, axis).compute()
            elif statistic == 'percentile':
                return percentile / 100
            elif statistic == 'sum':
                return da.sum(z, axis).compute()
            return 0
Example #9
 def __init__(self,
              z: zarr.hierarchy,
              name: str,
              cell_data: MetaData,
              nthreads: int,
              min_cells_per_feature: int = 10):
     """
     Args:
         z (zarr.hierarchy): Zarr hierarchy to use.
         name (str): Name for assay.
         cell_data: Metadata for the cells.
         nthreads (int): Number of threads to use for parallel computations.
         min_cells_per_feature (int): Minimum number of cells in which a feature must be detected.
     """
     self.name = name
     self.z = z[self.name]
     self.cells = cell_data
     self.nthreads = nthreads
     self.rawData = daskarr.from_zarr(self.z['counts'], inline_array=True)
     self.feats = MetaData(self.z['featureData'])
     self.attrs = self.z.attrs
     if 'percentFeatures' not in self.attrs:
         self.attrs['percentFeatures'] = {}
     self.normMethod = norm_dummy
     self.sf = None
     self._ini_feature_props(min_cells_per_feature)
Example #10
    def adata_dist(self, sc, request):
        # regular anndata except for X, which we replace on the next line
        a = ad.read_zarr(input_file)
        input_file_X = input_file + "/X"
        if request.param == "direct":
            a.X = zappy.direct.from_zarr(input_file_X)
            yield a
        elif request.param == "executor":
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=2) as executor:
                a.X = zappy.executor.from_zarr(executor, input_file_X)
                yield a
        elif request.param == "spark":
            a.X = zappy.spark.from_zarr(sc, input_file_X)
            yield a
        elif request.param == "dask":
            a.X = da.from_zarr(input_file_X)
            yield a
        elif request.param == "pywren":
            import s3fs.mapping

            s3 = s3fs.S3FileSystem()
            input_file_X = s3fs.mapping.S3Map(
                "sc-tom-test-data/10x-10k-subset.zarr/X", s3=s3)
            executor = zappy.executor.PywrenExecutor()
            a.X = zappy.executor.from_zarr(executor, input_file_X)
            yield a
Example #11
        def load_ceph_data(view):
            import dask.array as da
            import numpy as np
            from glue.utils import view_shape
            x = da.from_zarr('/mnt/cephfs/zarr_data_full')
            f = 1500
            scale = 2

            # Construct the data graph. No computation happens yet.
            # Only access part of the data for the purposes of the demo.
            lh = []
            for k in range(scale):
                lc = []
                for i in range(scale):
                    lr = []
                    for j in range(scale):
                        lr.append(x[f % 3500])
                        f = f + 1
                    lc.append(da.concatenate(lr))
                lh.append(da.concatenate(lc, 1))
            z = da.concatenate(lh, 2)

            if view is not None:
                z = z[view]

            # Fire the actual computation.
            z = z.compute()
            return z
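A small standalone sketch of the tiling pattern used above: equally shaped dask blocks are concatenated along successive axes into a larger lazy mosaic, and nothing is evaluated until `.compute()` is called (the arrays here are illustrative stand-ins for the zarr slices):

import dask.array as da
import numpy as np

block = da.from_array(np.ones((4, 4)), chunks=(2, 2))
row0 = da.concatenate([block, block], axis=1)      # 4 x 8, still lazy
row1 = da.concatenate([block, block], axis=1)
mosaic = da.concatenate([row0, row1], axis=0)      # 8 x 8, still lazy
print(mosaic.shape)            # (8, 8)
print(mosaic.sum().compute())  # 64.0 -- computation fires here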
Example #12
 def multi_zarr():
     path = 'https://s3.embassy.ebi.ac.uk/idr/zarr/v0.1/9822151.zarr'
     resolutions = [
         da.from_zarr(path, component=str(i))[0, 0, 0]
         for i in list(range(11))
     ]
     return napari.view_image(resolutions)
Example #13
    def __init__(self,
                 z: zarr.hierarchy,
                 name: str,
                 cell_data: MetaData,
                 nthreads: int,
                 min_cells_per_feature: int = 10):
        """

        Args:
            z:
            name:
            cell_data:
            nthreads:
            min_cells_per_feature:
        """
        self.name = name
        self.z = z[self.name]
        self.cells = cell_data
        self.nthreads = nthreads
        self.rawData = daskarr.from_zarr(self.z['counts'])
        self.feats = MetaData(self.z['featureData'])
        self.attrs = self.z.attrs
        if 'percentFeatures' not in self.attrs:
            self.attrs['percentFeatures'] = {}
        self.normMethod = norm_dummy
        self.sf = None
        self._ini_feature_props(min_cells_per_feature)
Example #14
def plot_image_umap_embeddings(input_dir,
                               embeddings_file,
                               basename,
                               outputfilename,
                               mpl_scatter,
                               remove_background_annotation,
                               max_background_area,
                               zoom,
                               n_neighbors,
                               sort_col='',
                               sort_mode='asc'):
    """Plots a UMAP embedding with each point as its corresponding patch image."""
    from pathflowai.visualize import plot_umap_images
    dask_arr_dict = {
        os.path.basename(f).split('.zarr')[0]: da.from_zarr(f)
        for f in glob.glob(os.path.join(input_dir, '*.zarr'))
        if (not basename) or (
            os.path.basename(f).split('.zarr')[0] == basename)
    }
    plot_umap_images(dask_arr_dict,
                     embeddings_file,
                     ID=basename,
                     cval=1.,
                     image_res=300.,
                     outputfname=outputfilename,
                     mpl_scatter=mpl_scatter,
                     remove_background_annotation=remove_background_annotation,
                     max_background_area=max_background_area,
                     zoom=zoom,
                     n_neighbors=n_neighbors,
                     sort_col=sort_col,
                     sort_mode=sort_mode)
Example #15
def _view_from_paths(paths, scale):
    """
    View labels and image volumes from a list of paths
    """
    with napari.gui_qt():
        viewer = napari.Viewer()
        for path in paths:
            name = Path(path).stem
            array = da.from_zarr(path)
            if '_labels.zarr' in path:
                viewer.add_labels(
                    array,
                    name=name,
                    scale=scale,
                    blending='additive',
                    visible=False,
                )
            else:
                viewer.add_image(
                    array,
                    name=name,
                    scale=scale,
                    blending='additive',
                    visible=False,
                )
Example #16
    def _update_viewer(self,display_data):

        # clean up viewer
        self.viewer.layers.clear()

        # channel names and colormaps to match control software
        channel_names = ['405nm','488nm','561nm','635nm','730nm']
        colormaps = ['bop purple','bop blue','bop orange','red','grey']

        active_channel_names=[]
        active_colormaps=[]

        dataset = da.from_zarr(zarr.open(self.dataset_zarr,mode='r'))

        # iterate through active channels and populate viewer
        for channel in self.channels_in_data:
            active_channel_names.append(channel_names[channel])
            active_colormaps.append(colormaps[channel])
        self.viewer.add_image(
            dataset, 
            channel_axis=1, 
            name=active_channel_names, 
            scale = self.scale,
            blending='additive', 
            colormap=active_colormaps)
        self.viewer.scale_bar.visible = True
        self.viewer.scale_bar.unit = 'um'
Example #17
def _add_field_to_dataset(
    category: str,
    key: str,
    vcfzarr_key: str,
    variable_name: str,
    dims: List[str],
    field_def: Dict[str, Any],
    vcfzarr: zarr.Array,
    ds: xr.Dataset,
) -> None:
    if "ID" not in vcfzarr[vcfzarr_key].attrs:
        # only convert fields that were defined in the original VCF
        return
    vcf_number = field_def.get("Number", vcfzarr[vcfzarr_key].attrs["Number"])
    dimension, _ = vcf_number_to_dimension_and_size(
        # ploidy and max_alt_alleles are not relevant since size is not used here
        vcf_number,
        category,
        key,
        field_def,
        ploidy=2,
        max_alt_alleles=0,
    )
    if dimension is not None:
        dims.append(dimension)
    array = da.from_zarr(vcfzarr[vcfzarr_key])
    ds[variable_name] = (dims, array)
    if "Description" in vcfzarr[vcfzarr_key].attrs:
        description = vcfzarr[vcfzarr_key].attrs["Description"]
        if len(description) > 0:
            ds[variable_name].attrs["comment"] = description
Example #18
        def func(uri, shape, dtype):
            group = self.handle[uri]

            # multi-scale?
            if ZarrDataset.is_multiscales(group):
                name = os.path.basename(uri)
                for group_attrs in group.attrs["multiscales"]:
                    # find the corresponding attributes
                    if group_attrs["name"] == name:
                        # get path for requested level
                        level = group_attrs["datasets"][self.level]["path"]
                        break
                else:
                    raise RuntimeError(
                        f'corrupted multiscale dataset, unable to find path for "{name}" (level: {self.level})'
                    )
            else:
                level = ""

            # build final path
            # NOTE "zarr.Group + sub-path" does not function properly, use "str + full
            # path" instead
            path = "/".join([uri, level])

            # zarr array contains shape, dtype and decompression info
            return da.from_zarr(self.root_dir, path)
Example #19
    def load_neighbour_info(self, cache_dir, mask=None, **kwargs):
        """Read index arrays from either the in-memory or disk cache."""
        mask_name = getattr(mask, 'name', None)
        cached = {}
        self._check_numpy_cache(cache_dir, mask=mask_name, **kwargs)

        filename = self._create_cache_filename(cache_dir, prefix='nn_lut-',
                                               mask=mask_name, **kwargs)
        for idx_name in NN_COORDINATES.keys():
            if mask_name in self._index_caches:
                cached[idx_name] = self._apply_cached_index(
                    self._index_caches[mask_name][idx_name], idx_name)
            elif cache_dir:
                try:
                    cache = da.from_zarr(filename, idx_name)
                    if idx_name == 'valid_input_index':
                        # valid input index array needs to be boolean
                        cache = cache.astype(bool)
                except ValueError:
                    raise IOError
                cache = self._apply_cached_index(cache, idx_name)
                cached[idx_name] = cache
            else:
                raise IOError
        self._index_caches[mask_name] = cached
Example #20
def demix_cells(save_root, dt, is_skip=True, dask_tmp=None, memory_limit=0):
    '''
      1. local pca denoise
      2. cell segmentation
    '''
    cluster, client = fdask.setup_workers(is_local=True, dask_tmp=dask_tmp, memory_limit=memory_limit)
    print_client_links(cluster)
    Y_svd = da.from_zarr(f'{save_root}/detrend_data.zarr')
    Y_svd = Y_svd[:, :, :, ::dt]
    mask = da.from_zarr(f'{save_root}/Y_d_max.zarr')
    if not os.path.exists(f'{save_root}/sup_demix_rlt/'):
        os.mkdir(f'{save_root}/sup_demix_rlt/')
    da.map_blocks(demix_blocks, Y_svd.astype('float'), mask.astype('float'), chunks=(1, 1, 1, 1), dtype='int8', save_folder=save_root, is_skip=is_skip).compute()
    fdask.terminate_workers(cluster, client)
    time.sleep(100)
    return None
Example #21
    def set_norm_factors(self, data_group, fold_group, overwrite=False):

        # Get Zarr arrays
        train_indexes = fold_group['train'][:]
        X = da.from_zarr(data_group['X'])
        norm_shape = X.shape[1:]

        # Create normalization data Zarr arrays
        norm_group = fold_group.require_group('norm_data')
        norm_group.require_dataset('s1', shape=norm_shape, dtype='float32', chunks=None)
        norm_group.require_dataset('s2', shape=norm_shape, dtype='float32', chunks=None)
        norm_group.require_dataset('mean', shape=norm_shape, dtype='float32', chunks=None)
        norm_group.require_dataset('std', shape=norm_shape, dtype='float32', chunks=None)

        # Stop processing if already done AND we don't want to overwrite the dataset
        if (norm_group['s1'].nchunks == norm_group['s1'].nchunks_initialized) and not overwrite:
            return

        # Compute normalization factors
        fold_num = pathlib.PurePath(fold_group.name).name[-1]
        print(f'Computing the normalization factors for the cross-validation fold #{fold_num}.\nThis may take some time...')

        # Compute sum and squared sum
        s1 = X[train_indexes,].sum(axis=0)
        s2 = (X[train_indexes,] ** 2).sum(axis=0)
        S = da.stack([s1, s2], axis=0).compute()
        s1 = S[0,]
        s2 = S[1,]
        n = train_indexes.size

        # Fill Zarr arrays with the normalization factors
        norm_group['s1'][:] = s1
        norm_group['s2'][:] = s2
        norm_group['mean'][:] = s1 / n
        norm_group['std'][:] = np.sqrt((n * s2 - (s1 * s1)) / (n * (n - 1)))
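A quick numerical check of the running-sum formulas used above, where the sample mean and standard deviation are recovered from s1 = sum(x) and s2 = sum(x**2); the data here are synthetic:

import numpy as np

x = np.random.default_rng(0).normal(size=(1000, 3))
s1, s2, n = x.sum(axis=0), (x ** 2).sum(axis=0), x.shape[0]
mean = s1 / n
std = np.sqrt((n * s2 - s1 * s1) / (n * (n - 1)))
assert np.allclose(mean, x.mean(axis=0))
assert np.allclose(std, x.std(axis=0, ddof=1))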
Example #22
def check_demix_cells_layer(save_root, nlayer, nsplit=(10, 16), mask=None):
    import matplotlib.pyplot as plt
    Y_d_ave = da.from_zarr(f'{save_root}/motion_corrected_data.zarr')
    _, xdim, ydim, _ = Y_d_ave.shape
    _, x_, y_, _ = Y_d_ave.chunksize
    A_mat = np.zeros((xdim, ydim))
    for nx in range(nsplit[0]):
        for ny in range(nsplit[1]):
            try:
                A_ = load_A_matrix(save_root=save_root, ext='', block_id=(nlayer, nx, ny, 0), min_size=0)
                A_[A_<A_.max(axis=0, keepdims=True)*0.5]=0
                A_ = A_.reshape((x_, y_, -1), order="F")
                mask_ = mask[x_*nx:x_*(nx+1), y_*ny:y_*(ny+1)]
                A_[~mask_]=0
                A_ = A_[:, :, (A_>0).sum(axis=(0,1))>10]
                A_mat[x_*nx:x_*(nx+1), y_*ny:y_*(ny+1)] = A_.sum(axis=-1)
            except:
                pass

    plt.figure(figsize=(8, 8))
    plt.imshow(A_mat)
    plt.title(f'Components {nlayer}')
    plt.axis('off')
    plt.show()
    return None
Example #23
def get_complete_non_nan_sites(z, slices):
    fill_tup = (slice(0, 1, 1), ) * (z.ndim - 3)
    tup = (slices['lat'], slices['lon'], slices['prob']) + fill_tup
    
    if isinstance(z, da.core.Array):
        sliced_da = z[tup]
    else:
        sliced_da = da.from_zarr(z)[tup]

    arr = sliced_da.compute()
    complete_sliced_non_nan_coords_linear = np.where(
        (arr != -np.inf) & (~np.isnan(arr))
    )[:3]
    
    complete_sliced_non_nan_coords_tuples = _convert_sliced_linear_coords_to_sliced_coords_tuples(
        *complete_sliced_non_nan_coords_linear
    )
    
    complete_non_nan_coords_tuples = _convert_sliced_linear_coords_to_global_coords_tuples(
        *complete_sliced_non_nan_coords_linear, # *(lat, lon, prob)
        slices
    )
    nr_complete_sites = len(complete_non_nan_coords_tuples)

    return nr_complete_sites, complete_non_nan_coords_tuples, complete_sliced_non_nan_coords_tuples
Example #24
    def save_results(self,zarr_path,chunks):
        logger.info("saving validation results...")
        if path.exists(zarr_path):
            if_overwrite = ask_ok_or_not(self.viewer,"Validation file already exists. Overwrite?")
            if not if_overwrite:
                return
        logger.info("saving mask ...")
        # to avoid IO from/to the same array, save to a temp array and then rename
        self.label_layer.data.rechunk(chunks).to_zarr(zarr_path,"mask_tmp",overwrite=True)
        zarr_file=zarr.open(zarr_path,"a")
        del zarr_file["mask"]
        zarr_file.store.rename("mask_tmp","mask")
        mask_ds = zarr_file["mask"]  # reopen the renamed mask dataset

        logger.info("saving segments...")
        zarr_file["df_segments"]=self.df_segments.reset_index().astype(int).values
        logger.info("saving divisions...")
        zarr_file["df_divisions"]=self.df_divisions.reset_index().astype(int).values
        mask_ds.attrs["df_divisions"]=self.df_divisions.reset_index().to_dict()
        logger.info("saving others...")
        mask_ds.attrs["finalized_segment_ids"]=list(map(int,self.finalized_segment_ids))
        mask_ds.attrs["candidate_segment_ids"]=list(map(int,self.candidate_segment_ids))
        mask_ds.attrs["target_Ts"]=list(map(int,self.target_Ts))

        logger.info("reading data ...")
        self.label_layer.data = da.from_zarr(mask_ds).persist()
        logger.info("saving validation results finished")
Example #25
 def __call__(self, *args, cache_dir: Optional[str] = None) -> Any:
     """Call the decorated function."""
     new_args = self._sanitize_args_func(
         *args) if self._sanitize_args_func is not None else args
     arg_hash = _hash_args(*new_args)
     should_cache, cache_dir = self._get_should_cache_and_cache_dir(
         new_args, cache_dir)
     zarr_fn = self._zarr_pattern(arg_hash)
     zarr_format = os.path.join(cache_dir, zarr_fn)
     zarr_paths = glob(zarr_format.format("*"))
     if not should_cache or not zarr_paths:
         # use sanitized arguments if we are caching, otherwise use original arguments
         args = new_args if should_cache else args
         res = self._func(*args)
         if should_cache and not zarr_paths:
             self._cache_results(res, zarr_format)
     # if we did any caching, let's load from the zarr files
     if should_cache:
         # re-calculate the cached paths
         zarr_paths = glob(zarr_format.format("*"))
         if not zarr_paths:
             raise RuntimeError(
                 "Data was cached to disk but no files were found")
         res = tuple(da.from_zarr(zarr_path) for zarr_path in zarr_paths)
     return res
Example #26
def load_preprocessed_img(img_file):
    if img_file.endswith(
            '.zarr') and not os.path.exists(f"{img_file}/.zarray"):
        img_file = img_file.replace(".zarr", ".npy")
    return npy2da(img_file) if (
        img_file.endswith('.npy')
        or img_file.endswith('.h5')) else da.from_zarr(img_file)
Example #27
def tifffile_dask_backend(image_filepath,
                          largest_series,
                          preprocessing,
                          force_rgb=None):
    """
    Read image with tifffile and use dask to read data into memory

    Parameters
    ----------
    image_filepath: str
        path to the image file
    largest_series: int
        index of the largest series in the image
    preprocessing:
        whether to do some read-time pre-processing
        - greyscale conversion (at the tile level)
        - read individual or range of channels (at the tile level)

    Returns
    -------
    image: sitk.Image
        image ready for other registration pre-processing

    """
    print("using dask backend")
    zarr_series = imread(image_filepath, aszarr=True, series=largest_series)
    zarr_store = zarr.open(zarr_series)

    dask_im = da.squeeze(da.from_zarr(zarr_get_base_pyr_layer(zarr_store)))
    if force_rgb is None:
        is_rgb = guess_rgb(dask_im.shape)
        is_interleaved = True if is_rgb else False
    elif force_rgb is True and guess_rgb(dask_im.shape) is False:
        is_rgb = True
        is_interleaved = False

    if is_rgb:
        if preprocessing is not None:
            image = grayscale(dask_im, is_interleaved=is_interleaved).compute()

            image = sitk.GetImageFromArray(image)
        else:
            image = dask_im.compute()
            if is_interleaved is False:
                image = np.rollaxis(image, 0, 3)
            image = sitk.GetImageFromArray(image, isVector=True)

    elif len(dask_im.shape) == 2:
        image = sitk.GetImageFromArray(dask_im.compute())

    else:
        if preprocessing is not None:
            if (preprocessing.get("ch_indices") is not None
                    and len(dask_im.shape) > 2):
                chs = np.asarray(preprocessing.get('ch_indices'))
                dask_im = dask_im[chs, :, :]

        image = sitk.GetImageFromArray(np.squeeze(dask_im.compute()))

    return image
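A hypothetical invocation of the function above; the file path, series index, and preprocessing options are illustrative, not taken from the original project:

image = tifffile_dask_backend(
    "slide.ome.tiff",                   # hypothetical multi-channel TIFF
    largest_series=0,
    preprocessing={"ch_indices": [0]},  # keep only the first channel
)
print(image.GetSize())                  # SimpleITK image dimensions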
Example #28
def plot_image_(image_file, compression_factor=2., test_image_name='test.png'):
    """Plots entire SVS/other image.

	Parameters
	----------
	image_file:str
		Image file.
	compression_factor:float
		Amount to shrink each dimension of image.
	test_image_name:str
		Output image file.

	"""
    from pathflowai.utils import svs2dask_array, npy2da
    import cv2
    if image_file.endswith('.zarr'):
        arr = da.from_zarr(image_file)
    else:
        arr = svs2dask_array(
            image_file,
            tile_size=1000,
            overlap=0,
            remove_last=True,
            allow_unknown_chunksizes=False) if (
                not image_file.endswith('.npy')) else npy2da(image_file)
    arr2 = to_pil(
        cv2.resize(arr.compute(),
                   dsize=tuple((np.array(arr.shape[:2]) /
                                compression_factor).astype(int).tolist()),
                   interpolation=cv2.INTER_CUBIC))
    arr2.save(test_image_name)
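A hypothetical call to the helper above, which writes a downsampled preview of a zarr-backed slide; the file names are illustrative:

# shrink each spatial dimension by 4x and save a PNG preview (paths are hypothetical)
plot_image_('slide.zarr', compression_factor=4., test_image_name='slide_preview.png')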
Example #29
    def load_ome_zarr(self):

        resolutions = ["0"]  # TODO: could be first alphanumeric dataset on err
        try:
            print('root_attrs', self.root_attrs)
            if 'multiscales' in self.root_attrs:
                datasets = self.root_attrs['multiscales'][0]['datasets']
                resolutions = [d['path'] for d in datasets]
            print('resolutions', resolutions)
        except Exception as e:
            raise e

        pyramid = []
        for resolution in resolutions:
            # data.shape is (t, c, z, y, x) by convention
            data = da.from_zarr(f"{self.zarr_path}{resolution}")
            chunk_sizes = [
                str(c[0]) + (" (+ %s)" % c[-1] if c[-1] != c[0] else '')
                for c in data.chunks
            ]
            print('resolution', resolution, 'shape (t, c, z, y, x)',
                  data.shape, 'chunks', chunk_sizes, 'dtype', data.dtype)
            pyramid.append(data)

        if len(pyramid) == 1:
            pyramid = pyramid[0]

        metadata = self.load_omero_metadata(data.shape[1])
        return (pyramid, {'channel_axis': 1, **metadata})
Example #30
def tifffile_dask_backend(image_filepath,
                          largest_series,
                          preprocessing,
                          force_rgb=None):
    """
    Read image with tifffile and use dask to read data into memory

    Parameters
    ----------
    image_filepath: str
        path to the image file
    largest_series: int
        index of the largest series in the image
    preprocessing:
        whether to do some read-time pre-processing
        - greyscale conversion (at the tile level)
        - read individual or range of channels (at the tile level)

    Returns
    -------
    image: sitk.Image
        image ready for other registration pre-processing

    """
    print("using dask backend")
    zarr_series = imread(image_filepath, aszarr=True, series=largest_series)
    zarr_store = zarr.open(zarr_series)
    dask_im = da.squeeze(da.from_zarr(zarr_get_base_pyr_layer(zarr_store)))
    return read_preprocess_array(dask_im,
                                 preprocessing=preprocessing,
                                 force_rgb=force_rgb)