Example #1
def _vdot_real_part(x, y):
    """Vector dot-product guaranteed to have a real valued result."""
    # all our uses of vdot() in CG are for computing an operator of the form
    # `z^H M z` where `M` is positive definite and Hermitian, so the result is
    # real valued:
    # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
    vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
    result = vdot(x.real, y.real)
    if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
        result += vdot(x.imag, y.imag)
    return result
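
A minimal usage sketch (my own addition, not part of the snippet), assuming the `_vdot_real_part` definition above together with `from functools import partial`, `import jax.numpy as jnp`, and `from jax import lax`: for complex inputs the function returns the real part of the Hermitian inner product x^H y.

x = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
y = jnp.array([0.5 - 1.0j, 2.0 + 0.5j])
# dropping the real-imaginary cross-terms is the same as taking Re(x^H y)
assert jnp.allclose(_vdot_real_part(x, y), jnp.vdot(x, y).real)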
Example #2
def _vdot_real_part(x, y):
    """Vector dot-product guaranteed to have a real valued result despite
     possibly complex input. Thus neglects the real-imaginary cross-terms.
     The result is a real float.
  """
    # all our uses of vdot() in CG are for computing an operator of the form
    #  z^H M z
    #  where M is positive definite and Hermitian, so the result is
    # real valued:
    # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
    result = _vdot(x.real, y.real)
    if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
        result += _vdot(x.imag, y.imag)
    return result
Example #3
def reim_activation(f, x):
    sqrt2 = jnp.sqrt(jnp.array(2, dtype=x.real.dtype))
    if jnp.iscomplexobj(x):
        return jax.lax.complex(f(sqrt2 * x.real), f(sqrt2 * x.imag)) / sqrt2
    else:
        return f(x)
Example #4
def odefun(state: MCState, driver: TDVP, t, w, *, stage=0):  # noqa: F811
    # pylint: disable=protected-access

    state.parameters = w
    state.reset()

    driver._loss_stats, driver._loss_grad = state.expect_and_grad(
        driver.generator(t),
        use_covariance=True,
    )
    driver._loss_grad = jax.tree_map(lambda x: driver._loss_grad_factor * x,
                                     driver._loss_grad)

    qgt = driver.qgt(driver.state)
    if stage == 0:  # TODO: This does not work with FSAL.
        driver._last_qgt = qgt

    initial_dw = None if driver.linear_solver_restart else driver._dw
    driver._dw, _ = qgt.solve(driver.linear_solver,
                              driver._loss_grad,
                              x0=initial_dw)

    # If parameters are real, then take only real part of the gradient (if it's complex)
    driver._dw = jax.tree_map(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real),
        driver._dw,
        state.parameters,
    )

    return driver._dw
Example #5
    def _forward_and_backward(self):
        """
        Performs a number of VMC optimization steps.

        Args:
            n_steps (int): Number of steps to perform.
        """

        self.state.reset()

        # Compute the local energy estimator and average Energy
        self._loss_stats, self._loss_grad = self.state.expect_and_grad(
            self._ham)

        # if the preconditioner is the identity, this reduces to
        # self._dp = self._loss_grad
        self._dp = self.preconditioner(self.state, self._loss_grad)

        # If parameters are real, then take only real part of the gradient (if it's complex)
        self._dp = jax.tree_map(
            lambda x, target: (x if jnp.iscomplexobj(target) else x.real),
            self._dp,
            self.state.parameters,
        )

        return self._dp
Example #6
    def test_optimization(self, opt_name, opt, target, dtype):
        if (opt_name in ('fromage', 'noisy_sgd', 'sm3')
                and jnp.iscomplexobj(dtype)):
            raise absltest.SkipTest(
                f'{opt_name} does not support complex parameters.')

        opt = opt()
        initial_params, final_params, get_updates = target(dtype)

        @jax.jit
        def step(params, state):
            updates = get_updates(params)
            if opt_name == 'dpsgd':
                updates = updates[None]
            # Complex gradients need to be conjugated before being added to parameters
            # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
            updates = jax.tree_map(lambda x: x.conj(), updates)
            updates, state = opt.update(updates, state, params)
            params = update.apply_updates(params, updates)
            return params, state

        params = initial_params
        state = opt.init(params)
        for _ in range(10000):
            params, state = step(params, state)

        chex.assert_tree_all_close(params, final_params, rtol=3e-2, atol=3e-2)
Example #7
def tree_conj(t: PyTree) -> PyTree:
    r"""
    Conjugate all complex leaves. The real leaves are left untouched.
    Args:
        t: pytree
    """
    return jax.tree_map(lambda x: jax.lax.conj(x) if jnp.iscomplexobj(x) else x, t)
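
A short, hypothetical usage sketch, assuming `import jax`, `import jax.numpy as jnp`, and the `tree_conj` definition above: complex leaves are conjugated while real leaves pass through untouched.

params = {"w": jnp.array([1.0 + 1.0j, 2.0 - 3.0j]), "b": jnp.array([0.5, -0.5])}
conj_params = tree_conj(params)
# conj_params["w"] == [1.0 - 1.0j, 2.0 + 3.0j]; conj_params["b"] is returned unchanged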
Example #8
    def _forward_and_backward(self):
        """
        Performs a number of VMC optimization steps.

        Args:
            n_steps (int): Number of steps to perform.
        """

        self.state.reset()

        # Compute the local energy estimator and average Energy
        self._loss_stats, self._loss_grad = self.state.expect_and_grad(
            self._ldag_l)

        if self.sr is not None:
            self._S = self.state.quantum_geometric_tensor(self.sr)

            # use the previous solution as an initial guess to speed up the solution of the linear system
            x0 = self._dp if self.sr_restart is False else None
            self._dp = self._S.solve(self._loss_grad, x0=x0)
        else:
            # tree_map(lambda x, y: x if is_ccomplex(y) else x.real, self._grads, self.state.parameters)
            self._dp = self._loss_grad

        # If parameters are real, then take only real part of the gradient (if it's complex)
        self._dp = jax.tree_multimap(
            lambda x, target: (x if jnp.iscomplexobj(target) else x.real),
            self._dp,
            self.state.parameters,
        )

        return self._dp
Example #9
def _to_im(x):
    if jnp.iscomplexobj(x):
        return x.imag
        # TODO find a way to make it a nop?
        # return jax.vmap(lambda y: jnp.array((y.real, y.imag)))(x)
    else:
        return None
Example #10
def grad_expect_hermitian_chunked(
    chunk_size: int,
    local_value_kernel_chunked: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel_chunked(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
        chunk_size=chunk_size,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        vjp_fun_chunked = nkjax.vjp_chunked(
            lambda w, σ: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            σ,
            conjugate=True,
            chunk_size=chunk_size,
            chunk_argnums=1,
            nondiff_argnums=1,
        )
        new_model_state = None
    else:
        raise NotImplementedError

    Ō_grad = vjp_fun_chunked(
        (jnp.conjugate(O_loc) / n_samples),
    )[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
Example #11
def tree_conj(t):
    r"""
    conjugate all complex leaves
    The real leaves are left untouched.

    t: pytree
    """
    return jax.tree_map(lambda x: jax.lax.conj(x) if jnp.iscomplexobj(x) else x, t)
Example #12
    def check(x, target):
        par_iscomplex = jnp.iscomplexobj(x)

        # Account for split real-imaginary part in Jacobian*** methods
        if isinstance(target, tuple):
            vec_iscomplex = len(target) == 2
        else:
            vec_iscomplex = jnp.iscomplexobj(target)

        if not par_iscomplex and vec_iscomplex:
            raise TypeError(
                dedent("""
                    Cannot multiply the (real part of the) QGT by a complex vector.
                    You should either take the real part of the vector, or perform
                    the multiplication against the real and imaginary parts of the
                    vector separately and then recompose the two.
                    """))
Example #13
def grad_expect_operator_kernel(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    machine_pow: int,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree, Stats]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    is_mutable = mutable is not False
    logpsi = lambda w, σ: model_apply_fun(
        {"params": w, **model_state}, σ, mutable=mutable
    )
    log_pdf = (
        lambda w, σ: machine_pow * model_apply_fun({"params": w, **model_state}, σ).real
    )

    def expect_closure_pars(pars):
        return nkjax.expect(
            log_pdf,
            partial(local_value_kernel, logpsi),
            pars,
            σ,
            local_value_args,
            n_chains=σ_shape[0],
        )

    Ō, Ō_pb, Ō_stats = nkjax.vjp(
        expect_closure_pars, parameters, has_aux=True, conjugate=True
    )
    Ō_pars_grad = Ō_pb(jnp.ones_like(Ō))[0]

    # This term below is needed, otherwise the result does not match the value
    # obtained by (ha@ha).collect(). I'm unsure why it is needed.
    Ō_pars_grad = jax.tree_multimap(
        lambda x, target: x / 2 if jnp.iscomplexobj(target) else x,
        Ō_pars_grad,
        parameters,
    )

    if is_mutable:
        raise NotImplementedError(
            "gradient of non-hermitian operators over mutable models "
            "is not yet implemented."
        )
    new_model_state = None

    return (
        Ō_stats,
        jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], Ō_pars_grad),
        new_model_state,
    )
Example #14
def grad_expect_hermitian(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * utils.n_nodes

    O_loc = local_cost_function(
        local_value_cost,
        model_apply_fun,
        {"params": parameters, **model_state},
        σp,
        mels,
        σ,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    if mutable is False:
        _, vjp_fun = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ),
            parameters,
            conjugate=True,
        )
        new_model_state = None
    else:
        _, vjp_fun, new_model_state = nkjax.vjp(
            lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
            parameters,
            conjugate=True,
            has_aux=True,
        )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    return Ō, tree_map(sum_inplace, Ō_grad), new_model_state
Example #15
def logsumexp_cplx(a, b=None, **kwargs):
    """Compute the log of the sum of exponentials of input elements, always returning a
    complex number.

    Equivalent to, but more numerically stable than, `np.log(np.sum(b*np.exp(a)))`.
    If the optional argument `b` is omitted, `np.log(np.sum(np.exp(a)))` is returned.

    Wraps `jax.scipy.special.logsumexp` but uses `return_sign=True` if both `a` and `b`
    are real numbers in order to support `b<0` instead of returning `nan`.

    See the JAX function for details of the calling sequence;
    `return_sign` is not supported.
    """
    if jnp.iscomplexobj(a) or jnp.iscomplexobj(b):
        # logsumexp uses complex algebra anyway
        return logsumexp(a, b=b, **kwargs)
    else:
        a, sgn = logsumexp(a, b=b, **kwargs, return_sign=True)
        a = a + jnp.where(sgn < 0, 1j * jnp.pi, 0j)
        return a
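
A small usage sketch (my own, not from the original source), assuming `import jax.numpy as jnp` and `from jax.scipy.special import logsumexp` alongside the definition above: for a negative weighted sum the real-valued `logsumexp` would produce `nan`, while `logsumexp_cplx` encodes the sign as an imaginary part of pi.

a = jnp.array([0.0, 0.0])
b = jnp.array([1.0, -3.0])      # sum(b * exp(a)) = -2 < 0
out = logsumexp_cplx(a, b=b)
# out is approximately log(2) + 1j*pi, so jnp.exp(out) is approximately -2;
# logsumexp(a, b=b) without return_sign would return nan here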
Example #16
def anf(signal, f0, sr, A=1, phi=0, lr=1e-4, device=cpus[0]):
    if jnp.iscomplexobj(signal):
        signal = jnp.stack([signal.real, signal.imag], axis=-1)
        signal = vmap(_anf, in_axes=(-1,) + (None,) * 6, out_axes=(-1,))(
            signal, f0, sr, A, phi, lr, device
        )
        signal = signal[...,0] + jnp.array(1j) * signal[...,1]
    else:
        signal = _anf(signal, f0, sr, A, phi, lr, device)

    return signal
Example #17
def test_logsumexp_cplx(a, b):
    a = jnp.asarray(a)
    if b is not None:
        b = jnp.asarray(b)
        expected = jnp.log(
            complex(jnp.exp(a[0]) * b[0] + jnp.exp(a[1]) * b[1]))
    else:
        expected = jnp.log(complex(jnp.exp(a[0]) + jnp.exp(a[1])))
    c = logsumexp_cplx(a, b=b)

    assert jnp.iscomplexobj(c)
    assert_allclose(c, expected, atol=1e-8)
Example #18
def _setup_parabola(dtype):
    """Quadratic function as an optimization target."""
    initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
    final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)

    if jnp.iscomplexobj(dtype):
        final_params *= 1 + 1j

    @jax.grad
    def get_updates(params):
        return jnp.sum(numerics.abs_sq(params - final_params))

    return initial_params, final_params, get_updates
Example #19
def grad_expect_hermitian(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Tuple[PyTree, PyTree]:

    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples = σ.shape[0] * mpi.n_nodes

    O_loc = local_value_kernel(
        model_apply_fun,
        {"params": parameters, **model_state},
        σ,
        local_value_args,
    )

    Ō = statistics(O_loc.reshape(σ_shape[:-1]).T)

    O_loc -= Ō.mean

    # Then compute the vjp.
    # Code is a bit more complex than a standard one because we support
    # mutable state (if it's there)
    is_mutable = mutable is not False
    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )
    Ō_grad = vjp_fun(jnp.conjugate(O_loc) / n_samples)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
Example #20
def gaussian_random_fill(work_array):
    """
    Return work_array filled with Gaussian random values of matching dtype.
    (JAX arrays are immutable, so a new array is returned rather than filled
    in place.)
    """
    key = jax.random.PRNGKey(int(time.time()))
    subkey1, subkey2 = jax.random.split(key, 2)

    # jnp.iscomplexobj only inspects the (static) dtype, so a plain Python
    # branch is sufficient; lax.cond is not needed here.
    if jnp.iscomplexobj(work_array):
        output = gaussian_random_complex_arr(
            (subkey1, work_array)).astype(work_array.dtype)
    else:
        output = gaussian_random_real_arr(
            (subkey2, work_array)).astype(work_array.dtype)
    # .at[:].set() replaces the removed jax.ops.index_update
    work_array = work_array.at[:].set(output)
    return work_array
Example #21
def tree_cast(x, target):
    r"""
    Cast each leaf of x to the dtype of the corresponding leaf in target.
    The imaginary part of complex leaves which are cast to real is discarded

    x: a pytree with arrays as leaves
    target: a pytree with the same treedef as x where only the dtypes of the leaves are accessed
    """
    # astype alone would also work, however that raises ComplexWarning when casting complex to real
    # therefore the real is taken first where needed
    return jax.tree_multimap(
        lambda x, target:
        (x if jnp.iscomplexobj(target) else x.real).astype(target.dtype),
        x,
        target,
    )
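
A brief, hypothetical usage sketch, assuming the `tree_cast` definition above (on a JAX version that still provides `jax.tree_multimap`) plus `import jax.numpy as jnp`: leaves whose target is real have their imaginary part discarded before the dtype cast.

x = {"a": jnp.array([1.0 + 2.0j]), "b": jnp.array([3.0 + 4.0j])}
target = {"a": jnp.zeros(1, dtype=jnp.complex64), "b": jnp.zeros(1, dtype=jnp.float32)}
out = tree_cast(x, target)
# out["a"] stays complex64; out["b"] becomes float32 with value [3.0] (the 4.0j is dropped)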
Example #22
def _setup_rosenbrock(dtype):
    """Rosenbrock function as an optimization target."""
    a = 1.0
    b = 100.0

    if jnp.iscomplexobj(dtype):
        a *= 1 + 1j

    initial_params = jnp.array([0.0, 0.0], dtype=dtype)
    final_params = jnp.array([a, a**2], dtype=dtype)

    @jax.grad
    def get_updates(params):
        return (numerics.abs_sq(a - params[0]) +
                b * numerics.abs_sq(params[1] - params[0]**2))

    return initial_params, final_params, get_updates
Example #23
def apply_fun(params, inputs, **kwargs):
    # distinguish between real and complex input because
    # jax is not smart enough to pick the right gemm on its own.
    if jnp.iscomplexobj(inputs):
        if use_bias:
            Wr, Wi, br, bi = params
            return jnp.dot(inputs, Wr + 1j * Wi) + br + 1j * bi
        else:
            Wr, Wi = params
            return jnp.dot(inputs, Wr + 1j * Wi)
    else:
        if use_bias:
            Wr, Wi, br, bi = params
            return jnp.dot(inputs, Wr) + 1j * jnp.dot(inputs, Wi) + br + 1j * bi
        else:
            Wr, Wi = params
            return jnp.dot(inputs, Wr) + 1j * jnp.dot(inputs, Wi)
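
A hypothetical usage sketch of my own, assuming `apply_fun` is reachable at module scope, `import jax.numpy as jnp`, and a closure/global `use_bias = False`; `Wr` and `Wi` are made-up parameters. It shows that the layer produces a complex output even for real inputs.

use_bias = False
Wr = jnp.eye(2)
Wi = jnp.zeros((2, 2))
out = apply_fun((Wr, Wi), jnp.ones((3, 2)))
# out has a complex dtype: real part = inputs @ Wr, imaginary part = inputs @ Wi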
Example #24
def c2r(c):
    ''' Unpack a complex-valued signal into a real-valued signal,
    for example converting
    [[0.+0.j 1.-1.j]
     [2.-2.j 3.-3.j]
     [4.-4.j 5.-5.j]
     [6.-6.j 7.-7.j]]
    to
    [[ 0.  0.  1. -1.]
     [ 2. -2.  3. -3.]
     [ 4. -4.  5. -5.]
     [ 6. -6.  7. -7.]]
    '''
    if jnp.iscomplexobj(c):
        if c.ndim != 2:
            raise ValueError('invalid ndim, expected 2 but got %d' % c.ndim)
        r = jnp.stack([c.real, c.imag], axis=-1).reshape((c.shape[0], -1))
    else:
        r = c
    return r
Example #25
def r2c(r):
    ''' Pack a real-valued signal into a complex-valued signal,
    for example converting
    [[ 0.  0.  1. -1.]
     [ 2. -2.  3. -3.]
     [ 4. -4.  5. -5.]
     [ 6. -6.  7. -7.]]
    to
    [[0.+0.j 1.-1.j]
     [2.-2.j 3.-3.j]
     [4.-4.j 5.-5.j]
     [6.-6.j 7.-7.j]]
    '''
    if not jnp.iscomplexobj(r):
        if r.ndim != 2:
            raise ValueError('invalid ndim, expected 2 but got %d' % r.ndim)
        r = r.reshape((r.shape[0], r.shape[-1] // 2, -1))
        c = r[..., 0] + 1j * r[..., 1]
    else:
        c = r
    return c
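
As a small round-trip check (my own sketch, assuming `import jax.numpy as jnp` and the `c2r` and `r2c` definitions from the two examples above): unpacking a complex signal and packing it back recovers the original array.

c = jnp.array([[0.0 + 0.0j, 1.0 - 1.0j],
               [2.0 - 2.0j, 3.0 - 3.0j]])
assert jnp.allclose(r2c(c2r(c)), c)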
Example #26
def _exp_grad(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    OΨ: jnp.ndarray,
    Ψ: jnp.ndarray,
) -> Tuple[PyTree, PyTree]:
    is_mutable = mutable is not False

    expval_O = (Ψ.conj() * OΨ).sum()
    ΔOΨ = (OΨ - expval_O * Ψ.conj()) * Ψ

    _, vjp_fun, *new_model_state = nkjax.vjp(
        lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
        parameters,
        conjugate=True,
        has_aux=is_mutable,
    )

    Ō_grad = vjp_fun(ΔOΨ)[0]

    Ō_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else 2 * x.real).astype(
            target.dtype
        ),
        Ō_grad,
        parameters,
    )

    new_model_state = new_model_state[0] if is_mutable else None

    return (
        None,
        Ō_grad,
        expval_O,
        new_model_state,
    )
Example #27
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[PyTree, PyTree, Stats]:
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    # if not has_aux:
    #    out_axes = (0, 0)
    # else:
    #    out_axes = (0, 0, 0)

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # local_kernel_vmap = jax.vmap(
    #    partial(local_value_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
    # )

    # _Lρ = local_kernel_vmap(parameters, σ, σp, mels).reshape((σ_shape[0], -1))
    (
        Lρ,
        der_loc_vals,
    ) = _der_local_values_jax._local_values_and_grads_notcentered_kernel(
        logpsi, parameters, σp, mels, σ
    )
    # _der_local_values_jax._local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    # Old implementation: in practice this is faster, even though the one below
    # should in principle be faster. Keep it here for now and remove it later.
    grad_fun = jax.vmap(nkjax.grad(logpsi, argnums=0), in_axes=(None, 0), out_axes=0)
    der_logs = grad_fun(parameters, σ)
    der_logs_ave = jax.tree_map(lambda x: mean(x, axis=0), der_logs)

    # TODO
    # NEW IMPLEMENTATION
    # This should be faster, but should benchmark as it seems slower
    # to compute der_logs_ave i can just do a jvp with a ones vector
    # _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    # der_logs_ave = d_logpsi(
    #    jnp.ones_like(_logpsi_ave).real / (n_samples_node * utils.n_nodes)
    # )[0]
    der_logs_ave = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave)

    def gradfun(der_loc_vals, der_logs_ave):
        par_dims = der_loc_vals.ndim - 1

        _lloc_r = Lρ.reshape((n_samples_node,) + tuple(1 for i in range(par_dims)))

        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    LdagL_grad = jax.tree_util.tree_multimap(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # We also do this for the standard gradient of the energy.
    # This avoids the errors reported in #867, #789 and #850.
    LdagL_grad = jax.tree_multimap(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
Example #28
def abs2(x):
    if jnp.iscomplexobj(x):
        return x.real**2 + x.imag**2
    else:
        return x**2
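
A tiny check of my own, assuming `import jax.numpy as jnp` and the `abs2` definition above: for complex input it matches `jnp.abs(x) ** 2` while avoiding the intermediate square root.

z = jnp.array([3.0 + 4.0j])
assert jnp.allclose(abs2(z), jnp.abs(z) ** 2)   # both give 25.0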
Example #29
def toreal(x):
    if jnp.iscomplexobj(x):
        return jnp.array([x.real, x.imag])
    else:
        return x
Example #30
    def groundstate(
        self,
        chi: int,
        system_size: Optional[Tuple[int, int]] = None,
        initial_state: Optional[str] = 'ps',
        initial_noise: Optional[float] = None,
        contraction_options: Optional[dict] = None,
        optimisation_options: Optional[dict] = None,
    ):
        """
        Computes the groundstate of the model by minimising a trial state's energy.

        Parameters
        ----------
        chi : int
            The bond-dimension of the (i)PEPS
        system_size : (int, int), optional
            for OBC and PBC: the system size, for INFINITE: no effect
        initial_state : str or jax.numpy.ndarray or PEPS or IPEPS, optional
            The initial state for the optimisation.
            If a string keyword:
                'ps' : the product state of the respective phase (z+ for g > gc ~ 3.5, x+ for g < gc)
                        `initial_noise>0` recommended, since product states might have zero gradient
                'z+' : the z+ product state, `initial_noise>0` recommended
                'x+' : the x+ product state, `initial_noise>0` recommended
                'ipeps' : (only for finite systems)
                        finds the (iPEPS) groundstate of the same model on the infinite lattice,
                        which can directly be used for PBC, or is cut off at the boundary for OBC
                TODO 'random' keyword
            If an ndarray:
                1D array of the local state, start from the product state of this local state,
                if it is a z or x eigenstate, `initial_noise=True` recommended
            If PEPS:
                (only for finite systems): PEPS of the initial state
            If IPEPS:
                (only for infinite systems): iPEPS of the initial state
            Default: like 'ps'
        initial_noise : float, optional
            If `initial_noise > 0`, a random deviation is added to the initial_state;
            `initial_noise` is the relative strength of the deviation.
        contraction_options : dict, optional
            options for the PEPS contraction of the energy expectation value, see `TFIM.energy`
        optimisation_options : dict, optional
            options for optimisation. kwargs for `jax_optimise.minimise`

        Returns
        -------
        gs : PEPS or IPEPS
        gs_energy : float
        """

        # parse dicts
        optimisation_options = parse_options(optimisation_options,
                                             OPTIMISATION_DEFAULTS_GS)
        contraction_options = parse_options(contraction_options)

        # parse system size
        if self.bc == INFINITE:
            lx, ly = None, None
        else:
            lx, ly = system_size
            assert lx > 0
            assert ly > 0

        # parse initial guess
        initial_guess = _parse_initial_state(
            initial_state,
            chi,
            self.bc,
            lx,
            ly,
            self.g,
            initial_noise,
            complex_tensors=np.iscomplexobj(optimisation_options['dtype']))

        # define cost_function
        def cost_function(new_tensors):
            new_state = initial_guess.with_different_tensors(new_tensors)
            energy = self.energy(new_state, **contraction_options)
            return np.reshape(energy, ())

        # optimisation
        optimal_tensors, optimal_energy, info = minimise(
            cost_function, initial_guess.get_tensors(), **optimisation_options)
        optimal_state = initial_guess.with_different_tensors(optimal_tensors)

        return optimal_state, optimal_energy