Exemple #1
0
def test_bw_corrcoef():
    """Check ``blockwise.bw_corrcoef`` against a naive per-block reference.

    An integer ramp and a copy of it rolled by one row along axis 0 are cut
    into non-overlapping ``block_shape`` tiles. ``np.corrcoef`` gives the
    Pearson correlation of each tile pair, and that result is compared with
    the dask implementation.
    """
    # params perhaps less critical, can be parametrized in future
    arr_shape = (4, 6)
    block_shape = (2, 3)

    def bruteforce_bw(first, second, blk_shape):
        # Naive reference on plain np.ndarray inputs: one corrcoef per tile.
        grid = [first.shape[axis] // blk_shape[axis]
                for axis in range(first.ndim)]
        result = np.empty(grid)
        for idx in itertools.product(*(range(n) for n in grid)):
            window = tuple(
                slice(idx[axis] * blk_shape[axis],
                      (idx[axis] + 1) * blk_shape[axis])
                for axis in range(first.ndim))
            corr = np.corrcoef(first[window].flatten(),
                               second[window].flatten())
            result[idx] = corr[0, 1]
        return result

    base_npy = np.arange(np.prod(arr_shape)).reshape(arr_shape)
    rolled_npy = np.roll(base_npy, 1, 0)

    expected = da.from_array(bruteforce_bw(base_npy, rolled_npy, block_shape))
    actual = blockwise.bw_corrcoef(da.from_array(base_npy),
                                   da.from_array(rolled_npy),
                                   block_shape)

    assert da.allclose(expected, actual)
Exemple #2
0
def test_bw_func(funcname):
    """Check ``blockwise.bw_<funcname>`` against a naive per-block reference.

    ``funcname`` names both a NumPy reduction (``np.<funcname>``) and the
    matching blockwise helper (``blockwise.bw_<funcname>``). It is presumably
    supplied by pytest parametrization; confirm at the decorator.
    """
    import operator

    # params perhaps less critical, can be parametrized in future
    arr_shape, block_shape = (4, 6), (2, 3)

    # Resolve the reference reduction once by attribute lookup instead of
    # eval()-ing a constructed source string for every block. attrgetter
    # still accepts dotted names such as "linalg.norm", like the eval did.
    np_func = operator.attrgetter(funcname)(np)

    # straightforward implementation, takes np.ndarray
    def bruteforce_bw(arr, block_shape):
        out_shape = [
            arr.shape[ax] // block_shape[ax] for ax in range(arr.ndim)
        ]
        out = np.empty(out_shape)
        for coord in itertools.product(*[range(d) for d in out_shape]):
            s = tuple(
                slice(coord[ax] * block_shape[ax],
                      (coord[ax] + 1) * block_shape[ax])
                for ax in range(arr.ndim)
            )
            out[coord] = np_func(arr[s])
        return out

    # get function name
    bw_func = getattr(blockwise, "bw_" + funcname)

    # prepare identical data with different format
    arr_npy = np.arange(np.prod(arr_shape)).reshape(arr_shape)
    arr_da = da.from_array(arr_npy)

    # reference result
    ref_npy = bruteforce_bw(arr_npy, block_shape)
    ref_da = da.from_array(ref_npy)

    # test result
    test_da = bw_func(arr_da, block_shape)

    assert da.allclose(ref_da, test_da)
Exemple #3
0
def test_repeat_block():
    """``blockwise.repeat_block`` expands each element into a (3, 2) patch."""
    source = da.arange(4).reshape((2, 2))
    top_rows = [[0, 0, 1, 1]] * 3
    bottom_rows = [[2, 2, 3, 3]] * 3
    expected = da.from_array(np.array(top_rows + bottom_rows))
    result = blockwise.repeat_block(source, (3, 2))
    assert da.allclose(expected, result)
Exemple #4
0
def test_allclose():
    """``da.allclose`` agrees with ``np.allclose`` on chunked input with NaNs."""
    left = np.array([0, np.nan, 1, 1.5])
    right = np.array([1e-9, np.nan, 1, 2])

    expected = np.allclose(left, right, equal_nan=True)
    actual = da.allclose(da.from_array(left, chunks=(2, )),
                         da.from_array(right, chunks=(2, )),
                         equal_nan=True)

    # np.allclose returns a plain bool; np.array(...)[()] turns it into a
    # NumPy scalar before comparing with the dask result.
    assert_eq(np.array(expected)[()], actual)
