Ejemplo n.º 1
0
def mpi_allgather_abstract_eval(x, token, comm):
    """Abstract evaluation rule for the MPI allgather primitive.

    The output shape is ``x.shape`` with the communicator size prepended as
    a new leading dimension; the dtype is that of ``x``.

    Args:
        x: abstract value of the local buffer (uses ``shape`` and ``dtype``).
        token: abstract token input; not inspected here.
        comm: communicator as passed to the primitive; unwrapped with
            ``unpack_hashable`` before querying its size.

    Returns:
        Tuple of (ShapedArray of shape ``(size, *x.shape)``,
        ``abstract_arrays.abstract_token``).
    """
    communicator = unpack_hashable(comm)
    n_ranks = communicator.Get_size()
    gathered_aval = abstract_arrays.ShapedArray((n_ranks, *x.shape), x.dtype)
    return gathered_aval, abstract_arrays.abstract_token
Ejemplo n.º 2
0
 def H_abstract_eval(x, d=0, full=False):
     """Abstract evaluation rule for the basis-function primitive ``H``.

     The output width is ``self.basisClass.m`` when ``full`` is true and
     ``self.basisClass.m - self.basisClass.numC`` otherwise. A rank-0 ``x``
     produces a 1-D aval of that width; any other ``x`` produces
     ``(x.shape[0], width)``. The dtype is taken from ``x``; ``d`` does not
     affect the result.
     """
     # ``self`` is a free variable captured from the enclosing scope.
     basis = self.basisClass
     width = basis.m if full else basis.m - basis.numC
     out_dims = (width,) if len(x.shape) == 0 else (x.shape[0], width)
     return abstract_arrays.ShapedArray(out_dims, x.dtype)
Ejemplo n.º 3
0
def _threefry2x32_abstract_eval(*args):
  if any(a.dtype != np.uint32 for a in args):
    raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                    .format(args))
  if all(isinstance(arg, abstract_arrays.ShapedArray) for arg in args):
    shape = lax._broadcasting_shape_rule(*args)
    aval = abstract_arrays.ShapedArray(shape, np.dtype(np.uint32))
  else:
    aval = abstract_arrays.UnshapedArray(np.dtype(np.uint32))
  return (aval,) * 2
Ejemplo n.º 4
0
Archivo: mtfc.py Proyecto: leakec/tfc
 def H_abstract_eval(*x, d=d0, full=False):
     """Abstract evaluation rule for the multivariate basis-function ``H``.

     Only the first positional argument shapes the output: its dtype is
     reused, and a rank-0 ``x[0]`` yields ``(width,)`` while any other
     yields ``(x[0].shape[0], width)``. ``width`` is
     ``self.basisClass.numBasisFuncFull`` when ``full`` is true and
     ``self.basisClass.numBasisFunc`` otherwise. ``d`` does not affect the
     result.
     """
     # ``self`` and ``d0`` are free variables from the enclosing scope.
     basis = self.basisClass
     width = basis.numBasisFuncFull if full else basis.numBasisFunc
     lead = x[0]
     out_dims = (width,) if len(lead.shape) == 0 else (lead.shape[0], width)
     return abstract_arrays.ShapedArray(out_dims, lead.dtype)
Ejemplo n.º 5
0
def mpi_scatter_abstract_eval(x, token, root, comm):
    """Abstract evaluation rule for the MPI scatter primitive.

    On the rank equal to ``root`` the output shape is ``x.shape[1:]`` (the
    leading axis is dropped); on every other rank it is ``x.shape``. The
    dtype always follows ``x``.

    Returns:
        Tuple of (output ShapedArray, ``core.abstract_token``).
    """
    my_rank = unpack_hashable(comm).Get_rank()
    out_shape = x.shape[1:] if my_rank == root else x.shape
    scattered_aval = abstract_arrays.ShapedArray(out_shape, x.dtype)
    return scattered_aval, core.abstract_token
