Exemplo n.º 1
0
 def __mul__(self, scalar):
     """Return ``self * scalar`` as a new operator.

     A non-scalar right operand is treated as an operator product and the
     call is delegated to ``self._op__matmul__``. For a scalar operand the
     result dtype is ``np.promote_types(self.dtype, _dtype(scalar))``; a copy
     of ``self`` with that dtype is created and scaled in place via
     ``__imul__``, and that copy is returned.
     """
     if is_scalar(scalar):
         # Promote first so the copy can represent the product
         # (e.g. a real operator multiplied by a complex scalar).
         promoted = np.promote_types(self.dtype, _dtype(scalar))
         result = self.copy(dtype=promoted)
         return result.__imul__(scalar)
     # we will overload this as matrix multiplication
     return self._op__matmul__(scalar)
Exemplo n.º 2
0
 def __imul__(self, scalar):
     """Multiply this operator in place by ``scalar`` and return ``self``.

     A non-scalar operand is treated as an operator product: the call is
     delegated to ``self._op__imatmul__`` and its result is returned.

     Raises:
         ValueError: if the scalar's dtype cannot be cast to ``self.dtype``
             with ``same_kind`` casting (e.g. a complex scalar on a real
             operator), because the product could not be stored in place.
     """
     if not is_scalar(scalar):
         # we will overload this as matrix multiplication. Return here: the
         # scalar path below must not run on a non-scalar operand.
         return self._op__imatmul__(scalar)
     if not np.can_cast(_dtype(scalar), self.dtype, casting="same_kind"):
         raise ValueError(
             f"Cannot multiply inplace scalar with dtype {_dtype(scalar)} "
             f"to operator with dtype {self.dtype}")
     # Convert the scalar to self.dtype (as a plain Python scalar via .item())
     # before scaling every term weight and the constant.
     scalar = np.array(scalar, dtype=self.dtype).item()
     self._operators = tree_map(lambda x: x * scalar, self._operators)
     self._constant *= scalar
     self._reset_caches()
     return self
Exemplo n.º 3
0
def tree_axpy(a: Scalar, x: PyTree, y: PyTree) -> PyTree:
    r"""
    Compute ``a * x + y`` leaf-wise over pytrees.

    Args:
      a: a scalar, or a pytree with the same treedef as x and y (one
         factor per leaf)
      x, y: pytrees with the same treedef
    Returns:
        A pytree with the common treedef whose leaves are
        ``a * x_leaf + y_leaf`` (``a_leaf * x_leaf + y_leaf`` when ``a``
        is a pytree).
    """
    # jax.tree_multimap was deprecated and later removed from JAX;
    # jax.tree_util.tree_map accepts several trees and is the replacement.
    if is_scalar(a):
        return jax.tree_util.tree_map(lambda x_, y_: a * x_ + y_, x, y)
    else:
        return jax.tree_util.tree_map(
            lambda a_, x_, y_: a_ * x_ + y_, a, x, y)
Exemplo n.º 4
0
 def __iadd__(self, other):
     """Add ``other`` to this operator in place and return ``self``.

     Args:
         other: either a scalar, which is added to the constant term (and
             ignored when ``_isclose(other, 0.0)``), or another
             ``FermionOperator2nd`` on the same Hilbert space, whose term
             weights and constant are merged into ``self``.

     Raises:
         NotImplementedError: if ``other`` is neither a scalar nor a
             ``FermionOperator2nd``.
         ValueError: if the Hilbert spaces differ, or if ``other``'s dtype
             cannot be cast to ``self.dtype`` with ``same_kind`` casting.
     """
     if is_scalar(other):
         if not _isclose(other, 0.0):
             self._constant += other
             # The constant changed: invalidate caches, as the operator
             # branch below and __imul__ do after mutating _constant.
             self._reset_caches()
         return self
     if not isinstance(other, FermionOperator2nd):
         raise NotImplementedError
     if not self.hilbert == other.hilbert:
         raise ValueError(
             f"Can only add identical hilbert spaces (got A+B, A={self.hilbert}, "
             f"B={other.hilbert})")
     if not np.can_cast(_dtype(other), self.dtype, casting="same_kind"):
         raise ValueError(
             f"Cannot add inplace operator with dtype {_dtype(other)} "
             f"to operator with dtype {self.dtype}")
     for term, weight in other._operators.items():
         if term in self._operators:
             # Identical term already present: accumulate its weight.
             self._operators[term] += weight
         else:
             self._operators[term] = weight
     self._constant += other._constant
     self._reset_caches()
     return self