Exemple #5
0
def test_allclose():
    """Compare ``da.allclose`` with ``np.allclose`` (``equal_nan=True``)."""
    numpy_pair = (np.array([0, np.nan, 1, 1.5]),
                  np.array([1e-9, np.nan, 1, 2]))
    dask_pair = tuple(da.from_array(x, chunks=(2,)) for x in numpy_pair)

    reference = np.allclose(*numpy_pair, equal_nan=True)
    computed = da.allclose(*dask_pair, equal_nan=True)

    # Lift the Python bool from np.allclose into a NumPy scalar.
    assert_eq(np.array(reference)[()], computed)
def store_correct(split_filepath, arr_list, logical_chunks_shape):
    """Assert that each array in ``arr_list`` matches its stored split.

    Split ``i`` is read from dataset ``'/data' + str(i)`` of the HDF5 file at
    ``split_filepath``. It is rechunked to ``logical_chunks_shape`` and
    compared with ``arr_list[i]`` using ``da.allclose``.

    Raises:
        AssertionError: if any stored split differs from its reference array.
    """
    print("Testing", len(arr_list), "matches...")
    with h5py.File(split_filepath, 'r') as f:
        for i, a in enumerate(arr_list):
            stored_a = da.from_array(f['/data' + str(i)])
            print("split shape:", stored_a.shape)

            # rechunk() returns a new array rather than modifying in place;
            # the result must be kept or the rechunk has no effect.
            stored_a = stored_a.rechunk(chunks=logical_chunks_shape)
            print("split rechunked to:", stored_a.chunksize)
            print("will be compared to : ", a.shape)

            print("Testing all close...")
            test = da.allclose(stored_a, a)
            assert test.compute()
            print("Passed.")
Exemple #7
0
    def store_correct():
        """Assert that every array in ``arr_list`` matches its stored split.

        Reads ``arr_list``, ``split_filepath`` and ``shape_to_test`` from the
        enclosing scope. Split ``i`` is loaded from dataset
        ``'/data' + str(i)``, rechunked to ``shape_to_test`` and compared
        with ``da.allclose``.

        Raises:
            AssertionError: if a stored split differs from its reference.
        """
        logger.info("Testing %s matches...", len(arr_list))
        with h5py.File(split_filepath, 'r') as f:
            for i, a in enumerate(arr_list):
                stored_a = da.from_array(f['/data' + str(i)])
                # rechunk() returns a new array; keep the result, otherwise
                # the rechunk never takes effect.
                stored_a = stored_a.rechunk(chunks=shape_to_test)
                test = da.allclose(stored_a, a)
                disable_clustering()  # TODO: remove this, make it work even for all close
                assert test.compute()
        logger.info("Passed.\n")
Exemple #8
0
def check_outputs():
    """Print each stored split file next to its expected ground truth.

    Debugging helper. It first dumps the structure of every ``[0-9].hdf5``
    file in the current directory. It then loads each file's ``/data``
    dataset, compares it with the matching sub-volume of
    ``reconstructed_array`` and prints the ``da.allclose`` result plus the
    first ten values of both.

    Relies on the module-level names ``O`` (presumably the output-file block
    shape; TODO confirm) and ``reconstructed_array`` being defined.
    """
    split_file_pattern = "[0-9].hdf5"

    # sanity check: dump the structure of every output file
    for fpath in sorted(glob.glob(split_file_pattern)):
        print(f'Filename: {fpath}')
        with h5py.File(fpath, 'r') as f:
            inspect_h5py_file(f)

    # prepare ground truth for verification
    arrays_expected = dict()
    outfiles_partition = get_blocks_shape((1, 120, 120), O)
    outfiles_volumes = get_named_volumes(outfiles_partition, O)
    for outfilekey, volume in outfiles_volumes.items():
        slices = convert_Volume_to_slices(volume)
        arrays_expected[outfilekey] = reconstructed_array[slices[0], slices[1],
                                                          slices[2]]

    # verify; only a small leading window of each array is printed
    preview = (slice(0, 1), slice(0, 1), slice(0, 10))
    for fpath in sorted(glob.glob(split_file_pattern)):
        outputfile_index = int(fpath.split('.')[0])
        print('Output file index:', outputfile_index)

        array_stored = get_dask_array_from_hdf5(fpath,
                                                '/data',
                                                logic_cs="dataset_shape")
        arr_expected = arrays_expected[outputfile_index]
        print("equal:", da.allclose(array_stored, arr_expected).compute())
        print("stored:", array_stored[preview].compute())
        print("expected", arr_expected[preview].compute())
