예제 #1
0
def xyz2lonlat(x__, y__, z__):
    """Convert cartesian coordinates to longitudes and latitudes.

    Args:
        x__: Cartesian x coordinate(s).
        y__: Cartesian y coordinate(s).
        z__: Cartesian z coordinate(s).

    Returns:
        Tuple ``(lons, lats)`` in degrees, with longitudes in [-180, 180]
        and latitudes in [-90, 90].
    """
    # arctan2 resolves the quadrant directly. The former
    # ``arccos(x / r) * sign(y)`` form yielded 0 instead of 180 on the
    # antimeridian (y == 0, x < 0) and NaN at the poles (r == 0).
    lons = da.rad2deg(da.arctan2(y__, x__))
    # Measure the elevation against the horizontal distance: this needs no
    # hard-coded sphere radius and cannot leave the arcsin domain when
    # rounding pushes r / R slightly above 1.
    lats = da.rad2deg(da.arctan2(z__, da.hypot(x__, y__)))

    return lons, lats
def xyz2lonlat(x__, y__, z__):
    """Convert cartesian coordinates to longitudes and latitudes.

    Args:
        x__: Cartesian x coordinate(s).
        y__: Cartesian y coordinate(s).
        z__: Cartesian z coordinate(s).

    Returns:
        Tuple ``(lons, lats)`` in degrees, with longitudes in [-180, 180]
        and latitudes in [-90, 90].
    """
    # arctan2 resolves the quadrant directly. The former
    # ``arccos(x / r) * sign(y)`` form yielded 0 instead of 180 on the
    # antimeridian (y == 0, x < 0) and NaN at the poles (r == 0).
    lons = da.rad2deg(da.arctan2(y__, x__))
    # Measure the elevation against the horizontal distance: this needs no
    # hard-coded sphere radius and cannot leave the arcsin domain when
    # rounding pushes r / R slightly above 1.
    lats = da.rad2deg(da.arctan2(z__, da.hypot(x__, y__)))

    return lons, lats
예제 #3
0
def test_incremental_basic(scheduler, dataframes):
    """Compare ``Incremental`` with an estimator trained block by block.

    A cloned ``SGDClassifier`` is fed the computed chunks of ``X``/``y`` one
    at a time through ``partial_fit``; the wrapped estimator must end up with
    similar coefficients, predictions and score. When ``dataframes`` is
    truthy the inputs are passed to ``fit`` as dask dataframes.
    """
    # Labels derived from a linear rule, so linear models can recover them
    n_samples, n_features = 100, 3
    state = da.random.RandomState(42)
    X = state.normal(size=(n_samples, n_features), chunks=30)
    true_coef = state.uniform(size=n_features, chunks=n_features)
    # Shift the sign values from [-1, 1] to [0, 1]
    y = (da.sign(X.dot(true_coef)) + 1) / 2
    if dataframes:
        X, y = dd.from_array(X), dd.from_array(y)

    with scheduler() as (s, [_, _]):
        wrapped = SGDClassifier(random_state=0, tol=1e-3, average=True)
        reference = clone(wrapped)

        clf = Incremental(wrapped, random_state=0)
        fitted = clf.fit(X, y, classes=[0, 1])
        assert fitted is clf

        # The reference estimator is plain scikit-learn and only serves as a
        # baseline; it is trained on concrete blocks, so go back to arrays.
        if dataframes:
            X = X.to_dask_array(lengths=True)
            y = y.to_dask_array(lengths=True)

        for block in da.core.slices_from_chunks(X.chunks):
            X_block = X[block].compute()
            y_block = y[block[0]].compute()
            reference.partial_fit(X_block, y_block, classes=[0, 1])

        assert isinstance(fitted.estimator_.coef_, np.ndarray)
        coef_gap = (np.linalg.norm(clf.coef_ - reference.coef_)
                    / np.linalg.norm(clf.coef_))
        assert coef_gap < 0.9

        assert set(dir(clf.estimator_)) == set(dir(reference))

        # Predict
        predicted = clf.predict(X)
        expected = reference.predict(X)
        assert isinstance(predicted, da.Array)
        if dataframes:
            # Chunk sizes of this array are unknown, so materialise it first
            predicted = predicted.compute()
        prediction_gap = (np.linalg.norm(predicted - expected)
                          / np.linalg.norm(expected))
        assert prediction_gap < 0.3

        # Score
        wrapped_score = clf.score(X, y)
        reference_score = reference.score(*dask.compute(X, y))
        assert abs(wrapped_score - reference_score) < 0.1

        # partial_fit on a fresh wrapper exposes the same attributes
        fresh = Incremental(SGDClassifier(random_state=0, tol=1e-3,
                                          average=True))
        fresh.partial_fit(X, y, classes=[0, 1])
        assert set(dir(fresh.estimator_)) == set(dir(reference))