Ejemplo n.º 6
0
def multiply_add_abstract_eval(xs, ys, zs):
    """Abstract evaluation of the multiply_add primitive.

    This function does not need to be JAX traceable. It is invoked with
    abstractions of the actual arguments.

    Args:
        xs, ys, zs: abstract values of the three operands; all must have the
            same shape.

    Returns:
        A ShapedArray with the shape and dtype of ``xs``.

    Raises:
        TypeError: if the operand shapes differ.
    """
    # An explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let mismatched shapes through silently.
    if xs.shape != ys.shape or xs.shape != zs.shape:
        raise TypeError(
            "multiply_add requires operands of identical shape, got "
            f"{xs.shape}, {ys.shape} and {zs.shape}")
    return abstract_arrays.ShapedArray(xs.shape, xs.dtype)
Ejemplo n.º 7
0
def mpi_gather_abstract_eval(x, token, root, comm):
    """Abstract evaluation rule for the MPI gather primitive.

    On the ``root`` rank the output shape is ``(size, *x.shape)``, where
    ``size`` is the communicator size; on other ranks it is ``x.shape``.
    The dtype follows ``x``.

    Returns:
        Tuple of (output ShapedArray, ``core.abstract_token``).
    """
    mpi_comm = unpack_hashable(comm)
    on_root = mpi_comm.Get_rank() == root
    n_procs = mpi_comm.Get_size()

    if not on_root:
        return abstract_arrays.ShapedArray(x.shape, x.dtype), core.abstract_token

    stacked_aval = abstract_arrays.ShapedArray((n_procs, *x.shape), x.dtype)
    return stacked_aval, core.abstract_token
Ejemplo n.º 8
0
def mpi_sendrecv_abstract_eval(
    sendbuf,
    recvbuf,
    token,
    source,
    dest,
    sendtag,
    recvtag,
    comm,
    status,
    _must_transpose=False,
):
    """Abstract evaluation rule for the MPI sendrecv primitive.

    The output aval copies the shape and dtype of ``recvbuf``; the remaining
    arguments (including ``sendbuf`` and ``_must_transpose``) are not used.

    Returns:
        Tuple of (ShapedArray like ``recvbuf``, ``core.abstract_token``).
    """
    recv_shape, recv_dtype = recvbuf.shape, recvbuf.dtype
    recv_aval = abstract_arrays.ShapedArray(recv_shape, recv_dtype)
    return recv_aval, core.abstract_token
