Example 1
 def update_fn(updates, state, params=None):  # pylint: disable=missing-docstring
     del params
     num_vars = len(jax.tree_leaves(updates))
     treedef = jax.tree_structure(updates)
     count_inc = _safe_int32_increment(state.count)
     variance = eta / count_inc**gamma
     all_keys = jax.random.split(state.rng_key, num=num_vars + 1)
     noise = jax.tree_multimap(
         lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype),
         updates, jax.tree_unflatten(treedef, all_keys[1:]))
     updates = jax.tree_multimap(
         lambda g, n: g + variance.astype(g.dtype) * n, updates, noise)
     return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0])
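
A minimal, self-contained sketch of the per-leaf RNG pattern used above: split one key per leaf, unflatten the keys back into the parameter structure, and map over both trees together. The params pytree and seed here are illustrative, and jax.tree_util is used in place of the older top-level aliases.

import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
leaves, treedef = jax.tree_util.tree_flatten(params)

# One key per leaf, plus one key to carry forward in the optimizer state.
all_keys = jax.random.split(jax.random.PRNGKey(0), len(leaves) + 1)
key_tree = jax.tree_util.tree_unflatten(treedef, list(all_keys[1:]))

# Sample noise with the same shape and dtype as each parameter leaf.
noise = jax.tree_util.tree_map(
    lambda g, k: jax.random.normal(k, g.shape, g.dtype), params, key_tree)
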
Example 2
 def join_differentiable(differentiable_xs, non_differentiable_xs):
   """Reconstitute inputs pytree from differentiable/non-d. partitions."""
   differentiable_leaves = list(jax.tree_leaves(differentiable_xs))
   non_differentiable_leaves = list(jax.tree_leaves(non_differentiable_xs))
   leaves = []
   for is_differentiable in jax.tree_leaves(inputs_is_differentiable):
     if is_differentiable:
       leaves.append(differentiable_leaves.pop(0))
     else:
       leaves.append(non_differentiable_leaves.pop(0))
   assert not differentiable_leaves
   assert not non_differentiable_leaves
   return jax.tree_unflatten(jax.tree_structure(inputs), leaves)
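
The helper above is one half of a partition/reconstitute pair. Below is a self-contained sketch of the full round trip, assuming a boolean mask pytree with the same structure as the inputs; all names and values are illustrative.

import jax

inputs = {"x": 1.0, "step": 3, "y": 2.0}
is_differentiable = {"x": True, "step": False, "y": True}

mask_leaves = jax.tree_util.tree_leaves(is_differentiable)
input_leaves = jax.tree_util.tree_leaves(inputs)
treedef = jax.tree_util.tree_structure(inputs)

# Partition the leaves according to the mask.
diff = [x for x, m in zip(input_leaves, mask_leaves) if m]
non_diff = [x for x, m in zip(input_leaves, mask_leaves) if not m]

# Reconstitute: interleave the two partitions back in leaf order.
diff_iter, non_diff_iter = iter(diff), iter(non_diff)
leaves = [next(diff_iter) if m else next(non_diff_iter) for m in mask_leaves]
assert jax.tree_util.tree_unflatten(treedef, leaves) == inputs
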
Example 3
  def apply_gradient(self, hyper_params, params, state, grads):
    step = state.step
    params_flat, treedef = jax.tree_flatten(params)
    states_flat = treedef.flatten_up_to(state.param_states)
    grads_flat = treedef.flatten_up_to(grads)

    # Optionally resize the global gradient to a maximum norm.
    if hyper_params.grad_norm_clip:
      grads_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
      grads_factor = jnp.minimum(1.0, hyper_params.grad_norm_clip / grads_l2)
      grads_flat = jax.tree_map(lambda param: grads_factor * param, grads_flat)

    out = [
        self.apply_param_gradient(step, hyper_params, param, state, grad)
        for param, state, grad in zip(params_flat, states_flat, grads_flat)
    ]

    new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
    new_params = jax.tree_unflatten(treedef, new_params_flat)
    new_param_states = jax.tree_unflatten(treedef, new_states_flat)
    new_state = flax.optim.OptimizerState(step + 1, new_param_states)
    return new_params, new_state
Example 4
  def apply_gradient(self, hyper_params, params, state, grads):
    """Applies a gradient for a set of parameters.

    Args:
      hyper_params: a named tuple of hyper parameters.
      params: the parameters that should be updated.
      state: a named tuple containing the state of the optimizer.
      grads: the gradient tensors for the parameters.
    Returns:
      A tuple containing the new parameters and the new optimizer state.
    """
    step = state.step
    params_flat, treedef = jax.tree_flatten(params)
    states_flat = treedef.flatten_up_to(state.param_states)
    grads_flat = treedef.flatten_up_to(grads)
    out = [self.apply_param_gradient(step, hyper_params, param, state, grad)
           for param, state, grad in zip(params_flat, states_flat, grads_flat)]

    new_params_flat, new_states_flat = list(zip(*out)) if out else ((), ())
    new_params = jax.tree_unflatten(treedef, new_params_flat)
    new_param_states = jax.tree_unflatten(treedef, new_states_flat)
    new_state = OptimizerState(step + 1, new_param_states)
    return new_params, new_state
Example 5
 def _load_leaves_from_zipfile(self, name, tempdir, zipfile, required):
     fn = name + '.npz'
     if fn not in zipfile.namelist():
         if required:
             raise IOError(f"required file missing from zipfile: {fn}")
         return False
     fp = os.path.join(tempdir, fn)
     zipfile.extract(fn, tempdir)
     npzfile = onp.load(fp)
     treedef = jax.tree_structure(getattr(self, name))
     leaves = npzfile.values()
     pytree = jax.tree_unflatten(treedef, leaves)
     setattr(self, name, pytree)
     return True
Example 6
    def wrapped_auto_registered(*args):
        flat_args, _ = jax.tree_flatten(args)
        # Mapping from variable -> value
        env = {}

        read = functools.partial(read_env, env)
        write = functools.partial(write_env, env)

        def tag(var):
            if matches.get(var) is not None:
                inv_map, tagging_func = matches[var]
                var_map = {
                    k: v
                    for k, v in inv_map.items() if not isinstance(k, str)
                }
                val_map = jax.tree_map(read, var_map)
                val = tagging_func(inv_map, val_map)
                env[var] = val

        # Bind args and consts to environment
        write(jax.core.unitvar, jax.core.unit)
        jax_util.safe_map(write, graph.jaxpr.invars, flat_args)
        jax_util.safe_map(write, graph.jaxpr.constvars, graph.consts)

        # Register any orphan parameters as generic
        for param_var in orphan_params:
            write(param_var, tags.register_generic(read(param_var)))

        # Set the correct output variables
        if compute_only_loss_tags:
            output_vars = loss_output_vars
            out_tree = jax.tree_structure(loss_output_vars)
        else:
            output_vars = graph.jaxpr.outvars
            out_tree = graph.out_tree

        # Loop through equations and evaluate primitives using `bind`
        losses_evaluated = 0
        for eqn in graph.jaxpr.eqns:
            evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
            jax_util.safe_map(tag, eqn.outvars)

            # If we want to output only tagged losses
            if isinstance(eqn.primitive, tags.LossTag):
                losses_evaluated += 1
            if compute_only_loss_tags and num_losses == losses_evaluated:
                break

        outputs = jax_util.safe_map(read, output_vars)
        return jax.tree_unflatten(out_tree, outputs)
Example 7
def add_grad_noise(rng, grads, noise_multiplier, l2_norm_clip,
                   global_batch_size):
  """Adds random noise to the clipped, averaged grads."""
  logging.info("We are adding the noise %f", noise_multiplier)
  grads_flat, grads_treedef = jax.tree_flatten(grads)
  rngs = jax.random.split(rng, len(grads_flat))
  # The grads are already normalized by the batch size (because for DP we
  # should use a loss with normalize_loss_by_num_nonpadding_tokens=True).
  factor = l2_norm_clip * noise_multiplier / global_batch_size
  noised_grads = [
      g + factor * jax.random.normal(r, g.shape)
      for r, g in zip(rngs, grads_flat)
  ]
  return jax.tree_unflatten(grads_treedef, noised_grads)
Example 8
def NES_profile_nn(params,
                   params_to_xL,
                   score_function,
                   npop=50,
                   sigma_noise=0.1,
                   alpha=0.05):
    """Natural Evolutionary strategy
  
  Args:
  		params: in orignal pytree form
  		npop: population size
  		sigma: standard deviation
  		alpha: learning rate
  """

    params_flat, treedef = jax.tree_flatten(params)
    params_shape = [l.shape for l in params_flat]
    params_size = [l.size for l in params_flat]
    params_arr = list_to_array(params_flat)

    num_params = np.sum(np.array(params_size))
    N = onp.random.randn(npop, num_params)
    R = onp.zeros(npop)
    for j in range(npop):
        params_try = params_arr.copy()  # 1d array
        params_try = params_try + sigma_noise * N[j]
        params_try = array_to_list(params_try, params_size,
                                   params_shape)  # list
        params_try = jax.tree_unflatten(treedef, params_try)  # pytree
        xL_try = params_to_xL(params_try)
        R[j] = score_function(xL=xL_try)
    A = (R - np.mean(R)) / (np.std(R) + 1e-6)
    params_update = params_arr - alpha / (npop * sigma_noise) * np.dot(N.T, A)
    params_update = array_to_list(params_update, params_size, params_shape)
    params_update = jax.tree_unflatten(treedef, params_update)

    return params_update
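
list_to_array and array_to_list above are project-specific helpers. A self-contained sketch of the pytree-to-flat-vector round trip they stand for, with illustrative names and shapes:

import jax
import jax.numpy as jnp
import numpy as onp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
leaves, treedef = jax.tree_util.tree_flatten(params)
shapes = [l.shape for l in leaves]
sizes = [l.size for l in leaves]

# pytree -> flat 1d vector
vec = jnp.concatenate([l.ravel() for l in leaves])

# flat 1d vector -> pytree
splits = onp.cumsum(sizes)[:-1]
new_leaves = [v.reshape(s) for v, s in zip(jnp.split(vec, splits), shapes)]
restored = jax.tree_util.tree_unflatten(treedef, new_leaves)
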
Example 9
 def _apply(self, parameters, *inputs, key):
     flat_inputs, in_tree = tree_flatten(inputs)
     flat_fun, out_tree = flatten_fun_nokwargs(self._wrapped_fun, in_tree)
     apply_trace = _top_trace(filter_type=ApplyTrace)
     with new_master(ApplyTrace) as master:
         global_parameters_by_primitive = apply_trace.state.global_parameters_by_primitive \
             if apply_trace else {}
         random_state = apply_trace.state.random_state if apply_trace else RandomState(
             key)
         master.state = ApplyTraceState(random_state, parameters,
                                        global_parameters_by_primitive)
         flat_outputs = _apply_transform(flat_fun,
                                         master).call_wrapped(*flat_inputs)
         del master
     return tree_unflatten(out_tree(), flat_outputs)
Example 10
    def scan_fn(broadcast_in, init, *args):
        xs = jax.tree_multimap(transpose_to_front, in_axes, args)

        def body_fn(c, xs, init_mode=False):
            # inject constants
            xs = jax.tree_multimap(
                lambda ax, arg, x: (arg if ax is broadcast else x), in_axes,
                args, xs)
            broadcast_out, c, ys = fn(broadcast_in, c, *xs)

            if init_mode:
                ys = jax.tree_multimap(
                    lambda ax, y: (y if ax is broadcast else ()), out_axes, ys)
                return broadcast_out, ys
            else:
                ys = jax.tree_multimap(
                    lambda ax, y: (() if ax is broadcast else y), out_axes, ys)
                return c, ys

        broadcast_body = functools.partial(body_fn, init_mode=True)

        carry_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(jnp.shape(x), jnp.result_type(x))), init)
        scan_pvals = jax.tree_map(
            lambda x: pe.PartialVal.unknown(
                jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))), xs)
        input_pvals = (carry_pvals, scan_pvals)
        in_pvals, in_tree = jax.tree_flatten(input_pvals)
        f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
            lu.wrap_init(broadcast_body), in_tree)
        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)

        out_flat = []
        for pv, const in out_pvals:
            if pv is not None:
                raise ValueError(
                    'broadcasted variable has a data dependency on the scan body.'
                )
            out_flat.append(const)
        broadcast_in, constants_out = jax.tree_unflatten(out_tree(), out_flat)

        c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse)
        ys = jax.tree_multimap(transpose_from_front, out_axes, ys)
        ys = jax.tree_multimap(
            lambda ax, const, y: (const if ax is broadcast else y), out_axes,
            constants_out, ys)
        return broadcast_in, c, ys
Example 11
def restore_state_list(obj, state_list):
  """Restore model state from a state list.

  Args:
    obj: the object that is to be duplicated with the
      restored state
    state_list: state as a list of jax.numpy arrays

  Returns:
    a copy of `obj` with the parameters from `state_list` loaded.

  >>> restored = restore_state_list(model, state_list)
  """
  state_list = replicate(state_list)
  structure = jax.tree_util.tree_structure(obj)
  return jax.tree_unflatten(structure, state_list)
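
A minimal sketch of the round trip this helper relies on: the state list is just the flat leaves, and a template object supplies the treedef. The replicate call above is project-specific and omitted here; names are illustrative.

import jax
import jax.numpy as jnp

params = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}}

state_list = jax.tree_util.tree_leaves(params)      # serialize to a flat list
structure = jax.tree_util.tree_structure(params)    # derived from a template object
restored = jax.tree_util.tree_unflatten(structure, state_list)
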
Example 12
 def test_flatten_nested_struct(self):
     d = {
         "foo": {
             "bar": [1, 2, 3]
         },
         "baz": {
             "bat": [4, 5, 6],
             "qux": [7, [8, 9]]
         }
     }
     f = FlatMapping(d)
     leaves, treedef = jax.tree_flatten(f)
     self.assertEqual([4, 5, 6, 7, 8, 9, 1, 2, 3], leaves)
     g = jax.tree_unflatten(treedef, leaves)
     self.assertEqual(g, f)
     self.assertEqual(g, d)
Example 13
    def named_fun(*args, **kwargs):
        # Wrap and flatten f for JAX internals.
        f = lu.wrap_init(fun)
        flat_args, in_tree = jax.tree_flatten((args, kwargs))
        flat_f, out_tree = api.flatten_fun(f, in_tree)

        # Hide any args that are not a valid JaxType by partially applying flat_f
        dyn_argnums = [
            i for (i, x) in enumerate(flat_args) if jax.api._valid_jaxtype(x)
        ]  # pylint: disable=protected-access
        part_flat_f, dyn_args = jax.argnums_partial(flat_f, dyn_argnums,
                                                    flat_args)

        # Call f with a custom XLA subcomputation via named_call & unflatten result.
        out_flat = named_call_p.bind(part_flat_f, *dyn_args, name=name)
        return jax.tree_unflatten(out_tree(), out_flat)
Example 14
    def test_tree_functions(self):
        f = FlatMapping({"foo": {"b": {"c": 1}, "d": 2}, "bar": {"c": 1}})

        m = jax.tree_map(lambda x: x + 1, f)
        self.assertEqual(type(m), FlatMapping)
        self.assertEqual(m, {"foo": {"b": {"c": 2}, "d": 3}, "bar": {"c": 2}})

        mm = jax.tree_multimap(lambda x, y: x + y, f, f)
        self.assertEqual(type(mm), FlatMapping)
        self.assertEqual(mm, {"foo": {"b": {"c": 2}, "d": 4}, "bar": {"c": 2}})

        leaves, treedef = jax.tree_flatten(f)
        self.assertEqual(leaves, [1, 1, 2])
        uf = jax.tree_unflatten(treedef, leaves)
        self.assertEqual(type(uf), FlatMapping)
        self.assertEqual(f, uf)
Example 15
def tree_map_zipped(fn: Callable[..., Any], nests: Sequence[Any]):
    """Map a function over a list of identical nested structures.

  Args:
    fn: the function to map; must have arity equal to `len(nests)`.
    nests: a list of identical nested structures.

  Returns:
    a nested structure whose leaves are outputs of applying `fn`.
  """
    if not nests:
        return nests
    tree_def = tree_structure(nests[0])
    if any([tree_structure(x) != tree_def for x in nests[1:]]):
        raise ValueError('All elements must share the same tree structure.')
    return jax.tree_unflatten(
        tree_def, [fn(*d) for d in zip(*[jax.tree_leaves(x) for x in nests])])
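
For illustration, assuming the helper above is importable as tree_map_zipped, averaging two identically structured pytrees leaf by leaf might look like this:

import jax.numpy as jnp

a = {"w": jnp.ones(3), "b": jnp.zeros(3)}
b = {"w": 3.0 * jnp.ones(3), "b": jnp.ones(3)}

# Each leaf of the result is fn(leaf_of_a, leaf_of_b).
avg = tree_map_zipped(lambda x, y: (x + y) / 2.0, [a, b])
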
Example 16
    def wrapper(*args, **kwargs):
        side_channel = {"non_jaxtypes": [], "treedef": None}
        wrapped_fun = hide_non_jaxtype_outputs(fun, side_channel)
        if base.inside_transform():
            wrapped_fun = thread_hk_state_in_kwargs(jax.named_call)(
                wrapped_fun, name=name)
        else:
            wrapped_fun = jax.named_call(wrapped_fun, name=name)

        jax_types = wrapped_fun(*args, **kwargs)

        non_jaxtypes = side_channel["non_jaxtypes"]
        out_leaves = [
            y if x is None else x for x, y in zip(jax_types, non_jaxtypes)
        ]
        out = jax.tree_unflatten(side_channel["treedef"], out_leaves)

        return out
Example 17
        def objective(parameters_vec):
            """Objective function for minimization.

      Args:
        parameters_vec: Float numpy array with shape (num_parameters,), the
          parameters for exchange-correlation functional.

      Returns:
        Float, the weighted root mean square deviation (WRMSD).
      """
            loss = float(
                eval_wrmsd(**jax.tree_unflatten(functional.parameters_spec,
                                                parameters_vec)))
            if self.l1_penalty > 1e-8:
                loss += self.l1_penalty * np.sum(np.abs(parameters_vec))
            if self.l2_penalty > 1e-8:
                loss += self.l2_penalty * np.sum(parameters_vec**2)
            return loss
Example 18
    def _wrap_and_compile(self, signature, args_flat, in_tree):
        """Compiles the function for the given signature."""
        def wrapped_function(*args_flat):
            args, kwargs = jax.tree_unflatten(in_tree, args_flat)
            return self._function(*args, **kwargs)

        # Compile the wrapped_function to IREE.
        binary = aot(wrapped_function, *args_flat, **self._options)
        cpp_vm_module = rt.VmModule.from_flatbuffer(binary)
        module = rt.load_module(cpp_vm_module, config=self._driver_config)

        # Get the output tree so it can be reconstructed from the outputs of the
        # compiled module. Duplicating execution here isn't ideal, and could
        # probably be avoided using internal APIs.
        args, kwargs = jax.tree_unflatten(in_tree, args_flat)
        _, out_tree = jax.tree_flatten(self._function(*args, **kwargs))

        self._memoized_signatures[signature] = (binary, module, out_tree)
Example 19
def parallel_read(old, fname):
    old_val, treedef = jax.tree_flatten(old)
    with open(fname, "rb") as f:
        buf = f.read()
        f_io = io.BytesIO(buf)
        loaded = np.load(f_io)

    new_vals = []
    for i in loaded:
        new_vals.append(loaded[i])

    for o, n in zip(new_vals, old_val):
        assert o.shape == n.shape, "Incompatible checkpoint"

        if o.dtype == np.dtype('V2'):
            o.dtype = jnp.bfloat16

    return jax.tree_unflatten(treedef, new_vals)
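
An illustrative round trip for the reader above: write the flat leaves with numpy.savez, then rebuild the pytree against an existing template. The real writer in the source project may differ; this is only a sketch.

import io
import numpy as np
import jax

old = {"w": np.zeros((2, 2)), "b": np.zeros((2,))}
leaves, treedef = jax.tree_util.tree_flatten(old)

buf = io.BytesIO()
np.savez(buf, *leaves)                           # stored as arr_0, arr_1, ...
buf.seek(0)

loaded = np.load(buf)
new_leaves = [loaded[name] for name in loaded]   # keys come back in archive order
restored = jax.tree_util.tree_unflatten(treedef, new_leaves)
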
Example 20
    def merged_func(*func_args):
        typed_jaxpr, out_avals = jax.make_jaxpr(f,
                                                return_shape=True)(*func_args)
        out_tree = jax.tree_structure(out_avals)
        jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals

        # Mapping from variable -> value
        env = dict()
        read = functools.partial(read_env, env)
        write = functools.partial(write_env, env)

        # Bind args and consts to environment
        flat_args = jax.tree_flatten(func_args)[0]
        write(jax.core.unitvar, jax.core.unit)
        jax_util.safe_map(write, jaxpr.invars, flat_args)
        jax_util.safe_map(write, jaxpr.constvars, consts)

        # Loop through equations and evaluate primitives using `bind`
        broadcasts_outputs = dict()
        for eqn in clean_jaxpr_eqns(jaxpr):
            # We ignore broadcasting of constants
            if (eqn.primitive.name == "broadcast_in_dim" and not all(
                    isinstance(v, jax_core.Literal) for v in eqn.invars)):
                if eqn.invars[0] in broadcasts_outputs:
                    x, dims = broadcasts_outputs[eqn.invars[0]]
                    kept_dims = eqn.params["broadcast_dimensions"]
                    kept_dims = [kept_dims[d] for d in dims]
                    y = lax.broadcast_in_dim(x, eqn.params["shape"], kept_dims)
                    jax_util.safe_map(write, eqn.outvars, [y])
                    broadcasts_outputs[eqn.outvars[0]] = (x, kept_dims)
                else:
                    inputs = jax_util.safe_map(read, eqn.invars)
                    evaluate_eqn(eqn, inputs, write)
                    broadcasts_outputs[eqn.outvars[0]] = (
                        inputs[0], eqn.params["broadcast_dimensions"])
            else:
                evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
        return jax.tree_unflatten(out_tree,
                                  jax_util.safe_map(read, jaxpr.outvars))
Example 21
  def update_fn(updates, state, params=None):
    del params
    grads_flat, grads_treedef = jax.tree_flatten(updates)
    bsize = grads_flat[0].shape[0]

    if any(g.ndim == 0 or bsize != g.shape[0] for g in grads_flat):
      raise ValueError(
          'Unlike other transforms, `differentially_private_aggregate` expects'
          ' `updates` to have a batch dimension in the 0th axis. That is, this'
          ' function expects per-example gradients as input.')

    new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat)+1)
    global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads_flat)
    divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
    clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads_flat]
    noised = [(g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
              for g, r in zip(clipped, rngs)]
    return (jax.tree_unflatten(grads_treedef, noised),
            DifferentiallyPrivateAggregateState(rng_key=new_key))
Example 22
def partial_eval_by_shape(fn, input_spec, *args, **kwargs):
    """Lazily evaluate a function by using the shapes of the inputs.

  This function is similar to `jax.eval_shape` with the key difference that
  function outputs that can be computed without a concrete value of the
  inputs are returned as is instead of only the shape. See for example
  `module.init_by_shape` where this functionality is used to initialize a
  model without using input data or computation.

  Args:
    fn: the function to be lazily evaluated.
    input_spec: an iterable of shapes or (shape, dtype) tuples specifying the
      shape and type of the inputs. If unspecified the dtype is float32.
    *args: other arguments passed to the module's apply function
    **kwargs: keyword arguments passed to the module's apply function
  Returns:
    A pair consisting of the model output and an instance of Model
  """
    # output cannot be returned in lazy_create because jax.eval_shape will only
    # return the shape and dtype.
    # TODO(mattjj,jheek): use a public JAX API
    f = lambda *inputs: fn(*inputs, *args, **kwargs)
    input_structs = [_parse_spec(spec) for spec in input_spec]
    inputs_flat, in_tree = jax.tree_flatten(input_structs)
    f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
        lu.wrap_init(f), in_tree)
    in_pvals = [
        pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
        for x in inputs_flat
    ]

    if config.omnistaging_enabled:
        _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
    else:
        with jax.core.initial_style_staging():
            _, out_pvals, _ = pe.trace_to_jaxpr(f_flat,
                                                in_pvals,
                                                stage_out=True)
    out_flat = [
        const if pv is None else jax.ShapeDtypeStruct(pv.shape, pv.dtype)
        for pv, const in out_pvals
    ]
    return jax.tree_unflatten(out_tree(), out_flat)
Example 23
 def vjp_func(tangents):
     flat_tangents = jax.tree_flatten(tangents)[0]
     loss_invars = []
     loss_targets = []
     for jaxpr_eqn, inputs in zip(loss_jaxpr_eqns, losses_inputs):
         num_inputs = _unbox_loss_tag(jaxpr_eqn).num_inputs
         loss_invars.append(tuple(jaxpr_eqn.invars[:num_inputs]))
         loss_targets.append(inputs[num_inputs:])
     treedef = jax.tree_structure(loss_invars)
     tangents = jax.tree_unflatten(treedef, flat_tangents)
     # Since the losses can also take targets as inputs, and we don't want this
     # function to compute vjps w.r.t. those (the user should not be providing
     # tangent vectors for the targets, only for the inputs), we have to
     # manually fill in these "extra" tangents with zeros.
     targets_tangents = jax.tree_map(jnp.zeros_like, loss_targets)
     tangents = tuple(ti + tti
                      for ti, tti in zip(tangents, targets_tangents))
     input_tangents = full_vjp_func(tangents)[0]
     return input_tangents,
Example 24
    def __init__(self, *args, **kwargs):
        """Accepts FlatComponents or the same arguments as `dict`."""
        if not kwargs and len(args) == 1 and type(args[0]) is FlatComponents:
            leaves, structure = args[0]
            mapping = None

            # When unflattening we cannot assume that the leaves are not pytrees (for
            # example: `jax.tree_map(list, my_map)` would pass a list of lists in
            # as leaves).
            if not jax.tree_util.all_leaves(leaves):
                mapping = jax.tree_unflatten(structure, leaves)
                leaves, structure = jax.tree_flatten(mapping)
        else:
            mapping = dict(*args, **kwargs)
            leaves, structure = jax.tree_flatten(mapping)

        self._structure = structure
        self._leaves = tuple(leaves)
        self._mapping = mapping
Example 25
 def unpack(bundle):
     per_example = list(bundle.per_example)
     per_head = list(bundle.per_head)
     per_example_per_head = list(bundle.per_example_per_head)
     singleton = list(bundle.singleton)
     leaves = []
     for is_per_example, is_per_head in zip(has_batch_dim,
                                            has_head_dim):
         if is_per_example and not is_per_head:
             leaves.append(per_example.pop(0))
         elif not is_per_example and is_per_head:
             leaves.append(per_head.pop(0))
         elif is_per_example and is_per_head:
             leaves.append(per_example_per_head.pop(0))
         else:
             leaves.append(singleton.pop(0))
     assert (not per_example) and (not per_head)
     assert (not per_example_per_head) and (not singleton)
     return jax.tree_unflatten(treedef, leaves)
Example 26
  def __init__(self, flat: FlatComponents, check_leaves: bool = True):
    """Constructs a flat mapping from already flattened components.

    Args:
      flat: A tuple containing a flat sequence of values and a PyTreeDef
      representing the output of jax.tree_flatten on a structure.
      check_leaves: Check if all leaves are flat values, and reflatten if not.
      This check is O(n), whereas the normal construction time is O(1).
    """
    leaves, structure = flat

    # TODO(lenamartens): upstream is_leaf check to Jax
    is_leaf = lambda x: type(x) not in tree_util._registry  # pylint: disable=unidiomatic-typecheck,protected-access
    if check_leaves and not all(map(is_leaf, leaves)):
      tree = jax.tree_unflatten(structure, leaves)
      leaves, structure = jax.tree_flatten(tree)

    self._structure = structure
    self._leaves = tuple(leaves)
    self._mapping = None
Example 27
    def __call__(self, *args, **kwargs):
        """Executes the function on the provided inputs, compiling if necessary."""
        args_flat, _ = jax.tree_flatten((args, kwargs))
        # Use the uncompiled function if the inputs are being traced.
        if any(issubclass(type(arg), jax.core.Tracer) for arg in args_flat):
            return self._function(*args, **kwargs)

        module, out_tree = self._get_compiled_artifacts(args, kwargs)
        results = module.main(*args_flat)
        if results is not None:
            if not isinstance(results, tuple):
                results = (results, )
            return jax.tree_unflatten(out_tree, results)
        else:
            # Address IREE returning None instead of empty sequences.
            if out_tree == jax.tree_flatten([])[-1]:
                return []
            elif out_tree == jax.tree_flatten(())[-1]:
                return ()
            else:
                return results
Example 28
    def get_fingerprint(self,
                        num_feature_samples=10,
                        num_parameter_samples=10,
                        num_decimals=5):
        """Gets a fingerprint for the functional.

    Fingerprint is evaluated as the MD5 hash value of functional signatures on
    a random sample of feature and parameter values. Signatures will be
    converted to strings with specified number of decimals.

    Args:
      num_feature_samples: Integer, number of samples of features.
      num_parameter_samples: Integer, number of samples of parameters.
      num_decimals: Integer, number of decimals when converting signatures to
        strings.

    Returns:
      String, the fingerprint of the functional.
    """
        format_string = f'{{:.{num_decimals}f}}'

        # fix random seed to have consistent sampling behavior when running the
        # code on different distributed workers
        random_state = np.random.RandomState(0)

        parameter_samples = random_state.rand(num_parameter_samples,
                                              self.num_parameters)

        signatures = []
        for parameters in parameter_samples:
            signatures.extend(
                self.get_signature(**jax.tree_unflatten(
                    self.parameters_spec, parameters),
                                   num_feature_samples=num_feature_samples,
                                   random_state=random_state,
                                   signature='e_xc'))

        signature_string = ','.join(map(format_string.format, signatures))

        return hashlib.md5(signature_string.encode('utf-8')).hexdigest()
Example 29
def _local_value_and_grad_notcentered_kernel(logpsi, pars, vp, mel, v):

    odtype = outdtype(logpsi, pars, v)
    # can use if with jit because that argument is exposed statically to the jit!
    # if real_to_complex:
    if not tree_leaf_iscomplex(pars) and jnp.issubdtype(
            odtype, jnp.complexfloating):
        logpsi_vp_r, f_vjp_r = jax.vjp(lambda w: (logpsi(w, vp).real), pars)
        logpsi_vp_j, f_vjp_j = jax.vjp(lambda w: (logpsi(w, vp).imag), pars)

        logpsi_vp = logpsi_vp_r + 1.0j * logpsi_vp_j

        vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))
        vec = jnp.asarray(vec, dtype=odtype)

        vec_r = vec.real
        vec_j = vec.imag

        loc_val = vec.sum()

        vr_grad_r, tree_fun = jax.tree_flatten(f_vjp_r(vec_r)[0])
        vj_grad_r, _ = jax.tree_flatten(f_vjp_r(vec_j)[0])
        vr_grad_j, _ = jax.tree_flatten(f_vjp_j(vec_r)[0])
        vj_grad_j, _ = jax.tree_flatten(f_vjp_j(vec_j)[0])

        r_flat = [rr + 1j * jr for rr, jr in zip(vr_grad_r, vj_grad_r)]
        j_flat = [rr + 1j * jr for rr, jr in zip(vr_grad_j, vj_grad_j)]
        out_flat = [re + 1.0j * im for re, im in zip(r_flat, j_flat)]

        grad_c = jax.tree_unflatten(tree_fun, out_flat)
    else:
        logpsi_vp, f_vjp = jax.vjp(lambda w: logpsi(w, vp), pars)

        vec = mel * jax.numpy.exp(logpsi_vp - logpsi(pars, v))
        vec = jnp.asarray(vec, dtype=odtype)

        loc_val = vec.sum()
        grad_c = f_vjp(vec)[0]

    return loc_val, grad_c
Example 30
def sample_simulation_parameters(simulation_parameter_ranges, num_trajectories,
                                 rng):
    """Samples simulation parameters."""

    is_tuple = lambda val: isinstance(val, tuple)
    ranges_flat, ranges_treedef = jax.tree_flatten(simulation_parameter_ranges,
                                                   is_leaf=is_tuple)
    rng, shuffle_rng, *rngs = jax.random.split(rng, len(ranges_flat) + 2)
    shuffle_indices = jax.random.permutation(shuffle_rng, num_trajectories)
    rng_tree = jax.tree_unflatten(ranges_treedef, rngs)

    def sample_simulation_parameter(simulation_parameter_range, parameter_rng):
        """Sample a single simulation parameter."""
        del parameter_rng
        minval, maxval = simulation_parameter_range
        samples = jnp.linspace(minval, maxval, num=num_trajectories)
        return jnp.sort(samples)[shuffle_indices]

    return jax.tree_map(sample_simulation_parameter,
                        simulation_parameter_ranges,
                        rng_tree,
                        is_leaf=is_tuple)
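
A minimal sketch of the is_leaf trick used above: treating the (min, max) tuples as leaves so each range gets exactly one RNG key. The ranges and seed are illustrative.

import jax

ranges = {"gravity": (9.0, 10.0), "mass": (0.5, 2.0)}

is_tuple = lambda v: isinstance(v, tuple)
flat, treedef = jax.tree_util.tree_flatten(ranges, is_leaf=is_tuple)

keys = jax.random.split(jax.random.PRNGKey(0), len(flat))
key_tree = jax.tree_util.tree_unflatten(treedef, list(keys))
# key_tree now mirrors `ranges`: one key per (min, max) tuple.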