Example 1
    def wrapped_exec(params):
        def wrapper(p):
            """Compute the forward pass by returning the jacobian too."""
            new_tapes = []

            for t, a in zip(tapes, p):
                new_tapes.append(t.copy(copy_operations=True))
                new_tapes[-1].set_parameters(a)

            res, jacs = execute_fn(new_tapes, **gradient_kwargs)

            # On the forward execution, return the Jacobian as well.
            return res, jacs

        fwd_shapes = [
            jax.ShapeDtypeStruct((1, ), dtype) for _ in range(total_size)
        ]
        jacobian_shape = [
            jax.ShapeDtypeStruct((1, len(p)), dtype) for p in params
        ]
        res, jacs = host_callback.call(
            wrapper,
            params,
            result_shape=tuple([fwd_shapes, jacobian_shape]),
        )
        return res, jacs
Example 2
def _parse_spec(spec):
  """Parse an input spec of the form (shape, dtype) or shape into a jax.ShapeDtypeStruct."""
  spec = tuple(spec)
  if len(spec) == 2 and isinstance(spec[0], Iterable):
    return jax.ShapeDtypeStruct(tuple(spec[0]), spec[1])
  else:
    return jax.ShapeDtypeStruct(spec, jnp.float32)
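
A quick check of the two accepted spec forms (a usage sketch, assuming jax.numpy is imported as jnp as in the surrounding examples):

print(_parse_spec(((3, 4), jnp.int32)))  # ShapeDtypeStruct(shape=(3, 4), dtype=int32)
print(_parse_spec((3, 4)))               # ShapeDtypeStruct(shape=(3, 4), dtype=float32), the default dtype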
Example 3
def jax_eig(A):
    complex_dtype = jax.lax.complex(A, A).dtype
    result_shape = (jax.ShapeDtypeStruct((A.shape[0], ), complex_dtype),
                    jax.ShapeDtypeStruct(A.shape, complex_dtype))
    w, v = host_callback.call(host_eig,
                              jax.lax.stop_gradient(A),
                              result_shape=result_shape)
    return w, v
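
The snippet above leaves host_eig undefined; a plausible host-side implementation (an assumption, not the project's actual code) computes the eigendecomposition with NumPy and casts to the complex dtype promised in result_shape:

import numpy as np

def host_eig(a):
    # np.linalg.eig returns real arrays when a real matrix has a purely real
    # spectrum, so cast explicitly to match result_shape (complex64 here,
    # assuming a float32 input).
    w, v = np.linalg.eig(a)
    return w.astype(np.complex64), v.astype(np.complex64)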
Example 4
 def wrapped_exec(params):
     exec_fn = partial(self.execute_device, device=device)
     return host_callback.call(
         exec_fn,
         params,
         result_shape=jax.ShapeDtypeStruct((1, ), JAXInterface.dtype),
     )
Example 5
    def _sample_next(sampler, machine, parameters, state):
        new_rng, rng = jax.random.split(state.rng)
        numbers = jax.random.choice(
            rng,
            sampler.hilbert.n_states,
            shape=(sampler.n_chains,),
            replace=True,
            p=state.pdf,
        )

        # We use a host-callback to convert integers labelling states to
        # valid state-arrays because that code is written with numba and
        # we have not yet converted it to jax.
        numbers_to_states = lambda numbers: sampler.hilbert.numbers_to_states(numbers)

        sample = hcb.call(
            numbers_to_states,
            numbers,
            result_shape=jax.ShapeDtypeStruct(
                (sampler.n_chains, sampler.hilbert.size), jnp.float64
            ),
        )

        new_state = state.replace(rng=new_rng)
        return new_state, jnp.asarray(sample, dtype=sampler.dtype)
Example 6
def eval_provenance(fun, *args, **kwargs):
    """
    Compute the provenance output of ``fun`` using JAX's abstract
    interpretation machinery. There is no actual array computation performed.

    :param fun: A callable to track provenance of its (keyword) arguments.
    :param args: Positional arguments of `fun`.
    :param kwargs: Keyword arguments of `fun`.
    :returns: A pytree of :class:`ProvenanceArray`.
    """
    # flatten the function and its arguments
    args_flat, in_tree = jax.tree_util.tree_flatten((args, kwargs))
    wrapped_fun, out_tree = jax.api_util.flatten_fun(wrap_init(fun), in_tree)
    fun = wrap_init(wrapped_fun.call_wrapped)
    avals = jax.util.safe_map(jax.api_util.shaped_abstractify, args_flat)

    # execute the function and trace provenance
    with jax.core.new_main(_ProvenanceJaxprTrace, dynamic=True) as main:
        main.jaxpr_stack = ()
        out = partial_eval.trace_to_subjaxpr_dynamic(fun, main, avals)[1]

    # unflatten the output and get its provenance
    out = [jax.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape) for x in out]
    out = jax.tree_util.tree_unflatten(out_tree(), out)
    return jax.tree_util.tree_map(
        lambda x: ProvenanceArray(
            x, x.named_shape.get("_provenance", frozenset())),
        out,
    )
Example 7
 def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
     """Create a parameter."""
     self.reserve(name)
     if self.has_variable('params', name):
         abs_rng = jax.ShapeDtypeStruct((2, ), jnp.uint32)
         value = self.get_variable('params', name)
         # Validate that the shape of the init_fn output matches the shape of
         # the existing parameter.
         abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args),
                                    abs_rng)
         abs_value_flat = jax.tree_leaves(abs_value)
         value_flat = jax.tree_leaves(value)
         for val, abs_val in zip(value_flat, abs_value_flat):
             # NOTE: we could check dtype consistency here as well, but its usefulness
             # is less obvious; we might intentionally change the dtype for inference,
             # for example to a half-precision float type.
             if jnp.shape(val) != jnp.shape(abs_val):
                 raise ValueError(
                     'Inconsistent shapes between value and initializer '
                     f'for parameter "{name}" in "{self.path_text}": {jnp.shape(val)}, {jnp.shape(abs_val)}'
                 )
         return value
     else:
         if not self.is_mutable_collection('params'):
             raise ValueError(
                 f'No parameter named "{name}" exists in "{self.path_text}".'
             )
         value = init_fn(self.make_rng('params'), *init_args)
         self.put_variable('params', name, value)
         return value
Example 8
def test_abstract_eval_simple():
    add_two = primitive(
        dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0'))
    x = jax.ShapeDtypeStruct((10, ), np.float32)
    output_shape = jax.eval_shape(add_two, x)
    assert output_shape.shape == (10, )
    assert output_shape.dtype == np.int32
Example 9
    def error_norm(self, error_norm: Union[str, Callable]):
        if isinstance(error_norm, Callable):
            self._error_norm = error_norm
        elif error_norm == "euclidean":
            self._error_norm = euclidean_norm
        elif error_norm == "maximum":
            self._error_norm = maximum_norm
        elif error_norm == "qgt":
            w = self.state.parameters
            norm_dtype = nk.jax.dtype_real(nk.jax.tree_dot(w, w))
            # QGT norm is called via host callback since it accesses the driver
            # TODO: also make this a HashablePartial on self to reduce recompilation
            self._error_norm = lambda x: hcb.call(
                HashablePartial(qgt_norm, self),
                x,
                result_shape=jax.ShapeDtypeStruct((), norm_dtype),
            )
        else:
            raise ValueError(
                "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum',"
                f" but {error_norm} was passed."
            )

        if self.integrator is not None:
            self.integrator.norm = self._error_norm
Example 10
  def test_BilinearPairwiseReadout_shapes(self):
    outs, _ = graph_layers.BilinearPairwiseReadout.init(
        jax.random.PRNGKey(0),
        node_embeddings=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32))

    expected = jax.ShapeDtypeStruct((NUM_NODES, NUM_NODES), jnp.float32)
    self._check_shape_and_dtype(outs, expected)
Example 11
def _sample_chain(
    sampler,
    machine: Callable,
    parameters: PyTree,
    state: SamplerState,
    chain_length: int,
) -> Tuple[jnp.ndarray, SamplerState]:
    new_rng, rng = jax.random.split(state.rng)
    numbers = jax.random.choice(
        rng,
        sampler.hilbert.n_states,
        shape=(chain_length * sampler.n_chains,),
        replace=True,
        p=state.pdf,
    )

    # For future investigators: this will lead to a crash if
    # numbers_to_states throws, and it throws if we feed it NaNs!
    numbers_to_states = lambda numbers: sampler.hilbert.numbers_to_states(numbers)

    samples = hcb.call(
        numbers_to_states,
        numbers,
        result_shape=jax.ShapeDtypeStruct(
            (chain_length * sampler.n_chains, sampler.hilbert.size), jnp.float64
        ),
    )
    samples = jnp.asarray(samples, dtype=sampler.dtype).reshape(
        chain_length, sampler.n_chains, sampler.hilbert.size
    )

    return samples, state.replace(rng=new_rng)
Example 12
def partial_eval_by_shape(fn, input_spec, *args, **kwargs):
  """Lazily evaluate a function by using the shapes of the inputs.

  This function is similar to `jax.eval_shape` with the key difference that
  function outputs that can be computed without a concrete value of the
  inputs are returned as is instead of only the shape. See for example
  `module.init_by_shape` where this functionality is used to initialize a
  model without using input data or computation.

  Args:
    fn: the function to be lazily evaluated.
    input_spec: an iterable of shapes or (shape, dtype) tuples specifying the
      shape and type of the inputs. If unspecified the dtype is float32.
    *args: other arguments passed to the module's apply function
    **kwargs: keyword arguments passed to the module's apply function
  Returns:
    The outputs of `fn`: values that can be computed from shapes alone are
    returned concretely; the rest are returned as `jax.ShapeDtypeStruct`s.
  """
  # output cannot be returned in lazy_create because jax.eval_shape will only
  # return the shape and dtype.
  # TODO(mattjj,jheek): use a public JAX API
  f = lambda *inputs: fn(*inputs, *args, **kwargs)
  input_structs = [_parse_spec(spec) for spec in input_spec]
  inputs_flat, in_tree = jax.tree_flatten(input_structs)
  f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
  in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
              for x in inputs_flat]

  if _is_omnistaging:
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
  else:
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)
  out_flat = [const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype)
              for pv, const in out_pvals]
  return jax.tree_unflatten(out_tree(), out_flat)
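
A hedged usage sketch (fn here is hypothetical): because the output below depends only on the shape of its input, partial_eval_by_shape returns the concrete array instead of a jax.ShapeDtypeStruct.

import jax.numpy as jnp

def fn(x):
    # Computable from the shape alone, without a concrete value for x.
    return jnp.zeros(x.shape[-1])

out = partial_eval_by_shape(fn, [((8, 4), jnp.float32)])  # concrete zeros of shape (4,)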
Example 13
  def test_NRIEdgeLayer_message_passing_shapes(self, which, message_passing):
    outs, _ = graph_layers.NRIEdgeLayer.init(
        jax.random.PRNGKey(0),
        edge_embeddings=None if which == "pairwise_only" else jnp.zeros(
            (NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM), jnp.float32),
        node_embeddings=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32),
        mlp_vtoe_dims=NRI_HIDDEN_DIMS + (MESSAGE_DIM,),
        mask=jnp.zeros((NUM_NODES, NUM_NODES), jnp.float32),
        allow_non_adjacent=(which != "edges_only"),
        message_passing=message_passing)

    if message_passing:
      expected = jax.ShapeDtypeStruct((NUM_NODES, MESSAGE_DIM), jnp.float32)
    else:
      expected = jax.ShapeDtypeStruct((NUM_NODES, NUM_NODES, MESSAGE_DIM),
                                      jnp.float32)
    self._check_shape_and_dtype(outs, expected)
Example 14
 def testLowerDonateArgnumsAvailable(self):
   x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
   def f(*args):
     x, *_ = args
     return x
   f_low = pjit(f, donate_argnums=(0,),
                in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
   f_com = f_low.compile()
   self.assertEqual(f_low.donate_argnums, (0,))
   self.assertEqual(f_com.donate_argnums, (0,))
Example 15
  def test_gated_recurrent_update_shapes(self):
    outs, _ = graph_layers.gated_recurrent_update.init(
        jax.random.PRNGKey(0),
        node_states=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32),
        messages=jnp.zeros((NUM_NODES, MESSAGE_DIM), jnp.float32))

    expected = jax.ShapeDtypeStruct((NUM_NODES, NODE_EMBEDDING_DIM),
                                    jnp.float32)
    self._check_shape_and_dtype(outs, expected)
Example 16
 def test_fast_eval_shape_already_transformed(self):
     f = transform.transform(lambda x: basic.Linear(20)(x))  # pylint: disable=unnecessary-lambda
     rng = jax.random.PRNGKey(0)
     x = jnp.ones([1, 12])
     # init_fn
     y_slow = jax.eval_shape(f.init, rng, x)
     y_fast = eval_shape.fast_eval_shape(f.init, rng, x)
     self.assertEqual(y_slow, y_fast)
     self.assertEqual(
         y_slow, {
             'linear': {
                 'w': jax.ShapeDtypeStruct((12, 20), jnp.float32),
                 'b': jax.ShapeDtypeStruct((20, ), jnp.float32)
             }
         })
     # apply_fn
     y_slow = jax.eval_shape(f.apply, y_slow, rng, x)
     y_fast = eval_shape.fast_eval_shape(f.apply, y_fast, rng, x)
     self.assertEqual(y_slow, y_fast)
Example 17
    def test_NodeTypeNodeEmbedding_shapes(self):
        outs, _ = graph_layers.NodeTypeNodeEmbedding.init(
            jax.random.PRNGKey(0),
            node_types=jnp.zeros((NUM_NODES, ), jnp.int32),
            num_node_types=NUM_NODE_TYPES,
            embedding_dim=NODE_EMBEDDING_DIM)

        expected = jax.ShapeDtypeStruct((NUM_NODES, NODE_EMBEDDING_DIM),
                                        jnp.float32)
        self._check_shape_and_dtype(outs, expected)
Example 18
    def test_fast_eval_shape_within_transform(self):
        def f(x):
            m = basic.Linear(20)
            y_slow = stateful.eval_shape(m, x)
            y_fast = eval_shape.fast_eval_shape(m, x)
            self.assertEqual(y_slow, y_fast)
            return m(x)

        f = transform.transform(f)
        rng = jax.random.PRNGKey(0)
        x = jnp.ones([1, 12])
        params = jax.eval_shape(f.init, rng, x)
        self.assertEqual(
            params, {
                'linear': {
                    'w': jax.ShapeDtypeStruct((12, 20), jnp.float32),
                    'b': jax.ShapeDtypeStruct((20, ), jnp.float32)
                }
            })
        jax.eval_shape(f.apply, params, rng, x)
Example 19
  def test_LinearMessagePassing_shapes(self):
    outs, _ = graph_layers.LinearMessagePassing.init(
        jax.random.PRNGKey(0),
        edge_embeddings=jnp.zeros((NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM),
                                  jnp.float32),
        node_embeddings=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32),
        message_dim=MESSAGE_DIM,
        with_bias=True)

    expected = jax.ShapeDtypeStruct((NUM_NODES, MESSAGE_DIM), jnp.float32)
    self._check_shape_and_dtype(outs, expected)
Example 20
 def test_no_namescopes_inside_abstract_dot(self):
   mod = AddModule()
   current_setting = module.modules_with_named_call
   a = b = jax.ShapeDtypeStruct(shape=tuple(), dtype=jnp.float32)
   try:
     module.profiler_name_scopes(enabled=True)
     with mock.patch.object(stateful, "named_call") as mock_f:
       _ = dot.abstract_to_dot(mod)(a, b)
       mock_f.assert_not_called()
   finally:
     module.profiler_name_scopes(enabled=current_setting)
Example 21
File: spin.py Project: vlpap/netket
def random_state_batch_spin_impl(hilb: Spin, key, batches, dtype):
    S = hilb._s
    shape = (batches, hilb.size)

    # If unconstrained space, use fast sampling
    if hilb._total_sz is None:
        n_states = int(2 * S + 1)
        rs = jax.random.randint(key, shape=shape, minval=0, maxval=n_states)

        two = jnp.asarray(2, dtype=dtype)
        return jnp.asarray(rs * two - (n_states - 1), dtype=dtype)
    else:
        N = hilb.size
        n_states = int(2 * S) + 1
        # if constrained and S == 1/2, use a trick to sample quickly
        if n_states == 2:
            m = hilb._total_sz * 2
            nup = (N + m) // 2
            ndown = (N - m) // 2

            x = jnp.concatenate(
                (
                    jnp.ones((batches, nup), dtype=dtype),
                    -jnp.ones((batches, ndown), dtype=dtype),
                ),
                axis=1,
            )

            # deprecated: return jax.random.shuffle(key, x, axis=1)
            return jax.vmap(jax.random.permutation)(
                jax.random.split(key, x.shape[0]), x
            )

        # if constrained and S != 1/2, then use a slow fallback algorithm
        # TODO: find a better, faster way to sample constrained arbitrary spaces.
        else:
            from jax.experimental import host_callback as hcb

            cb = lambda rng: _random_states_with_constraint(hilb, rng, batches, dtype)

            state = hcb.call(
                cb,
                key,
                result_shape=jax.ShapeDtypeStruct(shape, dtype),
            )

            return state
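
A worked check of the S == 1/2 branch above, with hypothetical values:

N, m = 4, 2                              # 4 sites, magnetization m = 2 * total_sz
nup, ndown = (N + m) // 2, (N - m) // 2  # 3 up-spins, 1 down-spin
assert nup + ndown == N and nup - ndown == m
# Each chain is then a random permutation of [1, 1, 1, -1].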
Example 22
def trace(state, fn, num_steps, **_):
    """Implementation of `trace` operator, without the calling convention."""
    # We need the shapes and dtypes of the outputs of `fn`.
    _, untraced_spec, traced_spec = jax.eval_shape(
        fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state))
    untraced_init = map_tree(lambda spec: np.zeros(spec.shape, spec.dtype),
                             untraced_spec)

    try:
        num_steps = int(num_steps)
        use_scan = True
    except TypeError:
        use_scan = False
        if flatten_tree(traced_spec):
            raise ValueError(
                'Cannot trace values when `num_steps` is not statically known. Pass '
                'False to `trace_mask` or return an empty structure (e.g. `()`) as '
                'the extra output.')

    if use_scan:

        def wrapper(state_untraced, _):
            state, _ = state_untraced
            state, untraced, traced = fn(state)
            return (state, untraced), traced

        (state, untraced), traced = lax.scan(
            wrapper,
            (state, untraced_init),
            xs=None,
            length=num_steps,
        )
    else:
        trace_arrays = map_tree(
            lambda spec: np.zeros((num_steps, ) + spec.shape, spec.dtype),
            traced_spec)

        def wrapper(i, state_untraced_traced):
            state, _, trace_arrays = state_untraced_traced
            state, untraced, traced = fn(state)
            trace_arrays = map_tree(lambda a, e: jax.ops.index_update(a, i, e),
                                    trace_arrays, traced)
            return (state, untraced, trace_arrays)

        state, untraced, traced = lax.fori_loop(
            np.asarray(0, num_steps.dtype),
            num_steps,
            wrapper,
            (state, untraced_init, trace_arrays),
        )
    return state, untraced, traced
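
A minimal usage sketch (with a hypothetical step function): fn must return a (state, untraced, traced) triple, and with a statically known num_steps the lax.scan path is taken, stacking the traced outputs along a leading axis.

import jax.numpy as jnp

def step(state):
    new_state = state + 1
    return new_state, (), new_state ** 2  # nothing untraced; trace the square

final_state, _, squares = trace(jnp.array(0), step, num_steps=5)
# squares == [1, 4, 9, 16, 25], shape (5,)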
Example 23
def _einsum_contract_path(*operands, **kwargs):
    """Like opt_einsum.contract_path, with support for DimPolynomial shapes.

  We use opt_einsum.contract_path to compute the schedule, using a fixed
  constant for all dimension variables. This is safe because we throw an
  error if there is more than one contraction. Essentially, we just use
  opt_einsum.contract_path to parse the specification.
  """

    # Replace the polymorphic shapes with some concrete shapes for calling
    # into opt_einsum.contract_path, because the latter wants to compute the
    # sizes of operands and intermediate results.
    fake_ops = []
    for operand in operands:
        # We replace only array operands
        if not hasattr(operand, "dtype"):
            fake_ops.append(operand)
        else:
            shape = np.shape(operand)

            def fake_dim(d):
                if core.is_constant_dim(d):
                    return d
                else:
                    if not isinstance(d, _DimPolynomial):
                        raise TypeError(
                            f"Encountered unexpected shape dimension {d}")
                    # It is OK to replace all polynomials with the same value.
                    # We may miss some errors here due to non-equal dimensions,
                    # but we catch them later.
                    return 8

            fake_ops.append(
                jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)),
                                     operand.dtype))

    contract_fake_ops, contractions = opt_einsum.contract_path(
        *fake_ops, **kwargs)
    if len(contractions) > 1:
        msg = (
            "Shape polymorphism is not yet supported for einsum with more than "
            f"one contraction {contractions}")
        raise ValueError(msg)
    contract_operands = []
    for operand in contract_fake_ops:
        idx = tuple(i for i, fake_op in enumerate(fake_ops)
                    if operand is fake_op)
        assert len(idx) == 1
        contract_operands.append(operands[idx[0]])
    return contract_operands, contractions
Example 24
    def shape_dtype(self) -> jax.ShapeDtypeStruct:
        """Computes the shape and dtype of the result of this `AddN`.

    Returns:
      A `jax.ShapeDtypeStruct` object describing the shape and dtype of the
      `AddN`.
    """
        operand_shape_dtypes = tuple(
            jax.ShapeDtypeStruct(operand.shape, operand.dtype)
            for operand in self.operands)

        def _eval_fun(*args):
            return functools.reduce(operator.add, args)

        return jax.eval_shape(_eval_fun, *operand_shape_dtypes)
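
For reference, the same inference in isolation: jax.eval_shape accepts jax.ShapeDtypeStruct arguments directly and returns the shape and dtype of the never-executed sum.

import jax
import jax.numpy as jnp

a = jax.ShapeDtypeStruct((3, 4), jnp.float32)
b = jax.ShapeDtypeStruct((3, 4), jnp.float32)
out = jax.eval_shape(lambda x, y: x + y, a, b)
assert out.shape == (3, 4) and out.dtype == jnp.float32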
Example 25
        def wrapped_exec_bwd(params, g):
            def jacobian(params):
                tape = self.copy()
                tape.set_parameters(params)
                return tape.jacobian(device,
                                     params=params,
                                     **tape.jacobian_options)

            val = g.reshape((-1, )) * host_callback.call(
                jacobian,
                params,
                result_shape=jax.ShapeDtypeStruct(
                    (1, len(params)), JAXInterface.dtype),
            )
            return (list(val.reshape((-1, ))), )  # Comma is on purpose.
Example 26
    def test_TokenOperatorNodeEmbedding_shapes(self, bottleneck_dim):
        outs, _ = graph_layers.TokenOperatorNodeEmbedding.init(
            jax.random.PRNGKey(0),
            operator=sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((NUM_TOKENS, 1), jnp.int32),
                output_indices=jnp.zeros((NUM_TOKENS, 1), jnp.int32),
                values=jnp.zeros((NUM_TOKENS, ), jnp.int32)),
            vocab_size=VOCAB_SIZE,
            num_nodes=NUM_NODES,
            embedding_dim=NODE_EMBEDDING_DIM,
            bottleneck_dim=bottleneck_dim)

        expected = jax.ShapeDtypeStruct((NUM_NODES, NODE_EMBEDDING_DIM),
                                        jnp.float32)
        self._check_shape_and_dtype(outs, expected)
Example 27
  def test_NodeSelfAttention_shapes(self, which):
    like_great = {"like_great": True, "full_relative": False}[which]
    outs, _ = graph_layers.NodeSelfAttention.init(
        jax.random.PRNGKey(0),
        edge_embeddings=jnp.zeros((NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM),
                                  jnp.float32),
        node_embeddings=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32),
        heads=NUM_HEADS,
        query_key_dim=QUERY_KEY_DIM,
        value_dim=VALUE_DIM,
        out_dim=MESSAGE_DIM,
        mask=jnp.zeros((NUM_NODES, NUM_NODES), jnp.float32),
        like_great=like_great)

    expected = jax.ShapeDtypeStruct((NUM_NODES, MESSAGE_DIM), jnp.float32)
    self._check_shape_and_dtype(outs, expected)
Example 28
  def test_LearnableEdgeEmbeddings_shapes(self):
    outs, _ = graph_layers.LearnableEdgeEmbeddings.init(
        jax.random.PRNGKey(0),
        edges=sparse_operator.SparseCoordOperator(
            input_indices=jnp.zeros((NUM_EDGES, 1), jnp.int32),
            output_indices=jnp.zeros((NUM_EDGES, 2), jnp.int32),
            values=jnp.zeros((NUM_EDGES,), jnp.int32)),
        num_nodes=NUM_NODES,
        num_edge_types=NUM_EDGE_TYPES,
        forward_edge_type_indices=[0, 2],
        reverse_edge_type_indices=[3, 1],
        embedding_dim=EDGE_EMBEDDING_DIM)

    expected = jax.ShapeDtypeStruct((NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM),
                                    jnp.float32)
    self._check_shape_and_dtype(outs, expected)
Example 29
def odefun_host_callback(state, driver, *args, **kwargs):
    """
    Calls odefun through a host callback in order to make the rest of the
    ODE solver jit-able.
    """
    result_shape = jax.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        state.parameters,
    )

    return hcb.call(
        lambda args_and_kw: odefun(state, driver, *args_and_kw[0], **args_and_kw[1]),
        # pack args and kwargs together, since host_callback passes a single argument:
        (args, kwargs),
        result_shape=result_shape,
    )
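
The packing trick in isolation (a sketch with a hypothetical host function): host_callback.call passes exactly one argument, so positional and keyword arguments are bundled into a single pytree and unpacked on the host.

import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def host_fn(args_and_kwargs):
    args, kwargs = args_and_kwargs
    return args[0] * kwargs["scale"]

out = hcb.call(
    host_fn,
    ((jnp.ones(3),), {"scale": jnp.float32(2.0)}),
    result_shape=jax.ShapeDtypeStruct((3,), jnp.float32),
)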
Example 30
def random_state(hilb: Fock, key, batches: int, *, dtype=np.float32):
    shape = (batches, hilb.size)

    # If unconstrained space, use fast sampling
    if hilb.n_particles is None:
        rs = jax.random.randint(key, shape=shape, minval=0, maxval=hilb.n_max + 1)
        return jnp.asarray(rs, dtype=dtype)

    else:
        from jax.experimental import host_callback as hcb

        state = hcb.call(
            lambda rng: _random_states_with_constraint(hilb, rng, batches, dtype),
            key,
            result_shape=jax.ShapeDtypeStruct(shape, dtype),
        )

        return state