Esempio n. 1
0
def default_qgt_matrix(variational_state, solver=False, **kwargs):
    """
    Determine the default quantum geometric tensor (QGT) implementation
    depending on the variational state and on the solver.

    Args:
        variational_state: the variational state for which the QGT is built.
        solver: the linear solver that will be used together with the QGT.
            If ``_is_dense_solver(solver)`` is true the dense implementation
            is always selected. Defaults to ``False``.
        **kwargs: keyword arguments bound into the returned partial.

    Returns:
        A ``functools.partial`` of ``QGTJacobianPyTree``, ``QGTJacobianDense``
        or ``QGTOnTheFly`` with ``kwargs`` bound.
    """
    from netket.vqs import ExactState

    if isinstance(variational_state, ExactState):
        return partial(QGTJacobianPyTree, **kwargs)

    # those require dense matrix that is known to be faster for this qgt
    if _is_dense_solver(solver):
        return partial(QGTJacobianDense, **kwargs)

    # Only the trainable parameters matter here: `variables` may also carry
    # non-trainable model state, which must not influence the choice.
    parameters = variational_state.parameters
    # `jax.tree_leaves` is deprecated/removed; use the tree_util namespace.
    n_param_leaves = len(jax.tree_util.tree_leaves(parameters))
    n_params = variational_state.n_parameters

    # arbitrary heuristic: if the network's parameters has many leaves
    # (an rbm has 3) then JacobianDense might be faster
    # the numbers chosen below are rather arbitrary and should be tuned.
    if n_param_leaves > 6 and n_params > 800:
        # The dense Jacobian is only chosen for homogeneous parameter trees;
        # mixed trees fall back to the PyTree implementation.
        if nkjax.tree_ishomogeneous(parameters):
            return partial(QGTJacobianDense, **kwargs)
        return partial(QGTJacobianPyTree, **kwargs)

    return partial(QGTOnTheFly, **kwargs)
Esempio n. 2
0
def O_mean(forward_fn, params, samples, holomorphic=True):
    r"""
    Compute \langle O \rangle, i.e. the weighted mean of the rows of the
    Jacobian of ``forward_fn`` with respect to ``params``.

    Each of the local ``samples.shape[0]`` samples gets weight
    ``1 / (samples.shape[0] * mpi.n_nodes)``.

    Args:
        forward_fn: function called as ``forward_fn(params, samples)``.
        params: pytree of parameters.
        samples: batch of samples; axis 0 is the sample axis.
        holomorphic: whether complex parameters may be treated with the
            holomorphic (C->C) code path.

    Returns:
        The result of ``O_vjp_rc`` (real parameters, complex output) or
        ``O_vjp`` (R->R and holomorphic C->C) with the weights above.

    Raises:
        NotImplementedError: for non-homogeneous parameter trees, or for
            complex parameters with ``holomorphic=False``.
    """
    # Abstractly evaluate the forward pass once; both the weight dtype and the
    # real/complex output check are derived from it.
    out_struct = jax.eval_shape(forward_fn, params, samples)
    n_samples = samples.shape[0]
    w = jnp.ones(n_samples, dtype=out_struct.dtype) * (
        1.0 / (n_samples * mpi.n_nodes)
    )

    homogeneous = nkjax.tree_ishomogeneous(params)
    real_params = not nkjax.tree_leaf_iscomplex(params)
    real_out = not nkjax.is_complex(out_struct)

    if not (homogeneous and (real_params or holomorphic)):
        # R&C -> C, non-holomorphic C->C and C->R are not supported here.
        # Raise explicitly: an `assert` would be stripped under `python -O`
        # and the function would silently return None.
        raise NotImplementedError(
            "O_mean only supports homogeneous parameters that are either real "
            "or used with holomorphic=True."
        )

    if real_params and not real_out:
        # R->C
        return O_vjp_rc(forward_fn, params, samples, w)
    # R->R and holomorphic C->C
    return O_vjp(forward_fn, params, samples, w)