Exemple #9
0
def test_trim():
    """``blockwise.trim`` cuts a (6, 5) array to (6, 4) for (3, 2) blocks."""
    arr_shape = (6, 5)
    block_shape = (3, 2)
    arr = da.arange(np.prod(arr_shape)).reshape(arr_shape)
    expected = arr[:, :4]
    assert da.allclose(expected, blockwise.trim(arr, block_shape))
Exemple #10
0
def test_shift(arr_shape, fill_value):
    """Check ``blockwise.shift`` along every axis for every shift amount.

    For each axis of length ``d`` and each ``num`` in ``[-(d - 1), d - 1]``,
    it asserts two things: the original data appears moved by ``num``
    positions, and the vacated band equals ``fill_value``. ``arr_shape``
    (a tuple) and ``fill_value`` presumably come from pytest
    parametrization; confirm at the decorator.
    """
    # less critical params
    seed = 42
    low, high = 0, 2
    # start test
    da.random.seed(seed)
    arr = da.random.randint(low=low, high=high, size=arr_shape)
    for ax in range(arr.ndim):
        d = arr_shape[ax]
        for num in range(-(d - 1), d):
            shifted = blockwise.shift(arr, num, ax, fill_value=fill_value)

            if num == 0:
                assert da.allclose(arr, shifted)
                continue

            # Build real index tuples instead of eval()-ing string slices.
            kept_slice = [slice(None)] * arr.ndim
            new_slice = [slice(None)] * arr.ndim
            filled_slice = [slice(None)] * arr.ndim
            if num > 0:
                kept_slice[ax] = slice(0, -num)
                new_slice[ax] = slice(num, None)
                filled_slice[ax] = slice(0, num)
                filled_len = num
            else:
                kept_slice[ax] = slice(-num, None)
                new_slice[ax] = slice(None, num)
                filled_slice[ax] = slice(num, None)
                filled_len = -num

            # data that survived the shift sits |num| positions further along
            ref = arr[tuple(kept_slice)]
            test = shifted[tuple(new_slice)]
            assert da.allclose(ref, test)

            # the band opened up by the shift holds fill_value
            filled_shape = arr_shape[:ax] + (filled_len, ) + arr_shape[ax + 1:]
            ref = da.full(shape=filled_shape, fill_value=fill_value)
            test = shifted[tuple(filled_slice)]
            assert da.allclose(ref, test)