Ejemplo n.º 9
0
    def build_output_vals(self, scope, carried_state_names, carried_tree,
                          init_vals, body_typed_jaxpr, body_const_vals):
        """Trace the loop condition and bind a ``while_p`` primitive.

        ``self.cond_func`` takes no arguments and reads the carried state from
        ``scope._mutable_state``. It is wrapped so that
        ``lax_control_flow._initial_style_jaxpr`` can trace it with the
        carried values passed as arguments.

        Args:
            scope: object whose ``_mutable_state`` mapping holds the carried
                fields; it is written during tracing of the condition.
            carried_state_names: names of the carried fields. These are zipped
                positionally with the traced arguments (presumably the same
                order as ``init_vals`` — verify against caller).
            carried_tree: pytree definition of the carried state, passed to
                the condition tracer.
            init_vals: initial values of the carried state.
            body_typed_jaxpr: already-traced loop body, bound as ``body_jaxpr``.
            body_const_vals: constants for the body jaxpr.

        Returns:
            The result of ``lax_control_flow.while_p.bind``.

        Raises:
            ValueError: if the condition function rebinds any carried field
                in ``scope._mutable_state``.
            TypeError: if the condition does not return a single boolean
                scalar.
        """
        # Trace the conditional function. cond_func takes 0 arguments, but
        # for lax.while we need a conditional function that takes the
        # carried_state_names. _initial_style_jaxpr will start its own trace and
        # will create tracers for all the carried state. We must put these values
        # in the scope._mutable_state before we trace the conditional
        # function.
        def cond_func_wrapped(*args):
            assert len(args) == len(carried_state_names)
            for ms, init_ms in zip(carried_state_names, args):
                scope._mutable_state[ms] = init_ms
            res = self.cond_func()
            # Conditional function is not allowed to modify the scope state
            # (identity check: any rebinding of a field counts as a change).
            for ms, init_ms in zip(carried_state_names, args):
                if not (scope._mutable_state[ms] is init_ms):
                    msg = "Conditional function modifies scope.{} field."
                    raise ValueError(msg.format(ms))
            return res

        # Abstract values of the initial carried state; these form the
        # argument signature used to trace the condition.
        init_avals = safe_map(_BodyTracer.abstractify, init_vals)
        cond_jaxpr, cond_consts, cond_tree = (
            lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
                                                  carried_tree,
                                                  tuple(init_avals)))
        # TODO: share these checks with lax_control_flow.while
        if not tree_util.treedef_is_leaf(cond_tree):
            msg = "cond_fun must return a boolean scalar, but got pytree {}."
            raise TypeError(msg.format(cond_tree))
        if cond_jaxpr.out_avals != [
                abstract_arrays.ShapedArray((), onp.bool_)
        ]:
            msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
            raise TypeError(msg.format(cond_jaxpr.out_avals))

        # Operands are ordered: condition consts, body consts, carried values;
        # cond_nconsts / body_nconsts give the sizes of the first two groups.
        return lax_control_flow.while_p.bind(*itertools.chain(
            cond_consts, body_const_vals, init_vals),
                                             cond_nconsts=len(cond_consts),
                                             cond_jaxpr=cond_jaxpr,
                                             body_nconsts=len(body_const_vals),
                                             body_jaxpr=body_typed_jaxpr)
Ejemplo n.º 10
0
def mpi_recv_abstract_eval(xs, token, source, tag, comm, status):
    """Abstract evaluation rule for the MPI recv primitive.

    The output aval copies the shape and dtype of ``xs``; ``source``, ``tag``,
    ``comm`` and ``status`` do not affect it.

    Returns:
        Tuple of (ShapedArray like ``xs``, ``abstract_arrays.abstract_token``).
    """
    received = abstract_arrays.ShapedArray(xs.shape, xs.dtype)
    return received, abstract_arrays.abstract_token
Ejemplo n.º 11
0
def random_normal_abstract(key, **_):
    """Abstract evaluation rule producing a float32 scalar aval.

    ``key`` and any keyword parameters are ignored.
    """
    del key  # unused: the output aval does not depend on it
    scalar_shape = ()
    return abstract_arrays.ShapedArray(scalar_shape, jnp.float32)
Ejemplo n.º 12
0
def abstractify(t: Union[tf.Tensor, tf.Variable]):
    """Build a ShapedArray mirroring a TensorFlow tensor or variable.

    The TF shape is converted to a tuple, and the TF dtype is mapped to NumPy
    through ``as_numpy_dtype``.
    """
    # NOTE(review): partially-known TF shapes may contribute ``None`` entries
    # here — confirm callers only pass fully-defined shapes.
    static_shape = tuple(t.shape)
    numpy_dtype = t.dtype.as_numpy_dtype
    return abstract_arrays.ShapedArray(static_shape, numpy_dtype)
Ejemplo n.º 13
0
def mpi_reduce_abstract_eval(xs, token, op, root, comm):
    """Abstract evaluation rule for the MPI reduce primitive.

    The output aval copies the shape and dtype of ``xs`` on every rank;
    ``op``, ``root`` and ``comm`` are not consulted.

    Returns:
        Tuple of (ShapedArray like ``xs``, ``core.abstract_token``).
    """
    reduced_aval = abstract_arrays.ShapedArray(xs.shape, xs.dtype)
    out_token = core.abstract_token
    return reduced_aval, out_token