예제 #4
0
def test_incremental_basic(scheduler):
    """Compare ``Incremental`` with an estimator trained block by block.

    A cloned ``SGDClassifier`` receives each chunk of ``X``/``y`` through
    ``partial_fit``; the wrapped estimator must end up with similar
    coefficients, predictions and score.
    """
    # Labels derived from a linear rule, so linear models can recover them
    n_samples, n_features = 100, 3
    state = da.random.RandomState(42)
    X = state.normal(size=(n_samples, n_features), chunks=30)
    true_coef = state.uniform(size=n_features, chunks=n_features)
    # Shift the sign values from [-1, 1] to [0, 1]
    y = (da.sign(X.dot(true_coef)) + 1) / 2

    with scheduler() as (s, [_, _]):
        wrapped = SGDClassifier(random_state=0, tol=1e-3, average=True)
        reference = clone(wrapped)

        clf = Incremental(wrapped, random_state=0)
        fitted = clf.fit(X, y, classes=[0, 1])
        for block in da.core.slices_from_chunks(X.chunks):
            rows = block[0]
            reference.partial_fit(X[block], y[rows], classes=[0, 1])

        assert fitted is clf

        assert isinstance(fitted.estimator_.coef_, np.ndarray)
        coef_gap = (np.linalg.norm(clf.coef_ - reference.coef_)
                    / np.linalg.norm(clf.coef_))
        assert coef_gap < 0.9

        assert set(dir(clf.estimator_)) == set(dir(reference))

        # Predict
        predicted = clf.predict(X)
        expected = reference.predict(X)
        assert isinstance(predicted, da.Array)
        prediction_gap = (np.linalg.norm(predicted - expected)
                          / np.linalg.norm(expected))
        assert prediction_gap < 0.2

        # Score
        wrapped_score = clf.score(X, y)
        reference_score = reference.score(X, y)
        assert abs(wrapped_score - reference_score) < 0.1

        # partial_fit on a fresh wrapper exposes the same attributes
        fresh = Incremental(SGDClassifier(random_state=0, tol=1e-3,
                                          average=True))
        fresh.partial_fit(X, y, classes=[0, 1])
        assert set(dir(fresh.estimator_)) == set(dir(reference))
예제 #5
0
    def calc_moments(self):
        """Fill in the normalisation moments stored on this instance.

        Opens ``self.infile`` read-only and reads its ``'data'`` dataset
        (CNHW layout according to the inline note), moving the first axis
        last. ``self.mean``/``self.std`` are reduced over the first three
        axes of the transposed data when ``self.mean`` is ``None``; otherwise
        the preset values are cast to the dataset dtype. When
        ``self.log1p_norm`` is truthy the same is done for
        ``self.mean_log1p``/``self.std_log1p`` on the signed-log1p transform
        of the z-normalised data. The resulting moments are printed; nothing
        is returned.
        """
        with h5py.File(self.infile, 'r', rdcc_nbytes=1000 * 1000 * 1000) as f:
            data = da.from_array(f['data'],
                                 chunks=(-1, 256, -1, -1))  # CNHW layout
            data = da.transpose(data, (1, 2, 3, 0))
            dtype = data.dtype

            if dtype != np.float32:
                # Report the real dtype; the message used to claim float64
                # for every non-float32 input.
                print(
                    'WARNING: data will be saved as float32 but input is {}!'
                    .format(dtype))

            if self.mean is None:
                with ProgressBar():
                    self.mean, self.std = da.compute(data.mean(axis=[0, 1, 2]),
                                                     data.std(axis=[0, 1, 2]),
                                                     num_workers=8)
            else:
                self.mean = np.asarray(self.mean, dtype=dtype)
                self.std = np.asarray(self.std, dtype=dtype)

            print('mean: {}, std: {}'.format(list(self.mean), list(self.std)))

            if self.log1p_norm:
                data_z_norm = (data - self.mean) / self.std
                # Signed log1p: compresses magnitudes while keeping the sign
                data_log1p = da.sign(data_z_norm) * da.log1p(
                    da.fabs(data_z_norm))

                if self.mean_log1p is None:
                    with ProgressBar():
                        self.mean_log1p, self.std_log1p = da.compute(
                            data_log1p.mean(axis=[0, 1, 2]),
                            data_log1p.std(axis=[0, 1, 2]),
                            num_workers=8)
                else:
                    self.mean_log1p = np.asarray(self.mean_log1p, dtype=dtype)
                    self.std_log1p = np.asarray(self.std_log1p, dtype=dtype)

                print('mean_log1p: {}, std_log1p: {}'.format(
                    list(self.mean_log1p), list(self.std_log1p)))