Exemple #11
0
def new_grid_mapping_from_coords(
    x_coords: xr.DataArray,
    y_coords: xr.DataArray,
    crs: Union[str, pyproj.crs.CRS],
    *,
    tile_size: Union[int, Tuple[int, int]] = None,
    tolerance: float = DEFAULT_TOLERANCE,
) -> GridMapping:
    """Create a grid mapping from 1D or 2D x/y coordinate arrays.

    :param x_coords: x coordinates, either 1D or 2D.
    :param y_coords: y coordinates. They must match x_coords in
        dimensionality; for 2D they must also match in shape and dims.
    :param crs: coordinate reference system, as a string or pyproj CRS.
    :param tile_size: optional tile size, an int or (width, height). If
        omitted, it is derived from the coordinates' chunking when available.
    :param tolerance: absolute tolerance used to decide whether the
        coordinate spacing is regular. Must be a float greater than zero.
    :return: a Coords1DGridMapping for 1D coordinates, or a
        Coords2DGridMapping for 2D coordinates.
    :raise RuntimeError: if a positive x/y resolution cannot be determined.
        Invalid arguments are rejected by assert_instance/assert_true.
    """
    crs = _normalize_crs(crs)
    assert_instance(x_coords, xr.DataArray, name='x_coords')
    assert_instance(y_coords, xr.DataArray, name='y_coords')
    assert_true(x_coords.ndim in (1, 2),
                'x_coords and y_coords must be either 1D or 2D arrays')
    assert_instance(tolerance, float, name='tolerance')
    assert_true(tolerance > 0.0, 'tolerance must be greater zero')

    if x_coords.name and y_coords.name:
        xy_var_names = str(x_coords.name), str(y_coords.name)
    else:
        xy_var_names = _default_xy_var_names(crs)

    tile_size = _normalize_int_pair(tile_size, default=None)
    is_lon_360 = None  # None means "not yet known"
    if crs.is_geographic:
        is_lon_360 = bool(np.any(x_coords > 180))

    x_res = 0
    y_res = 0

    if x_coords.ndim == 1:
        # We have 1D x,y coordinates
        cls = Coords1DGridMapping

        assert_true(x_coords.size >= 2 and y_coords.size >= 2,
                    'sizes of x_coords and y_coords 1D arrays must be >= 2')

        size = x_coords.size, y_coords.size

        x_dim, y_dim = x_coords.dims[0], y_coords.dims[0]

        x_diff = _abs_no_zero(x_coords.diff(dim=x_dim).values)
        y_diff = _abs_no_zero(y_coords.diff(dim=y_dim).values)

        if not is_lon_360 and crs.is_geographic:
            is_anti_meridian_crossed = np.any(np.nanmax(x_diff) > 180)
            if is_anti_meridian_crossed:
                x_coords = to_lon_360(x_coords)
                # Take .values as in the initial computation above, so that
                # x_diff is the same kind of array on both code paths.
                x_diff = _abs_no_zero(x_coords.diff(dim=x_dim).values)
                is_lon_360 = True

        x_res, y_res = x_diff[0], y_diff[0]
        x_diff_equal = np.allclose(x_diff, x_res, atol=tolerance)
        y_diff_equal = np.allclose(y_diff, y_res, atol=tolerance)
        is_regular = x_diff_equal and y_diff_equal
        if is_regular:
            x_res = round_to_fraction(x_res, 5, 0.25)
            y_res = round_to_fraction(y_res, 5, 0.25)
        else:
            x_res = round_to_fraction(float(np.nanmedian(x_diff)), 2, 0.5)
            y_res = round_to_fraction(float(np.nanmedian(y_diff)), 2, 0.5)

        if tile_size is None \
                and x_coords.chunks is not None \
                and y_coords.chunks is not None:
            tile_size = (max(0,
                             *x_coords.chunks[0]), max(0, *y_coords.chunks[0]))

        # Guess j axis direction
        is_j_axis_up = bool(y_coords[0] < y_coords[-1])

    else:
        # We have 2D x,y coordinates
        cls = Coords2DGridMapping

        assert_true(
            x_coords.shape == y_coords.shape, 'shapes of x_coords and y_coords'
            ' 2D arrays must be equal')
        assert_true(
            x_coords.dims == y_coords.dims,
            'dimensions of x_coords and y_coords'
            ' 2D arrays must be equal')

        y_dim, x_dim = x_coords.dims

        height, width = x_coords.shape
        size = width, height

        x = da.asarray(x_coords)
        y = da.asarray(y_coords)

        x_x_diff = _abs_no_nan(da.diff(x, axis=1))
        x_y_diff = _abs_no_nan(da.diff(x, axis=0))
        y_x_diff = _abs_no_nan(da.diff(y, axis=1))
        y_y_diff = _abs_no_nan(da.diff(y, axis=0))

        if not is_lon_360 and crs.is_geographic:
            is_anti_meridian_crossed = da.any(da.max(x_x_diff) > 180) \
                                       or da.any(da.max(x_y_diff) > 180)
            if is_anti_meridian_crossed:
                x_coords = to_lon_360(x_coords)
                x = da.asarray(x_coords)
                x_x_diff = _abs_no_nan(da.diff(x, axis=1))
                x_y_diff = _abs_no_nan(da.diff(x, axis=0))
                is_lon_360 = True

        is_regular = False

        # Regular only if x varies solely along columns and y solely along
        # rows, with constant step sizes along the outer rows/columns.
        if da.all(x_y_diff == 0) and da.all(y_x_diff == 0):
            x_res = x_x_diff[0, 0]
            y_res = y_y_diff[0, 0]
            is_regular = \
                da.allclose(x_x_diff[0, :], x_res, atol=tolerance) \
                and da.allclose(x_x_diff[-1, :], x_res, atol=tolerance) \
                and da.allclose(y_y_diff[:, 0], y_res, atol=tolerance) \
                and da.allclose(y_y_diff[:, -1], y_res, atol=tolerance)

        if not is_regular:
            # Let diff arrays have same shape as original by
            # doubling last rows and columns.
            x_x_diff_c = da.concatenate([x_x_diff, x_x_diff[:, -1:]], axis=1)
            y_x_diff_c = da.concatenate([y_x_diff, y_x_diff[:, -1:]], axis=1)
            x_y_diff_c = da.concatenate([x_y_diff, x_y_diff[-1:, :]], axis=0)
            y_y_diff_c = da.concatenate([y_y_diff, y_y_diff[-1:, :]], axis=0)
            # Find resolution via area
            x_abs_diff = da.sqrt(da.square(x_x_diff_c) + da.square(x_y_diff_c))
            y_abs_diff = da.sqrt(da.square(y_x_diff_c) + da.square(y_y_diff_c))
            if crs.is_geographic:
                # Convert degrees into meters
                # NOTE(review): cos() is applied to the x step here; a
                # latitude-dependent scale would normally use the y
                # coordinate. Confirm this is intended.
                x_abs_diff_r = da.radians(x_abs_diff)
                y_abs_diff_r = da.radians(y_abs_diff)
                x_abs_diff = _ER * da.cos(x_abs_diff_r) * y_abs_diff_r
                y_abs_diff = _ER * y_abs_diff_r
            xy_areas = (x_abs_diff * y_abs_diff).flatten()
            xy_areas = da.where(xy_areas > 0, xy_areas, np.nan)
            # Get indices of min and max area
            xy_area_index_min = da.nanargmin(xy_areas)
            xy_area_index_max = da.nanargmax(xy_areas)
            # Convert area to edge length
            xy_res_min = math.sqrt(xy_areas[xy_area_index_min])
            xy_res_max = math.sqrt(xy_areas[xy_area_index_max])
            # Empirically weight min more than max
            xy_res = 0.7 * xy_res_min + 0.3 * xy_res_max
            if crs.is_geographic:
                # Convert meters back into degrees
                xy_res = math.degrees(xy_res / _ER)
            # Because this is an estimation, we can round to a nice number
            xy_res = round_to_fraction(xy_res, digits=1, resolution=0.5)
            x_res, y_res = float(xy_res), float(xy_res)

        if tile_size is None and x_coords.chunks is not None:
            j_chunks, i_chunks = x_coords.chunks
            tile_size = max(0, *i_chunks), max(0, *j_chunks)

        if tile_size is not None:
            tile_width, tile_height = tile_size
            x_coords = x_coords.chunk((tile_height, tile_width))
            y_coords = y_coords.chunk((tile_height, tile_width))

        # Guess j axis direction: True if y increases down every column,
        # otherwise None ("unknown").
        is_j_axis_up = np.all(y_coords[0, :] < y_coords[-1, :]) or None

    assert_true(x_res > 0 and y_res > 0,
                'internal error: x_res and y_res could not be determined',
                exception_type=RuntimeError)

    x_res, y_res = _to_int_or_float(x_res), _to_int_or_float(y_res)
    # Bounding box extends half a cell beyond the outermost cell centers.
    x_res_05, y_res_05 = x_res / 2, y_res / 2
    x_min = _to_int_or_float(x_coords.min() - x_res_05)
    y_min = _to_int_or_float(y_coords.min() - y_res_05)
    x_max = _to_int_or_float(x_coords.max() + x_res_05)
    y_max = _to_int_or_float(y_coords.max() + y_res_05)

    return cls(x_coords=x_coords,
               y_coords=y_coords,
               crs=crs,
               size=size,
               tile_size=tile_size,
               xy_bbox=(x_min, y_min, x_max, y_max),
               xy_res=(x_res, y_res),
               xy_var_names=xy_var_names,
               xy_dim_names=(str(x_dim), str(y_dim)),
               is_regular=is_regular,
               is_lon_360=is_lon_360,
               is_j_axis_up=is_j_axis_up)
Exemple #12
0
def copy_dataset(h5_orig_dset, h5_dest_grp, alias=None, verbose=False):
    """
    Copies the provided HDF5 dataset to the provided destination. This function
    is handy when needing to make copies of datasets to a different HDF5 file.

    Notes
    -----
    This function does NOT copy all linked objects such as ancillary
    datasets. Call `copy_linked_objects` to accomplish that goal.
    If a dataset named `alias` already exists in `h5_dest_grp`, no data is
    written. It is only checked to have the same shape and contents as the
    source, and the simple attributes are then copied onto it.

    Parameters
    ----------
    h5_orig_dset : h5py.Dataset
        Dataset to copy
    h5_dest_grp : h5py.Group or h5py.File object :
        Destination where the duplicate dataset will be created
    alias : str, optional. Default = name from `h5_orig_dset`:
        Name to be assigned to the copied dataset
    verbose : bool, optional. Default = False
        Whether or not to print logs to assist in debugging

    Returns
    -------
    h5_new_dset : h5py.Dataset
        The dataset named `alias` in `h5_dest_grp`. This is either the newly
        created copy or the pre-existing matching dataset.

    Raises
    ------
    TypeError
        If the inputs have the wrong types, or if `alias` already names a
        non-dataset object in `h5_dest_grp`.
    ValueError
        If a dataset named `alias` already exists but differs in shape or
        contents from `h5_orig_dset`.
    """
    if not isinstance(h5_orig_dset, h5py.Dataset):
        raise TypeError("'h5_orig_dset' should be a h5py.Dataset object")
    if not isinstance(h5_dest_grp, (h5py.File, h5py.Group)):
        raise TypeError("'h5_dest_grp' should either be a h5py.File or "
                        "h5py.Group object")
    if alias is not None:
        validate_single_string_arg(alias, 'alias')
    else:
        alias = h5_orig_dset.name.split('/')[-1]

    if alias in h5_dest_grp.keys():
        if verbose:
            warn('{} already contains an object with the same name: {}'
                 ''.format(h5_dest_grp, alias))
        h5_new_dset = h5_dest_grp[alias]
        if not isinstance(h5_new_dset, h5py.Dataset):
            raise TypeError(
                '{} already contains an object: {} with the desired'
                ' name which is not a dataset'.format(h5_dest_grp,
                                                      h5_new_dset))

        da_source = lazy_load_array(h5_orig_dset)
        da_dest = lazy_load_array(h5_new_dset)

        if da_source.shape != da_dest.shape:
            raise ValueError('Existing dataset: {} has a different shape '
                             'compared to the original dataset: {}'
                             ''.format(h5_new_dset, h5_orig_dset))
        # da.allclose is lazy; compute it explicitly before testing it.
        if not da.allclose(da_source, da_dest).compute():
            raise ValueError('Existing dataset: {} has different contents '
                             'compared to the original dataset: {}'
                             ''.format(h5_new_dset, h5_orig_dset))
    else:

        kwargs = {
            'shape': h5_orig_dset.shape,
            'dtype': h5_orig_dset.dtype,
            'compression': h5_orig_dset.compression,
            'chunks': h5_orig_dset.chunks
        }
        if h5_orig_dset.file.driver == 'mpio':
            if kwargs.pop('compression', None) is not None:
                warn('This HDF5 file has been opened with the '
                     '"mpio" communicator. mpi4py does not allow '
                     'creation of compressed datasets. Compression'
                     ' kwarg has been removed')
        if verbose:
            print('Creating new HDF5 dataset named: {} at: {} with'
                  ' kwargs: {}'.format(alias, h5_dest_grp, kwargs))
        h5_new_dset = h5_dest_grp.create_dataset(alias, **kwargs)
        if verbose:
            print('dask.array will copy data from source dataset '
                  'to new dataset')
        # NOTE(review): to_hdf5 re-opens the destination file by name while
        # it is already open here; confirm this works with the h5py/HDF5
        # versions in use.
        da.to_hdf5(h5_new_dset.file.filename,
                   {h5_new_dset.name: lazy_load_array(h5_orig_dset)})
    if verbose:
        print('Copying simple attributes of original dataset: {} to '
              'destination dataset: {}'.format(h5_orig_dset, h5_new_dset))

    copy_attributes(h5_orig_dset, h5_new_dset, skip_refs=True)
    # TODO: reinstate copy all region_refs()
    # copy_all_region_refs(h5_orig_dset, h5_new_dset)

    return h5_new_dset