Ejemplo n.º 14
0
 def slice_aval(aval):
     """Return a ShapedArray like ``aval`` with its leading axis removed.

     Dtype and ``weak_type`` are carried over unchanged.
     """
     trailing_dims = aval.shape[1:]
     dtype, weak = aval.dtype, aval.weak_type
     return abstract_arrays.ShapedArray(trailing_dims, dtype, weak)
Ejemplo n.º 15
0
def abstractify(t: tf.Tensor):
    """Return a ShapedArray with the shape and NumPy dtype of TF tensor ``t``.

    The shape is converted to a tuple; the dtype is taken from
    ``t.dtype.as_numpy_dtype``.
    """
    dims = tuple(t.shape)
    np_dtype = t.dtype.as_numpy_dtype
    return abstract_arrays.ShapedArray(dims, np_dtype)
Ejemplo n.º 16
0
def superbee_abstract_eval(var, u_wgrid, v_wgrid, w_wgrid, maskW, dxt, dyt,
                           dzw, cost, cosu, dt_tracer):
    """Abstract evaluation rule for the superbee primitive.

    Produces three outputs, all the same ShapedArray object with the shape
    and dtype of ``var``; the other arguments do not affect the result.
    """
    like_var = abstract_arrays.ShapedArray(var.shape, var.dtype)
    return like_var, like_var, like_var
Ejemplo n.º 17
0
def sum_inplace_abstract_eval(xs, comm):
    """Abstract evaluation rule for sum_inplace: output aval mirrors ``xs``.

    ``comm`` is not used for shape inference.
    """
    shape, dtype = xs.shape, xs.dtype
    return abstract_arrays.ShapedArray(shape, dtype)
Ejemplo n.º 18
0
def mpi_alltoall_abstract_eval(xs, token, comm):
    """Abstract evaluation rule for the MPI alltoall primitive.

    The output aval has the same shape and dtype as ``xs``.

    Returns:
        Tuple of (ShapedArray like ``xs``, ``core.abstract_token``).
    """
    aval_out = abstract_arrays.ShapedArray(xs.shape, xs.dtype)
    token_out = core.abstract_token
    return aval_out, token_out
Ejemplo n.º 19
0
def mpi_bcast_abstract_eval(xs, token, root, comm):
    """Abstract evaluation rule for the MPI bcast primitive.

    The output aval copies the shape and dtype of ``xs`` regardless of
    ``root``.

    Returns:
        Tuple of (ShapedArray like ``xs``, ``abstract_arrays.abstract_token``).
    """
    bcast_aval = abstract_arrays.ShapedArray(xs.shape, xs.dtype)
    return bcast_aval, abstract_arrays.abstract_token
Ejemplo n.º 20
0
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
    """Collects and injects values into/from the scan body.

    The scan body ``jaxpr`` is re-evaluated under ``harvest`` so that values
    sown inside the loop can be reaped and planted values can be injected:

    * plants for names sown with mode ``'clobber'`` are closed over by the
      new body and so are the same on every iteration;
    * plants for names sown with mode ``'append'`` are scanned over
      alongside ``xs``;
    * reaped values are returned as extra per-iteration scan outputs, then
      post-processed ('append': ``np.concatenate`` over the leaves;
      'clobber': only the last iteration's value is kept) and re-sown in
      ``'strict'`` mode outside the loop.

    Args:
      trace: the active HarvestTrace; scan outputs are wrapped with
        ``trace.pure``.
      *tracers: scan operands; their ``.val`` values are split into consts,
        initial carry and xs using ``num_consts`` and ``num_carry``.
      length, reverse, jaxpr, num_consts, num_carry, linear, unroll: the
        ``scan_p`` parameters of the original scan.

    Returns:
      List of the final carry values followed by the ``ys`` outputs.

    Raises:
      ValueError: if any sow found in the scan body uses ``'strict'`` mode.
    """
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    values = [t.val for t in tracers]
    consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

    # Sow parameters for this tag found in the scan body (presumably one
    # params dict per sow equation — see _find_sows).
    active_sows = _find_sows(jaxpr, settings.tag)
    active_modes = [params['mode'] for params in active_sows]
    if any(mode == 'strict' for mode in active_modes):
        raise ValueError('Cannot use strict mode in a scan.')
    active_names = [params['name'] for params in active_sows]
    sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
    # Plants split by mode: 'clobber' plants are closed over by the body,
    # 'append' plants are scanned over together with xs.
    carry_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'clobber'
    }
    xs_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'append'
    }

    # Re-expresses the original scan body as a (carry, x) -> (carry, y)
    # Python function so it can be wrapped by harvest.
    def jaxpr_fun(carry, x):
        body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                       *(consts + carry + x))
        carry, y = jax_util.split_list(body_out, [num_carry])
        return carry, y

    harvest_body = harvest(jaxpr_fun,
                           tag=settings.tag,
                           allowlist=settings.allowlist,
                           blocklist=settings.blocklist)

    # New scan body: injects plants and emits reaps as part of the y output.
    def body(carry, x):
        x_plants, x_vals = x
        (carry, y), reaps = harvest_body({
            **carry_plants,
            **x_plants
        }, carry, x_vals)
        return carry, (y, reaps)

    # Per-iteration avals of the scanned inputs: the leading (scan) axis is
    # dropped; abstract units pass through unchanged.
    xs_flat = tree_util.tree_leaves((xs_plants, xs))
    x_avals = []
    for x in xs_flat:
        x_aval = jax_core.get_aval(x)
        if x_aval is jax_core.abstract_unit:
            x_avals.append(x_aval)
        else:
            x_shape, x_dtype = masking.padded_shape_as_value(
                x.shape[1:]), x.dtype
            x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
    x_avals = tuple(x_avals)
    init_avals = tuple(
        abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
    in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
    body_jaxpr, new_consts, out_tree = (
        jax.lax.lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
            body, in_tree, init_avals + x_avals))
    new_values = list(new_consts) + in_flat
    # Whatever in new_values is not a new const, an init carry value or an
    # original xs value is a flattened 'append' plant leaf.
    num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
    remaining_linear = linear[num_consts:]
    # New consts and plant inputs are marked non-linear; the original
    # linearity flags are kept for the carry and xs operands.
    new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                  (False, ) * num_xs_plants + remaining_linear[len(init):])
    assert len(new_linear) == len(new_values)

    outs = lax.scan_p.bind(*new_values,
                           length=length,
                           reverse=reverse,
                           jaxpr=body_jaxpr,
                           num_consts=len(new_consts),
                           num_carry=num_carry,
                           linear=new_linear,
                           unroll=unroll)
    outs = safe_map(trace.pure, outs)
    carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
    out_reaps = {}
    for k, val in reaps.items():
        mode = sow_modes.get(k, 'strict')
        if mode == 'append':
            val = tree_util.tree_map(np.concatenate, val)
        elif mode == 'clobber':
            # Keep only the value from the last iteration.
            val = tree_util.tree_map(lambda x: x[-1], val)
        out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
    # Presumably ties the re-sown reaps to the outputs so the sows stay in
    # the computation — see prim.tie_in.
    (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
    return carry + ys
Ejemplo n.º 21
0
def random_normal_abstract(_):
    """Abstract evaluation rule producing a float32 scalar aval.

    The single positional argument is ignored.
    """
    scalar_shape = ()
    return abstract_arrays.ShapedArray(scalar_shape, np.float32)
Ejemplo n.º 22
0
def _get_partial_value(object):
    """Wrap a value's abstract description in a ``PartialVal``.

    ShapedArrays are abstract values that carry around shape and dtype
    information: here the shape comes from ``numpy.shape(object)`` and the
    dtype from ``numpy.result_type(object)``. The aval is paired with
    ``j_core.unit`` inside the PartialVal.
    """
    # NOTE: the parameter name shadows the builtin ``object``; it is kept
    # unchanged because callers may pass it by keyword.
    value_shape = numpy.shape(object)
    value_dtype = numpy.result_type(object)
    aval = j_abstract_arrays.ShapedArray(value_shape, value_dtype)
    return ji_partial_eval.PartialVal((aval, j_core.unit))
Ejemplo n.º 23
0
def mpi_sendrecv_abstract_eval(sendbuf, recvbuf, token, source, dest, sendtag,
                               recvtag, comm, status):
    """Abstract evaluation rule for the MPI sendrecv primitive.

    The output aval copies the shape and dtype of ``recvbuf``; every other
    argument is ignored for shape inference.

    Returns:
        Tuple of (ShapedArray like ``recvbuf``,
        ``abstract_arrays.abstract_token``).
    """
    out_aval = abstract_arrays.ShapedArray(recvbuf.shape, recvbuf.dtype)
    out_token = abstract_arrays.abstract_token
    return out_aval, out_token
Ejemplo n.º 24
0
def get_shaped_aval(x):
    """Converts a JAX value type into a shaped abstract value.

    Objects exposing both ``dtype`` and ``shape`` become a ShapedArray whose
    dtype is passed through ``dtypes.canonicalize_dtype``. Anything else is
    abstracted with ``jax_core.get_aval`` and raised to a shaped aval.
    """
    array_like = hasattr(x, 'dtype') and hasattr(x, 'shape')
    if not array_like:
        return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))
    shape = x.shape
    canonical_dtype = dtypes.canonicalize_dtype(x.dtype)
    return abstract_arrays.ShapedArray(shape, canonical_dtype)
