Ejemplo n.º 1
0
    re_dense = ravel(re)
    im_dense = ravel(im)
    res = jnp.stack([re_dense, im_dense], axis=0)

    return res


def ravel(x: PyTree) -> Array:
    """
    Flatten a pytree with ``nkjax.tree_ravel`` and return only the first
    element of the pair it yields.

    Args:
        x: the pytree to flatten

    Returns:
        the first item returned by ``nkjax.tree_ravel(x)``; the second item
        (presumably the un-ravelling function — confirm against nkjax) is
        discarded
    """
    flattened, _unused = nkjax.tree_ravel(x)
    return flattened


# Dense variants of the jacobian helpers, built by function composition.
# NOTE(review): assumes nkjax.compose(f, g) calls g first and feeds its result
# to f — confirm against nkjax.
# Here: the output of jacobian_real_holo is passed through ravel.
dense_jacobian_real_holo = nkjax.compose(ravel, jacobian_real_holo)
# jacobian_cplx is called with a _build_fn that simply returns its positional
# arguments as a tuple; that tuple is then handed to stack_jacobian_tuple.
dense_jacobian_cplx = nkjax.compose(
    stack_jacobian_tuple, partial(jacobian_cplx, _build_fn=lambda *x: x)
)


def _rescale(centered_oks):
    """
    compute ΔOₖ/√Sₖₖ and √Sₖₖ
    to do scale-invariant regularization (Becca & Sorella 2017, pp. 143)
    Sₖₗ/(√Sₖₖ√Sₗₗ) = ΔOₖᴴΔOₗ/(√Sₖₖ√Sₗₗ) = (ΔOₖ/√Sₖₖ)ᴴ(ΔOₗ/√Sₗₗ)
    """
    scale = (
        mpi.mpi_sum_jax(
            jnp.sum((centered_oks * centered_oks.conj()).real, axis=0, keepdims=True)
        )[0]
Ejemplo n.º 2
0
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    pdf=None,
    chunk_size: Optional[int] = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩

    When ``pdf`` is None the mean is taken by the centered jacobian helper and
    the result is divided by √n. When ``pdf`` is given, the mean is the
    pdf-weighted sum over samples and the centered result is multiplied by
    √pdf instead.

    In the 'real' and 'complex' modes the parameters are internally split into
    real ones (for C→R, R&C→R, R&C→C and general C→C), so the resulting ΔOⱼₖ
    is only compatible with split-to-real pytree vectors.

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model; ``None`` is
            treated as an empty state
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used
        pdf: |ψ(x)|^2 if exact optimization is being used else None
        chunk_size: an int specifying the size of the chunks the gradient should be computed in (default: None)

    Returns:
        if not rescale_shift:
            a tuple ``(centered_oks, None)`` where ``centered_oks`` is the
            centered jacobian of ln Ψ evaluated at the samples σ
        else:
            whatever ``_rescale(centered_oks)`` returns (the rescaled jacobian
            together with the scale that was divided out)

    Raises:
        NotImplementedError: if ``mode`` is not one of the supported modes
    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # `model_state` is declared Optional: unpacking None with ** would raise a
    # TypeError, so fall back to an empty mapping.
    extra_variables = {} if model_state is None else model_state

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **extra_variables}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C

        # avoid converting to complex and then back
        # by passing around the oks as a tuple of two pytrees representing the real and imag parts
        centered_jacobian_fun = compose(
            stack_jacobian_tuple,
            partial(centered_jacobian_cplx, _build_fn=lambda *x: x),
        )
        # NOTE(review): the pdf branch below uses jacobian_cplx with its
        # default _build_fn, so its output layout may differ from the stacked
        # tuple produced by centered_jacobian_fun — confirm this is intended.
        jacobian_fun = jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or '
            f'"holomorphic", got {mode}'
        )

    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_real(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    if pdf is None:
        centered_oks = _divide_by_sqrt_n_samp(
            centered_jacobian_fun(
                f,
                params,
                samples,
                chunk_size=chunk_size,
            ),
            samples,
        )
    else:
        # Forward chunk_size here as well, so the requested chunking is not
        # silently ignored when an explicit pdf is supplied.
        oks = jacobian_fun(f, params, samples, chunk_size=chunk_size)
        # NOTE(review): `sum` must resolve to an axis-aware reduction brought
        # into scope elsewhere in the file (the builtin `sum` does not accept
        # an `axis` keyword) — confirm.
        oks_mean = jax.tree_util.tree_map(
            partial(sum, axis=0), _multiply_by_pdf(oks, pdf)
        )
        centered_oks = jax.tree_util.tree_map(lambda x, y: x - y, oks, oks_mean)

        centered_oks = _multiply_by_pdf(centered_oks, jnp.sqrt(pdf))

    if rescale_shift:
        return _rescale(centered_oks)
    return centered_oks, None
Ejemplo n.º 3
0
        partial(_vjp, nondiff_argnums=nondiff_argnums, conjugate=conjugate),
        scan_fun=scan_fun,
        argnums=argnums,
    )(fun, cotangents, *primals)

    return _multimap(lambda c, l: _tree_unchunk(l)
                     if c else l, append_cond, res)