Esempio n. 3
0
def DeltaOdagger_DeltaO_v(forward_fn, params, samples, v, holomorphic=True):
    r"""
    Compute \langle \Delta O^\dagger \Delta O \rangle v,
    where \Delta O = O - \langle O \rangle.

    Parameter trees that are non-homogeneous, or complex while
    ``holomorphic`` is false, are first split into real parts with
    ``nkjax.tree_to_real`` (``v`` is split the same way). The product is then
    computed in that real representation and reassembled at the end.

    Centering is done by subtracting ``tree_dot(p, omean)`` from the forward
    output, presumably so that its Jacobian becomes O - \langle O \rangle.
    """
    homogeneous = nkjax.tree_ishomogeneous(params)
    real_params = not nkjax.tree_leaf_iscomplex(params)
    # Everything except R->R, holomorphic C->C and R->C is split into reals.
    split_to_real = not homogeneous or not (real_params or holomorphic)

    if split_to_real:
        params, reassemble = nkjax.tree_to_real(params)
        v, _ = nkjax.tree_to_real(v)

        def fn(p, x):
            return forward_fn(reassemble(p), x)

    else:
        fn = forward_fn

    omean = O_mean(fn, params, samples, holomorphic=holomorphic)

    def centered_fn(p, x):
        return fn(p, x) - tree_dot(p, omean)

    result = Odagger_O_v(centered_fn, params, samples, v)
    return reassemble(result) if split_to_real else result
Esempio n. 4
0
def _choose_jacobian_mode(apply_fun, pars, model_state, samples, mode, holomorphic):
    """
    Select how the Jacobian of ``apply_fun`` with respect to ``pars`` is computed.

    Args:
        apply_fun: model function, evaluated abstractly as
            ``apply_fun({"params": pars, **model_state}, samples_2d)``.
        pars: pytree of model parameters.
        model_state: dict of extra variable collections merged next to "params".
        samples: samples, reshaped to ``(-1, samples.shape[-1])`` for the
            abstract forward pass.
        mode: NOTE(review): the passed value is currently ignored; the mode is
            always re-derived below.
        holomorphic: ``True`` selects holomorphic mode. Otherwise the mode is
            inferred from the output dtype. With ``None``, a warning is emitted
            for complex-parameter/complex-output models.

    Returns:
        0 for "real", 1 for "complex", 2 for "holomorphic".

    Raises:
        ValueError: if the resolved mode is unknown (defensive; not reachable
            with the branches below).
    """
    if holomorphic is True:
        if not nkjax.tree_ishomogeneous(pars):
            # The string starts with a newline so that `dedent` actually
            # strips the common indentation.
            warnings.warn(
                dedent(
                    """
                    The ansatz has non-homogeneous variables, which might not behave well with the
                    holomorphic implementation.
                    Use `holomorphic=False` or mode='complex' for more accurate results but
                    lower performance.
                    """
                ),
                UserWarning,
            )
        mode = "holomorphic"
    else:
        leaf_iscomplex = nkjax.tree_leaf_iscomplex(pars)
        complex_output = nkjax.is_complex(
            jax.eval_shape(
                apply_fun,
                {"params": pars, **model_state},
                samples.reshape(-1, samples.shape[-1]),
            )
        )

        if not complex_output:
            mode = "real"
        else:
            # Both R->C and non-holomorphic C->C use the "complex" mode; only
            # the C->C case warns when `holomorphic` was left unspecified.
            if leaf_iscomplex and holomorphic is None:
                warnings.warn(
                    dedent(
                        """
                            Complex-to-Complex model detected. Defaulting to `holomorphic=False` for
                            the implementation of QGTJacobianDense.
                            If your model is holomorphic, specify `holomorphic=True` to use a more
                            performant implementation.
                            To suppress this warning specify `holomorphic`.
                            """
                    ),
                    UserWarning,
                )
            mode = "complex"

    mode_ids = {"real": 0, "complex": 1, "holomorphic": 2}
    try:
        return mode_ids[mode]
    except KeyError:
        raise ValueError(f"unknown mode {mode}") from None