Ejemplo n.º 25
0
def _random_normal_abstract(key, loc, scale):
  """Abstract evaluation rule returning a one-element list of avals.

  The single aval is a float32 scalar; ``key``, ``loc`` and ``scale`` are
  ignored.
  """
  del key, loc, scale  # the output aval does not depend on the inputs
  scalar_aval = abstract_arrays.ShapedArray((), np.float32)
  return [scalar_aval]
def random_normal_abstract(_, name=None):
    """Abstract evaluation rule producing a float32 scalar aval.

    Both the positional argument and ``name`` are ignored.
    """
    del name  # unused
    out_shape, out_dtype = (), np.float32
    return abstract_arrays.ShapedArray(out_shape, out_dtype)
Ejemplo n.º 27
0
def tridiag_abstract_eval(a, b, c, d):
    """Abstract evaluation rule for ``tridiag``.

    The output aval takes its shape and dtype from ``a``; ``b``, ``c`` and
    ``d`` are not inspected.
    """
    shape, dtype = a.shape, a.dtype
    return abstract_arrays.ShapedArray(shape, dtype)
Ejemplo n.º 28
0
def mpi_allreduce_abstract_eval(xs, token, op, comm):
    """Abstract evaluation rule for the MPI allreduce primitive.

    The output aval copies the shape and dtype of ``xs``; ``op`` and ``comm``
    do not affect it.

    Returns:
        Tuple of (ShapedArray like ``xs``, ``abstract_arrays.abstract_token``).
    """
    reduced = abstract_arrays.ShapedArray(xs.shape, xs.dtype)
    return reduced, abstract_arrays.abstract_token
Ejemplo n.º 29
0
def get_shaped_aval(x):
    """Return a shaped abstract value describing ``x``.

    Objects exposing both ``dtype`` and ``shape`` become a ShapedArray with
    that shape and dtype as-is; anything else is abstracted with
    ``jax_core.get_aval`` and raised to a shaped aval.
    """
    if not (hasattr(x, 'dtype') and hasattr(x, 'shape')):
        return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))
    return abstract_arrays.ShapedArray(x.shape, x.dtype)