Exemple #13
0
def compute_angles(traj, angle_indices, periodic=True, opt=True, **kwargs):
    """Lazily compute angles for index triplets over a dask trajectory.

    Dask counterpart of :py:func:`mdtraj.compute_angles`. It returns a
    :py:class:`dask.array.Array` instead of an in-memory NumPy array, built
    from one lazy piece per frame chunk of ``traj.xyz``.

    Parameters
    ----------
    traj : :py:class:`dask_traj.Trajectory`
        Trajectory whose frames the angles are computed for.
    angle_indices : array-like, shape (n_angles, 3)
        Atom index triplets; every index must lie in ``[0, traj.n_atoms)``.
    periodic : bool, default=True
        Whether to use the periodic boundary. It only takes effect when the
        trajectory carries unit-cell information.
    opt : bool, default=True
        Use MDTraj's optimized native implementation, which its authors
        report as 10-20x faster than the numpy implementation.
    **kwargs
        Passed to ``wrap_da`` together with the per-chunk arguments.

    Returns
    -------
    angles : dask.array.Array, shape (n_frames, n_angles)
        Delayed angle for every triplet in every frame.

    Raises
    ------
    ValueError
        If any entry of ``angle_indices`` is out of range.
    """
    xyz = traj.xyz
    n_frames_total = len(xyz)
    n_angles = len(angle_indices)
    triplets = ensure_type(angle_indices,
                           dtype=np.int32,
                           ndim=2,
                           name="angle_indices",
                           shape=(None, 3),
                           warn_on_cast=False)
    in_range = np.logical_and(triplets < traj.n_atoms, triplets >= 0)
    if not np.all(in_range):
        raise ValueError("angle_indices must be between 0 and %d" %
                         traj.n_atoms)

    if len(triplets) == 0:
        return da.zeros((n_frames_total, 0), dtype=np.float32)

    box = None
    orthogonal = False
    if periodic and traj._have_unitcell:
        box = ensure_type(traj.unitcell_vectors,
                          dtype=np.float32,
                          ndim=3,
                          name="unitcell_vectors",
                          shape=(n_frames_total, 3, 3),
                          warn_on_cast=False)

    pieces = []
    start = 0
    for n_chunk_frames in xyz.chunks[0]:
        stop = start + n_chunk_frames
        if box is None:
            chunk_box = None
        else:
            chunk_box = box[start:stop]
            # Per-chunk flag: are all unit-cell angles (close to) 90 degrees?
            orthogonal = da.allclose(traj.unitcell_angles[start:stop], 90)
        pieces.append(
            wrap_da(_compute_angles_chunk,
                    (n_chunk_frames, n_angles),
                    xyz=xyz[start:stop],
                    triplets=triplets,
                    box=chunk_box,
                    orthogonal=orthogonal,
                    opt=opt,
                    **kwargs))
        start = stop
    return da.concatenate(pieces)[:n_frames_total]
