Example #1
def test_R_hat():
    # detect disagreeing chains
    x = np.array(
        [
            [1.0, 1.0, 1.0],
            [1.1, 1.1, 1.1],
        ]
    )
    assert statistics(x).R_hat > 1.01

    # detect non-stationary chains
    x = np.array(
        [
            [1.0, 1.5, 2.0],
            [2.0, 1.5, 1.0],
        ]
    )
    assert statistics(x).R_hat > 1.01

    # detect "stuck" chains
    x = np.array(
        [
            np.random.normal(size=1000),
            np.random.normal(size=1000),
        ]
    )
    # not stuck -> good R_hat:
    assert statistics(x).R_hat <= 1.01
    # stuck -> bad  R_hat:
    x[1, 100:] = 1.0
    assert statistics(x).R_hat > 1.01
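
The snippets on this page call statistics and numpy without showing the surrounding imports. A minimal setup sketch, assuming the usual imports of the netket test files these examples are taken from (the exact import path is an assumption, not shown in the excerpts):

import numpy as np
from netket.stats import statistics  # assumed import path

chains = np.random.normal(size=(4, 1000))  # shape (n_chains, n_samples_per_chain)
stats = statistics(chains)
# Stats fields used throughout the examples on this page:
print(stats.mean, stats.variance, stats.error_of_mean, stats.R_hat, stats.tau_corr)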
Example #2
File: test_stats.py Project: FermiQ/netket
def _test_stats_mean_std(hi, ham, ma, n_chains):
    w = ma.init(jax.random.PRNGKey(WEIGHT_SEED * n_chains), jnp.zeros((1, hi.size)))

    sampler = nk.sampler.MetropolisLocal(hi, n_chains=n_chains)

    n_samples = 16000
    num_samples_per_chain = n_samples // n_chains

    # Discard a few samples
    _, state = sampler.sample(ma, w, chain_length=1000)

    samples, state = sampler.sample(
        ma, w, chain_length=num_samples_per_chain, state=state
    )
    assert samples.shape == (num_samples_per_chain, n_chains, hi.size)

    # eloc = np.empty((num_samples_per_chain, n_chains), dtype=np.complex128)
    eloc = local_values(ma.apply, w, ham, samples)
    assert eloc.shape == (num_samples_per_chain, n_chains)

    stats = statistics(eloc.T)

    assert stats.mean == pytest.approx(np.mean(eloc))
    if n_chains > 1:

        # variance == variance over all samples (all chains pooled)
        assert stats.variance == pytest.approx(np.var(eloc))
        # R estimate
        B_over_n = stats.error_of_mean ** 2
        W = stats.variance
        assert stats.R_hat == pytest.approx(np.sqrt(1.0 + B_over_n / W), abs=1e-3)
Example #3
def _test_stats_mean_std(hi, ham, ma, n_chains):
    sampler = nk.sampler.MetropolisLocal(ma, n_chains=n_chains)

    n_samples = 16000
    num_samples_per_chain = n_samples // n_chains

    # Discard a few samples
    sampler.generate_samples(1000)

    samples = sampler.generate_samples(num_samples_per_chain)
    assert samples.shape == (num_samples_per_chain, n_chains, hi.size)

    eloc = np.empty((num_samples_per_chain, n_chains), dtype=np.complex128)
    for i in range(num_samples_per_chain):
        eloc[i] = local_values(ham, ma, samples[i])

    stats = statistics(eloc.T)

    # These tests only work for one MPI process
    assert nk.stats.MPI.COMM_WORLD.size == 1

    assert stats.mean == pytest.approx(np.mean(eloc))
    if n_chains > 1:

        # variance == variance over all samples (all chains pooled)
        assert stats.variance == pytest.approx(np.var(eloc))
        # R estimate
        B_over_n = stats.error_of_mean ** 2
        W = stats.variance
        assert stats.R_hat == pytest.approx(np.sqrt(1.0 + B_over_n / W), abs=1e-3)
Example #4
def _test_stats_mean_std(hi, ham, ma, n_chains):
    sampler = nk.sampler.MetropolisLocal(ma, n_chains=n_chains)

    n_samples = 16000
    num_samples_per_chain = n_samples // n_chains

    # Discard a few samples
    sampler.generate_samples(1000)

    samples = sampler.generate_samples(num_samples_per_chain)
    assert samples.shape == (num_samples_per_chain, n_chains, hi.size)

    eloc = local_values(ham, ma, samples)
    assert eloc.shape == (num_samples_per_chain, n_chains)

    stats = statistics(eloc)

    # These tests only work for one MPI process
    assert nk.MPI.size() == 1

    assert stats.mean == pytest.approx(np.mean(eloc))
    if n_chains > 1:
        # error of mean == stdev of sample mean between chains / sqrt(#chains)
        assert stats.error_of_mean == pytest.approx(
            eloc.mean(axis=0).std(ddof=0) / np.sqrt(n_chains)
        )
        # variance == average sample variance over chains
        assert stats.variance == pytest.approx(eloc.var(axis=0).mean())
        # R estimate
        B_over_n = stats.error_of_mean ** 2
        W = stats.variance
        assert stats.R == pytest.approx(
            np.sqrt((n_samples - 1.0) / n_samples + B_over_n / W), abs=1e-3
        )
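
The assertions above spell out how the R estimate is assembled from the between-chain and within-chain variances. A minimal numpy sketch mirroring those assertions, assuming chains has shape (n_chains, n_samples_per_chain); the function name is illustrative, not part of the library:

import numpy as np

def r_hat_sketch(chains):
    # chains: (n_chains, n_samples_per_chain)
    n_chains = chains.shape[0]
    n_total = chains.size
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1).mean()                  # average within-chain variance (stats.variance)
    B_over_n = chain_means.var(ddof=0) / n_chains  # stats.error_of_mean ** 2
    return np.sqrt((n_total - 1.0) / n_total + B_over_n / W)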
Example #5
def grad_expect_hermitian_chunked(
    chunk_size: int,
    local_value_kernel_chunked: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[Stats, PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel_chunked(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
        chunk_size=chunk_size,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        vjp_fun_chunked = nkjax.vjp_chunked(
            lambda w, σ: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            σ,
            conjugate=True,
            chunk_size=chunk_size,
            chunk_argnums=1,
            nondiff_argnums=1,
        )
        new_model_state = None
    else:
        raise NotImplementedError

    Ō_grad = vjp_fun_chunked(
        (jnp.conjugate(O_loc) / n_samples),
    )[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
Example #6
def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[Stats, PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * utils.n_nodes

    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
            parameters,
            conjugate=True,
            has_aux=True,
        )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(sum_inplace, Ō_grad), new_model_state
Example #7
def grad_expect_hermitian(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[Stats, PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    is_mutable = mutable is not False
    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
Example #8
    def expect_operator(self, Ô: AbstractOperator) -> Stats:
        σ = self.diagonal.samples
        σ_shape = σ.shape
        σ = σ.reshape((-1, σ.shape[-1]))

        σ_np = np.asarray(σ)
        σp, mels = Ô.get_conn_padded(σ_np)

        # compute the local values of the operator on the sampled configurations
        O_loc = local_cost_function(
            local_value_op_op_cost,
            self._apply_fun,
            self.variables,
            σp,
            mels,
            σ,
        ).reshape(σ_shape[:-1])

        # notice that O_loc.T is passed to statistics, since that function assumes
        # that the first index is the batch (chain) index.
        return statistics(O_loc.T)
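
The comment above states the input convention relied on throughout these examples: statistics treats the first index as the batch (chain) index, so local values stored as (n_samples_per_chain, n_chains) are transposed before the call. A minimal sketch of that convention, with the same assumed import as above:

import numpy as np
from netket.stats import statistics  # assumed import path

eloc = np.random.normal(size=(500, 8))  # (n_samples_per_chain, n_chains)
stats = statistics(eloc.T)              # first index of the argument is the chain index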
Example #9
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[Stats, PyTree, PyTree]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    if not has_aux:
        out_axes = (0, 0)
    else:
        out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    local_kernel_vmap = jax.vmap(
        partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = netket.operator._der_local_values_jax._local_values_and_grads_notcentered_kernel(
        logpsi, parameters, σp, mels, σ
    )
    # netket.operator._der_local_values_jax._local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    # Old implementation: kept for now because it is faster in practice,
    # even though the version below should in principle be faster.
    # (This works; keep it here for now and remove it later.)
    grad_fun = jax.vmap(nkjax.grad(logpsi, argnums=0), in_axes=(None, 0), out_axes=0)
    der_logs = grad_fun(parameters, σ)
    der_logs_ave = tree_map(lambda x: mean(x, axis=0), der_logs)

    # TODO: new implementation (commented out below).
    # This should be faster, but benchmarks suggest it is slower.
    # der_logs_ave can be computed with a single vjp against a vector of ones:
    # _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complex dtype where a real one suffices;
    # using a real dtype would halve the number of operations.
    # der_logs_ave = d_logpsi(
    #    jnp.ones_like(_logpsi_ave).real / (n_samples_node * utils.n_nodes)
    # )[0]
    der_logs_ave = tree_map(sum_inplace, der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
Example #10
File: test_stats.py Project: FermiQ/netket
def _test_tau_corr(batch_size, sig_corr):
    def next_pow_two(n):
        i = 1
        while i < n:
            i = i << 1
        return i

    def autocorr_func_1d(x, norm=True):
        x = np.atleast_1d(x)
        if len(x.shape) != 1:
            raise ValueError("invalid dimensions for 1D autocorrelation function")
        n = next_pow_two(len(x))

        # Compute the FFT and then (from that) the auto-correlation function
        f = np.fft.fft(x - np.mean(x), n=2 * n)
        acf = np.fft.ifft(f * np.conjugate(f))[: len(x)].real
        acf /= 4 * n

        # Optionally normalize
        if norm:
            acf /= acf[0]

        return acf

    @jit
    def gen_data(n_samples, log_f, dx, seed=1234):
        np.random.seed(seed)
        # Generates data with a simple Markov chain
        x = np.empty(n_samples)
        x_old = np.random.normal()
        for i in range(n_samples):
            x_new = x_old + np.random.normal(scale=dx, loc=0.0)
            if np.exp(log_f(x_new) - log_f(x_old)) > np.random.uniform(0, 1):
                x[i] = x_new
            else:
                x[i] = x_old

            x_old = x[i]
        return x

    @jit
    def log_f(x):
        return -(x ** 2.0) / 2.0

    def func_corr(x, tau):
        return np.exp(-x / (tau))

    n_samples = 8000000 // batch_size

    data = np.empty((batch_size, n_samples))
    tau_fit = np.empty((batch_size))

    for i in range(batch_size):
        data[i] = gen_data(n_samples, log_f, sig_corr, seed=i + batch_size)
        autoc = autocorr_func_1d(data[i])
        popt, pcov = curve_fit(func_corr, np.arange(40), autoc[0:40])
        tau_fit[i] = popt[0]

    tau_fit_m = tau_fit.mean()

    stats = statistics(data)

    assert np.mean(data) == pytest.approx(stats.mean)
    assert np.var(data) == pytest.approx(stats.variance)

    assert tau_fit_m == pytest.approx(stats.tau_corr, rel=1, abs=3)

    eom_fit = np.sqrt(np.var(data) * tau_fit_m / float(n_samples * batch_size))

    print(stats.error_of_mean, eom_fit)
    assert eom_fit == pytest.approx(stats.error_of_mean, rel=0.6)
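
The final assertion rests on the relation encoded by eom_fit: the squared error of the mean is approximately the sample variance times the integrated autocorrelation time, divided by the total number of samples. A minimal sketch restating it, assuming data has shape (batch_size, n_samples) as above; the function name is illustrative:

import numpy as np

def eom_from_tau(data, tau_corr):
    # error_of_mean ** 2 ≈ Var[x] * tau_corr / N_total (same quantity as eom_fit above)
    return np.sqrt(np.var(data) * tau_corr / data.size)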
Example #11
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[Stats, PyTree, PyTree]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #    out_axes = (0, 0)
    # else:
    #    out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #    partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = _local_values_and_grads_notcentered_kernel(logpsi, parameters, σp, mels, σ)
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complex dtype where a real one suffices;
    # using a real dtype would halve the number of operations.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    )[0]
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # We do this also for the standard gradient of the energy;
    # this avoids errors in #867, #789, #850.
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )