def default_qgt_matrix(variational_state, solver=False, **kwargs):
    """
    Determine the default metric tensor (QGT) implementation depending on
    the variational state and the solver.

    Args:
        variational_state: the state the QGT is built for. Its `parameters`,
            `n_parameters` and `variables` attributes are read.
        solver: the linear solver that will be used; it is only passed to
            `_is_dense_solver`.
        **kwargs: keyword arguments bound onto the returned constructor.

    Returns:
        A `functools.partial` of `QGTJacobianPyTree`, `QGTJacobianDense` or
        `QGTOnTheFly` with `kwargs` already applied.
    """
    # Imported locally; presumably to avoid a circular import -- TODO confirm.
    from netket.vqs import ExactState

    if isinstance(variational_state, ExactState):
        return partial(QGTJacobianPyTree, **kwargs)

    # those require dense matrix that is known to be faster for this qgt
    if _is_dense_solver(solver):
        return partial(QGTJacobianDense, **kwargs)

    # `jax.tree_leaves` was a deprecated top-level alias that newer JAX
    # releases removed; `jax.tree_util.tree_leaves` works on old and new JAX.
    n_param_leaves = len(jax.tree_util.tree_leaves(variational_state.parameters))
    n_params = variational_state.n_parameters

    # arbitrary heuristic: if the network's parameters has many leaves
    # (an rbm has 3) then JacobianDense might be faster
    # the numbers chosen below are rather arbitrary and should be tuned.
    if n_param_leaves > 6 and n_params > 800:
        if nkjax.tree_ishomogeneous(variational_state.variables):
            return partial(QGTJacobianDense, **kwargs)
        return partial(QGTJacobianPyTree, **kwargs)

    return partial(QGTOnTheFly, **kwargs)
def O_mean(forward_fn, params, samples, holomorphic=True):
    r"""
    Compute \langle O \rangle,
    i.e. the mean of the rows of the jacobian of forward_fn.

    Args:
        forward_fn: callable invoked as `forward_fn(params, samples)`.
        params: pytree of parameters the jacobian is taken with respect to.
        samples: batch of inputs; `samples.shape[0]` is used as the number
            of samples.
        holomorphic: whether `forward_fn` may be treated as holomorphic when
            `params` are complex.

    Returns:
        The result of `O_vjp_rc` (real parameters, complex output) or of
        `O_vjp` (all other supported cases), called with uniform weights.

    Raises:
        AssertionError: if `params` are not homogeneous, or are complex while
            `holomorphic` is falsy. Such parameters have to be converted to a
            real representation by the caller first.
    """
    # Abstractly evaluate the forward pass once; both the output dtype and
    # its complexness are read from the same result.
    out_struct = jax.eval_shape(forward_fn, params, samples)

    # Uniform weights 1 / (n_samples * mpi.n_nodes).
    # NOTE(review): presumably the contributions of all MPI nodes are summed
    # downstream so that this yields the global mean -- confirm in O_vjp.
    n_samples = samples.shape[0]
    w = jnp.ones(n_samples, dtype=out_struct.dtype) * (
        1.0 / (n_samples * mpi.n_nodes)
    )

    homogeneous = nkjax.tree_ishomogeneous(params)
    real_params = not nkjax.tree_leaf_iscomplex(params)
    real_out = not nkjax.is_complex(out_struct)

    if not (homogeneous and (real_params or holomorphic)):
        # R&C -> C
        # non-holomorphic
        # C->R
        # Raised explicitly instead of `assert False`, which is removed when
        # Python runs with -O and would make this function return None.
        raise AssertionError(
            "O_mean only supports homogeneous parameters that are real, or "
            "complex with holomorphic=True; convert the parameters to a real "
            "representation before calling it."
        )

    if real_params and not real_out:
        # R->C
        return O_vjp_rc(forward_fn, params, samples, w)

    # R->R and holomorphic C->C
    return O_vjp(forward_fn, params, samples, w)
