Code example #1
def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * utils.n_nodes

    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # The code is slightly more complex than usual because we also support
    # mutable model state (when present).
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
            parameters,
            conjugate=True,
            has_aux=True,
        )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(sum_inplace, Ō_grad), new_model_state
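The function above centers the local estimator and feeds its conjugate, divided by the sample count, into the vjp of the log-amplitude. Below is a minimal, self-contained sketch of that same pattern using plain jax.vjp instead of nkjax.vjp; the toy model log_psi, its parameters, and the sample data are purely illustrative and not part of NetKet's API.

import jax
import jax.numpy as jnp

def log_psi(params, sigma):
    # toy log-amplitude: a linear function of the configurations
    return sigma @ params["w"] + params["b"]

params = {"w": 0.1 * jnp.ones(3), "b": jnp.array(0.0)}
sigma = jnp.ones((8, 3))              # batch of 8 configurations
O_loc = jnp.linspace(-1.0, 1.0, 8)    # stand-in for the local estimator values

O_loc = O_loc - O_loc.mean()          # center, as done with Ō.mean above
_, vjp_fun = jax.vjp(lambda w: log_psi(w, sigma), params)
O_grad = vjp_fun(jnp.conj(O_loc) / sigma.shape[0])[0]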
Code example #2
def grad_expect_operator_kernel(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    machine_pow: int,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree, Stats]:

    if not config.FLAGS["NETKET_EXPERIMENTAL"]:
        raise RuntimeError(
            """
                           Computing the gradient of a squared or non-hermitian
                           operator is an experimental feature under development
                           and is known to sometimes return wrong values.

                           If you want to debug it, set the environment variable
                           NETKET_EXPERIMENTAL=1
                           """
        )

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    is_mutable = mutable is not False
    logpsi = lambda w, σ: model_apply_fun(
        {"params": w, **model_state}, σ, mutable=mutable
    )
    log_pdf = (
        lambda w, σ: machine_pow * model_apply_fun({"params": w, **model_state}, σ).real
    )

    def expect_closure_pars(pars):
        return nkjax.expect(
            log_pdf,
            partial(local_value_kernel, logpsi),
            pars,
            σ,
            local_value_args,
            n_chains=σ_shape[0],
        )

    Ō, Ō_pb, Ō_stats = nkjax.vjp(expect_closure_pars, parameters, has_aux=True)
    Ō_pars_grad = Ō_pb(jnp.ones_like(Ō))[0]

    if is_mutable:
        raise NotImplementedError(
            "gradient of non-hermitian operators over mutable models "
            "is not yet implemented."
        )
    new_model_state = None

    return (
        Ō_stats,
        jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], Ō_pars_grad),
        new_model_state,
    )
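For comparison with the hermitian case, differentiating through a Monte Carlo expectation as in expect_closure_pars also needs the score-function term, roughly d/dθ E_p[f] ≈ <df/dθ> + <(f - <f>) d log p/dθ>, which is presumably what nkjax.expect provides in differentiable form. Below is a minimal sketch of that identity in plain JAX; log_pdf, f and the sample array are illustrative stand-ins, and in real use the samples would come from the sampler.

import jax
import jax.numpy as jnp

def log_pdf(theta, x):
    return -0.5 * (x - theta) ** 2          # toy log-density (unit-variance Gaussian)

def f(theta, x):
    return (x * theta) ** 2                 # toy local estimator

def expect_and_grad(theta, samples):
    fx = jax.vmap(lambda x: f(theta, x))(samples)
    mean_f = fx.mean()
    # pathwise term: direct dependence of f on theta
    grad_direct = jax.vmap(lambda x: jax.grad(f)(theta, x))(samples).mean()
    # score-function term: dependence of the sampling density on theta,
    # with the sample mean of f used as a variance-reducing baseline
    score = jax.vmap(lambda x: jax.grad(log_pdf)(theta, x))(samples)
    grad = grad_direct + ((fx - mean_f) * score).mean()
    return mean_f, grad

value, grad = expect_and_grad(jnp.array(0.3), jnp.linspace(-1.0, 1.0, 16))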
Code example #3
def grad_expect_operator_kernel(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    machine_pow: int,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree, Stats]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    is_mutable = mutable is not False
    logpsi = lambda w, σ: model_apply_fun(
        {"params": w, **model_state}, σ, mutable=mutable
    )
    log_pdf = (
        lambda w, σ: machine_pow * model_apply_fun({"params": w, **model_state}, σ).real
    )

    def expect_closure_pars(pars):
        return nkjax.expect(
            log_pdf,
            partial(local_value_kernel, logpsi),
            pars,
            σ,
            local_value_args,
            n_chains=σ_shape[0],
        )

    Ō, Ō_pb, Ō_stats = nkjax.vjp(
        expect_closure_pars, parameters, has_aux=True, conjugate=True
    )
    Ō_pars_grad = Ō_pb(jnp.ones_like(Ō))[0]

    # The term below is needed, otherwise the result does not match the value
    # obtained by (ha@ha).collect(). I'm unsure why it is needed.
    Ō_pars_grad = jax.tree_multimap(
        lambda x, target: x / 2 if jnp.iscomplexobj(target) else x,
        Ō_pars_grad,
        parameters,
    )

    if is_mutable:
        raise NotImplementedError(
            "gradient of non-hermitian operators over mutable models "
            "is not yet implemented."
        )
    new_model_state = None

    return (
        Ō_stats,
        jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], Ō_pars_grad),
        new_model_state,
    )
Code example #4
File: expect_grad.py  Project: tobiaswiener/netket
def _local_values_and_grads_notcentered_kernel(logpsi, pars, vp, mel, v):
    logpsi_vp, f_vjp = nkjax.vjp(lambda w: logpsi(w, vp), pars, conjugate=False)
    vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))

    # TODO: someone should bring order to these multiple conjugations
    odtype = jax.eval_shape(logpsi, pars, v).dtype
    vec = jnp.asarray(jnp.conjugate(vec), dtype=odtype)
    loc_val = vec.sum()
    grad_c = f_vjp(vec.conj())[0]

    return loc_val, grad_c
Code example #5
def grad_expect_hermitian(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # The code is slightly more complex than usual because we also support
    # mutable model state (when present).
    is_mutable = mutable is not False
    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
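One detail worth noting in the version above is the *new_model_state star-unpacking, which absorbs the optional auxiliary output so the mutable and non-mutable cases share one code path. A small illustrative stand-in (not nkjax.vjp itself) of the same pattern:

def fake_vjp(has_aux: bool):
    out, vjp_fun = 1.0, lambda ct: (ct,)
    # a third element is returned only when there is auxiliary output
    return (out, vjp_fun, {"batch_stats": None}) if has_aux else (out, vjp_fun)

_, vjp_fun, *rest = fake_vjp(has_aux=False)
new_state = rest[0] if rest else None      # None when no aux was produced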
Code example #6
def tree_ravel(pytree: PyTree) -> Tuple[jnp.ndarray, Callable]:
    """Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

    Args:
      pytree: a pytree to ravel

    Returns:
      A pair where the first element is a 1D array representing the flattened and
      concatenated leaf values, and the second element is a callable for
      unflattening a 1D vector of the same length back to a pytree of the same
      structure as the input ``pytree``.
    """
    leaves, treedef = tree_flatten(pytree)
    flat, unravel_list = nkjax.vjp(_ravel_list, *leaves)
    unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
    return flat, unravel_pytree
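A hypothetical usage of tree_ravel, shown here with jax.flatten_util.ravel_pytree, which provides an equivalent flatten/unflatten pair in stock JAX so the sketch stays self-contained. The vjp trick above works because ravelling the leaves is a linear map, so the vjp of _ravel_list sends a flat cotangent back to arrays with the original leaf shapes, which is exactly the unflattening operation.

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

tree = {"a": jnp.zeros((2, 3)), "b": jnp.arange(4.0)}
flat, unravel = ravel_pytree(tree)   # flat is a 1D array of length 10
restored = unravel(flat)             # same structure and leaf shapes as `tree`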
Code example #7
def _exp_grad(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    OΨ: jnp.ndarray,
    Ψ: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:
    is_mutable = mutable is not False

    expval_O = (Ψ.conj() * OΨ).sum()
    ΔOΨ = (OΨ - expval_O * Ψ.conj()) * Ψ

    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )

    Ō_grad = vjp_fun(ΔOΨ)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return (
        None,
        Ō_grad,
        expval_O,
        new_model_state,
    )
Code example #8
def grad_expect_operator_kernel(
    machine_pow: int,
    model_apply_fun: Callable,
    local_kernel: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:

    if not config.FLAGS["NETKET_EXPERIMENTAL"]:
        raise RuntimeError(
            """
                           Computing the gradient of a squared or non-hermitian
                           operator is an experimental feature under development
                           and is known to sometimes return wrong values.

                           If you want to debug it, set the environment variable
                           NETKET_EXPERIMENTAL=1
                           """
        )

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    has_aux = mutable is not False
    if not has_aux:
        out_axes = (0, 0)
    else:
        out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    log_pdf = (
        lambda w, σ: machine_pow * model_apply_fun({"params": w, **model_state}, σ).real
    )

    def expect_closure(*args):
        local_kernel_vmap = jax.vmap(
            partial(local_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
        )

        return nkjax.expect(log_pdf, local_kernel_vmap, *args, n_chains=σ_shape[0])

    def expect_closure_pars(pars):
        return expect_closure(pars, σ, σp, mels)

    Ō, Ō_pb, Ō_stats = nkjax.vjp(expect_closure_pars, parameters, has_aux=True)
    Ō_pars_grad = Ō_pb(jnp.ones_like(Ō))[0]  # unpack the single-primal vjp result

    return (
        Ō_stats,
        tree_map(lambda x: sum_inplace(x) / utils.n_nodes, Ō_pars_grad),
        model_state,
    )
Code example #9
File: expect_grad.py  Project: tobiaswiener/netket
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #    out_axes = (0, 0)
    # else:
    #    out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #    partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = _local_values_and_grads_notcentered_kernel(logpsi, parameters, σp, mels, σ)
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX dtype where floatXX would
    # suffice, which would halve the number of operations.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    )[0]
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # We do this for the standard energy gradient as well;
    # this avoids errors in #867, #789, #850.
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )