Example #1
0
def jacobian_cplx(
        forward_fn: Callable,
        params: PyTree,
        samples: Array,
        chunk_size: int = None,
        _build_fn: Callable = partial(jax.tree_util.tree_map, jax.lax.complex),
) -> PyTree:
    """Calculates Jacobian entries by vmapping grad.

    Assumes the function is R→C and backpropagates the cotangents 1 and -1j.
    Under JAX's VJP convention for R→C functions, these yield the gradients
    of the real and the imaginary part of the output respectively. The two
    results are then combined by ``_build_fn``.

    Args:
        forward_fn: the log wavefunction ln Ψ
        params: a pytree of parameters p
        samples: an array of n samples σ
        chunk_size: chunk size forwarded to ``vmap_chunked`` (default: None)
        _build_fn: combines the two VJP results ``(gr, gi)`` into the output.
            By default it applies ``jax.lax.complex(gr, gi)`` leaf-wise.

    Returns:
        The Jacobian matrix ∂/∂pₖ ln Ψ(σⱼ) as a PyTree
    """
    def _jacobian_cplx(forward_fn, params, samples, _build_fn):
        y, vjp_fun = jax.vjp(single_sample(forward_fn), params, samples)
        # Both cotangents must carry the output dtype of forward_fn.
        out_dtype = jnp.result_type(y)
        gr, _ = vjp_fun(np.array(1.0, dtype=out_dtype))
        gi, _ = vjp_fun(np.array(-1.0j, dtype=out_dtype))
        return _build_fn(gr, gi)

    # Only the samples are batched; the function, params and builder are shared.
    return vmap_chunked(_jacobian_cplx,
                        in_axes=(None, None, 0, None),
                        chunk_size=chunk_size)(forward_fn, params, samples,
                                               _build_fn)
Example #2
0
def local_value_kernel_chunked(
    logpsi: Callable,
    pars: PyTree,
    σ: Array,
    args: PyTree,
    *,
    chunk_size: Optional[int] = None,
):
    """
    Chunked local-value kernel for MCState and generic operators.

    For every sample, sums ``mel * exp(logpsi(σ') - logpsi(σ))`` over its
    connected configurations σ'.

    Args:
        logpsi: log-amplitude function, called as ``logpsi(pars, x)``.
        pars: parameters bound as the first argument of ``logpsi``.
        σ: samples; the last axis has the configuration length N.
        args: pair ``(σp, mels)`` of connected configurations and matrix elements.
        chunk_size: chunk size forwarded to ``nkjax.vmap_chunked``.

    Returns:
        The local values, i.e. the sum over the last (connected-element) axis.
    """
    conn, matrix_els = args

    # Bring the connected configurations into a (σ.shape[0], -1, N) layout.
    if jnp.ndim(conn) != 3:
        conn = conn.reshape((σ.shape[0], -1, σ.shape[-1]))
        matrix_els = matrix_els.reshape(conn.shape[:-1])

    eval_logpsi = nkjax.vmap_chunked(
        partial(logpsi, pars), in_axes=0, chunk_size=chunk_size
    )
    n_sites = σ.shape[-1]

    def _eval_flat(configs, out_shape):
        # Evaluate on a flat (-1, N) batch, then restore the requested shape.
        return eval_logpsi(configs.reshape((-1, n_sites))).reshape(out_shape)

    log_amp = _eval_flat(σ, σ.shape[:-1] + (1,))
    log_amp_conn = _eval_flat(conn, conn.shape[:-1])

    ratios = jnp.exp(log_amp_conn - log_amp)
    return jnp.sum(matrix_els * ratios, axis=-1)
