def __mul__(self, scalar):
    """Return ``self * scalar`` as a new operator, leaving ``self`` untouched.

    A non-scalar operand is treated as operator (matrix) multiplication and
    forwarded to ``self._op__matmul__``. For a scalar, a copy of this operator
    is made with the dtype obtained by promoting ``self.dtype`` with the
    scalar's dtype, and that copy is scaled in place via ``__imul__``.
    """
    if is_scalar(scalar):
        # Promote first, so that e.g. a real operator times a complex scalar
        # yields a complex operator instead of failing the in-place cast.
        result_dtype = np.promote_types(self.dtype, _dtype(scalar))
        scaled = self.copy(dtype=result_dtype)
        return scaled.__imul__(scalar)
    # Non-scalar operands are overloaded as matrix multiplication.
    return self._op__matmul__(scalar)
def __imul__(self, scalar):
    """Multiply this operator in place by ``scalar`` and return the result.

    A non-scalar operand is treated as in-place operator (matrix)
    multiplication: it is forwarded to ``self._op__imatmul__`` and that
    call's result is returned. For a scalar, every stored term weight and the
    constant are multiplied by the scalar (cast to ``self.dtype``), caches are
    reset and ``self`` is returned.

    Raises:
        ValueError: if the scalar's dtype cannot be cast to ``self.dtype``
            under ``same_kind`` casting (e.g. a complex scalar on a real
            operator).
    """
    if not is_scalar(scalar):
        # We overload this as matrix multiplication. Return immediately: the
        # dtype check and scaling below only make sense for scalars.
        return self._op__imatmul__(scalar)
    if not np.can_cast(_dtype(scalar), self.dtype, casting="same_kind"):
        raise ValueError(
            f"Cannot multiply inplace scalar with dtype {_dtype(scalar)} "
            f"to operator with dtype {self.dtype}")
    # Cast to the operator's dtype so the stored weights keep a uniform type.
    scalar = np.array(scalar, dtype=self.dtype).item()
    self._operators = tree_map(lambda x: x * scalar, self._operators)
    self._constant *= scalar
    self._reset_caches()
    return self
def tree_axpy(a: Scalar, x: PyTree, y: PyTree) -> PyTree:
    r"""
    compute a * x + y

    Args:
      a: scalar, or a pytree with the same treedef as x and y, in which case
         each leaf of x is scaled by the corresponding leaf of a
      x, y: pytrees with the same treedef

    Returns:
      The sum of the respective leaves of the two pytrees x and y
      where the leaves of x are first scaled with a.
    """
    # jax.tree_multimap was deprecated and removed from JAX;
    # jax.tree_util.tree_map maps over several trees at once.
    if is_scalar(a):
        return jax.tree_util.tree_map(lambda x_, y_: a * x_ + y_, x, y)
    else:
        return jax.tree_util.tree_map(
            lambda a_, x_, y_: a_ * x_ + y_, a, x, y
        )
def __iadd__(self, other):
    """Add ``other`` to this operator in place and return ``self``.

    Args:
        other: either a scalar, which is added to the constant term, or
            another ``FermionOperator2nd`` on the same Hilbert space, whose
            term weights and constant are accumulated into this operator.

    Raises:
        NotImplementedError: if ``other`` is neither a scalar nor a
            ``FermionOperator2nd``.
        ValueError: if the Hilbert spaces differ, or if ``other``'s dtype
            cannot be cast to ``self.dtype`` under ``same_kind`` casting.
    """
    if is_scalar(other):
        # A scalar only shifts the constant; adding ~0 leaves it unchanged.
        # NOTE(review): unlike the operator branch below, this path does not
        # call self._reset_caches(); confirm caches do not depend on
        # self._constant.
        if not _isclose(other, 0.0):
            self._constant += other
        return self
    if not isinstance(other, FermionOperator2nd):
        raise NotImplementedError
    if self.hilbert != other.hilbert:
        raise ValueError(
            f"Can only add identical hilbert spaces (got A+B, A={self.hilbert}, "
            f"B={other.hilbert})")
    if not np.can_cast(_dtype(other), self.dtype, casting="same_kind"):
        raise ValueError(
            f"Cannot add inplace operator with dtype {_dtype(other)} "
            f"to operator with dtype {self.dtype}")
    # Accumulate weights of shared terms; copy over terms only in `other`.
    for term, weight in other._operators.items():
        if term in self._operators:
            self._operators[term] += weight
        else:
            self._operators[term] = weight
    self._constant += other._constant
    self._reset_caches()
    return self