Exemple #14
0
def reorder_mpas_data(ds, var, client, comp, path_zarr):
    """Reorder one MPAS variable, store it compressed as Zarr and log metrics.

    The function works in four steps:

    1. Permute ``ds[var]`` along its last axis with an index array read from
       a fixed file. 3-D variables are first transposed with axes (0, 2, 1).
    2. Pad that axis to a multiple of ``row_len`` and reshape it to
       ``(n_rows, row_len)``.
    3. Write the result to ``f'{path_zarr[:-4]}{var}.zarr'`` with
       ``comp[var]`` as compressor, deleting any existing store there first.
    4. Read the store back and append one CSV row to a fixed summary file.
       The row holds the original/stored size, the compression ratio, an
       absolute-tolerance check and the original/reconstructed means.

    Parameters
    ----------
    ds : presumably an xarray.Dataset containing ``var``.
    var : str
        Name of the variable to process.
    client : presumably a dask.distributed client; used to broadcast the
        permutation array.
    comp : mapping from variable name to a compressor. The compressor needs
        ``codec_id``, plus ``tolerance`` for 'zfpy' or ``level`` otherwise.
    path_zarr : str
        Path template; its last four characters are dropped before
        ``f'{var}.zarr'`` is appended.
    """
    nCells = 41943042
    # Length of the trailing axis after the final reshape. The cell axis is
    # padded up to the next multiple of it. Deriving padding and row count
    # keeps them consistent with nCells (2046 and 20481 for this nCells).
    row_len = 2048
    n_pad = -nCells % row_len
    n_rows = (nCells + n_pad) // row_len

    perm_arr = np.fromfile(
        f'/glade/work/haiyingx/mpas_655362/mc2gv.dat.{nCells}', dtype='i4')
    print(perm_arr.shape)
    # NOTE(review): `future` is never used below; map_blocks receives
    # perm_arr itself. Presumably the scattered future was meant to be passed
    # instead. Confirm before relying on the broadcast.
    [future] = client.scatter([perm_arr], broadcast=True)
    arr_shape = ds[var].data.shape
    print(var, ds[var].dims, arr_shape)
    if len(ds[var].dims) == 3:
        var_arr = da.transpose(ds[var].data, (0, 2, 1))
    else:
        var_arr = ds[var].data
    arr_size = var_arr.nbytes
    # Using Ellipsis ... here to deal with both 2D and 3D variables
    reindex_arr = da.map_blocks(lambda x, y: x[..., y],
                                var_arr,
                                perm_arr,
                                dtype='f4')
    # Only pad the last dimension
    padded_tuple = ((0, 0), ) * (len(ds[var].dims) - 1) + ((0, n_pad), )
    padded_arr = da.pad(reindex_arr, padded_tuple, 'constant')
    print('var', var, padded_tuple)
    arr = padded_arr.reshape(padded_arr.shape[:-1] + (n_rows, row_len))
    print(padded_arr.shape[:-1])
    # persist() keeps the result in memory, which speeds up the later compute()
    pre_b = arr.mean().persist()
    print(arr.shape)
    encoding = {f'{var}': {'compressor': comp[var]}}
    out_ds = xr.DataArray(arr, name=f'{var}').to_dataset()
    filename = f'{path_zarr[:-4]}{var}.zarr'
    if exists(filename):
        shutil.rmtree(filename)
    out_ds.to_zarr(filename, encoding=encoding)
    # Read the compressed store back to get mean() and check the abs tolerance
    filesize = sum(p.stat().st_size for p in Path(filename).rglob('*'))
    decomp_f = xr.open_zarr(filename)
    decomp_arr = decomp_f[var].data
    print(comp[var])
    if comp[var].codec_id == 'zfpy':
        tol = comp[var].tolerance
    else:
        # NOTE(review): the compressor level is used as atol here; confirm.
        tol = comp[var].level
    a = da.allclose(decomp_arr, arr, rtol=0.0, atol=tol).persist()
    b = decomp_f[var].mean().persist()
    # Save metric info to csv file (one appended row, no header)
    res_dict = {
        'var_name': var,
        'orig_size': arr_size,
        'recon_size': filesize,
        'ratio': round(arr_size / filesize, 2),
        'abs_valid': a.compute(),
        'orig_avg': f'{pre_b.compute():.7f}',
        'recon_avg': f'{b.compute().values:.7f}',
    }
    pd.DataFrame([res_dict]).to_csv(
        '/glade/scratch/haiyingx/Falko/hybrid/size.txt',
        index=False,
        sep=',',
        mode='a',
        header=False,
    )