Example #3
0
def jacobian_real_holo(forward_fn: Callable,
                       params: PyTree,
                       samples: Array,
                       chunk_size: int = None) -> PyTree:
    """Compute per-sample Jacobian entries via a chunked vmap over a VJP.

    Intended for R→R or holomorphic C→C functions. For these, a single
    vector-Jacobian product with cotangent 1 is sufficient.

    Args:
        forward_fn: the log wavefunction ln Ψ
        params: a pytree of parameters p
        samples: an array of n samples σ
        chunk_size: chunk size forwarded to ``vmap_chunked`` (default: None)

    Returns:
        The Jacobian matrix ∂/∂pₖ ln Ψ(σⱼ) as a PyTree
    """
    def _per_sample_grad(fn, p, x):
        out, pullback = jax.vjp(single_sample(fn), p, x)
        unit_cotangent = np.array(1.0, dtype=jnp.result_type(out))
        grad_p, _ = pullback(unit_cotangent)
        return grad_p

    # Batch over the samples only; function and params are broadcast.
    batched = vmap_chunked(_per_sample_grad,
                           in_axes=(None, None, 0),
                           chunk_size=chunk_size)
    return batched(forward_fn, params, samples)
Example #4
0
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    chunk_size: int = None,
) -> PyTree:
    """
    Compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩, divided by √n.

    In the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and general C→C),
    all parameters are first split into real and imaginary parts. The
    resulting ΔOⱼₖ is therefore only compatible with split-to-real parameter
    vectors. In 'holomorphic' mode the parameters are differentiated as-is.

    Args:
        apply_fun: The forward pass of the Ansatz, called as
            ``apply_fun({"params": W, **model_state}, σ)``.
        params: a pytree of parameters p
        samples: an array of (n in total) batched samples σ. All leading axes
            are flattened into a single sample axis.
        model_state: non-trainable state of the model. ``None`` is treated as
            an empty state.
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used
            (no default; must be passed explicitly)
        chunk_size: chunk size forwarded to ``nkjax.vmap_chunked`` when
            computing the per-sample Jacobians (default: None)

    Returns:
        A pair. If ``rescale_shift`` is False, it is the centered Jacobian of
        ln Ψ evaluated at the samples σ, divided by √n, as a 2-D array,
        together with None. Otherwise it is whatever ``_rescale`` returns for
        that array: per the original contract, the same matrix with each
        parameter column normalised to unit norm, together with the norms
        that were divided out.

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported modes.
    """
    # un-batch the samples: every leading axis becomes one sample axis
    samples = samples.reshape((-1, samples.shape[-1]))

    # `**None` would raise a TypeError, so an absent state becomes empty
    state = {} if model_state is None else model_state

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        jacobian_fun = dense_jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # NOTE(review): presumably dense_jacobian_cplx keeps the real and
        # imaginary parts separate instead of converting to complex and
        # back — confirm against its implementation.
        jacobian_fun = dense_jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        jacobian_fun = dense_jacobian_real_holo
    else:
        raise NotImplementedError(
            f'Differentiation mode should be one of "real", "complex", or "holomorphic", got {mode}'
        )

    if split_complex_params:
        # Stored as contiguous real stacked on top of contiguous imaginary
        # (SOA); doesn't do anything if the params are already real.
        params, reassemble = tree_to_reim(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    def _per_sample_jacobian(W, σ):
        return jacobian_fun(f, W, σ)

    # params are shared; only the samples are batched (optionally chunked)
    jacobians = nkjax.vmap_chunked(
        _per_sample_jacobian, in_axes=(None, 0), chunk_size=chunk_size
    )(params, samples)

    # total sample count across MPI nodes; assumes every node holds the
    # same number of samples as this one
    n_samp = samples.shape[0] * mpi.n_nodes
    centered_oks = subtract_mean(jacobians, axis=0) / np.sqrt(
        n_samp, dtype=jacobians.dtype
    )

    # collapse all but the last (parameter) axis into rows
    centered_oks = centered_oks.reshape(-1, centered_oks.shape[-1])

    if rescale_shift:
        return _rescale(centered_oks)
    return centered_oks, None
Example #5
0
def _local_continuous_kernel(kernel, logpsi, pars, σ, args, *, chunk_size=None):
    """Map ``kernel`` over the first axis of ``σ`` with a chunked vmap.

    Each slice ``x`` of ``σ`` is processed as ``kernel(logpsi, pars, x, args)``.
    ``chunk_size`` is forwarded unchanged to ``nkjax.vmap_chunked``.
    """
    def _apply_single(sample):
        return kernel(logpsi, pars, sample, args)

    vmapped = nkjax.vmap_chunked(_apply_single, in_axes=0, chunk_size=chunk_size)
    return vmapped(σ)