def DeltaOdagger_DeltaO_v(forward_fn, params, samples, v, holomorphic=True):
    r"""
    Compute \langle \Delta O^\dagger \Delta O \rangle v
    where \Delta O = O - \langle O \rangle.

    Args:
        forward_fn: callable invoked as `forward_fn(params, samples)`.
        params: pytree of parameters.
        samples: batch of inputs forwarded to `O_mean` and `Odagger_O_v`.
        v: pytree the operator is applied to.
        holomorphic: whether `forward_fn` may be treated as holomorphic when
            `params` are complex.

    Returns:
        The output of `Odagger_O_v` for the centered forward function. When
        the parameters had to be converted with `nkjax.tree_to_real`, the
        output is passed through the reassembling function it returned.
    """
    is_homogeneous = nkjax.tree_ishomogeneous(params)
    has_real_params = not nkjax.tree_leaf_iscomplex(params)

    # True for everything except R->R, holomorphic C->C and R->C
    split_to_real = not is_homogeneous or not (has_real_params or holomorphic)

    apply_fn = forward_fn
    if split_to_real:
        params, reassemble = nkjax.tree_to_real(params)
        v, _ = nkjax.tree_to_real(v)

        # Accept the converted parameters and rebuild the original ones
        # before calling the user-supplied function.
        def apply_fn(p, x):
            return forward_fn(reassemble(p), x)

    mean_O = O_mean(apply_fn, params, samples, holomorphic=holomorphic)

    def centered_fn(p, x):
        # Subtracting tree_dot(p, mean_O) removes the mean from the jacobian.
        return apply_fn(p, x) - tree_dot(p, mean_O)

    result = Odagger_O_v(centered_fn, params, samples, v)

    return reassemble(result) if split_to_real else result
def _choose_jacobian_mode(apply_fun, pars, model_state, samples, mode, holomorphic):
    """
    Select the jacobian mode and return its integer code.

    Args:
        apply_fun: the model's apply function, abstractly evaluated as
            `apply_fun({"params": pars, **model_state}, samples_2d)`.
        pars: pytree of parameters.
        model_state: mapping merged next to `"params"` into the variables.
        samples: array of samples; flattened to `(-1, samples.shape[-1])`
            before the abstract evaluation.
        mode: accepted for interface compatibility.
            NOTE(review): this argument is never read; the mode is always
            derived from `holomorphic`, the parameters and the output dtype.
        holomorphic: `True` to force the holomorphic mode. With `False` or
            `None` the mode is inferred; `None` additionally emits a warning
            for complex parameters with complex output.

    Returns:
        0 for the "real" mode, 1 for "complex", 2 for "holomorphic".
    """
    if holomorphic is True:
        if not nkjax.tree_ishomogeneous(pars):
            # The text starts on its own line so that `dedent` can actually
            # strip the common indentation.
            warnings.warn(
                dedent(
                    """
                    The ansatz has non homogeneous variables, which might not behave well with the
                    holomorphic implementation.
                    Use `holomorphic=False` or mode='complex' for more accurate results but
                    lower performance.
                    """
                ),
                UserWarning,
            )
        return 2  # holomorphic

    # Abstract evaluation: only shapes and dtypes are computed, no values.
    complex_output = nkjax.is_complex(
        jax.eval_shape(
            apply_fun,
            {"params": pars, **model_state},
            samples.reshape(-1, samples.shape[-1]),
        )
    )
    if not complex_output:
        return 0  # real

    if holomorphic is None and nkjax.tree_leaf_iscomplex(pars):
        warnings.warn(
            dedent(
                """
                Complex-to-Complex model detected. Defaulting to `holomorphic=False` for
                the implementation of QGTJacobianDense.
                If your model is holomorphic, specify `holomorphic=True` to use a more
                performant implementation.
                To suppress this warning specify `holomorphic`.
                """
            ),
            UserWarning,
        )

    # Complex output maps to the complex mode whether the parameters are
    # real or complex.
    return 1  # complex