def check_split_output_hdf5(input_filepath,
                            output_filepath,
                            logic_cs,
                            input_dset_key='/data',
                            output_dset_keyprefix='/data'):
    """ Compare the real chunks to the splits to see if the split process went well.

    Load the input array, successively extract one chunk and compare it to the
    corresponding dataset in the output file. By default, dataset keys in the
    output file are assumed to be of the form: prefix + id (integer).

    Arguments:
    ----------
        input_filepath: hdf5 file with (at least) 1 dataset containing the multidim array that has been split.
        output_filepath: hdf5 file containing the split input_file (1 dataset per split).
        logic_cs: logic chunk shape used for the split
        input_dset_key: dataset key where input array is stored.
        output_dset_keyprefix: prefix for the keys of datasets containing splits in output file.

    Returns:
    --------
        False if either file fails the sanity check (cannot be opened or has
        no keys) or if any split differs from its input chunk; True
        otherwise. Per-chunk results are logged at DEBUG level.
    """
    def sanity_check(file_path):
        # Returns True if the file opens and is non-empty, logging any error.
        logger.debug(f"\nChecking file integrity: {file_path}")
        if os.path.isfile(file_path):
            logger.debug('File has been found.')
        else:
            logger.debug('File not found.')
        try:
            apply_sanity_check(file_path)
        except Exception:
            logger.debug('-' * 60)
            traceback.print_exc(file=sys.stdout)
            logger.debug('-' * 60)
            logger.debug(
                "Sanity check failed, aborting goodness of split checking.")
            return False
        logger.debug("Sanity check passed.")
        return True

    def apply_sanity_check(file_path):
        """ Check that the split file is not empty.
        """
        with h5py.File(file_path, 'r') as f:
            logger.debug("file object %s", f)
            keys = list(f.keys())
            logger.debug("keys %s", keys)
            # Explicit raise instead of assert, which is stripped under -O.
            if not keys:
                raise ValueError(f"No dataset found in {file_path}")
        logger.debug("Integrity check passed.")

    if not sanity_check(input_filepath) or not sanity_check(output_filepath):
        return False

    logger.debug("\nChecking files data...")
    input_arr = get_dask_array_from_hdf5(input_filepath,
                                         input_dset_key,
                                         to_da=True,
                                         logic_cs=logic_cs)
    input_arr_list = get_arr_chunks(input_arr)
    nb_chunks = len(input_arr_list)
    all_passed = True
    with h5py.File(output_filepath, 'r') as split_file:
        for i, a in enumerate(input_arr_list):
            stored_a = da.from_array(split_file[output_dset_keyprefix +
                                                str(i)])
            logger.debug("Stored split shape: %s ", stored_a.shape)
            # rechunk() returns a new array; keep the result.
            stored_a = stored_a.rechunk(chunks=logic_cs)
            logger.debug("Split rechunked to: %s", stored_a.chunksize)
            logger.debug("Original data chunk: %s", a.shape)
            logger.debug("Testing all close...")
            test = da.allclose(stored_a, a)
            if test.compute():
                logger.debug(f"Test {i+1}/{nb_chunks} passed.")
            else:
                logger.debug(f"Test {i+1}/{nb_chunks} failed.")
                all_passed = False
    return all_passed