Exemplo n.º 5
0
def build_SR(*args, solver_restart: bool = False, **kwargs):
    """
    Construct the structure holding the parameters for using the
    Stochastic Reconfiguration/Natural gradient method.

    Depending on the arguments, an implementation is chosen. For
    details on all possible kwargs check the specific SR implementations
    in the documentation.

    You can also construct one of those structures directly.

    Two calling conventions are accepted. A call passing ``matrix`` is
    always forwarded unchanged (new API). Otherwise the call is treated as
    the legacy API if it passes any of ``iterative``, ``method``, ``tol``,
    ``atol``, ``maxiter``, ``M``, ``restart`` or ``solve_method`` as a
    keyword, a scalar as first positional argument (``diag_shift``), or a
    string as second positional argument (``method``). Legacy calls are
    translated into a ``solver`` keyword before construction.

    Args:
        diag_shift: Diagonal shift added to the S matrix
        method: (cg, gmres) The specific method (legacy API).
        iterative: Whether to use an iterative method or not (legacy API;
                   accepted but ignored).
        jacobian: Differentiation mode to precompute gradients
                  can be "holomorphic", "R2R", "R2C",
                  None (if they shouldn't be precomputed)
        rescale_shift: Whether to rescale the diagonal offsets in SR according
                       to diagonal entries (only with precomputed gradients)
        solver_restart: Forwarded unchanged to the SR structure.

    Returns:
        The SR parameter structure.

    Raises:
        TypeError: if a legacy call passes more than 2 positional arguments.
        ValueError: if a legacy ``method`` is not "cg" or "gmres".
    """
    if "matrix" in kwargs:
        # new syntax
        return _SR(*args, solver_restart=solver_restart, **kwargs)

    legacy_kwargs = ["iterative", "method"]
    legacy_solver_kwargs = [
        "tol", "atol", "maxiter", "M", "restart", "solve_method"
    ]

    # Try to understand if this is the old API or the new one.
    old_api = any(key in kwargs for key in legacy_kwargs + legacy_solver_kwargs)
    if len(args) > 0 and is_scalar(args[0]):  # it's diag_shift
        old_api = True
    if len(args) > 1 and isinstance(args[1], str):  # it's method
        old_api = True

    if not old_api:
        return _SR(*args, solver_restart=solver_restart, **kwargs)

    # Legacy positional arguments are (diag_shift, method).
    if len(args) > 2:
        raise TypeError(
            f"SR takes at most 2 positional arguments but {len(args)} were provided"
        )
    for name, value in zip(("diag_shift", "method"), args):
        kwargs[name] = value
    args = tuple()

    kwargs.pop("iterative", None)

    legacy_solvers = {
        "cg": jax.scipy.sparse.linalg.cg,
        "gmres": jax.scipy.sparse.linalg.gmres,
    }
    method = kwargs.pop("method", "cg")
    if method not in legacy_solvers:
        raise ValueError(
            "The old API only supports cg and gmres solvers. "
            "Migrate to the new API and use any solver from "
            "jax.scipy.sparse.linalg")
    solver = legacy_solvers[method]

    # Legacy solver options belong to the solver call: move them out of
    # kwargs (instead of copying) so they are not also forwarded to _SR.
    solver_options = {
        key: kwargs.pop(key) for key in legacy_solver_kwargs if key in kwargs
    }
    if solver_options:
        solver = partial(solver, **solver_options)

    kwargs["solver"] = solver

    return _SR(*args, solver_restart=solver_restart, **kwargs)