def _gen_append_cond_vjp(primals, nondiff_argnums, chunk_argnums):
    """
    Build one boolean per positional index of ``primals`` that is not listed
    in ``nondiff_argnums``.

    Args:
        primals: sequence of arguments; only its length is used
        nondiff_argnums: container of indices to leave out
        chunk_argnums: container of indices to flag

    Returns:
        a tuple, in increasing index order, holding ``True`` for each kept
        index that is in ``chunk_argnums`` and ``False`` otherwise
    """
    return tuple(
        idx in chunk_argnums
        for idx in range(len(primals))
        if idx not in nondiff_argnums
    )


# NOTE(review): the comments below assume compose(f, g) calls g first and
# passes its result to f — confirm against the definition of compose.

# Same flags as _gen_append_cond_vjp, with an extra leading True entry
# (presumably for the primal value returned alongside the vjp — confirm).
_gen_append_cond_value_vjp = compose(lambda t: (True, ) + t,
                                     _gen_append_cond_vjp)

# vjp-only variant: the first element of what _vjp returns is dropped
# (yr[1:]) before __vjp_fun_chunked processes it.
_vjp_fun_chunked = partial(
    __vjp_fun_chunked,
    _vjp=compose(lambda yr: yr[1:], _vjp),
    _append_cond_fun=_gen_append_cond_vjp,
)
# value-and-vjp variant: the full result of __vjp_fun_chunked is split into
# a pair (first element, remaining elements).
_value_and_vjp_fun_chunked = compose(
    lambda yr: (yr[0], yr[1:]),
    partial(__vjp_fun_chunked,
            _vjp=_vjp,
            _append_cond_fun=_gen_append_cond_value_vjp),
)


def vjp_chunked(
Ejemplo n.º 4
0
    Returns:
        The Jacobian matrix ∂/∂pₖ ln Ψ(σⱼ) as a PyTree
    """
    def _jacobian_cplx(forward_fn, params, samples, _build_fn):
        y, vjp_fun = jax.vjp(single_sample(forward_fn), params, samples)
        gr, _ = vjp_fun(np.array(1.0, dtype=jnp.result_type(y)))
        gi, _ = vjp_fun(np.array(-1.0j, dtype=jnp.result_type(y)))
        return _build_fn(gr, gi)

    return vmap_chunked(_jacobian_cplx,
                        in_axes=(None, None, 0, None),
                        chunk_size=chunk_size)(forward_fn, params, samples,
                                               _build_fn)


# Centered variants of the jacobian helpers: the jacobian output is passed
# through tree_subtract_mean.
# NOTE(review): assumes compose(f, g) calls g first and feeds its result to f,
# with keyword arguments such as chunk_size forwarded to g — confirm.
centered_jacobian_real_holo = compose(tree_subtract_mean, jacobian_real_holo)
centered_jacobian_cplx = compose(tree_subtract_mean, jacobian_cplx)


def _divide_by_sqrt_n_samp(oks, samples):
    """
    divide Oⱼₖ by √n

    Args:
        oks: a pytree whose leaves are arrays
        samples: the array of samples; only the size of its leading axis is used

    Returns:
        a pytree with the same structure as ``oks`` in which every leaf has
        been divided by √n, where n = ``samples.shape[0] * mpi.n_nodes``
    """
    n_samp = samples.shape[0] * mpi.n_nodes  # MPI

    def _scale_leaf(leaf):
        # take the square root in the leaf's own dtype so the division does
        # not promote the result to a wider type
        return leaf / np.sqrt(n_samp, dtype=leaf.dtype)

    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias is deprecated.
    return jax.tree_util.tree_map(_scale_leaf, oks)


def _multiply_by_pdf(oks, pdf):
    """
    Computes  O'ⱼ̨ₖ = Oⱼₖ pⱼ .
    Used to multiply the log-derivatives by the probability density.