Example #1
def check_dmd_dask(D, mu, Phi, show_warning=True):
    """
        Checks how close the approximation using DMD is to the original data.

        Returns:
            None if the difference is within the tolerance
            Displays a warning otherwise.
    """
    X = D[:, 0:-1]
    Y = D[:, 1:]
    Phi_inv = pinv_SVD(Phi)
    PhiMu = da.dot(Phi, da.diag(mu))
    # Group the products as PhiMu @ (Phi_inv @ X) so the full n x n
    # operator Phi @ diag(mu) @ pinv(Phi) is never materialized.
    Y_est = da.dot(PhiMu, da.dot(Phi_inv, X))
    diff = da.real(Y - Y_est)
    res = da.fabs(diff)
    rtol = 1.e-8
    atol = 1.e-5

    if da.all(res < atol + rtol * da.fabs(da.real(Y_est))).compute():
        return None
    elif show_warning:
        warn('dmd result does not satisfy Y=AX')
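A minimal usage sketch (hedged): on noise-free linear data Y = AX the check should pass silently. pinv_SVD is assumed to be an SVD-based pseudo-inverse; a stand-in is defined here so the sketch is self-contained.

import numpy as np
import dask.array as da
from warnings import warn

def pinv_SVD(A):
    # stand-in: SVD-based pseudo-inverse, wrapped back into a dask array
    return da.from_array(np.linalg.pinv(np.asarray(A)), chunks=-1)

A = np.diag([0.9, 0.5])                        # known linear dynamics
snaps = [np.array([1.0, 1.0])]
for _ in range(10):
    snaps.append(A @ snaps[-1])
D = da.from_array(np.stack(snaps, axis=1), chunks=(2, 11))

mu, Phi = np.linalg.eig(A)                     # exact modes of noise-free data
check_dmd_dask(D, da.from_array(mu, chunks=2),
               da.from_array(Phi, chunks=(2, 2)))   # passes, no warning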
Example #2
    def get_angle_info(vza, sza, raa, m_pi):
        """
        Gets the angle information
        """

        AngleInfo = namedtuple('AngleInfo',
                               'vza sza raa vza_rad sza_rad raa_rad')

        # View zenith angle
        vza_rad = da.deg2rad(vza)

        # Solar zenith angle
        sza_rad = da.deg2rad(sza)

        # Relative azimuth angle
        raa_rad = da.deg2rad(raa)

        vza_abs = da.fabs(vza_rad)
        sza_abs = da.fabs(sza_rad)

        raa_abs = da.where((vza_rad < 0) | (sza_rad < 0), m_pi, raa_rad)

        return AngleInfo(vza=vza,
                         sza=sza,
                         raa=raa,
                         vza_rad=vza_abs,
                         sza_rad=sza_abs,
                         raa_rad=raa_abs)
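A hedged usage sketch (assumes dask arrays of angles in degrees, m_pi = math.pi, and that namedtuple is imported from collections in the enclosing module):

import math
import numpy as np
import dask.array as da
from collections import namedtuple

vza = da.from_array(np.array([-10.0, 25.0]), chunks=2)
sza = da.from_array(np.array([30.0, -40.0]), chunks=2)
raa = da.from_array(np.array([50.0, 120.0]), chunks=2)

info = get_angle_info(vza, sza, raa, math.pi)
print(info.vza_rad.compute())  # absolute view zenith angles, radians
print(info.raa_rad.compute())  # pi wherever either zenith was negative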
Example #3
    def grid_glm_data(self, flashes):
        """
        Aggregate the point flashes into a grid of flash counts occurring within each grid box.

        Args:
            flashes (:class:`pandas.DataFrame`): Contains the longitudes and latitudes of each flash

        Returns:
            :class:`numpy.ndarray` [y, x]: The number of flashes occurring at each grid point.
        """
        flash_x, flash_y = self.glm_proj(flashes["flash_lon"].values,
                                         flashes["flash_lat"].values)
        flash_x /= 1000
        flash_y /= 1000
        valid_flashes = np.where(
            (flash_x >= self.x_points.min() - self.dx_km / 2)
            & (flash_x <= self.x_points.max() + self.dx_km / 2)
            & (flash_y >= self.y_points.min() - self.dx_km / 2)
            & (flash_y <= self.y_points.max() + self.dx_km / 2))[0]
        if valid_flashes.size > 0:
            if PARALLEL:
                x_grid_flat = da.from_array(self.x_grid.reshape(
                    (self.x_grid.size, 1)),
                                            chunks=512)
                y_grid_flat = da.from_array(self.y_grid.reshape(
                    (self.x_grid.size, 1)),
                                            chunks=512)
                flash_x_flat = da.from_array(flash_x[valid_flashes].reshape(
                    1, valid_flashes.size),
                                             chunks=512)
                flash_y_flat = da.from_array(flash_y[valid_flashes].reshape(
                    1, valid_flashes.size),
                                             chunks=512)
                x_dist = da.fabs(x_grid_flat - flash_x_flat)
                y_dist = da.fabs(y_grid_flat - flash_y_flat)
                flash_grid_counts = da.sum(
                    (x_dist <= self.dx_km / 2) & (y_dist <= self.dx_km / 2),
                    axis=1)
                flash_grid = flash_grid_counts.reshape(
                    self.lon_grid.shape).astype(np.int32).compute()
            else:
                x_grid_flat = self.x_grid.reshape((self.x_grid.size, 1))
                y_grid_flat = self.y_grid.reshape((self.x_grid.size, 1))
                flash_x_flat = flash_x[valid_flashes].reshape(
                    1, valid_flashes.size)
                flash_y_flat = flash_y[valid_flashes].reshape(
                    1, valid_flashes.size)
                x_dist = np.abs(x_grid_flat - flash_x_flat)
                y_dist = np.abs(y_grid_flat - flash_y_flat)
                flash_grid_counts = np.sum(
                    (x_dist <= self.dx_km / 2) & (y_dist <= self.dx_km / 2),
                    axis=1)
                flash_grid = flash_grid_counts.reshape(
                    self.lon_grid.shape).astype(np.int32)
        else:
            flash_grid = np.zeros(self.lon_grid.shape, dtype=np.int32)
        return flash_grid
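The counting trick shared by both branches, in isolation (a hedged sketch): broadcasting grid centers against flash coordinates gives an (n_grid, n_flash) distance matrix, and summing matches along axis 1 yields per-cell counts.

import numpy as np

grid = np.array([0.0, 10.0, 20.0])           # 1-D grid centers, km
flash = np.array([1.0, 9.0, 9.5, 27.0])      # flash positions, km
dx_km = 10.0
counts = np.sum(np.abs(grid[:, None] - flash[None, :]) <= dx_km / 2, axis=1)
print(counts)  # [1 2 0]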
Example #4
def initialize_da(X, k, init='random', W=None, H=None):
    n_components = k
    n_samples, n_features = X.shape
    if init == 'random':
        avg = da.sqrt(X.mean() / n_components)
        H = avg * da.random.RandomState(42).normal(
            0,
            1,
            size=(n_components, n_features),
            chunks=(n_components, X.chunks[1][0]))
        W = avg * da.random.RandomState(42).normal(
            0,
            1,
            size=(n_samples, n_components),
            chunks=(n_samples, n_components))

        H = da.fabs(H)
        W = da.fabs(W)
        return W, H

    if init in ('nndsvd', 'nndsvda'):
        # not converted to da yet
        raise NotImplementedError

    if init == 'custom':
        return W, H

    if init == 'random_vcol':

        import math
        #p_c = options.get('p_c', int(ceil(1. / 5 * X.shape[1])))
        #p_r = options.get('p_r', int(ceil(1. / 5 * X.shape[0])))
        p_c = int(math.ceil(1. / 5 * X.shape[1]))
        p_r = int(math.ceil(1. / 5 * X.shape[0]))
        prng = np.random.RandomState(42)

        #W = da.zeros((X.shape[0], n_components), chunks = (X.shape[0],n_components))
        #H = da.zeros((n_components, X.shape[1]), chunks = (n_components,X.chunks[1][0]))

        W = []
        H = []

        for i in range(k):
            W.append(X[:, prng.randint(low=0, high=X.shape[1], size=p_c)].mean(
                axis=1).compute())
            H.append(X[prng.randint(low=0, high=X.shape[0], size=p_r), :].mean(
                axis=0).compute())
        W = np.stack(W, axis=1)
        H = np.stack(H, axis=0)

        return W, H
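A hedged usage sketch for the random initializer (shapes and chunking are assumptions):

import numpy as np
import dask.array as da

X = da.fabs(da.random.RandomState(0).normal(size=(100, 40),
                                            chunks=(100, 20)))
W, H = initialize_da(X, k=5, init='random')
print(W.shape, H.shape)  # (100, 5) (5, 40)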
Example #5
def _transformlng_dask(lng, lat):
    ret = 300.0 + lng + 2.0 * lat + 0.1 * lng * lng + \
          0.1 * lng * lat + 0.1 * da.sqrt(da.fabs(lng))
    ret += (20.0 * da.sin(6.0 * lng * pi) +
            20.0 * da.sin(2.0 * lng * pi)) * 2.0 / 3.0
    ret += (20.0 * da.sin(lng * pi) +
            40.0 * da.sin(lng / 3.0 * pi)) * 2.0 / 3.0
    ret += (150.0 * da.sin(lng / 12.0 * pi) +
            300.0 * da.sin(lng / 30.0 * pi)) * 2.0 / 3.0
    return ret
Example #6
def _transformlat_dask(lng, lat):
    ret = -100.0 + 2.0 * lng + 3.0 * lat + 0.2 * lat * lat + \
          0.1 * lng * lat + 0.2 * da.sqrt(da.fabs(lng))
    ret += (20.0 * da.sin(6.0 * lng * pi) +
            20.0 * da.sin(2.0 * lng * pi)) * 2.0 / 3.0
    ret += (20.0 * da.sin(lat * pi) +
            40.0 * da.sin(lat / 3.0 * pi)) * 2.0 / 3.0
    ret += (160.0 * da.sin(lat / 12.0 * pi) +
            320 * da.sin(lat * pi / 30.0)) * 2.0 / 3.0
    return ret
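A hedged usage sketch covering both helpers (pi is assumed to come from math; the -105.0 and -35.0 offsets follow the usual GCJ-02 convention these transforms are typically paired with):

from math import pi
import numpy as np
import dask.array as da

lng = da.from_array(np.array([116.39, 121.47]) - 105.0, chunks=2)
lat = da.from_array(np.array([39.91, 31.23]) - 35.0, chunks=2)
print(_transformlng_dask(lng, lat).compute())
print(_transformlat_dask(lng, lat).compute())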
Example #7
    def get_li(self, kernel_type, li_recip):

        # relative azimuth angle
        # ensure it is in a [0,2] pi range
        phi = da.fabs(
            (self.angle_info.raa_rad % (2.0 * self.global_args.m_pi)))

        cos_phi = da.cos(phi)
        sin_phi = da.sin(phi)

        tanti = da.tan(self.angle_info.sza_rad)
        tantv = da.tan(self.angle_info.vza_rad)

        cos1, sin1, tan1 = self.get_pangles(tantv, self.global_args.br,
                                            self.global_args.nearly_zero)
        cos2, sin2, tan2 = self.get_pangles(tanti, self.global_args.br,
                                            self.global_args.nearly_zero)

        # sets cos & sin phase angle terms
        cos_phaang, phaang, sin_phaang = self.get_phaang(
            cos1, cos2, sin1, sin2, cos_phi)
        distance = self.get_distance(tan1, tan2, cos_phi)
        overlap_info = self.get_overlap(cos1, cos2, tan1, tan2, sin_phi,
                                        distance, self.global_args.hb,
                                        self.global_args.m_pi)

        if kernel_type.lower() == 'sparse':

            if li_recip:
                li = overlap_info.overlap - overlap_info.temp + 0.5 * (
                    1.0 + cos_phaang) / cos1 / cos2
            else:
                li = overlap_info.overlap - overlap_info.temp + 0.5 * (
                    1.0 + cos_phaang) / cos1

        elif kernel_type.lower() == 'dense':

            if li_recip:
                li = (1.0 + cos_phaang) / (
                    cos1 * cos2 *
                    (overlap_info.temp - overlap_info.overlap)) - 2.0
            else:
                li = (1.0 + cos_phaang) / (
                    cos1 *
                    (overlap_info.temp - overlap_info.overlap)) - 2.0

        return li
Example #8
    def _radiance_to_reflectance(self, rad):
        """Convert VIS radiance to reflectance factor.

        Note: Produces huge reflectances in situations where both radiance and
        solar zenith angle are small. Maybe the corresponding uncertainties
        can be used to filter these cases before calculating reflectances.

        Reference: [PUG], equation (6).
        """
        sza = self.solar_zenith_angle.where(
            da.fabs(self.solar_zenith_angle) < 90,
            np.float32(np.nan))  # direct illumination only
        cos_sza = np.cos(np.deg2rad(sza))
        refl = ((np.pi * self.coefs['distance_sun_earth']**2) /
                (self.coefs['solar_irradiance'] * cos_sza) * rad)
        return self.refl_factor_to_percent(refl)
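The core conversion in isolation (a hedged sketch of the [PUG] equation (6) formula; the coefficient values here are made up):

import numpy as np

rad = np.array([5.0, 10.0])      # VIS radiance
sza = np.array([30.0, 60.0])     # solar zenith angle, degrees
d_se, e_sun = 1.0, 1500.0        # hypothetical distance_sun_earth, solar_irradiance
refl = np.pi * d_se**2 / (e_sun * np.cos(np.deg2rad(sza))) * rad
print(refl)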
Example #9
    def calc_moments(self):
        with h5py.File(self.infile, 'r', rdcc_nbytes=1000 * 1000 * 1000) as f:
            data = da.from_array(f['data'],
                                 chunks=(-1, 256, -1, -1))  # CNHW layout
            data = da.transpose(data, (1, 2, 3, 0))
            dtype = data.dtype

            if dtype != np.float32:
                print(
                    'WARNING: data will be saved as float32 but input is float64!'
                )

            if self.mean is None:
                arr = data
                with ProgressBar():
                    self.mean, self.std = da.compute(arr.mean(axis=[0, 1, 2]),
                                                     arr.std(axis=[0, 1, 2]),
                                                     num_workers=8)
            else:
                self.mean, self.std = np.asarray(
                    self.mean, dtype=dtype), np.asarray(self.std, dtype=dtype)

            print('mean: {}, std: {}'.format(list(self.mean), list(self.std)))

            if self.log1p_norm:
                data_z_norm = (data - self.mean) / self.std
                data_log1p = da.sign(data_z_norm) * da.log1p(
                    da.fabs(data_z_norm))

                if self.mean_log1p is None:
                    arr = data_log1p
                    with ProgressBar():
                        self.mean_log1p, self.std_log1p = da.compute(
                            arr.mean(axis=[0, 1, 2]),
                            arr.std(axis=[0, 1, 2]),
                            num_workers=8)
                else:
                    self.mean_log1p, self.std_log1p = np.asarray(
                        self.mean_log1p,
                        dtype=dtype), np.asarray(self.std_log1p, dtype=dtype)

                print('mean_log1p: {}, std_log1p: {}'.format(
                    list(self.mean_log1p), list(self.std_log1p)))
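The signed log1p normalization used above, in isolation (a hedged sketch): z-score the data, then compress magnitudes while preserving sign. The transform is invertible via sign(x) * expm1(|x|).

import numpy as np
import dask.array as da

x = da.from_array(np.array([-100.0, -1.0, 0.0, 1.0, 100.0]), chunks=5)
z = (x - x.mean()) / x.std()
z_log1p = da.sign(z) * da.log1p(da.fabs(z))
z_back = da.sign(z_log1p) * da.expm1(da.fabs(z_log1p))
print(da.allclose(z, z_back).compute())  # True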
Example #10
def test_arithmetic():
    x = np.arange(5).astype('f4') + 2
    y = np.arange(5).astype('i8') + 2
    z = np.arange(5).astype('i4') + 2
    a = da.from_array(x, chunks=(2,))
    b = da.from_array(y, chunks=(2,))
    c = da.from_array(z, chunks=(2,))
    assert eq(a + b, x + y)
    assert eq(a * b, x * y)
    assert eq(a - b, x - y)
    assert eq(a / b, x / y)
    assert eq(b & b, y & y)
    assert eq(b | b, y | y)
    assert eq(b ^ b, y ^ y)
    assert eq(a // b, x // y)
    assert eq(a ** b, x ** y)
    assert eq(a % b, x % y)
    assert eq(a > b, x > y)
    assert eq(a < b, x < y)
    assert eq(a >= b, x >= y)
    assert eq(a <= b, x <= y)
    assert eq(a == b, x == y)
    assert eq(a != b, x != y)

    assert eq(a + 2, x + 2)
    assert eq(a * 2, x * 2)
    assert eq(a - 2, x - 2)
    assert eq(a / 2, x / 2)
    assert eq(b & True, y & True)
    assert eq(b | True, y | True)
    assert eq(b ^ True, y ^ True)
    assert eq(a // 2, x // 2)
    assert eq(a ** 2, x ** 2)
    assert eq(a % 2, x % 2)
    assert eq(a > 2, x > 2)
    assert eq(a < 2, x < 2)
    assert eq(a >= 2, x >= 2)
    assert eq(a <= 2, x <= 2)
    assert eq(a == 2, x == 2)
    assert eq(a != 2, x != 2)

    assert eq(2 + b, 2 + y)
    assert eq(2 * b, 2 * y)
    assert eq(2 - b, 2 - y)
    assert eq(2 / b, 2 / y)
    assert eq(True & b, True & y)
    assert eq(True | b, True | y)
    assert eq(True ^ b, True ^ y)
    assert eq(2 // b, 2 // y)
    assert eq(2 ** b, 2 ** y)
    assert eq(2 % b, 2 % y)
    assert eq(2 > b, 2 > y)
    assert eq(2 < b, 2 < y)
    assert eq(2 >= b, 2 >= y)
    assert eq(2 <= b, 2 <= y)
    assert eq(2 == b, 2 == y)
    assert eq(2 != b, 2 != y)

    assert eq(-a, -x)
    assert eq(abs(a), abs(x))
    assert eq(~(a == b), ~(x == y))

    assert eq(da.logaddexp(a, b), np.logaddexp(x, y))
    assert eq(da.logaddexp2(a, b), np.logaddexp2(x, y))
    assert eq(da.exp(b), np.exp(y))
    assert eq(da.log(a), np.log(x))
    assert eq(da.log10(a), np.log10(x))
    assert eq(da.log1p(a), np.log1p(x))
    assert eq(da.expm1(b), np.expm1(y))
    assert eq(da.sqrt(a), np.sqrt(x))
    assert eq(da.square(a), np.square(x))

    assert eq(da.sin(a), np.sin(x))
    assert eq(da.cos(b), np.cos(y))
    assert eq(da.tan(a), np.tan(x))
    assert eq(da.arcsin(b/10), np.arcsin(y/10))
    assert eq(da.arccos(b/10), np.arccos(y/10))
    assert eq(da.arctan(b/10), np.arctan(y/10))
    assert eq(da.arctan2(b*10, a), np.arctan2(y*10, x))
    assert eq(da.hypot(b, a), np.hypot(y, x))
    assert eq(da.sinh(a), np.sinh(x))
    assert eq(da.cosh(b), np.cosh(y))
    assert eq(da.tanh(a), np.tanh(x))
    assert eq(da.arcsinh(b*10), np.arcsinh(y*10))
    assert eq(da.arccosh(b*10), np.arccosh(y*10))
    assert eq(da.arctanh(b/10), np.arctanh(y/10))
    assert eq(da.deg2rad(a), np.deg2rad(x))
    assert eq(da.rad2deg(a), np.rad2deg(x))

    assert eq(da.logical_and(a < 1, b < 4), np.logical_and(x < 1, y < 4))
    assert eq(da.logical_or(a < 1, b < 4), np.logical_or(x < 1, y < 4))
    assert eq(da.logical_xor(a < 1, b < 4), np.logical_xor(x < 1, y < 4))
    assert eq(da.logical_not(a < 1), np.logical_not(x < 1))
    assert eq(da.maximum(a, 5 - a), np.maximum(x, 5 - x))
    assert eq(da.minimum(a, 5 - a), np.minimum(x, 5 - x))
    assert eq(da.fmax(a, 5 - a), np.fmax(x, 5 - x))
    assert eq(da.fmin(a, 5 - a), np.fmin(x, 5 - x))

    assert eq(da.isreal(a + 1j * b), np.isreal(x + 1j * y))
    assert eq(da.iscomplex(a + 1j * b), np.iscomplex(x + 1j * y))
    assert eq(da.isfinite(a), np.isfinite(x))
    assert eq(da.isinf(a), np.isinf(x))
    assert eq(da.isnan(a), np.isnan(x))
    assert eq(da.signbit(a - 3), np.signbit(x - 3))
    assert eq(da.copysign(a - 3, b), np.copysign(x - 3, y))
    assert eq(da.nextafter(a - 3, b), np.nextafter(x - 3, y))
    assert eq(da.ldexp(c, c), np.ldexp(z, z))
    assert eq(da.fmod(a * 12, b), np.fmod(x * 12, y))
    assert eq(da.floor(a * 0.5), np.floor(x * 0.5))
    assert eq(da.ceil(a), np.ceil(x))
    assert eq(da.trunc(a / 2), np.trunc(x / 2))

    assert eq(da.degrees(b), np.degrees(y))
    assert eq(da.radians(a), np.radians(x))

    assert eq(da.rint(a + 0.3), np.rint(x + 0.3))
    assert eq(da.fix(a - 2.5), np.fix(x - 2.5))

    assert eq(da.angle(a + 1j), np.angle(x + 1j))
    assert eq(da.real(a + 1j), np.real(x + 1j))
    assert eq((a + 1j).real, np.real(x + 1j))
    assert eq(da.imag(a + 1j), np.imag(x + 1j))
    assert eq((a + 1j).imag, np.imag(x + 1j))
    assert eq(da.conj(a + 1j * b), np.conj(x + 1j * y))
    assert eq((a + 1j * b).conj(), (x + 1j * y).conj())

    assert eq(da.clip(b, 1, 4), np.clip(y, 1, 4))
    assert eq(da.fabs(b), np.fabs(y))
    assert eq(da.sign(b - 2), np.sign(y - 2))

    l1, l2 = da.frexp(a)
    r1, r2 = np.frexp(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    l1, l2 = da.modf(a)
    r1, r2 = np.modf(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    assert eq(da.around(a, -1), np.around(x, -1))
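These assertions lean on an eq helper from dask's test suite; a hedged minimal stand-in is sketched below (the real helper also validates dtypes and graph structure):

import numpy as np
import dask.array as da

def eq(a, b):
    # compute any dask operands, then compare against the numpy reference
    a = a.compute() if isinstance(a, da.Array) else a
    b = b.compute() if isinstance(b, da.Array) else b
    return np.allclose(a, b, equal_nan=True)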
Example #11
    def ds(self):
        if self._ds is None:
            file_exists = os.path.exists(self._result_file)

            reprocess = not file_exists or self._reprocess

            if reprocess:
                if file_exists:
                    print('Old file exists ' + self._result_file)
                    #print('Removing old file ' + self._result_file)
                    #shutil.rmtree(self._result_file)

                ds_data = OrderedDict()

                to_seconds = np.vectorize(
                    lambda x: x.seconds + x.microseconds / 1E6)

                print('Processing binary data...')
                xx, yy, zz = self._loadgrid()
                if xx is None:
                    if self._from_nc:
                        print('Processing existing netcdf...')
                        fn = self._result_file[:-5] + '_QC_raw.nc'
                        if os.path.exists(fn):
                            ds_temp = xr.open_dataset(self._result_file[:-5] +
                                                      '_QC_raw.nc',
                                                      chunks={'time': 50})
                            u = da.transpose(ds_temp['U'].data,
                                             axes=[3, 0, 1, 2])
                            v = da.transpose(ds_temp['V'].data,
                                             axes=[3, 0, 1, 2])
                            w = da.transpose(ds_temp['W'].data,
                                             axes=[3, 0, 1, 2])
                            tt = ds_temp['time']
                            te = (tt - tt[0]) / np.timedelta64(1, 's')
                            xx = ds_temp['x'].values
                            yy = ds_temp['y'].values
                            zz = ds_temp['z'].values
                        else:
                            print('USING OLD ZARR DATA')
                            ds_temp = xr.open_zarr(self._result_file)
                            u = da.transpose(ds_temp['U'].data,
                                             axes=[3, 0, 1, 2])
                            v = da.transpose(ds_temp['V'].data,
                                             axes=[3, 0, 1, 2])
                            w = da.transpose(ds_temp['W'].data,
                                             axes=[3, 0, 1, 2])
                            tt = ds_temp['time']
                            te = (tt - tt[0]) / np.timedelta64(1, 's')
                            xx = ds_temp['x'].values
                            yy = ds_temp['y'].values
                            zz = ds_temp['z'].values
                            print('ERROR: No NetCDF data found for ' +
                                  self._xml_file)
                            #return None
                            # print(u.shape)

                else:
                    tt, uvw = self._loaddata(xx, yy, zz)
                    if tt is None:
                        print('ERROR: No binary data found for ' +
                              self._xml_file)
                        return None

                    # calculate the elapsed time from the Timestamp objects and then convert to datetime64 datatype
                    te = to_seconds(tt - tt[0])
                    tt = pd.to_datetime(tt)
                    uvw = uvw.persist()
                    u = uvw[:, :, :, :, 0]
                    v = uvw[:, :, :, :, 1]
                    w = uvw[:, :, :, :, 2]


#                    u = xr.DataArray(uvw[:,:,:,:,0], coords=[tt, xx, yy, zz], dims=['time','x', 'y', 'z'],
#                                     name='U', attrs={'standard_name': 'sea_water_x_velocity', 'units': 'm s-1'})
#                    v = xr.DataArray(uvw[:,:,:,:,1], coords=[tt, xx, yy, zz], dims=['time', 'x', 'y', 'z'],
#                                     name='V', attrs={'standard_name': 'sea_water_x_velocity', 'units': 'm s-1'})
#                    w = xr.DataArray(uvw[:,:,:,:,2], coords=[tt, xx, yy, zz], dims=['time', 'x', 'y', 'z'],
#                                     name='W', attrs={'standard_name': 'upward_sea_water_velocity', 'units': 'm s-1'})

                if xx is None:
                    print('No data found')
                    return None

                u = u.persist()
                v = v.persist()
                w = w.persist()

                dx = float(xx[1] - xx[0])
                dy = float(yy[1] - yy[0])
                dz = float(zz[1] - zz[0])

                if self._norm_dims:
                    exp = self._result_root.split('/')[4]
                    runSheet = pd.read_csv('~/RunSheet-%s.csv' % exp)
                    runSheet = runSheet.set_index('RunID')
                    runDetails = runSheet.loc[int(self.run_id[-2:])]

                    T = runDetails['T (s)']
                    h = runDetails['h (m)']
                    D = runDetails['D (m)']

                    ww = te / T
                    om = 2. * np.pi / T
                    d_s = (2. * 1E-6 / om)**0.5
                    bl = 3. * np.pi / 4. * d_s

                    if exp == 'Exp6':
                        if D == 0.1:
                            dy_c = (188. + 82.) / 2
                            dx_c = 39.25
                            cx = dx_c / 1000.
                            cy = dy_c / 1000.
                        else:
                            dy_c = (806. + 287.) / 2. * 0.22
                            dx_c = 113 * 0.22
                            cx = dx_c / 1000.
                            cy = dy_c / 1000.
                    elif exp == 'Exp8':
                        dy_c = 624 * 0.22
                        dx_c = 15
                        cx = dx_c / 1000.
                        cy = dy_c / 1000.
                    xn = (xx + (D / 2. - cx)) / D
                    yn = (yy - cy) / D
                    zn = zz / h

                    xnm, ynm = np.meshgrid(xn, yn)
                    rr = np.sqrt(xnm**2. + ynm**2)
                    cylMask = rr < 0.5

                    nanPlane = np.ones(cylMask.shape)
                    nanPlane[cylMask] = np.nan
                    nanPlane = nanPlane.T
                    nanPlane = nanPlane[np.newaxis, :, :, np.newaxis]

                    u = u * nanPlane
                    v = v * nanPlane
                    w = w * nanPlane

                    if D == 0.1:
                        xInds = xn > 3.
                    else:
                        xInds = xn > 2.

                    blInd = np.argmax(zn > bl / h)
                    blPlane = int(round(blInd))

                    Ue = u[:, xInds, :, :]
                    Ue_bar = da.nanmean(Ue, axis=(1, 2, 3)).compute()
                    Ue_bl = da.nanmean(Ue[:, :, :, blPlane],
                                       axis=(1, 2)).compute()

                    inds = ~np.isnan(Ue_bl)

                    xv = ww[inds] % 1.
                    xv = xv + np.random.normal(scale=1E-6, size=xv.shape)
                    yv = Ue_bl[inds]
                    xy = np.stack([
                        np.concatenate([xv - 1., xv, xv + 1.]),
                        np.concatenate([yv, yv, yv])
                    ]).T
                    xy = xy[xy[:, 0].argsort(), :]
                    xi = np.linspace(-0.5, 1.5, len(xv) // 8)
                    n = np.nanmax(xy[:, 1])
                    # print(n)
                    # fig,ax = pl.subplots()
                    # ax.scatter(xy[:,0],xy[:,1]/n)
                    # print(xy)
                    spl = si.LSQUnivariateSpline(xy[:, 0],
                                                 xy[:, 1] / n,
                                                 t=xi,
                                                 k=3)
                    roots = spl.roots()
                    der = spl.derivative()
                    slope = der(roots)
                    inds = np.min(np.where(slope > 0))
                    dt = (roots[inds] % 1.).mean() - 0.5

                    tpx = np.arange(0, 0.5, 0.001)
                    U0_bl = np.abs(spl(tpx + dt).min() * n)
                    ws = ww - dt
                    Ue_spl = spl((ws - 0.5) % 1.0 + dt) * n * -1.0

                    #maxima = spl.derivative().roots()
                    #Umax = spl(maxima)
                    #UminIdx = np.argmin(Umax)
                    #U0_bl = np.abs(Umax[UminIdx]*n)

                    #ww_at_min = maxima[UminIdx]
                    #ws = ww - ww_at_min + 0.25

                    inds = ~np.isnan(Ue_bar)

                    xv = ww[inds] % 1.
                    xv = xv + np.random.normal(scale=1E-6, size=xv.shape)
                    yv = Ue_bar[inds]
                    xy = np.stack([
                        np.concatenate([xv - 1., xv, xv + 1.]),
                        np.concatenate([yv, yv, yv])
                    ]).T
                    xy = xy[xy[:, 0].argsort(), :]
                    xi = np.linspace(-0.5, 1.5, len(xv) // 8)
                    n = np.nanmax(xy[:, 1])
                    spl = si.LSQUnivariateSpline(xy[:, 0],
                                                 xy[:, 1] / n,
                                                 t=xi,
                                                 k=4)
                    maxima = spl.derivative().roots()
                    Umax = spl(maxima)
                    UminIdx = np.argmin(Umax)
                    U0_bar = np.abs(Umax[UminIdx] * n)

                    ww = xr.DataArray(ww, coords=[tt], dims=['time'])
                    ws = xr.DataArray(ws - 0.5, coords=[tt], dims=['time'])

                    xn = xr.DataArray(xn, coords=[xx], dims=['x'])
                    yn = xr.DataArray(yn, coords=[yy], dims=['y'])
                    zn = xr.DataArray(zn, coords=[zz], dims=['z'])

                    Ue_bar = xr.DataArray(Ue_bar, coords=[tt], dims=['time'])
                    Ue_bl = xr.DataArray(Ue_bl, coords=[tt], dims=['time'])
                    Ue_spl = xr.DataArray(Ue_spl, coords=[tt], dims=['time'])

                    ds_data['ww'] = ww
                    ds_data['ws'] = ws

                    ds_data['xn'] = xn
                    ds_data['yn'] = yn
                    ds_data['zn'] = zn

                    ds_data['Ue_bar'] = Ue_bar
                    ds_data['Ue_bl'] = Ue_bl
                    ds_data['Ue_spl'] = Ue_spl

                te = xr.DataArray(te, coords=[tt], dims=['time'])

                dims = ['time', 'x', 'y', 'z']
                coords = [tt, xx, yy, zz]

                ds_data['U'] = xr.DataArray(u,
                                            coords=coords,
                                            dims=dims,
                                            name='U',
                                            attrs={
                                                'standard_name':
                                                'sea_water_x_velocity',
                                                'units': 'm s-1'
                                            })
                ds_data['V'] = xr.DataArray(v,
                                            coords=coords,
                                            dims=dims,
                                            name='V',
                                            attrs={
                                                'standard_name':
                                                'sea_water_y_velocity',
                                                'units': 'm s-1'
                                            })
                ds_data['W'] = xr.DataArray(w,
                                            coords=coords,
                                            dims=dims,
                                            name='W',
                                            attrs={
                                                'standard_name':
                                                'upward_sea_water_velocity',
                                                'units': 'm s-1'
                                            })
                ds_data['te'] = te

                # stdV = da.nanstd(v)
                # stdW = da.nanstd(w)
                # thres=7.
                if 'U0_bl' in locals():
                    condition = (da.fabs(v) / U0_bl >
                                 1.5) | (da.fabs(w) / U0_bl > 0.6)
                    for var in ['U', 'V', 'W']:
                        ds_data[var].data = da.where(condition, np.nan,
                                                     ds_data[var].data)

                piv_step_frame = float(
                    self._xml_root.findall('piv/stepFrame')[0].text)

                print('Calculating tensor')
                # j = jacobianConv(ds.U, ds.V, ds.W, dx, dy, dz, sigma=1.5)
                j = jacobianDask(u, v, w, piv_step_frame, dx, dy, dz)
                print('Done')
                #j = da.from_array(j,chunks=(20,-1,-1,-1,-1,-1))

                #                j = jacobianDask(uvw[:,:,:,:,0],uvw[:,:,:,:,1], uvw[:,:,:,:,2], piv_step_frame, dx, dy, dz)
                jT = da.transpose(j, axes=[0, 1, 2, 3, 5, 4])

                #                j = j.persist()
                #                jT = jT.persist()

                jacobianNorm = da.sqrt(
                    da.nansum(da.nansum(j**2., axis=-1), axis=-1))

                strainTensor = (j + jT) / 2.
                vorticityTensor = (j - jT) / 2.

                strainTensorNorm = da.sqrt(
                    da.nansum(da.nansum(strainTensor**2., axis=-1), axis=-1))
                vorticityTensorNorm = da.sqrt(
                    da.nansum(da.nansum(vorticityTensor**2., axis=-1),
                              axis=-1))
                divergence = j[:, :, :, :, 0, 0] + j[:, :, :, :, 1,
                                                     1] + j[:, :, :, :, 2, 2]
                # print(divergence)
                omx = vorticityTensor[:, :, :, :, 2, 1] * 2.
                omy = vorticityTensor[:, :, :, :, 0, 2] * 2.
                omz = vorticityTensor[:, :, :, :, 1, 0] * 2.

                divNorm = divergence / jacobianNorm

                #                divNorm = divNorm.persist()

                #                divNorm_mean = da.nanmean(divNorm)
                #                divNorm_std = da.nanstd(divNorm)

                dims = ['x', 'y', 'z']
                comp = ['u', 'v', 'w']

                ds_data['jacobian'] = xr.DataArray(
                    j,
                    coords=[tt, xx, yy, zz, comp, dims],
                    dims=['time', 'x', 'y', 'z', 'comp', 'dims'],
                    name='jacobian')

                ds_data['jacobianNorm'] = xr.DataArray(
                    jacobianNorm,
                    coords=[tt, xx, yy, zz],
                    dims=['time', 'x', 'y', 'z'],
                    name='jacobianNorm')

                ds_data['strainTensor'] = xr.DataArray(
                    strainTensor,
                    coords=[tt, xx, yy, zz, comp, dims],
                    dims=['time', 'x', 'y', 'z', 'comp', 'dims'],
                    name='strainTensor')

                ds_data['vorticityTensor'] = xr.DataArray(
                    vorticityTensor,
                    coords=[tt, xx, yy, zz, comp, dims],
                    dims=['time', 'x', 'y', 'z', 'comp', 'dims'],
                    name='vorticityTensor')

                ds_data['vorticityNorm'] = xr.DataArray(
                    vorticityTensorNorm,
                    coords=[tt, xx, yy, zz],
                    dims=['time', 'x', 'y', 'z'],
                    name='vorticityNorm')

                ds_data['strainNorm'] = xr.DataArray(
                    strainTensorNorm,
                    coords=[tt, xx, yy, zz],
                    dims=['time', 'x', 'y', 'z'],
                    name='strainNorm')

                ds_data['divergence'] = xr.DataArray(
                    divergence,
                    coords=[tt, xx, yy, zz],
                    dims=['time', 'x', 'y', 'z'],
                    name='divergence')

                ds_data['omx'] = xr.DataArray(omx,
                                              coords=[tt, xx, yy, zz],
                                              dims=['time', 'x', 'y', 'z'],
                                              name='omx')

                ds_data['omy'] = xr.DataArray(omy,
                                              coords=[tt, xx, yy, zz],
                                              dims=['time', 'x', 'y', 'z'],
                                              name='omy')

                ds_data['omz'] = xr.DataArray(omz,
                                              coords=[tt, xx, yy, zz],
                                              dims=['time', 'x', 'y', 'z'],
                                              name='omz')

                ds_data['divNorm'] = xr.DataArray(divNorm,
                                                  coords=[tt, xx, yy, zz],
                                                  dims=['time', 'x', 'y', 'z'],
                                                  name='divNorm')

                #                ds_data['divNorm_mean'] = xr.DataArray(divNorm_mean)
                #                ds_data['divNorm_std'] = xr.DataArray(divNorm_std)

                ds = xr.Dataset(ds_data)
                #                if self._from_nc:
                #                    for k,v in ds_temp.attrs.items():
                #                        ds.attrs[k]=v
                #ds = ds.chunk({'time': 20})

                self._append_CF_attrs(ds)
                self._append_attrs(ds)
                ds.attrs['filename'] = self._result_file

                if self._norm_dims:

                    KC = U0_bl * T / D
                    delta = (2. * np.pi * d_s) / h
                    S = delta / KC

                    ds.attrs['T'] = T
                    ds.attrs['h'] = h
                    ds.attrs['D'] = D
                    ds.attrs['U0_bl'] = U0_bl
                    ds.attrs['U0_bar'] = U0_bar
                    ds.attrs['KC'] = KC
                    ds.attrs['S'] = S
                    ds.attrs['Delta+'] = ((1E-6 * T)**0.5) / h
                    ds.attrs['Delta_l'] = 2 * np.pi * d_s
                    ds.attrs['Delta_s'] = d_s
                    ds.attrs['Re_D'] = U0_bl * D / 1E-6
                    ds.attrs['Beta'] = D**2. / (1E-6 * T)

                delta = (ds.attrs['dx'] * ds.attrs['dy'] *
                         ds.attrs['dz'])**(1. / 3.)
                dpx = (ds.attrs['pdx'] * ds.attrs['pdy'] *
                       ds.attrs['pdz'])**(1. / 3.)
                delta_px = delta / dpx
                dt = ds.attrs['piv_step_ensemble']

                #                divRMS = da.sqrt(da.nanmean((divergence * dt) ** 2.))
                #                divRMS = divRMS.persist()
                #                vorticityTensorNorm.persist()
                #                velocityError = divRMS/((3./(2.*delta_px**2.))**0.5)
                # print(da.percentile(ds_new['vorticityTensorNorm'].data.ravel(),99.))
                # print(ds_new['divRMS'])
                # print(ds_new['divNorm_mean'])
                #                vorticityError = divRMS/dt/da.percentile(vorticityTensorNorm.ravel(),99.)

                #                divNorm_mean = da.nanmean(divNorm)
                #                divNorm_std = da.nanstd(divNorm)

                # print("initial save")
                #ds.to_zarr(self._result_file,compute=False)
                #ds = xr.open_zarr(self._result_file)

                #                xstart = np.argmax(xx > 0.05)
                #                ystart = np.argmax(yy > 0.07)

                divRMS = da.sqrt(da.nanmean(
                    (divergence * dt)**2.))  #.compute()
                #divNorm = divergence / jacobianNorm
                #divNorm = divNorm.compute()
                #divNorm_mean = da.nanmean(divNorm).compute()
                #divNorm_std = da.nanstd(divNorm).compute()
                velocityError = divRMS / ((3. / (2. * delta_px**2.))**0.5)
                vortNorm = vorticityTensorNorm  #.compute()

                vorticityError = divRMS / dt / np.percentile(
                    vortNorm.ravel(), 99.)

                velocityError, vorticityError = da.compute(
                    velocityError, vorticityError)

                #ds.attrs['divNorm_mean'] = divNorm_mean
                #ds.attrs['divNorm_std'] = divNorm_std
                ds.attrs['velocityError'] = velocityError
                ds.attrs['vorticityError'] = vorticityError

                if self._norm_dims:
                    xInds = (xn > 0.5) & (xn < 2.65)
                    yInds = (yn > -0.75) & (yn < 0.75)
                else:
                    xInds = range(len(ds['x']))
                    yInds = range(len(ds['y']))
                vrms = (ds['V'][:, xInds, yInds, :]**2.).mean(
                    dim=['time', 'x', 'y', 'z'])**0.5
                wrms = (ds['W'][:, xInds, yInds, :]**2.).mean(
                    dim=['time', 'x', 'y', 'z'])**0.5
                ds.attrs['Vrms'] = float(vrms.compute())
                ds.attrs['Wrms'] = float(wrms.compute())

                #fig,ax = pl.subplots()
                #ax.plot(ds.ws,ds.Ue_spl/U0_bl,color='k')
                #ax.plot(ds.ws,ds.Ue_bl/U0_bl,color='g')
                #ax.set_xlabel(r'$t/T$')
                #ax.set_ylabel(r'$U_{bl}/U_0$')
                #fig.savefig(self._result_file[:-4] + 'png',dpi=125)
                #pl.close(fig)
                # print("second save")
                #ds.to_netcdf(self._result_file)
                ds.to_zarr(self._result_file, mode='w')

                print('Cached ' + self._result_file)

                #ds = xr.open_dataset(self._result_file,chunks={'time':20})
                ds = xr.open_zarr(self._result_file)
                ds.attrs['filename'] = self._result_file
            else:
                #ds = xr.open_dataset(self._result_file,chunks={'time':20})
                ds = xr.open_zarr(self._result_file)
                ds.attrs['filename'] = self._result_file

            self._ds = ds

        return self._ds
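The velocity despiking step above, in isolation (a hedged sketch): samples whose normalized cross-flow components exceed fixed thresholds are masked with NaN.

import numpy as np
import dask.array as da

U0_bl = 0.2
v = da.from_array(np.array([0.05, 0.5, -0.01]), chunks=3)
w = da.from_array(np.array([0.01, 0.0, 0.3]), chunks=3)
condition = (da.fabs(v) / U0_bl > 1.5) | (da.fabs(w) / U0_bl > 0.6)
print(da.where(condition, np.nan, v).compute())  # [0.05  nan  nan]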
Example #12
def update_H_da(M, H, W):
    denominator = da.dot(W.T, da.dot(W, H))
    denominator_new = da.where(
        da.fabs(denominator) < EPSILON, EPSILON, denominator)
    H_new = H * da.dot(W.T, M) / denominator_new
    return H_new
Example #13
def gradient(image):
    return da.where(
            da.fabs(image) <= huber['threshold'],
            2 * image,
            2 * huber['threshold'] * da.sign(image))
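A hedged usage sketch: this is the gradient of the Huber penalty (quadratic inside the threshold, linear outside); huber is assumed to be a module-level dict, as the lookup above implies.

import numpy as np
import dask.array as da

huber = {'threshold': 1.0}
image = da.from_array(np.array([-3.0, -0.5, 0.0, 0.5, 3.0]), chunks=5)
print(gradient(image).compute())  # [-2. -1.  0.  1.  2.]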
Example #14
def update_W_da(M, H, W):
    denominator = da.dot(W, da.dot(H, H.T))
    denominator_new = da.where(
        da.fabs(denominator) < EPSILON, EPSILON, denominator)
    W_new = W * da.dot(M, H.T) / denominator_new
    return W_new
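A hedged sketch of how this update and update_H_da above compose into a multiplicative NMF loop (EPSILON, shapes and chunking are assumptions):

import numpy as np
import dask.array as da

EPSILON = 1e-10
M = da.fabs(da.random.RandomState(0).normal(size=(60, 40), chunks=(30, 40)))
W = da.fabs(da.random.RandomState(1).normal(size=(60, 5), chunks=(30, 5)))
H = da.fabs(da.random.RandomState(2).normal(size=(5, 40), chunks=(5, 40)))

for _ in range(10):
    H = update_H_da(M, H, W)
    W = update_W_da(M, H, W)
print(float(((M - da.dot(W, H)) ** 2).mean().compute()))  # reconstruction MSE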
Example #15
def main():
    ########################################################
    
    infile = '/data/ERA5/hdf5_hr/ds_eval_2000_to_2002.hdf5'
    outfile = 'sr_eval_2000_2002.{}.tfrecords'
    n_files = 8
    gzip = True
    shuffle = False

    ########################################################

    SR_ratio = 4
    log1p_norm = True
    z_norm = False
    
    mean, std = [2.0152406e-08, 2.1581373e-07], [2.8560082e-05, 5.0738556e-05]
    mean_log1p, std_log1p = [0.008315503, 0.0028762482], [0.5266841, 0.5418187]

    #########################################################

    HR_reduce_latitude = 0  # 107 = 15 deg from each pole
    patch_size = None  # (96, 96)
    n_patches = 50
    mode = 'train'

    ########################################################

    f = h5py.File(infile, 'r', rdcc_nbytes=1000*1000*1000)
    data = da.from_array(f['data'], chunks=(-1, 16 if shuffle else 256, -1, -1))  # CNHW layout
    data = da.transpose(data, (1,2,3,0))
    dtype = data.dtype

    if dtype != np.float32:
        print('WARNING: data will be saved as float32 but input is float64!')

    if mean is None:
        arr = data
        with ProgressBar():
            mean, std = da.compute(arr.mean(axis=[0,1,2]), arr.std(axis=[0,1,2]), num_workers=8)
    else:
        mean, std = np.asarray(mean, dtype=dtype), np.asarray(std, dtype=dtype)

    print('mean: {}, std: {}'.format(list(mean), list(std)))

    if log1p_norm:
        data_z_norm = (data-mean) / std
        data_log1p = da.sign(data_z_norm) * da.log1p(da.fabs(data_z_norm))

        if mean_log1p is None:
            arr = data_log1p
            with ProgressBar():
                mean_log1p, std_log1p = da.compute(arr.mean(axis=[0,1,2]), arr.std(axis=[0,1,2]), num_workers=8)
        else:
            mean_log1p, std_log1p = np.asarray(mean_log1p, dtype=dtype), np.asarray(std_log1p, dtype=dtype)

        print('mean_log1p: {}, std_log1p: {}'.format(list(mean_log1p), list(std_log1p)))

        data = data_log1p
    elif z_norm:
        data = (data-mean) / std

    if shuffle:
        block_indices = np.random.permutation(data.numblocks[0])
    else:
        block_indices = np.arange(data.numblocks[0])


    file_blocks = np.array_split(block_indices, n_files)
    i = 0
    for n, indices in enumerate(file_blocks):
        if n_files > 1:
            name = outfile.format(n)
        else:
            name = outfile
        
        with tf.io.TFRecordWriter(name, options='ZLIB' if gzip else None) as writer:
            for block_idx in indices:
                block = data.blocks[block_idx].compute()
                if shuffle:
                    block = np.random.permutation(block)

                if HR_reduce_latitude:
                    lat_start = HR_reduce_latitude//2
                    block = block[:, lat_start:(-lat_start),:, :]

                generate_TFRecords(writer, block, SR_ratio, mode, patch_size, n_patches)
                i += 1
                print('{} / {}'.format(i, data.numblocks[0]), flush=True)
Example #16
def _apply(func,
           datasets,
           chunk=CHUNK,
           pad=None,
           relabel=False,
           stack=False,
           compute=True,
           out=None,
           normalize=False,
           **kwargs):
    """
    Applies a function to a given set of datasets. Wraps a standard
    function call of the form:

        func(*datasets, **kwargs)

    Named parameters give extra functionality.

    Parameters
    ----------
    func: callable
        Function to be mapped across datasets.
    datasets: list of numpy array-like
        Input datasets.
    chunk: boolean
        If `True` then input datasets will be assumed to be `Dask.Array`s and
        the function will be mapped across array blocks.
    pad: None, int or iterable
        The padding to apply (only if `chunk = True`). If `pad != None` then
        `dask.array.ghost.map_overlap` will be used to map the function across
        overlapping blocks, otherwise `dask.array.map_blocks` will be used.
    relabel: boolean
        Some of the labelling functions will yield local labelling if `chunk=True`.
        If `func` is a labelling function, set `relabel = True` to map the result
        for global consistency. See `survos2.improc.utils.dask_relabel_chunks` for
        more details.
    stack: boolean
        If `True`, input datasets are stacked along a new leading axis and
        `func` is mapped over the stacked array.
    compute: boolean
        If `True` the result will be computed and returned in numpy array form,
        otherwise a lazy dask result will be returned if `chunk = True`.
    out: None or numpy array-like
        If `out != None` then the result will be stored in there.
    normalize: boolean
        If `True` (and `relabel = False`), the result is divided by its
        maximum absolute value before being stored or returned.
    **kwargs: other keyword arguments
        Arguments to be passed to `func`.

    Returns
    -------
    result: numpy array-like
        The computed result if `compute = True` or `chunk = False`, the result
        of the lazy wrapping otherwise.
    """
    if stack and len(datasets) > 1:
        dataset = da.stack(datasets, axis=0)
        dataset = da.rechunk(dataset,
                             chunks=(dataset.shape[0], ) + dataset.chunks[1:])
        datasets = [dataset]

    if chunk:
        kwargs.setdefault('dtype', out.dtype if out else datasets[0].dtype)
        kwargs.setdefault('drop_axis', 0 if stack else None)
        if not pad:
            result = da.map_blocks(func, *datasets, **kwargs)
        elif len(datasets) == 1:
            if np.isscalar(pad):
                pad = [pad] * datasets[0].ndim

            if stack:
                pad[0] = 0  # don't pad feature channel
                depth = {i: d for i, d in enumerate(pad)}
                trim = {i: d for i, d in enumerate(pad[1:])}
            else:
                depth = trim = {i: d for i, d in enumerate(pad)}

            g = da.ghost.ghost(datasets[0], depth=depth, boundary='reflect')
            r = g.map_blocks(func, **kwargs)
            result = da.ghost.trim_internal(r, trim)
        else:
            raise ValueError('`pad` only works with a single dataset')

        rchunks = result.chunks

        if not relabel and normalize:
            result = result / da.nanmax(da.fabs(result))

        if out is not None:
            result.store(out, compute=True)
        elif compute:
            result = result.compute()

        if relabel:
            if out is not None:
                result = dask_relabel_chunks(da.from_array(out,
                                                           chunks=rchunks))
                result.store(out, compute=True)
            else:
                result = dask_relabel_chunks(
                    da.from_array(result, chunks=rchunks))
                if compute:
                    result = result.compute()
    else:
        result = func(*datasets, **kwargs)
        if out is not None:
            out[...] = result

    if out is None:
        return result
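A hedged usage sketch: map a per-block function over a chunked array, with the optional max-abs normalization (CHUNK is assumed to default to True in the surrounding module, so chunk is passed explicitly here):

import numpy as np
import dask.array as da

x = da.random.RandomState(0).normal(size=(64, 64), chunks=(32, 32))
result = _apply(np.abs, [x], chunk=True, normalize=True)
print(result.max())  # 1.0 after normalization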