def wrapped_exec(params):
    def wrapper(p):
        """Compute the forward pass by returning the jacobian too."""
        new_tapes = []
        for t, a in zip(tapes, p):
            new_tapes.append(t.copy(copy_operations=True))
            new_tapes[-1].set_parameters(a)

        res, jacs = execute_fn(new_tapes, **gradient_kwargs)

        # On the forward execution return the jacobian too
        return res, jacs

    fwd_shapes = [jax.ShapeDtypeStruct((1,), dtype) for _ in range(total_size)]
    jacobian_shape = [jax.ShapeDtypeStruct((1, len(p)), dtype) for p in params]
    res, jacs = host_callback.call(
        wrapper,
        params,
        result_shape=tuple([fwd_shapes, jacobian_shape]),
    )
    return res, jacs
def _parse_spec(spec):
    """Parse an input spec of the form (shape, dtype) or shape into a jax.ShapeDtypeStruct."""
    spec = tuple(spec)
    if len(spec) == 2 and isinstance(spec[0], Iterable):
        return jax.ShapeDtypeStruct(tuple(spec[0]), spec[1])
    else:
        return jax.ShapeDtypeStruct(spec, jnp.float32)
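# Usage sketch for _parse_spec above (assumes `jax` and `jax.numpy as jnp` are
# in scope, as in the surrounding snippets): a bare shape defaults to float32,
# while a (shape, dtype) pair preserves the requested dtype.
spec_a = _parse_spec((4, 3))               # ShapeDtypeStruct(shape=(4, 3), dtype=float32)
spec_b = _parse_spec(((4, 3), jnp.int32))  # ShapeDtypeStruct(shape=(4, 3), dtype=int32)
assert spec_a.shape == spec_b.shape == (4, 3)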
def jax_eig(A):
    complex_dtype = jax.lax.complex(A, A).dtype
    result_shape = (jax.ShapeDtypeStruct((A.shape[0],), complex_dtype),
                    jax.ShapeDtypeStruct(A.shape, complex_dtype))
    w, v = host_callback.call(host_eig, jax.lax.stop_gradient(A),
                              result_shape=result_shape)
    return w, v
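# host_eig is not shown in the snippet above; a minimal host-side
# implementation could look like this (an assumption, not the original code).
# It runs NumPy's eigensolver on the host and casts to the complex dtype
# declared in result_shape; because the eigensolver escapes XLA via the
# callback, jax_eig remains usable from jit-compiled code.
import numpy as np

def host_eig(a):
    a = np.asarray(a)
    w, v = np.linalg.eig(a)
    cdtype = np.result_type(a.dtype, np.complex64)
    return w.astype(cdtype), v.astype(cdtype)

# Example call (float32 input -> complex64 eigenvalues/eigenvectors):
# w, v = jax.jit(jax_eig)(jnp.eye(3, dtype=jnp.float32))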
def wrapped_exec(params):
    exec_fn = partial(self.execute_device, device=device)
    return host_callback.call(
        exec_fn,
        params,
        result_shape=jax.ShapeDtypeStruct((1,), JAXInterface.dtype),
    )
def _sample_next(sampler, machine, parameters, state):
    new_rng, rng = jax.random.split(state.rng)
    numbers = jax.random.choice(
        rng,
        sampler.hilbert.n_states,
        shape=(sampler.n_chains,),
        replace=True,
        p=state.pdf,
    )

    # We use a host-callback to convert integers labelling states to
    # valid state-arrays because that code is written with numba and
    # we have not yet converted it to jax.
    numbers_to_states = lambda numbers: sampler.hilbert.numbers_to_states(numbers)
    sample = hcb.call(
        numbers_to_states,
        numbers,
        result_shape=jax.ShapeDtypeStruct(
            (sampler.n_chains, sampler.hilbert.size), jnp.float64
        ),
    )

    new_state = state.replace(rng=new_rng)
    return new_state, jnp.asarray(sample, dtype=sampler.dtype)
def eval_provenance(fun, *args, **kwargs):
    """
    Compute the provenance output of ``fun`` using JAX's abstract
    interpretation machinery. There is no actual array computation performed.

    :param fun: A callable to track provenance of its (keyword) arguments.
    :param args: Positional arguments of `fun`.
    :param kwargs: Keyword arguments of `fun`.
    :returns: A pytree of :class:`ProvenanceArray`.
    """
    # flatten the function and its arguments
    args_flat, in_tree = jax.tree_util.tree_flatten((args, kwargs))
    wrapped_fun, out_tree = jax.api_util.flatten_fun(wrap_init(fun), in_tree)
    fun = wrap_init(wrapped_fun.call_wrapped)
    avals = jax.util.safe_map(jax.api_util.shaped_abstractify, args_flat)

    # execute the function and trace provenance
    with jax.core.new_main(_ProvenanceJaxprTrace, dynamic=True) as main:
        main.jaxpr_stack = ()
        out = partial_eval.trace_to_subjaxpr_dynamic(fun, main, avals)[1]

    # unflatten the output and get its provenance
    out = [jax.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape) for x in out]
    out = jax.tree_util.tree_unflatten(out_tree(), out)
    return jax.tree_util.tree_map(
        lambda x: ProvenanceArray(x, x.named_shape.get("_provenance", frozenset())),
        out,
    )
def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
    """Create a parameter."""
    self.reserve(name)
    if self.has_variable('params', name):
        abs_rng = jax.ShapeDtypeStruct((2,), jnp.uint32)
        value = self.get_variable('params', name)
        # Validate that the shape of init_fn's output matches the shape of the
        # existing parameter.
        abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng)
        abs_value_flat = jax.tree_leaves(abs_value)
        value_flat = jax.tree_leaves(value)
        for val, abs_val in zip(value_flat, abs_value_flat):
            # NOTE: we could check dtype consistency here as well, but its
            # usefulness is less obvious; we might intentionally change the
            # dtype to a half-float type for inference, for example.
            if jnp.shape(val) != jnp.shape(abs_val):
                raise ValueError(
                    'Inconsistent shapes between value and initializer '
                    f'for parameter "{name}" in "{self.path_text}": '
                    f'{jnp.shape(val)}, {jnp.shape(abs_val)}')
        return value
    else:
        if not self.is_mutable_collection('params'):
            raise ValueError(
                f'No parameter named "{name}" exists in "{self.path_text}".')
        value = init_fn(self.make_rng('params'), *init_args)
        self.put_variable('params', name, value)
        return value
def test_abstract_eval_simple():
    add_two = primitive(
        dex.eval(r'\x:((Fin 10)=>Float). for i. FToI $ x.i + 2.0'))
    x = jax.ShapeDtypeStruct((10,), np.float32)
    output_shape = jax.eval_shape(add_two, x)
    assert output_shape.shape == (10,)
    assert output_shape.dtype == np.int32
def error_norm(self, error_norm: Union[str, Callable]):
    if isinstance(error_norm, Callable):
        self._error_norm = error_norm
    elif error_norm == "euclidean":
        self._error_norm = euclidean_norm
    elif error_norm == "maximum":
        self._error_norm = maximum_norm
    elif error_norm == "qgt":
        w = self.state.parameters
        norm_dtype = nk.jax.dtype_real(nk.jax.tree_dot(w, w))
        # The QGT norm is called via host callback since it accesses the driver.
        # TODO: make this also a HashablePartial on self to reduce recompilation.
        self._error_norm = lambda x: hcb.call(
            HashablePartial(qgt_norm, self),
            x,
            result_shape=jax.ShapeDtypeStruct((), norm_dtype),
        )
    else:
        raise ValueError(
            "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum', "
            f"but {error_norm} was passed."
        )

    if self.integrator is not None:
        self.integrator.norm = self._error_norm
def test_BilinearPairwiseReadout_shapes(self):
    outs, _ = graph_layers.BilinearPairwiseReadout.init(
        jax.random.PRNGKey(0),
        node_embeddings=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32))
    expected = jax.ShapeDtypeStruct((NUM_NODES, NUM_NODES), jnp.float32)
    self._check_shape_and_dtype(outs, expected)
def _sample_chain(
    sampler,
    machine: Callable,
    parameters: PyTree,
    state: SamplerState,
    chain_length: int,
) -> Tuple[jnp.ndarray, SamplerState]:
    new_rng, rng = jax.random.split(state.rng)
    numbers = jax.random.choice(
        rng,
        sampler.hilbert.n_states,
        shape=(chain_length * sampler.n_chains,),
        replace=True,
        p=state.pdf,
    )

    # For future investigators: this will lead to a crash if numbers_to_states
    # throws, and it throws if we feed it NaNs!
    numbers_to_states = lambda numbers: sampler.hilbert.numbers_to_states(numbers)
    samples = hcb.call(
        numbers_to_states,
        numbers,
        result_shape=jax.ShapeDtypeStruct(
            (chain_length * sampler.n_chains, sampler.hilbert.size), jnp.float64
        ),
    )
    samples = jnp.asarray(samples, dtype=sampler.dtype).reshape(
        chain_length, sampler.n_chains, sampler.hilbert.size
    )

    return samples, state.replace(rng=new_rng)
def partial_eval_by_shape(fn, input_spec, *args, **kwargs):
    """Lazily evaluate a function by using the shapes of the inputs.

    This function is similar to `jax.eval_shape` with the key difference that
    function outputs that can be computed without a concrete value of the
    inputs are returned as is instead of only the shape. See for example
    `module.init_by_shape` where this functionality is used to initialize a
    model without using input data or computation.

    Args:
      fn: the function to be lazily evaluated.
      input_spec: an iterable of shapes or (shape, dtype) tuples specifying the
        shape and type of the inputs. If unspecified the dtype is float32.
      *args: other arguments passed to the module's apply function.
      **kwargs: keyword arguments passed to the module's apply function.
    Returns:
      The output of `fn`, where outputs that could not be computed from the
      input shapes alone are returned as `jax.ShapeDtypeStruct` instances.
    """
    # output cannot be returned in lazy_create because jax.eval_shape will only
    # return the shape and dtype.
    # TODO(mattjj,jheek): use a public JAX API
    f = lambda *inputs: fn(*inputs, *args, **kwargs)
    input_structs = [_parse_spec(spec) for spec in input_spec]
    inputs_flat, in_tree = jax.tree_flatten(input_structs)
    f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
    in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
                for x in inputs_flat]
    if _is_omnistaging:
        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
    else:
        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)
    out_flat = [const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype)
                for pv, const in out_pvals]
    return jax.tree_unflatten(out_tree(), out_flat)
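# Hypothetical usage of partial_eval_by_shape above (my_dense is made up for
# illustration): only the input's shape and dtype are supplied, so no input
# data is materialized, and outputs that depend on the inputs come back as
# jax.ShapeDtypeStruct instances.
def my_dense(x):
    kernel = jnp.zeros((x.shape[-1], 8), x.dtype)
    return x @ kernel

out = partial_eval_by_shape(my_dense, [((2, 4), jnp.float32)])
# out should be a jax.ShapeDtypeStruct with shape (2, 8) and dtype float32.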
def test_NRIEdgeLayer_message_passing_shapes(self, which, message_passing):
    outs, _ = graph_layers.NRIEdgeLayer.init(
        jax.random.PRNGKey(0),
        edge_embeddings=None if which == "pairwise_only" else jnp.zeros(
            (NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM), jnp.float32),
        node_embeddings=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32),
        mlp_vtoe_dims=NRI_HIDDEN_DIMS + (MESSAGE_DIM,),
        mask=jnp.zeros((NUM_NODES, NUM_NODES), jnp.float32),
        allow_non_adjacent=(which != "edges_only"),
        message_passing=message_passing)

    if message_passing:
        expected = jax.ShapeDtypeStruct((NUM_NODES, MESSAGE_DIM), jnp.float32)
    else:
        expected = jax.ShapeDtypeStruct((NUM_NODES, NUM_NODES, MESSAGE_DIM),
                                        jnp.float32)
    self._check_shape_and_dtype(outs, expected)
def testLowerDonateArgnumsAvailable(self):
    x = jax.ShapeDtypeStruct((2, 2), jnp.float32)

    def f(*args):
        x, *_ = args
        return x

    f_low = pjit(f, donate_argnums=(0,),
                 in_axis_resources=P('x'),
                 out_axis_resources=P('x')).lower(x)
    f_com = f_low.compile()
    self.assertEqual(f_low.donate_argnums, (0,))
    self.assertEqual(f_com.donate_argnums, (0,))
def test_gated_recurrent_update_shapes(self):
    outs, _ = graph_layers.gated_recurrent_update.init(
        jax.random.PRNGKey(0),
        node_states=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32),
        messages=jnp.zeros((NUM_NODES, MESSAGE_DIM), jnp.float32))
    expected = jax.ShapeDtypeStruct((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32)
    self._check_shape_and_dtype(outs, expected)
def test_fast_eval_shape_already_transformed(self):
    f = transform.transform(lambda x: basic.Linear(20)(x))  # pylint: disable=unnecessary-lambda
    rng = jax.random.PRNGKey(0)
    x = jnp.ones([1, 12])

    # init_fn
    y_slow = jax.eval_shape(f.init, rng, x)
    y_fast = eval_shape.fast_eval_shape(f.init, rng, x)
    self.assertEqual(y_slow, y_fast)
    self.assertEqual(
        y_slow,
        {'linear': {'w': jax.ShapeDtypeStruct((12, 20), jnp.float32),
                    'b': jax.ShapeDtypeStruct((20,), jnp.float32)}})

    # apply_fn
    y_slow = jax.eval_shape(f.apply, y_slow, rng, x)
    y_fast = eval_shape.fast_eval_shape(f.apply, y_fast, rng, x)
    self.assertEqual(y_slow, y_fast)
def test_NodeTypeNodeEmbedding_shapes(self):
    outs, _ = graph_layers.NodeTypeNodeEmbedding.init(
        jax.random.PRNGKey(0),
        node_types=jnp.zeros((NUM_NODES,), jnp.int32),
        num_node_types=NUM_NODE_TYPES,
        embedding_dim=NODE_EMBEDDING_DIM)
    expected = jax.ShapeDtypeStruct((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32)
    self._check_shape_and_dtype(outs, expected)
def test_fast_eval_shape_within_transform(self):
    def f(x):
        m = basic.Linear(20)
        y_slow = stateful.eval_shape(m, x)
        y_fast = eval_shape.fast_eval_shape(m, x)
        self.assertEqual(y_slow, y_fast)
        return m(x)

    f = transform.transform(f)
    rng = jax.random.PRNGKey(0)
    x = jnp.ones([1, 12])
    params = jax.eval_shape(f.init, rng, x)
    self.assertEqual(
        params,
        {'linear': {'w': jax.ShapeDtypeStruct((12, 20), jnp.float32),
                    'b': jax.ShapeDtypeStruct((20,), jnp.float32)}})
    jax.eval_shape(f.apply, params, rng, x)
def test_LinearMessagePassing_shapes(self):
    outs, _ = graph_layers.LinearMessagePassing.init(
        jax.random.PRNGKey(0),
        edge_embeddings=jnp.zeros((NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM),
                                  jnp.float32),
        node_embeddings=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32),
        message_dim=MESSAGE_DIM,
        with_bias=True)
    expected = jax.ShapeDtypeStruct((NUM_NODES, MESSAGE_DIM), jnp.float32)
    self._check_shape_and_dtype(outs, expected)
def test_no_namescopes_inside_abstract_dot(self):
    mod = AddModule()
    current_setting = module.modules_with_named_call
    a = b = jax.ShapeDtypeStruct(shape=tuple(), dtype=jnp.float32)
    try:
        module.profiler_name_scopes(enabled=True)
        with mock.patch.object(stateful, "named_call") as mock_f:
            _ = dot.abstract_to_dot(mod)(a, b)
        mock_f.assert_not_called()
    finally:
        module.profiler_name_scopes(enabled=current_setting)
def random_state_batch_spin_impl(hilb: Spin, key, batches, dtype):
    S = hilb._s
    shape = (batches, hilb.size)

    # If unconstrained space, use fast sampling
    if hilb._total_sz is None:
        n_states = int(2 * S + 1)
        rs = jax.random.randint(key, shape=shape, minval=0, maxval=n_states)
        two = jnp.asarray(2, dtype=dtype)
        return jnp.asarray(rs * two - (n_states - 1), dtype=dtype)
    else:
        N = hilb.size
        n_states = int(2 * S) + 1
        # if constrained and S == 1/2, use a trick to sample quickly
        if n_states == 2:
            m = hilb._total_sz * 2
            nup = (N + m) // 2
            ndown = (N - m) // 2
            x = jnp.concatenate(
                (
                    jnp.ones((batches, nup), dtype=dtype),
                    -jnp.ones((batches, ndown), dtype=dtype),
                ),
                axis=1,
            )
            # deprecated: return jax.random.shuffle(key, x, axis=1)
            return jax.vmap(jax.random.permutation)(
                jax.random.split(key, x.shape[0]), x
            )
        # if constrained and S != 1/2, then use a slow fallback algorithm
        # TODO: find a better, faster way to sample constrained arbitrary spaces.
        else:
            from jax.experimental import host_callback as hcb

            cb = lambda rng: _random_states_with_constraint(hilb, rng, batches, dtype)

            state = hcb.call(
                cb,
                key,
                result_shape=jax.ShapeDtypeStruct(shape, dtype),
            )
            return state
def trace(state, fn, num_steps, **_):
    """Implementation of `trace` operator, without the calling convention."""
    # We need the shapes and dtypes of the outputs of `fn`.
    _, untraced_spec, traced_spec = jax.eval_shape(
        fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state))
    untraced_init = map_tree(lambda spec: np.zeros(spec.shape, spec.dtype),
                             untraced_spec)

    try:
        num_steps = int(num_steps)
        use_scan = True
    except TypeError:
        use_scan = False
        if flatten_tree(traced_spec):
            raise ValueError(
                'Cannot trace values when `num_steps` is not statically known. Pass '
                'False to `trace_mask` or return an empty structure (e.g. `()`) as '
                'the extra output.')

    if use_scan:
        def wrapper(state_untraced, _):
            state, _ = state_untraced
            state, untraced, traced = fn(state)
            return (state, untraced), traced

        (state, untraced), traced = lax.scan(
            wrapper,
            (state, untraced_init),
            xs=None,
            length=num_steps,
        )
    else:
        trace_arrays = map_tree(
            lambda spec: np.zeros((num_steps,) + spec.shape, spec.dtype),
            traced_spec)

        def wrapper(i, state_untraced_traced):
            state, _, trace_arrays = state_untraced_traced
            state, untraced, traced = fn(state)
            trace_arrays = map_tree(lambda a, e: jax.ops.index_update(a, i, e),
                                    trace_arrays, traced)
            return (state, untraced, trace_arrays)

        state, untraced, traced = lax.fori_loop(
            np.asarray(0, num_steps.dtype),
            num_steps,
            wrapper,
            (state, untraced_init, trace_arrays),
        )
    return state, untraced, traced
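# Hypothetical usage of the trace operator above: fn must return a
# (new_state, untraced, traced) triple, and with a statically known num_steps
# the traced outputs are stacked along a leading axis of length num_steps.
def step(x):
    new_x = x + 1.0
    return new_x, (), new_x   # state, untraced extras, traced extras

final_x, _, xs = trace(np.float32(0.0), step, num_steps=5)
# final_x == 5.0; xs has shape (5,) containing 1.0 through 5.0.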
def _einsum_contract_path(*operands, **kwargs):
    """Like opt_einsum.contract_path, with support for DimPolynomial shapes.

    We use opt_einsum.contract_path to compute the schedule, using a fixed
    constant for all dimension variables. This is safe because we throw an
    error if there is more than one contraction. Essentially, we just use
    opt_einsum.contract_path to parse the specification.
    """
    # Replace the polymorphic shapes with some concrete shapes for calling
    # into opt_einsum.contract_path, because the latter wants to compute the
    # sizes of operands and intermediate results.
    fake_ops = []
    for operand in operands:
        # We replace only array operands
        if not hasattr(operand, "dtype"):
            fake_ops.append(operand)
        else:
            shape = np.shape(operand)

            def fake_dim(d):
                if core.is_constant_dim(d):
                    return d
                else:
                    if not isinstance(d, _DimPolynomial):
                        raise TypeError(
                            f"Encountered unexpected shape dimension {d}")
                    # It is OK to replace all polynomials with the same value.
                    # We may miss some errors here due to unequal dimensions,
                    # but we catch them later.
                    return 8

            fake_ops.append(
                jax.ShapeDtypeStruct(tuple(map(fake_dim, shape)), operand.dtype))

    contract_fake_ops, contractions = opt_einsum.contract_path(
        *fake_ops, **kwargs)
    if len(contractions) > 1:
        msg = ("Shape polymorphism is not yet supported for einsum with more "
               f"than one contraction {contractions}")
        raise ValueError(msg)

    contract_operands = []
    for operand in contract_fake_ops:
        idx = tuple(i for i, fake_op in enumerate(fake_ops) if operand is fake_op)
        assert len(idx) == 1
        contract_operands.append(operands[idx[0]])
    return contract_operands, contractions
def shape_dtype(self) -> jax.ShapeDtypeStruct:
    """Computes the shape and dtype of the result of this `AddN`.

    Returns:
      A `jax.ShapeDtypeStruct` object describing the shape and dtype of the
      `AddN`.
    """
    operand_shape_dtypes = tuple(
        jax.ShapeDtypeStruct(operand.shape, operand.dtype)
        for operand in self.operands)

    def _eval_fun(*args):
        return functools.reduce(operator.add, args)

    return jax.eval_shape(_eval_fun, *operand_shape_dtypes)
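# A standalone illustration of the eval_shape pattern used by shape_dtype
# above: abstract ShapeDtypeStruct operands let the broadcasted result shape
# and dtype be computed without materializing any arrays. The operands here
# are made up for illustration.
import functools
import operator
import jax
import jax.numpy as jnp

operands = (jax.ShapeDtypeStruct((4, 1), jnp.float32),
            jax.ShapeDtypeStruct((1, 5), jnp.float32))
out = jax.eval_shape(lambda *xs: functools.reduce(operator.add, xs), *operands)
assert out.shape == (4, 5)   # dtype follows the operands: float32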
def wrapped_exec_bwd(params, g):
    def jacobian(params):
        tape = self.copy()
        tape.set_parameters(params)
        return tape.jacobian(device, params=params, **tape.jacobian_options)

    val = g.reshape((-1,)) * host_callback.call(
        jacobian,
        params,
        result_shape=jax.ShapeDtypeStruct((1, len(params)), JAXInterface.dtype),
    )
    return (list(val.reshape((-1,))),)  # Comma is on purpose.
def test_TokenOperatorNodeEmbedding_shapes(self, bottleneck_dim):
    outs, _ = graph_layers.TokenOperatorNodeEmbedding.init(
        jax.random.PRNGKey(0),
        operator=sparse_operator.SparseCoordOperator(
            input_indices=jnp.zeros((NUM_TOKENS, 1), jnp.int32),
            output_indices=jnp.zeros((NUM_TOKENS, 1), jnp.int32),
            values=jnp.zeros((NUM_TOKENS,), jnp.int32)),
        vocab_size=VOCAB_SIZE,
        num_nodes=NUM_NODES,
        embedding_dim=NODE_EMBEDDING_DIM,
        bottleneck_dim=bottleneck_dim)
    expected = jax.ShapeDtypeStruct((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32)
    self._check_shape_and_dtype(outs, expected)
def test_NodeSelfAttention_shapes(self, which):
    like_great = {"like_great": True, "full_relative": False}[which]
    outs, _ = graph_layers.NodeSelfAttention.init(
        jax.random.PRNGKey(0),
        edge_embeddings=jnp.zeros((NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM),
                                  jnp.float32),
        node_embeddings=jnp.zeros((NUM_NODES, NODE_EMBEDDING_DIM), jnp.float32),
        heads=NUM_HEADS,
        query_key_dim=QUERY_KEY_DIM,
        value_dim=VALUE_DIM,
        out_dim=MESSAGE_DIM,
        mask=jnp.zeros((NUM_NODES, NUM_NODES), jnp.float32),
        like_great=like_great)
    expected = jax.ShapeDtypeStruct((NUM_NODES, MESSAGE_DIM), jnp.float32)
    self._check_shape_and_dtype(outs, expected)
def test_LearnableEdgeEmbeddings_shapes(self):
    outs, _ = graph_layers.LearnableEdgeEmbeddings.init(
        jax.random.PRNGKey(0),
        edges=sparse_operator.SparseCoordOperator(
            input_indices=jnp.zeros((NUM_EDGES, 1), jnp.int32),
            output_indices=jnp.zeros((NUM_EDGES, 2), jnp.int32),
            values=jnp.zeros((NUM_EDGES,), jnp.int32)),
        num_nodes=NUM_NODES,
        num_edge_types=NUM_EDGE_TYPES,
        forward_edge_type_indices=[0, 2],
        reverse_edge_type_indices=[3, 1],
        embedding_dim=EDGE_EMBEDDING_DIM)
    expected = jax.ShapeDtypeStruct((NUM_NODES, NUM_NODES, EDGE_EMBEDDING_DIM),
                                    jnp.float32)
    self._check_shape_and_dtype(outs, expected)
def odefun_host_callback(state, driver, *args, **kwargs):
    """
    Calls odefun through a host callback in order to make the rest of the
    ODE solver jit-able.
    """
    result_shape = jax.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        state.parameters,
    )

    return hcb.call(
        lambda args_and_kw: odefun(state, driver, *args_and_kw[0], **args_and_kw[1]),
        # pack args and kwargs together, since host_callback passes a single argument:
        (args, kwargs),
        result_shape=result_shape,
    )
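# A minimal sketch of the same pattern (assumptions: hcb is
# jax.experimental.host_callback, and host_double below is made up for
# illustration): result_shape may be any pytree of jax.ShapeDtypeStruct, and
# the host function must return a matching pytree of arrays.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def host_double(tree):
    # Runs on the host with concrete NumPy arrays.
    return {k: np.asarray(v) * 2 for k, v in tree.items()}

params = {"w": jnp.ones((3,)), "b": jnp.zeros((2,))}
result_shape = jax.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), params)
doubled = jax.jit(
    lambda p: hcb.call(host_double, p, result_shape=result_shape))(params)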
def random_state(hilb: Fock, key, batches: int, *, dtype=np.float32):
    shape = (batches, hilb.size)

    # If unconstrained space, use fast sampling
    if hilb.n_particles is None:
        rs = jax.random.randint(key, shape=shape, minval=0, maxval=hilb.n_max + 1)
        return jnp.asarray(rs, dtype=dtype)
    else:
        from jax.experimental import host_callback as hcb

        state = hcb.call(
            lambda rng: _random_states_with_constraint(hilb, rng, batches, dtype),
            key,
            result_shape=jax.ShapeDtypeStruct(shape, dtype),
        )
        return state