Example #1
def custom_layer_cau_batch(trace, f, tracers, params):
    """Batching rule for layer_cau primitive to handle custom layers."""
    vals, dims = jax_util.unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is batching.not_mapped for dim in dims):
        return layer_cau_p.bind(f, *vals, **params)
    args = tree_util.tree_unflatten(params['in_tree'], vals)
    dims_ = [not_mapped if dim is None else dim for dim in dims]
    layer, args = args[0], args[1:]
    if hasattr(layer, '_call_and_update_batched'):
        num_params = len(tree_util.tree_leaves(layer))
        layer_dims, arg_dims = dims_[:num_params], dims_[num_params:]
        if params['kwargs']['has_rng']:
            rng, args = args[0], args[1:]
            rng_dim, arg_dims = arg_dims[0], arg_dims[1:]
        mapping_over_layer = all(layer_dim is not not_mapped
                                 for layer_dim in layer_dims)
        mapping_over_args = all(arg_dim is not not_mapped
                                for arg_dim in arg_dims)
        assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims)
        if not mapping_over_layer and mapping_over_args:
            if params['kwargs']['has_rng']:
                if rng_dim is not not_mapped:
                    arg_dims = tuple(None if dim is not_mapped else dim
                                     for dim in arg_dims)
                    map_fun = jax.vmap(
                        lambda layer, rng, *args: _layer_cau_batched(
                            layer,
                            rng,
                            *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                            **params['kwargs']),
                        in_axes=(None, rng_dim) + (None, ) * len(arg_dims))
                else:
                    map_fun = lambda layer, *args: _layer_cau_batched(
                        layer,
                        *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                        **params['kwargs'])
                vals_out, update_out = map_fun(layer, rng, *args)
            else:
                vals_out, update_out = _layer_cau_batched(
                    layer, *args, **params['kwargs'])
            vals_out = tree_util.tree_leaves(vals_out)
            update_out = tree_util.tree_leaves(update_out)
            assert all(dim == 0 for dim in arg_dims)
            # Assume dimensions out are consistent
            dims_out = (0, ) * len(vals_out)
            dims_update = (None, ) * len(update_out)

            # Call wrapped function to avoid linear_util error
            f.call_wrapped(*tracers)
            return [
                batching.BatchTracer(trace, v, d)
                for v, d in zip(vals_out + update_out, dims_out + dims_update)
            ]
    f, dims_out = batching.batch_subtrace(f, trace.master, dims)
    vals_out = layer_cau_p.subcall('batch').bind(f, *vals, **params)
    return [
        batching.BatchTracer(trace, v, d)
        for v, d in zip(vals_out, dims_out())
    ]
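
For context, the branch above dispatches to layers that define a `_call_and_update_batched` method. That contract is not shown in this excerpt, so the sketch below is only an assumption of what such a layer might look like (the `Dense` class, its fields, and `_call_and_update` are illustrative names; a real layer would also be registered as a pytree so that `tree_util.tree_leaves(layer)` yields its parameters):

import jax
import jax.numpy as jnp

class Dense:
    """Illustrative layer opting into custom batching."""

    def __init__(self, w, b):
        self.w, self.b = w, b

    def _call_and_update(self, x):
        # Unbatched call: returns (output, updated_layer); stateless here.
        return jnp.dot(self.w, x) + self.b, self

    def _call_and_update_batched(self, x):
        # Parameters are shared across the batch, so map only over the
        # input, mirroring the `mapping_over_args` branch of the rule above.
        out = jax.vmap(lambda v: jnp.dot(self.w, v) + self.b)(x)
        return out, self

layer = Dense(jnp.ones((4, 3)), jnp.zeros(4))
batched_out, _ = layer._call_and_update_batched(jnp.ones((8, 3)))
assert batched_out.shape == (8, 4)
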
Example #2
def batch_jaxpr(jaxpr, axis_size, in_dims):
  dimvars = dict((v, v.aval) for v in jaxpr.in_dim_binders)
  in_avals = [_replace_vars_with_avals(dimvars, v.aval) for v in jaxpr.in_binders]

  in_avals = [core.unmapped_aval(axis_size, d, aval)
              if d is not batching.not_mapped else aval
              for d, aval in zip(in_dims, in_avals)]

  fun, out_dims = batching.batch_subtrace(lu.wrap_init(jaxpr_as_fun(jaxpr)))
  f = _batch_fun(fun, in_dims)
  jaxpr, consts, _ = trace_to_jaxpr_dynamic(f, in_avals)
  return jaxpr, consts, out_dims()
Example #3
def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
    f, out_axes = batching.batch_subtrace(f)
    f = batching._batch_outer(f, axis_name, axis_size, in_axes,
                              batching.BatchTrace)
    outs = f.call_wrapped(*args)
    return outs, out_axes()
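
A minimal usage sketch for the helper above, assuming an older JAX release in which `jax.linear_util`, `batching.batch_subtrace`, and `batching._batch_outer` still exist with the signatures used here (the function and array names are illustrative):

import jax.numpy as jnp
from jax import linear_util as lu

def dot_flat(x, y):
    # batch_subtrace expects the traced function to return a flat sequence
    # of outputs, so the single result is wrapped in a list.
    return [jnp.dot(x, y)]

xs = jnp.ones((8, 3))    # mapped along axis 0
y = jnp.ones((3,))       # broadcast (not mapped)

outs, out_axes = vmap_unrestricted(
    lu.wrap_init(dot_flat), xs, y,
    in_axes=(0, None), axis_name='batch', axis_size=8)
# outs is the flat list of batched outputs; out_axes gives their batch
# dimensions, e.g. (0,) for the single (8,)-shaped result here.
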
Example #4
def batch_fun(fun: lu.WrappedFun, in_dims):
    fun, out_dims = batching.batch_subtrace(fun)
    return _batch_fun(fun, in_dims), out_dims
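
For orientation, the internal machinery in these examples is what powers the public `jax.vmap` API: an `in_dims` entry of `batching.not_mapped` corresponds to `in_axes=None`, i.e. an argument that is broadcast rather than mapped. A small public-API illustration (not taken from any of the snippets above):

import jax
import jax.numpy as jnp

def affine(w, x):
    return jnp.dot(w, x)

w = jnp.ones((4, 3))      # shared (unmapped) parameter
xs = jnp.ones((8, 3))     # batch of inputs along axis 0

# None here plays the role of batching.not_mapped in the rules above.
ys = jax.vmap(affine, in_axes=(None, 0))(w, xs)
assert ys.shape == (8, 4)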