def build_SR(*args, solver_restart: bool = False, **kwargs):
    """
    Construct the structure holding the parameters for using the
    Stochastic Reconfiguration/Natural gradient method.

    Depending on the arguments, an implementation is chosen. For
    details on all possible kwargs check the specific SR implementations
    in the documentation.

    You can also construct one of those structures directly.

    Two calling conventions are accepted. A call is treated as the legacy
    API when any of the keyword arguments ``iterative``, ``method``, ``tol``,
    ``atol``, ``maxiter``, ``M``, ``restart`` or ``solve_method`` is given,
    when the first positional argument is a scalar (taken as ``diag_shift``),
    or when the second positional argument is a string (taken as
    ``method``). Legacy calls are translated into a single ``solver`` keyword
    argument before constructing the SR structure. Any other call, or a call
    with a ``matrix`` keyword, is forwarded unchanged.

    Args:
        diag_shift: Diagonal shift added to the S matrix
        method: (cg, gmres) The specific method.
        iterative: Whether to use an iterative method or not. Accepted
            for backwards compatibility and otherwise ignored.
        jacobian: Differentiation mode to precompute gradients
                  can be "holomorphic", "R2R", "R2C",
                  None (if they shouldn't be precomputed)
        rescale_shift: Whether to rescale the diagonal offsets in SR according
                       to diagonal entries (only with precomputed gradients)
        solver_restart: Forwarded to the SR structure.

    Returns:
        The SR parameter structure.

    Raises:
        TypeError: if a legacy call passes more than 2 positional arguments.
        ValueError: if a legacy ``method`` other than "cg" or "gmres" is
            requested.
    """
    if "matrix" in kwargs:
        # new syntax
        return _SR(*args, solver_restart=solver_restart, **kwargs)

    legacy_kwargs = ["iterative", "method"]
    legacy_solver_kwargs = [
        "tol", "atol", "maxiter", "M", "restart", "solve_method"
    ]

    # Try to understand whether this is the old API or the new one.
    old_api = any(key in kwargs for key in legacy_kwargs + legacy_solver_kwargs)
    if len(args) > 0 and is_scalar(args[0]):
        # A scalar first positional argument is the legacy diag_shift.
        old_api = True
    if len(args) > 1 and isinstance(args[1], str):
        # A string second positional argument is the legacy method name.
        old_api = True

    if not old_api:
        return _SR(*args, solver_restart=solver_restart, **kwargs)

    # --- Legacy API: translate positional args and solver options. ---
    if len(args) > 2:
        raise TypeError(
            f"SR takes at most 2 positional arguments but {len(args)} were provided"
        )
    for name, value in zip(("diag_shift", "method"), args):
        kwargs[name] = value

    # `iterative` has no equivalent anymore; drop it.
    kwargs.pop("iterative", None)

    if "method" in kwargs:
        legacy_solvers = {
            "cg": jax.scipy.sparse.linalg.cg,
            "gmres": jax.scipy.sparse.linalg.gmres,
        }
        method = kwargs.pop("method")
        if method not in legacy_solvers:
            raise ValueError(
                "The old API only supports cg and gmres solvers. "
                "Migrate to the new API and use any solver from "
                "jax.scipy.sparse.linalg")
        solver = legacy_solvers[method]
    else:
        solver = jax.scipy.sparse.linalg.cg

    # Bind legacy solver options into the solver and remove them from kwargs,
    # so they are not also forwarded to _SR as unexpected keyword arguments.
    solver_options = {
        key: kwargs.pop(key) for key in legacy_solver_kwargs if key in kwargs
    }
    if solver_options:
        solver = partial(solver, **solver_options)

    kwargs["solver"] = solver
    return _SR(solver_restart=solver_restart, **kwargs)