def update_fn(updates, state, params=None):  # pylint: disable=missing-docstring
  del params
  num_vars = len(jax.tree_leaves(updates))
  treedef = jax.tree_structure(updates)
  count_inc = _safe_int32_increment(state.count)
  variance = eta / count_inc**gamma
  all_keys = jax.random.split(state.rng_key, num=num_vars + 1)
  noise = jax.tree_multimap(
      lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype),
      updates, jax.tree_unflatten(treedef, all_keys[1:]))
  updates = jax.tree_multimap(
      lambda g, n: g + variance.astype(g.dtype) * n,
      updates, noise)
  return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0])
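A minimal, self-contained sketch (assuming only `jax`) of the per-leaf RNG pattern used above: split one key into one subkey per leaf, then unflatten the subkeys into a pytree matching the updates so they can be paired leaf-by-leaf.

import jax

def per_leaf_keys(key, tree):
  # Split `key` into one subkey per leaf and arrange them in `tree`'s structure.
  leaves, treedef = jax.tree_flatten(tree)
  keys = jax.random.split(key, len(leaves))
  return jax.tree_unflatten(treedef, list(keys))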
def join_differentiable(differentiable_xs, non_differentiable_xs):
  """Reconstitute inputs pytree from differentiable/non-d. partitions."""
  differentiable_leaves = list(jax.tree_leaves(differentiable_xs))
  non_differentiable_leaves = list(jax.tree_leaves(non_differentiable_xs))
  leaves = []
  for is_differentiable in jax.tree_leaves(inputs_is_differentiable):
    if is_differentiable:
      leaves.append(differentiable_leaves.pop(0))
    else:
      leaves.append(non_differentiable_leaves.pop(0))
  assert not differentiable_leaves
  assert not non_differentiable_leaves
  return jax.tree_unflatten(jax.tree_structure(inputs), leaves)
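`inputs` and `inputs_is_differentiable` above are closure variables. A hedged, self-contained sketch of the matching split (a hypothetical helper, not the original API) makes the partition/join round trip explicit:

import jax

def split_differentiable(inputs, inputs_is_differentiable):
  # Partition the leaves of `inputs` according to a boolean mask pytree with
  # the same structure; `join_differentiable` reverses this split.
  leaves = jax.tree_leaves(inputs)
  flags = jax.tree_leaves(inputs_is_differentiable)
  differentiable = [x for x, d in zip(leaves, flags) if d]
  non_differentiable = [x for x, d in zip(leaves, flags) if not d]
  return differentiable, non_differentiable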
def apply_gradient(self, hyper_params, params, state, grads):
  step = state.step
  params_flat, treedef = jax.tree_flatten(params)
  states_flat = treedef.flatten_up_to(state.param_states)
  grads_flat = treedef.flatten_up_to(grads)
  # Optionally resize the global gradient to a maximum norm.
  if hyper_params.grad_norm_clip:
    grads_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
    grads_factor = jnp.minimum(1.0, hyper_params.grad_norm_clip / grads_l2)
    grads_flat = jax.tree_map(lambda param: grads_factor * param, grads_flat)
  out = [
      self.apply_param_gradient(step, hyper_params, param, state, grad)
      for param, state, grad in zip(params_flat, states_flat, grads_flat)
  ]
  new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
  new_params = jax.tree_unflatten(treedef, new_params_flat)
  new_param_states = jax.tree_unflatten(treedef, new_states_flat)
  new_state = flax.optim.OptimizerState(step + 1, new_param_states)
  return new_params, new_state
def apply_gradient(self, hyper_params, params, state, grads): """Applies a gradient for a set of parameters. Args: hyper_params: a named tuple of hyper parameters. params: the parameters that should be updated. state: a named tuple containing the state of the optimizer grads: the gradient tensors for the parameters. Returns: A tuple containing the new parameters and the new optimizer state. """ step = state.step params_flat, treedef = jax.tree_flatten(params) states_flat = treedef.flatten_up_to(state.param_states) grads_flat = treedef.flatten_up_to(grads) out = [self.apply_param_gradient(step, hyper_params, param, state, grad) for param, state, grad in zip(params_flat, states_flat, grads_flat)] new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ()) new_params = jax.tree_unflatten(treedef, new_params_flat) new_param_states = jax.tree_unflatten(treedef, new_states_flat) new_state = OptimizerState(step + 1, new_param_states) return new_params, new_state
def _load_leaves_from_zipfile(self, name, tempdir, zipfile, required):
  fn = name + '.npz'
  if fn not in zipfile.namelist():
    if required:
      raise IOError(f"required file missing from zipfile: {fn}")
    return False
  fp = os.path.join(tempdir, fn)
  zipfile.extract(fn, tempdir)
  npzfile = onp.load(fp)
  treedef = jax.tree_structure(getattr(self, name))
  leaves = npzfile.values()
  pytree = jax.tree_unflatten(treedef, leaves)
  setattr(self, name, pytree)
  return True
def wrapped_auto_registered(*args):
  flat_args, _ = jax.tree_flatten(args)
  # Mapping from variable -> value
  env = {}
  read = functools.partial(read_env, env)
  write = functools.partial(write_env, env)

  def tag(var):
    if matches.get(var) is not None:
      inv_map, tagging_func = matches[var]
      var_map = {k: v for k, v in inv_map.items() if not isinstance(k, str)}
      val_map = jax.tree_map(read, var_map)
      val = tagging_func(inv_map, val_map)
      env[var] = val

  # Bind args and consts to environment
  write(jax.core.unitvar, jax.core.unit)
  jax_util.safe_map(write, graph.jaxpr.invars, flat_args)
  jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts)

  # Register any orphan parameters as generic
  for param_var in orphan_params:
    write(param_var, tags.register_generic(read(param_var)))

  # Set the correct output variables
  if compute_only_loss_tags:
    output_vars = loss_output_vars
    out_tree = jax.tree_structure(loss_output_vars)
  else:
    output_vars = graph.jaxpr.outvars
    out_tree = graph.out_tree

  # Loop through equations and evaluate primitives using `bind`
  losses_evaluated = 0
  for eqn in graph.jaxpr.eqns:
    evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
    jax_util.safe_map(tag, eqn.outvars)
    # If we want to output only tagged losses
    if isinstance(eqn.primitive, tags.LossTag):
      losses_evaluated += 1
    if compute_only_loss_tags and num_losses == losses_evaluated:
      break
  outputs = jax_util.safe_map(read, output_vars)
  return jax.tree_unflatten(out_tree, outputs)
def add_grad_noise(rng, grads, noise_multiplier, l2_norm_clip,
                   global_batch_size):
  """Adds random noise to the clipped, averaged grads."""
  logging.info("We are adding the noise %f", noise_multiplier)
  grads_flat, grads_treedef = jax.tree_flatten(grads)
  rngs = jax.random.split(rng, len(grads_flat))
  # The grads are already normalized by the batch size (because for DP we
  # should use loss with normalize_loss_by_num_nonpadding_tokens=True).
  factor = l2_norm_clip * noise_multiplier / global_batch_size
  noised_grads = [
      g + factor * jax.random.normal(r, g.shape)
      for r, g in zip(rngs, grads_flat)
  ]
  return jax.tree_unflatten(grads_treedef, noised_grads)
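A hedged usage sketch of `add_grad_noise` on a toy gradient pytree; the hyperparameter values are illustrative only:

import jax
import jax.numpy as jnp

grads = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
noised = add_grad_noise(jax.random.PRNGKey(0), grads,
                        noise_multiplier=1.1, l2_norm_clip=1.0,
                        global_batch_size=256)
# `noised` has the same structure and leaf shapes as `grads`.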
def NES_profile_nn(params, params_to_xL, score_function, npop=50,
                   sigma_noise=0.1, alpha=0.05):
  """Natural Evolutionary Strategy.

  Args:
    params: parameters in original pytree form.
    params_to_xL: function mapping the parameter pytree to `xL` for scoring.
    score_function: scoring function, called as `score_function(xL=...)`.
    npop: population size.
    sigma_noise: standard deviation of the perturbation noise.
    alpha: learning rate.
  """
  params_flat, treedef = jax.tree_flatten(params)
  params_shape = [l.shape for l in params_flat]
  params_size = [l.size for l in params_flat]
  params_arr = list_to_array(params_flat)
  num_params = np.sum(np.array(params_size))
  N = onp.random.randn(npop, num_params)
  R = onp.zeros(npop)
  for j in range(npop):
    params_try = params_arr.copy()  # 1d array
    params_try = params_try + sigma_noise * N[j]
    params_try = array_to_list(params_try, params_size, params_shape)  # list
    params_try = jax.tree_unflatten(treedef, params_try)  # pytree
    xL_try = params_to_xL(params_try)
    R[j] = score_function(xL=xL_try)
  A = (R - np.mean(R)) / (np.std(R) + 1e-6)
  params_update = params_arr - alpha / (npop * sigma_noise) * np.dot(N.T, A)
  params_update = array_to_list(params_update, params_size, params_shape)
  params_update = jax.tree_unflatten(treedef, params_update)
  return params_update
def _apply(self, parameters, *inputs, key):
  flat_inputs, in_tree = tree_flatten(inputs)
  flat_fun, out_tree = flatten_fun_nokwargs(self._wrapped_fun, in_tree)
  apply_trace = _top_trace(filter_type=ApplyTrace)
  with new_master(ApplyTrace) as master:
    global_parameters_by_primitive = (
        apply_trace.state.global_parameters_by_primitive if apply_trace else {})
    random_state = (apply_trace.state.random_state if apply_trace
                    else RandomState(key))
    master.state = ApplyTraceState(random_state, parameters,
                                   global_parameters_by_primitive)
    flat_outputs = _apply_transform(flat_fun, master).call_wrapped(*flat_inputs)
    del master
  return tree_unflatten(out_tree(), flat_outputs)
def scan_fn(broadcast_in, init, *args):
  xs = jax.tree_multimap(transpose_to_front, in_axes, args)

  def body_fn(c, xs, init_mode=False):
    # inject constants
    xs = jax.tree_multimap(
        lambda ax, arg, x: (arg if ax is broadcast else x),
        in_axes, args, xs)
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)

    if init_mode:
      ys = jax.tree_multimap(
          lambda ax, y: (y if ax is broadcast else ()), out_axes, ys)
      return broadcast_out, ys
    else:
      ys = jax.tree_multimap(
          lambda ax, y: (() if ax is broadcast else y), out_axes, ys)
      return c, ys

  broadcast_body = functools.partial(body_fn, init_mode=True)

  carry_pvals = jax.tree_map(
      lambda x: pe.PartialVal.unknown(
          jax.ShapedArray(jnp.shape(x), jnp.result_type(x))),
      init)
  scan_pvals = jax.tree_map(
      lambda x: pe.PartialVal.unknown(
          jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))),
      xs)
  input_pvals = (carry_pvals, scan_pvals)
  in_pvals, in_tree = jax.tree_flatten(input_pvals)
  f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
      lu.wrap_init(broadcast_body), in_tree)
  _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)

  out_flat = []
  for pv, const in out_pvals:
    if pv is not None:
      raise ValueError(
          'broadcasted variable has a data dependency on the scan body.')
    out_flat.append(const)
  broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)

  c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse)
  ys = jax.tree_multimap(transpose_from_front, out_axes, ys)
  ys = jax.tree_multimap(
      lambda ax, const, y: (const if ax is broadcast else y),
      out_axes, constants_out, ys)
  return broadcast_in, c, ys
def restore_state_list(obj, state_list):
  """Restore model state from a state list.

  Args:
    obj: the object that is to be duplicated with the restored state
    state_list: state as a list of jax.numpy arrays

  Returns:
    a copy of `obj` with the parameters from `state_list` loaded

  >>> restored = restore_state_list(model, state_list)
  """
  state_list = replicate(state_list)
  structure = jax.tree_util.tree_structure(obj)
  return jax.tree_unflatten(structure, state_list)
def test_flatten_nested_struct(self):
  d = {
      "foo": {
          "bar": [1, 2, 3]
      },
      "baz": {
          "bat": [4, 5, 6],
          "qux": [7, [8, 9]]
      }
  }
  f = FlatMapping(d)
  leaves, treedef = jax.tree_flatten(f)
  self.assertEqual([4, 5, 6, 7, 8, 9, 1, 2, 3], leaves)
  g = jax.tree_unflatten(treedef, leaves)
  self.assertEqual(g, f)
  self.assertEqual(g, d)
def named_fun(*args, **kwargs):
  # Wrap and flatten f for JAX internals.
  f = lu.wrap_init(fun)
  flat_args, in_tree = jax.tree_flatten((args, kwargs))
  flat_f, out_tree = api.flatten_fun(f, in_tree)
  # Hide any args that are not a valid JaxType by partially applying flat_f
  dyn_argnums = [i for (i, x) in enumerate(flat_args)
                 if jax.api._valid_jaxtype(x)]  # pylint: disable=protected-access
  part_flat_f, dyn_args = jax.argnums_partial(flat_f, dyn_argnums, flat_args)
  # Call f with a custom XLA subcomputation via named_call & unflatten result.
  out_flat = named_call_p.bind(part_flat_f, *dyn_args, name=name)
  return jax.tree_unflatten(out_tree(), out_flat)
def test_tree_functions(self):
  f = FlatMapping({"foo": {"b": {"c": 1}, "d": 2}, "bar": {"c": 1}})

  m = jax.tree_map(lambda x: x + 1, f)
  self.assertEqual(type(m), FlatMapping)
  self.assertEqual(m, {"foo": {"b": {"c": 2}, "d": 3}, "bar": {"c": 2}})

  mm = jax.tree_multimap(lambda x, y: x + y, f, f)
  self.assertEqual(type(mm), FlatMapping)
  self.assertEqual(mm, {"foo": {"b": {"c": 2}, "d": 4}, "bar": {"c": 2}})

  leaves, treedef = jax.tree_flatten(f)
  self.assertEqual(leaves, [1, 1, 2])
  uf = jax.tree_unflatten(treedef, leaves)
  self.assertEqual(type(f), FlatMapping)
  self.assertEqual(f, uf)
def tree_map_zipped(fn: Callable[..., Any], nests: Sequence[Any]):
  """Map a function over a list of identical nested structures.

  Args:
    fn: the function to map; must have arity equal to `len(nests)`.
    nests: a list of identical nested structures.

  Returns:
    a nested structure whose leaves are outputs of applying `fn`.
  """
  if not nests:
    return nests
  tree_def = tree_structure(nests[0])
  if any([tree_structure(x) != tree_def for x in nests[1:]]):
    raise ValueError('All elements must share the same tree structure.')
  return jax.tree_unflatten(
      tree_def, [fn(*d) for d in zip(*[jax.tree_leaves(x) for x in nests])])
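A hedged usage sketch for `tree_map_zipped`, summing the leaves of two identically structured pytrees (assuming the function above is in scope):

trees = [{"a": 1, "b": (2, 3)}, {"a": 10, "b": (20, 30)}]
summed = tree_map_zipped(lambda x, y: x + y, trees)
# summed == {"a": 11, "b": (22, 33)}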
def wrapper(*args, **kwargs):
  side_channel = {"non_jaxtypes": [], "treedef": None}
  wrapped_fun = hide_non_jaxtype_outputs(fun, side_channel)
  if base.inside_transform():
    wrapped_fun = thread_hk_state_in_kwargs(jax.named_call)(wrapped_fun,
                                                            name=name)
  else:
    wrapped_fun = jax.named_call(wrapped_fun, name=name)
  jax_types = wrapped_fun(*args, **kwargs)
  non_jaxtypes = side_channel["non_jaxtypes"]
  out_leaves = [y if x is None else x
                for x, y in zip(jax_types, non_jaxtypes)]
  out = jax.tree_unflatten(side_channel["treedef"], out_leaves)
  return out
def objective(parameters_vec):
  """Objective function for minimization.

  Args:
    parameters_vec: Float numpy array with shape (num_parameters,), the
      parameters for exchange-correlation functional.

  Returns:
    Float, the weighted root mean square deviation (WRMSD).
  """
  loss = float(
      eval_wrmsd(**jax.tree_unflatten(functional.parameters_spec,
                                      parameters_vec)))
  if self.l1_penalty > 1e-8:
    loss += self.l1_penalty * np.sum(np.abs(parameters_vec))
  if self.l2_penalty > 1e-8:
    loss += self.l2_penalty * np.sum(parameters_vec**2)
  return loss
def _wrap_and_compile(self, signature, args_flat, in_tree):
  """Compiles the function for the given signature."""

  def wrapped_function(*args_flat):
    args, kwargs = jax.tree_unflatten(in_tree, args_flat)
    return self._function(*args, **kwargs)

  # Compile the wrapped_function to IREE.
  binary = aot(wrapped_function, *args_flat, **self._options)
  cpp_vm_module = rt.VmModule.from_flatbuffer(binary)
  module = rt.load_module(cpp_vm_module, config=self._driver_config)

  # Get the output tree so it can be reconstructed from the outputs of the
  # compiled module. Duplicating execution here isn't ideal, and could
  # probably be avoided using internal APIs.
  args, kwargs = jax.tree_unflatten(in_tree, args_flat)
  _, out_tree = jax.tree_flatten(self._function(*args, **kwargs))

  self._memoized_signatures[signature] = (binary, module, out_tree)
def parallel_read(old, fname):
  old_val, treedef = jax.tree_flatten(old)
  with open(fname, "rb") as f:
    buf = f.read()
  f_io = io.BytesIO(buf)
  loaded = np.load(f_io)
  new_vals = []
  for i in loaded:
    new_vals.append(loaded[i])
  for o, n in zip(new_vals, old_val):
    assert o.shape == n.shape, "Incompatible checkpoint"
    if o.dtype == np.dtype('V2'):
      o.dtype = jnp.bfloat16
  return jax.tree_unflatten(treedef, new_vals)
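A hedged sketch of a matching writer (a hypothetical helper, not necessarily the original): it stores the flattened leaves with `np.savez`, so the key order seen by `parallel_read` follows the pytree's leaf order.

import numpy as np
import jax

def parallel_write(tree, fname):
  # Flatten the pytree and store its leaves as arr_0, arr_1, ... in a .npz file.
  leaves, _ = jax.tree_flatten(tree)
  with open(fname, "wb") as f:
    np.savez(f, *[np.asarray(x) for x in leaves])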
def merged_func(*func_args):
  typed_jaxpr, out_avals = jax.make_jaxpr(f, return_shape=True)(*func_args)
  out_tree = jax.tree_structure(out_avals)
  jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals

  # Mapping from variable -> value
  env = dict()
  read = functools.partial(read_env, env)
  write = functools.partial(write_env, env)

  # Bind args and consts to environment
  flat_args = jax.tree_flatten(func_args)[0]
  write(jax.core.unitvar, jax.core.unit)
  jax_util.safe_map(write, jaxpr.invars, flat_args)
  jax_util.safe_map(write, jaxpr.constvars, consts)

  # Loop through equations and evaluate primitives using `bind`
  broadcasts_outputs = dict()
  for eqn in clean_jaxpr_eqns(jaxpr):
    # We ignore broadcasting of constants
    if (eqn.primitive.name == "broadcast_in_dim" and
        not all(isinstance(v, jax_core.Literal) for v in eqn.invars)):
      if eqn.invars[0] in broadcasts_outputs:
        x, dims = broadcasts_outputs[eqn.invars[0]]
        kept_dims = eqn.params["broadcast_dimensions"]
        kept_dims = [kept_dims[d] for d in dims]
        y = lax.broadcast_in_dim(x, eqn.params["shape"], kept_dims)
        jax_util.safe_map(write, eqn.outvars, [y])
        broadcasts_outputs[eqn.outvars[0]] = (x, kept_dims)
      else:
        inputs = jax_util.safe_map(read, eqn.invars)
        evaluate_eqn(eqn, inputs, write)
        broadcasts_outputs[eqn.outvars[0]] = (
            inputs[0], eqn.params["broadcast_dimensions"])
    else:
      evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)

  return jax.tree_unflatten(out_tree, jax_util.safe_map(read, jaxpr.outvars))
def update_fn(updates, state, params=None):
  del params
  grads_flat, grads_treedef = jax.tree_flatten(updates)
  bsize = grads_flat[0].shape[0]

  if any(g.ndim == 0 or bsize != g.shape[0] for g in grads_flat):
    raise ValueError(
        'Unlike other transforms, `differentially_private_aggregate` expects'
        ' `updates` to have a batch dimension in the 0th axis. That is, this'
        ' function expects per-example gradients as input.')

  new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1)
  global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads_flat)
  divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
  clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads_flat]
  noised = [(g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
            for g, r in zip(clipped, rngs)]
  return (jax.tree_unflatten(grads_treedef, noised),
          DifferentiallyPrivateAggregateState(rng_key=new_key))
def partial_eval_by_shape(fn, input_spec, *args, **kwargs):
  """Lazily evaluate a function by using the shapes of the inputs.

  This function is similar to `jax.eval_shape` with the key difference that
  function outputs that can be computed without a concrete value of the
  inputs are returned as is instead of only the shape. See for example
  `module.init_by_shape` where this functionality is used to initialize a
  model without using input data or computation.

  Args:
    fn: the function to be lazily evaluated.
    input_spec: an iterable of shapes or (shape, dtype) tuples specifying the
      shape and type of the inputs. If unspecified the dtype is float32.
    *args: other arguments passed to the module's apply function
    **kwargs: keyword arguments passed to the module's apply function

  Returns:
    A pair consisting of the model output and an instance of Model
  """
  # output cannot be returned in lazy_create because jax.eval_shape will only
  # return the shape and dtype.
  # TODO(mattjj,jheek): use a public JAX API
  f = lambda *inputs: fn(*inputs, *args, **kwargs)
  input_structs = [_parse_spec(spec) for spec in input_spec]
  inputs_flat, in_tree = jax.tree_flatten(input_structs)
  f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
      lu.wrap_init(f), in_tree)
  in_pvals = [
      pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
      for x in inputs_flat
  ]
  if config.omnistaging_enabled:
    _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
  else:
    with jax.core.initial_style_staging():
      _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)
  out_flat = [
      const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype)
      for pv, const in out_pvals
  ]
  return jax.tree_unflatten(out_tree(), out_flat)
def vjp_func(tangents):
  flat_tangents = jax.tree_flatten(tangents)[0]
  loss_invars = []
  loss_targets = []
  for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs):
    num_inputs = _unbox_loss_tag(jaxpr_eqn).num_inputs
    loss_invars.append(tuple(jaxpr_eqn.invars[:num_inputs]))
    loss_targets.append(inputs[num_inputs:])
  treedef = jax.tree_structure(loss_invars)
  tangents = jax.tree_unflatten(treedef, flat_tangents)
  # Since the losses could also take targets as inputs, and we don't want this
  # function to compute the vjp w.r.t. those (i.e. the user should not be
  # providing tangent vectors for the targets, only for the inputs), we have
  # to manually fill in these "extra" tangents with zeros.
  targets_tangents = jax.tree_map(jnp.zeros_like, loss_targets)
  tangents = tuple(ti + tti for ti, tti in zip(tangents, targets_tangents))
  input_tangents = full_vjp_func(tangents)[0]
  return input_tangents,
def __init__(self, *args, **kwargs):
  """Accepts FlatComponents or the same arguments as `dict`."""
  if not kwargs and len(args) == 1 and type(args[0]) is FlatComponents:
    leaves, structure = args[0]
    mapping = None
    # When unflattening we cannot assume that the leaves are not pytrees (for
    # example: `jax.tree_map(list, my_map)` would pass a list of lists in
    # as leaves).
    if not jax.tree_util.all_leaves(leaves):
      mapping = jax.tree_unflatten(structure, leaves)
      leaves, structure = jax.tree_flatten(mapping)
  else:
    mapping = dict(*args, **kwargs)
    leaves, structure = jax.tree_flatten(mapping)
  self._structure = structure
  self._leaves = tuple(leaves)
  self._mapping = mapping
def unpack(bundle):
  per_example = list(bundle.per_example)
  per_head = list(bundle.per_head)
  per_example_per_head = list(bundle.per_example_per_head)
  singleton = list(bundle.singleton)
  leaves = []
  for is_per_example, is_per_head in zip(has_batch_dim, has_head_dim):
    if is_per_example and not is_per_head:
      leaves.append(per_example.pop(0))
    elif not is_per_example and is_per_head:
      leaves.append(per_head.pop(0))
    elif is_per_example and is_per_head:
      leaves.append(per_example_per_head.pop(0))
    else:
      leaves.append(singleton.pop(0))
  assert (not per_example) and (not per_head)
  assert (not per_example_per_head) and (not singleton)
  return jax.tree_unflatten(treedef, leaves)
def __init__(self, flat: FlatComponents, check_leaves: bool = True):
  """Constructs a flat mapping from already flattened components.

  Args:
    flat: A tuple containing a flat sequence of values and a PyTreeDef
      representing the output of jax.tree_flatten on a structure.
    check_leaves: Check if all leaves are flat values, and reflatten if not.
      This check is O(n), whereas the normal construction time is O(1).
  """
  leaves, structure = flat
  # TODO(lenamartens): upstream is_leaf check to Jax
  is_leaf = lambda x: type(x) not in tree_util._registry  # pylint: disable=unidiomatic-typecheck pylint: disable=protected-access
  if check_leaves and not all(map(is_leaf, leaves)):
    tree = jax.tree_unflatten(structure, leaves)
    leaves, structure = jax.tree_flatten(tree)
  self._structure = structure
  self._leaves = tuple(leaves)
  self._mapping = None
def __call__(self, *args, **kwargs):
  """Executes the function on the provided inputs, compiling if necessary."""
  args_flat, _ = jax.tree_flatten((args, kwargs))
  # Use the uncompiled function if the inputs are being traced.
  if any(issubclass(type(arg), jax.core.Tracer) for arg in args_flat):
    return self._function(*args, **kwargs)

  module, out_tree = self._get_compiled_artifacts(args, kwargs)
  results = module.main(*args_flat)
  if results is not None:
    if not isinstance(results, tuple):
      results = (results,)
    return jax.tree_unflatten(out_tree, results)
  else:
    # Address IREE returning None instead of empty sequences.
    if out_tree == jax.tree_flatten([])[-1]:
      return []
    elif out_tree == jax.tree_flatten(())[-1]:
      return ()
    else:
      return results
def get_fingerprint(self, num_feature_samples=10, num_parameter_samples=10,
                    num_decimals=5):
  """Gets a fingerprint for the functional.

  Fingerprint is evaluated as the MD5 hash value of functional signatures on
  a random sample of feature and parameter values. Signatures will be
  converted to strings with specified number of decimals.

  Args:
    num_feature_samples: Integer, number of samples of features.
    num_parameter_samples: Integer, number of samples of parameters.
    num_decimals: Integer, number of decimals when converting signatures to
      strings.

  Returns:
    String, the fingerprint of the functional.
  """
  format_string = f'{{:.{num_decimals}f}}'
  # fix random seed to have consistent sampling behavior when running the
  # code on different distributed workers
  random_state = np.random.RandomState(0)
  parameter_samples = random_state.rand(num_parameter_samples,
                                        self.num_parameters)
  signatures = []
  for parameters in parameter_samples:
    signatures.extend(
        self.get_signature(
            **jax.tree_unflatten(self.parameters_spec, parameters),
            num_feature_samples=num_feature_samples,
            random_state=random_state,
            signature='e_xc'))
  signature_string = ','.join(map(format_string.format, signatures))
  return hashlib.md5(signature_string.encode('utf-8')).hexdigest()
def _local_value_and_grad_notcentered_kernel(logpsi, pars, vp, mel, v):
  odtype = outdtype(logpsi, pars, v)

  # can use if with jit because that argument is exposed statically to the jit!
  # if real_to_complex:
  if not tree_leaf_iscomplex(pars) and jnp.issubdtype(odtype,
                                                      jnp.complexfloating):
    logpsi_vp_r, f_vjp_r = jax.vjp(lambda w: (logpsi(w, vp).real), pars)
    logpsi_vp_j, f_vjp_j = jax.vjp(lambda w: (logpsi(w, vp).imag), pars)
    logpsi_vp = logpsi_vp_r + 1.0j * logpsi_vp_j

    vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))
    vec = jnp.asarray(vec, dtype=odtype)
    vec_r = vec.real
    vec_j = vec.imag
    loc_val = vec.sum()

    vr_grad_r, tree_fun = jax.tree_flatten(f_vjp_r(vec_r)[0])
    vj_grad_r, _ = jax.tree_flatten(f_vjp_r(vec_j)[0])
    vr_grad_j, _ = jax.tree_flatten(f_vjp_j(vec_r)[0])
    vj_grad_j, _ = jax.tree_flatten(f_vjp_j(vec_j)[0])

    r_flat = [rr + 1j * jr for rr, jr in zip(vr_grad_r, vj_grad_r)]
    j_flat = [rr + 1j * jr for rr, jr in zip(vr_grad_j, vj_grad_j)]
    out_flat = [re + 1.0j * im for re, im in zip(r_flat, j_flat)]
    grad_c = jax.tree_unflatten(tree_fun, out_flat)
  else:
    logpsi_vp, f_vjp = jax.vjp(lambda w: logpsi(w, vp), pars)
    vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))
    vec = jnp.asarray(vec, dtype=odtype)
    loc_val = vec.sum()
    grad_c = f_vjp(vec)[0]
  return loc_val, grad_c
def sample_simulation_parameters(simulation_parameter_ranges, num_trajectories,
                                 rng):
  """Samples simulation parameters."""
  is_tuple = lambda val: isinstance(val, tuple)
  ranges_flat, ranges_treedef = jax.tree_flatten(simulation_parameter_ranges,
                                                 is_leaf=is_tuple)
  rng, shuffle_rng, *rngs = jax.random.split(rng, len(ranges_flat) + 2)
  shuffle_indices = jax.random.permutation(shuffle_rng, num_trajectories)
  rng_tree = jax.tree_unflatten(ranges_treedef, rngs)

  def sample_simulation_parameter(simulation_parameter_range, parameter_rng):
    """Sample a single simulation parameter."""
    del parameter_rng
    minval, maxval = simulation_parameter_range
    samples = jnp.linspace(minval, maxval, num=num_trajectories)
    return jnp.sort(samples)[shuffle_indices]

  return jax.tree_map(sample_simulation_parameter,
                      simulation_parameter_ranges, rng_tree, is_leaf=is_tuple)
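A hedged usage sketch of `sample_simulation_parameters` with a dict of (min, max) ranges; the names and range values are illustrative only:

import jax

ranges = {"gravity": (9.0, 10.0), "mass": (0.5, 1.5)}
params = sample_simulation_parameters(ranges, num_trajectories=3,
                                      rng=jax.random.PRNGKey(0))
# `params` mirrors `ranges`, with a length-3 array of sampled values per leaf.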