Code example #1
File: base.py  Project: seanmb/probability
import jax
from jax import tree_util
from jax import util as jax_util
from jax.interpreters import batching

not_mapped = batching.not_mapped  # sentinel for "no mapped axis"

# `layer_cau_p` and `_layer_cau_batched` are defined elsewhere in this module.
def custom_layer_cau_batch(trace, f, tracers, params):
    """Batching rule for the layer_cau primitive to handle custom layers."""
    vals, dims = jax_util.unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is batching.not_mapped for dim in dims):
        return layer_cau_p.bind(f, *vals, **params)
    args = tree_util.tree_unflatten(params['in_tree'], vals)
    dims_ = [not_mapped if dim is None else dim for dim in dims]
    layer, args = args[0], args[1:]
    if hasattr(layer, '_call_and_update_batched'):
        num_params = len(tree_util.tree_leaves(layer))
        layer_dims, arg_dims = dims_[:num_params], dims_[num_params:]
        if params['kwargs']['has_rng']:
            rng, args = args[0], args[1:]
            rng_dim, arg_dims = arg_dims[0], arg_dims[1:]
        mapping_over_layer = all(layer_dim is not not_mapped
                                 for layer_dim in layer_dims)
        mapping_over_args = all(arg_dim is not not_mapped
                                for arg_dim in arg_dims)
        assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims)
        if not mapping_over_layer and mapping_over_args:
            if params['kwargs']['has_rng']:
                if rng_dim is not not_mapped:
                    arg_dims = tuple(None if dim is not_mapped else dim
                                     for dim in arg_dims)
                    map_fun = jax.vmap(
                        lambda layer, rng, *args: _layer_cau_batched(
                            layer,
                            rng,
                            *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                            **params['kwargs']),
                        in_axes=(None, rng_dim) + (None,) * len(arg_dims))
                else:
                    map_fun = lambda layer, *args: _layer_cau_batched(
                        layer,
                        *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                        **params['kwargs'])
                vals_out, update_out = map_fun(layer, rng, *args)
            else:
                vals_out, update_out = _layer_cau_batched(
                    layer, *args, **params['kwargs'])
            vals_out = tree_util.tree_leaves(vals_out)
            update_out = tree_util.tree_leaves(update_out)
            assert all(dim == 0 for dim in arg_dims)
            # Assume output dims are consistent: mapped outputs are batched
            # along axis 0, and layer updates are unbatched. (The original
            # concatenated dims_update onto dims_out here and again in the
            # zip below; the duplicate was masked only by zip truncation.)
            dims_out = (0,) * len(vals_out)
            dims_update = (None,) * len(update_out)

            # linear_util requires the wrapped function to be called exactly
            # once; call it here even though its output is unused.
            f.call_wrapped(*tracers)
            return [
                batching.BatchTracer(trace, v, d)
                for v, d in zip(vals_out + update_out, dims_out + dims_update)
            ]
    # Fallback: the layer has no custom batched implementation, so batch the
    # call through a subtrace as usual.
    f, dims_out = batching.batch_subtrace(f, trace.master, dims)
    vals_out = layer_cau_p.subcall('batch').bind(f, *vals, **params)
    return [
        batching.BatchTracer(trace, v, d)
        for v, d in zip(vals_out, dims_out())
    ]
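This rule plugs into JAX's batching interpreter the same way any custom
primitive does. Below is a minimal, self-contained sketch of that mechanism;
the primitive `scale2_p` and all associated names are hypothetical, not part
of oryx or JAX:

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import batching

scale2_p = core.Primitive('scale2')  # toy primitive computing 2 * x

def scale2(x):
  return scale2_p.bind(x)

scale2_p.def_impl(lambda x: 2 * x)
scale2_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

def scale2_batch(batched_args, batch_dims):
  # As in custom_layer_cau_batch above, a batching rule receives the batched
  # operands with their mapped axes and returns outputs plus output axes.
  (x,), (bdim,) = batched_args, batch_dims
  return scale2(x), bdim  # elementwise, so the batch axis passes through

batching.primitive_batchers[scale2_p] = scale2_batch

print(jax.vmap(scale2)(jnp.arange(3.)))  # [0. 2. 4.]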
Code example #2
File: parallel.py  Project: satishreddycrec/jax
def _process_axis_index(self, frame):
  # Under vmap, `axis_index` is just arange(axis_size), batched along dim 0,
  # so each element of the batch sees its own position along the mapped axis.
  return batching.BatchTracer(self, lax_numpy.arange(frame.size, dtype=np.int32), 0)
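For reference, the behavior this handler implements is visible through the
public API: inside a vmap with a named axis, `jax.lax.axis_index` yields each
element's position along the mapped axis (the axis name 'i' below is
arbitrary):

import jax
import jax.numpy as jnp

def f(x):
  return x + jax.lax.axis_index('i')

print(jax.vmap(f, axis_name='i')(jnp.zeros(4, dtype=jnp.int32)))  # [0 1 2 3]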