예제 #6
0
def test_arithmetic():
    """Check dask array operators and element-wise functions against numpy.

    ``a``, ``b`` and ``c`` are chunked dask views of the numpy arrays ``x``
    (float32), ``y`` (int64) and ``z`` (int32); every assertion compares the
    dask result with the same operation applied to the numpy originals.
    """
    x = np.arange(5).astype('f4') + 2
    y = np.arange(5).astype('i8') + 2
    z = np.arange(5).astype('i4') + 2
    a = da.from_array(x, chunks=(2,))
    b = da.from_array(y, chunks=(2,))
    c = da.from_array(z, chunks=(2,))

    # Binary operators: array with array
    assert eq(a + b, x + y)
    assert eq(a * b, x * y)
    assert eq(a - b, x - y)
    assert eq(a / b, x / y)
    assert eq(b & b, y & y)
    assert eq(b | b, y | y)
    assert eq(b ^ b, y ^ y)
    assert eq(a // b, x // y)
    assert eq(a ** b, x ** y)
    assert eq(a % b, x % y)
    assert eq(a > b, x > y)
    assert eq(a < b, x < y)
    assert eq(a >= b, x >= y)
    assert eq(a <= b, x <= y)
    assert eq(a == b, x == y)
    assert eq(a != b, x != y)

    # Binary operators: array with scalar on the right
    assert eq(a + 2, x + 2)
    assert eq(a * 2, x * 2)
    assert eq(a - 2, x - 2)
    assert eq(a / 2, x / 2)
    assert eq(b & True, y & True)
    assert eq(b | True, y | True)
    assert eq(b ^ True, y ^ True)
    assert eq(a // 2, x // 2)
    assert eq(a ** 2, x ** 2)
    assert eq(a % 2, x % 2)
    assert eq(a > 2, x > 2)
    assert eq(a < 2, x < 2)
    assert eq(a >= 2, x >= 2)
    assert eq(a <= 2, x <= 2)
    assert eq(a == 2, x == 2)
    assert eq(a != 2, x != 2)

    # Binary operators: scalar on the left (reflected operators)
    assert eq(2 + b, 2 + y)
    assert eq(2 * b, 2 * y)
    assert eq(2 - b, 2 - y)
    assert eq(2 / b, 2 / y)
    assert eq(True & b, True & y)
    assert eq(True | b, True | y)
    assert eq(True ^ b, True ^ y)
    assert eq(2 // b, 2 // y)
    assert eq(2 ** b, 2 ** y)
    assert eq(2 % b, 2 % y)
    assert eq(2 > b, 2 > y)
    assert eq(2 < b, 2 < y)
    assert eq(2 >= b, 2 >= y)
    assert eq(2 <= b, 2 <= y)
    assert eq(2 == b, 2 == y)
    assert eq(2 != b, 2 != y)

    # Unary operators (a duplicated ``~(a == b)`` check was dropped)
    assert eq(-a, -x)
    assert eq(abs(a), abs(x))
    assert eq(~(a == b), ~(x == y))

    # Exponentials, logarithms and powers
    assert eq(da.logaddexp(a, b), np.logaddexp(x, y))
    assert eq(da.logaddexp2(a, b), np.logaddexp2(x, y))
    assert eq(da.exp(b), np.exp(y))
    assert eq(da.log(a), np.log(x))
    assert eq(da.log10(a), np.log10(x))
    assert eq(da.log1p(a), np.log1p(x))
    assert eq(da.expm1(b), np.expm1(y))
    assert eq(da.sqrt(a), np.sqrt(x))
    assert eq(da.square(a), np.square(x))

    # Trigonometric and hyperbolic functions
    assert eq(da.sin(a), np.sin(x))
    assert eq(da.cos(b), np.cos(y))
    assert eq(da.tan(a), np.tan(x))
    assert eq(da.arcsin(b/10), np.arcsin(y/10))
    assert eq(da.arccos(b/10), np.arccos(y/10))
    assert eq(da.arctan(b/10), np.arctan(y/10))
    assert eq(da.arctan2(b*10, a), np.arctan2(y*10, x))
    assert eq(da.hypot(b, a), np.hypot(y, x))
    assert eq(da.sinh(a), np.sinh(x))
    assert eq(da.cosh(b), np.cosh(y))
    assert eq(da.tanh(a), np.tanh(x))
    assert eq(da.arcsinh(b*10), np.arcsinh(y*10))
    assert eq(da.arccosh(b*10), np.arccosh(y*10))
    assert eq(da.arctanh(b/10), np.arctanh(y/10))
    assert eq(da.deg2rad(a), np.deg2rad(x))
    assert eq(da.rad2deg(a), np.rad2deg(x))

    # Logical functions and element-wise extrema
    assert eq(da.logical_and(a < 1, b < 4), np.logical_and(x < 1, y < 4))
    assert eq(da.logical_or(a < 1, b < 4), np.logical_or(x < 1, y < 4))
    assert eq(da.logical_xor(a < 1, b < 4), np.logical_xor(x < 1, y < 4))
    assert eq(da.logical_not(a < 1), np.logical_not(x < 1))
    # The expected side must come from the numpy array ``x``; it used to be
    # built from the dask array ``a``, so numpy was never the reference.
    assert eq(da.maximum(a, 5 - a), np.maximum(x, 5 - x))
    assert eq(da.minimum(a, 5 - a), np.minimum(x, 5 - x))
    assert eq(da.fmax(a, 5 - a), np.fmax(x, 5 - x))
    assert eq(da.fmin(a, 5 - a), np.fmin(x, 5 - x))

    # Floating-point predicates and manipulation
    assert eq(da.isreal(a + 1j * b), np.isreal(x + 1j * y))
    assert eq(da.iscomplex(a + 1j * b), np.iscomplex(x + 1j * y))
    assert eq(da.isfinite(a), np.isfinite(x))
    assert eq(da.isinf(a), np.isinf(x))
    assert eq(da.isnan(a), np.isnan(x))
    assert eq(da.signbit(a - 3), np.signbit(x - 3))
    assert eq(da.copysign(a - 3, b), np.copysign(x - 3, y))
    assert eq(da.nextafter(a - 3, b), np.nextafter(x - 3, y))
    assert eq(da.ldexp(c, c), np.ldexp(z, z))
    assert eq(da.fmod(a * 12, b), np.fmod(x * 12, y))
    assert eq(da.floor(a * 0.5), np.floor(x * 0.5))
    assert eq(da.ceil(a), np.ceil(x))
    assert eq(da.trunc(a / 2), np.trunc(x / 2))

    assert eq(da.degrees(b), np.degrees(y))
    assert eq(da.radians(a), np.radians(x))

    assert eq(da.rint(a + 0.3), np.rint(x + 0.3))
    assert eq(da.fix(a - 2.5), np.fix(x - 2.5))

    # Complex-number helpers
    assert eq(da.angle(a + 1j), np.angle(x + 1j))
    assert eq(da.real(a + 1j), np.real(x + 1j))
    assert eq((a + 1j).real, np.real(x + 1j))
    assert eq(da.imag(a + 1j), np.imag(x + 1j))
    assert eq((a + 1j).imag, np.imag(x + 1j))
    assert eq(da.conj(a + 1j * b), np.conj(x + 1j * y))
    assert eq((a + 1j * b).conj(), (x + 1j * y).conj())

    assert eq(da.clip(b, 1, 4), np.clip(y, 1, 4))
    assert eq(da.fabs(b), np.fabs(y))
    assert eq(da.sign(b - 2), np.sign(y - 2))

    # Functions returning two arrays
    l1, l2 = da.frexp(a)
    r1, r2 = np.frexp(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    l1, l2 = da.modf(a)
    r1, r2 = np.modf(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    assert eq(da.around(a, -1), np.around(x, -1))
예제 #7
0
    def _test_basic(c, s, a, b):
        """Run a ``HyperbandSearchCV`` fit and check its bookkeeping.

        ``c`` is used as the client for ``compute`` calls. ``array_type``,
        ``library`` and ``max_iter`` are read from the enclosing scope, which
        is outside this view (presumably test parameters -- confirm against
        the outer function).
        """
        state = da.random.RandomState(42)

        n_samples, n_features = 50, 2
        # Observations built from a linear rule, so linear models can fit them
        X = state.normal(size=(n_samples, n_features), chunks=n_samples // 2)
        true_coef = state.uniform(size=n_features, chunks=n_features)
        y = da.sign(X.dot(true_coef))

        if array_type == "numpy":
            X, y = yield c.compute((X, y))

        search_space = {
            "loss": [
                "hinge",
                "log",
                "modified_huber",
                "squared_hinge",
                "perceptron",
            ],
            "average": [True, False],
            "learning_rate": ["constant", "invscaling", "optimal"],
            "eta0": np.logspace(-2, 0, num=1000),
        }
        estimator = SGDClassifier(
            tol=-np.inf,
            penalty="elasticnet",
            random_state=42,
            eta0=0.1,
        )
        if library == "dask-ml":
            estimator = Incremental(estimator)
            # Route every parameter to the wrapped estimator
            search_space = {
                "estimator__" + key: values
                for key, values in search_space.items()
            }
        elif library == "ConstantFunction":
            estimator = ConstantFunction()
            search_space = {"value": np.linspace(0, 1, num=1000)}

        search = HyperbandSearchCV(
            estimator, search_space, max_iter=max_iter, random_state=42)
        classes = c.compute(da.unique(y))
        yield search.fit(X, y, classes=classes)

        if library == "dask-ml":
            X, y = yield c.compute((X, y))
        score = search.best_estimator_.score(X, y)
        assert score == search.score(X, y)
        assert 0 <= score <= 1

        if library == "ConstantFunction":
            assert score == search.best_score_
        else:
            # Not exactly equal: IncrementalSearchCV scores on a train/test
            # split, while here the whole training dataset is scored rather
            # than only the validation/test part.
            assert abs(score - search.best_score_) < 0.1

        assert type(search.best_estimator_) == type(estimator)
        assert isinstance(search.best_params_, dict)

        # Expected counts, keyed by max_iter
        expected_models = {9: 17, 15: 17, 20: 17, 27: 49, 30: 49, 81: 143}
        expected_pf_calls = {
            9: 69, 15: 101, 20: 144, 27: 357, 30: 379, 81: 1581,
        }
        fitted_model_count = len(set(search.cv_results_["model_id"]))
        partial_fit_total = sum(
            history[-1]["partial_fit_calls"]
            for history in search.model_history_.values()
        )
        assert fitted_model_count == expected_models[max_iter]
        assert partial_fit_total == expected_pf_calls[max_iter]

        best_idx = search.best_index_
        if isinstance(estimator, ConstantFunction):
            test_scores = search.cv_results_["test_score"]
            assert test_scores[best_idx] == max(
                search.cv_results_["test_score"])
        model_ids = {record["model_id"] for record in search.history_}

        if math.log(max_iter, 3) % 1.0 == 0:
            # log(max_iter, 3) % 1.0 == 0 is the good case, where max_iter is
            # a power of search.aggressiveness: more models are tried than
            # max_iter.
            assert len(model_ids) > max_iter
        else:
            # Otherwise allow some padding: "almost as many estimators are
            # tried as max_iter". 3 is a fudge number chosen to be the
            # minimum; when max_iter=20, len(model_ids) == 17.
            assert len(model_ids) + 3 >= max_iter

        assert all("bracket" in id_ for id_ in model_ids)
예제 #8
0
def sign(A):
    """Return the element-wise sign of ``A`` as computed by ``da.sign``."""
    signs = da.sign(A)
    return signs
예제 #9
0
def test_arithmetic():
    """Check dask array operators and element-wise functions against numpy.

    ``a``, ``b`` and ``c`` are chunked dask views of the numpy arrays ``x``
    (float32), ``y`` (int64) and ``z`` (int32); every assertion compares the
    dask result with the same operation applied to the numpy originals.
    """
    x = np.arange(5).astype('f4') + 2
    y = np.arange(5).astype('i8') + 2
    z = np.arange(5).astype('i4') + 2
    a = da.from_array(x, chunks=(2, ))
    b = da.from_array(y, chunks=(2, ))
    c = da.from_array(z, chunks=(2, ))

    # Binary operators: array with array
    assert eq(a + b, x + y)
    assert eq(a * b, x * y)
    assert eq(a - b, x - y)
    assert eq(a / b, x / y)
    assert eq(b & b, y & y)
    assert eq(b | b, y | y)
    assert eq(b ^ b, y ^ y)
    assert eq(a // b, x // y)
    assert eq(a**b, x**y)
    assert eq(a % b, x % y)
    assert eq(a > b, x > y)
    assert eq(a < b, x < y)
    assert eq(a >= b, x >= y)
    assert eq(a <= b, x <= y)
    assert eq(a == b, x == y)
    assert eq(a != b, x != y)

    # Binary operators: array with scalar on the right
    assert eq(a + 2, x + 2)
    assert eq(a * 2, x * 2)
    assert eq(a - 2, x - 2)
    assert eq(a / 2, x / 2)
    assert eq(b & True, y & True)
    assert eq(b | True, y | True)
    assert eq(b ^ True, y ^ True)
    assert eq(a // 2, x // 2)
    assert eq(a**2, x**2)
    assert eq(a % 2, x % 2)
    assert eq(a > 2, x > 2)
    assert eq(a < 2, x < 2)
    assert eq(a >= 2, x >= 2)
    assert eq(a <= 2, x <= 2)
    assert eq(a == 2, x == 2)
    assert eq(a != 2, x != 2)

    # Binary operators: scalar on the left (reflected operators)
    assert eq(2 + b, 2 + y)
    assert eq(2 * b, 2 * y)
    assert eq(2 - b, 2 - y)
    assert eq(2 / b, 2 / y)
    assert eq(True & b, True & y)
    assert eq(True | b, True | y)
    assert eq(True ^ b, True ^ y)
    assert eq(2 // b, 2 // y)
    assert eq(2**b, 2**y)
    assert eq(2 % b, 2 % y)
    assert eq(2 > b, 2 > y)
    assert eq(2 < b, 2 < y)
    assert eq(2 >= b, 2 >= y)
    assert eq(2 <= b, 2 <= y)
    assert eq(2 == b, 2 == y)
    assert eq(2 != b, 2 != y)

    # Unary operators (a duplicated ``~(a == b)`` check was dropped)
    assert eq(-a, -x)
    assert eq(abs(a), abs(x))
    assert eq(~(a == b), ~(x == y))

    # Exponentials, logarithms and powers
    assert eq(da.logaddexp(a, b), np.logaddexp(x, y))
    assert eq(da.logaddexp2(a, b), np.logaddexp2(x, y))
    assert eq(da.exp(b), np.exp(y))
    assert eq(da.log(a), np.log(x))
    assert eq(da.log10(a), np.log10(x))
    assert eq(da.log1p(a), np.log1p(x))
    assert eq(da.expm1(b), np.expm1(y))
    assert eq(da.sqrt(a), np.sqrt(x))
    assert eq(da.square(a), np.square(x))

    # Trigonometric and hyperbolic functions
    assert eq(da.sin(a), np.sin(x))
    assert eq(da.cos(b), np.cos(y))
    assert eq(da.tan(a), np.tan(x))
    assert eq(da.arcsin(b / 10), np.arcsin(y / 10))
    assert eq(da.arccos(b / 10), np.arccos(y / 10))
    assert eq(da.arctan(b / 10), np.arctan(y / 10))
    assert eq(da.arctan2(b * 10, a), np.arctan2(y * 10, x))
    assert eq(da.hypot(b, a), np.hypot(y, x))
    assert eq(da.sinh(a), np.sinh(x))
    assert eq(da.cosh(b), np.cosh(y))
    assert eq(da.tanh(a), np.tanh(x))
    assert eq(da.arcsinh(b * 10), np.arcsinh(y * 10))
    assert eq(da.arccosh(b * 10), np.arccosh(y * 10))
    assert eq(da.arctanh(b / 10), np.arctanh(y / 10))
    assert eq(da.deg2rad(a), np.deg2rad(x))
    assert eq(da.rad2deg(a), np.rad2deg(x))

    # Logical functions and element-wise extrema
    assert eq(da.logical_and(a < 1, b < 4), np.logical_and(x < 1, y < 4))
    assert eq(da.logical_or(a < 1, b < 4), np.logical_or(x < 1, y < 4))
    assert eq(da.logical_xor(a < 1, b < 4), np.logical_xor(x < 1, y < 4))
    assert eq(da.logical_not(a < 1), np.logical_not(x < 1))
    # The expected side must come from the numpy array ``x``; it used to be
    # built from the dask array ``a``, so numpy was never the reference.
    assert eq(da.maximum(a, 5 - a), np.maximum(x, 5 - x))
    assert eq(da.minimum(a, 5 - a), np.minimum(x, 5 - x))
    assert eq(da.fmax(a, 5 - a), np.fmax(x, 5 - x))
    assert eq(da.fmin(a, 5 - a), np.fmin(x, 5 - x))

    # Floating-point predicates and manipulation
    assert eq(da.isreal(a + 1j * b), np.isreal(x + 1j * y))
    assert eq(da.iscomplex(a + 1j * b), np.iscomplex(x + 1j * y))
    assert eq(da.isfinite(a), np.isfinite(x))
    assert eq(da.isinf(a), np.isinf(x))
    assert eq(da.isnan(a), np.isnan(x))
    assert eq(da.signbit(a - 3), np.signbit(x - 3))
    assert eq(da.copysign(a - 3, b), np.copysign(x - 3, y))
    assert eq(da.nextafter(a - 3, b), np.nextafter(x - 3, y))
    assert eq(da.ldexp(c, c), np.ldexp(z, z))
    assert eq(da.fmod(a * 12, b), np.fmod(x * 12, y))
    assert eq(da.floor(a * 0.5), np.floor(x * 0.5))
    assert eq(da.ceil(a), np.ceil(x))
    assert eq(da.trunc(a / 2), np.trunc(x / 2))

    assert eq(da.degrees(b), np.degrees(y))
    assert eq(da.radians(a), np.radians(x))

    assert eq(da.rint(a + 0.3), np.rint(x + 0.3))
    assert eq(da.fix(a - 2.5), np.fix(x - 2.5))

    # Complex-number helpers
    assert eq(da.angle(a + 1j), np.angle(x + 1j))
    assert eq(da.real(a + 1j), np.real(x + 1j))
    assert eq((a + 1j).real, np.real(x + 1j))
    assert eq(da.imag(a + 1j), np.imag(x + 1j))
    assert eq((a + 1j).imag, np.imag(x + 1j))
    assert eq(da.conj(a + 1j * b), np.conj(x + 1j * y))
    assert eq((a + 1j * b).conj(), (x + 1j * y).conj())

    assert eq(da.clip(b, 1, 4), np.clip(y, 1, 4))
    assert eq(da.fabs(b), np.fabs(y))
    assert eq(da.sign(b - 2), np.sign(y - 2))

    # Functions returning two arrays
    l1, l2 = da.frexp(a)
    r1, r2 = np.frexp(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    l1, l2 = da.modf(a)
    r1, r2 = np.modf(x)
    assert eq(l1, r1)
    assert eq(l2, r2)

    assert eq(da.around(a, -1), np.around(x, -1))
예제 #10
0
def gradient(image):
    """Return the piecewise gradient selected by ``huber['threshold']``.

    Elements whose absolute value is at most the threshold map to
    ``2 * image``; all others map to ``2 * threshold * sign(image)``.
    """
    threshold = huber['threshold']
    within_threshold = da.fabs(image) <= threshold
    quadratic_branch = 2 * image
    linear_branch = 2 * threshold * da.sign(image)
    return da.where(within_threshold, quadratic_branch, linear_branch)
예제 #11
0
def main():
    """Normalise the HDF5 dataset at ``infile`` and write it as TFRecords.

    The ``'data'`` dataset (CNHW layout according to the inline note) is
    loaded lazily, its first axis is moved last, and it is optionally
    normalised (signed log1p of the z-score, or a plain z-score). Its blocks
    along the first axis are split across ``n_files`` output files and each
    block is passed to ``generate_TFRecords``. All settings are the
    constants at the top of the function.
    """
    ########################################################

    infile = '/data/ERA5/hdf5_hr/ds_eval_2000_to_2002.hdf5'
    outfile = 'sr_eval_2000_2002.{}.tfrecords'
    n_files = 8
    gzip = True
    shuffle = False

    ########################################################

    SR_ratio = 4
    log1p_norm = True
    z_norm = False

    mean, std = [2.0152406e-08, 2.1581373e-07], [2.8560082e-05, 5.0738556e-05]
    mean_log1p, std_log1p = [0.008315503, 0.0028762482], [0.5266841, 0.5418187]

    #########################################################

    HR_reduce_latitude = 0  # 107  # 15 deg from each pole
    patch_size = None  # (96, 96)
    n_patches = 50
    mode = 'train'

    ########################################################

    # Rows trimmed from each end of axis 1. Computed once; a value of 0
    # (including HR_reduce_latitude == 1) must skip the crop, because
    # ``block[:, 0:-0]`` would select nothing.
    lat_start = HR_reduce_latitude // 2

    # The dask graph reads from the file lazily, so every compute has to
    # happen while the handle is open; the context manager closes it after.
    with h5py.File(infile, 'r', rdcc_nbytes=1000*1000*1000) as f:
        data = da.from_array(
            f['data'],
            chunks=(-1, 16 if shuffle else 256, -1, -1))  # CNHW layout
        data = da.transpose(data, (1, 2, 3, 0))
        dtype = data.dtype

        if dtype != np.float32:
            # Report the real dtype; the message used to claim float64 for
            # every non-float32 input.
            print('WARNING: data will be saved as float32 but input is {}!'
                  .format(dtype))

        if mean is None:
            with ProgressBar():
                mean, std = da.compute(data.mean(axis=[0, 1, 2]),
                                       data.std(axis=[0, 1, 2]),
                                       num_workers=8)
        else:
            mean = np.asarray(mean, dtype=dtype)
            std = np.asarray(std, dtype=dtype)

        print('mean: {}, std: {}'.format(list(mean), list(std)))

        if log1p_norm:
            data_z_norm = (data - mean) / std
            # Signed log1p: compresses magnitudes while keeping the sign
            data_log1p = da.sign(data_z_norm) * da.log1p(da.fabs(data_z_norm))

            if mean_log1p is None:
                with ProgressBar():
                    mean_log1p, std_log1p = da.compute(
                        data_log1p.mean(axis=[0, 1, 2]),
                        data_log1p.std(axis=[0, 1, 2]),
                        num_workers=8)
            else:
                mean_log1p = np.asarray(mean_log1p, dtype=dtype)
                std_log1p = np.asarray(std_log1p, dtype=dtype)

            print('mean_log1p: {}, std_log1p: {}'.format(
                list(mean_log1p), list(std_log1p)))

            data = data_log1p
        elif z_norm:
            data = (data - mean) / std

        n_blocks = data.numblocks[0]
        if shuffle:
            block_indices = np.random.permutation(n_blocks)
        else:
            block_indices = np.arange(n_blocks)

        file_blocks = np.array_split(block_indices, n_files)
        blocks_done = 0
        for n, indices in enumerate(file_blocks):
            # NOTE(review): with a single output file the name is used as
            # given, so a '{}' placeholder in it stays literal -- confirm
            # that this is intended.
            if n_files > 1:
                name = outfile.format(n)
            else:
                name = outfile

            with tf.io.TFRecordWriter(
                    name, options='ZLIB' if gzip else None) as writer:
                for block_idx in indices:
                    block = data.blocks[block_idx].compute()
                    if shuffle:
                        # Shuffles along the first axis of the block
                        block = np.random.permutation(block)

                    if lat_start:
                        block = block[:, lat_start:(-lat_start), :, :]

                    generate_TFRecords(writer, block, SR_ratio, mode,
                                       patch_size, n_patches)
                    blocks_done += 1
                    print('{} / {}'.format(blocks_done, n_blocks), flush=True)