示例#1
0
文件: core_test.py 项目: cloudhan/jax
 def test_comparing_var(self):
     """Vars minted by one gensym are ordered by creation and pairwise distinct."""
     fresh_var = core.gensym()
     # Three vars created in sequence, each with its own int32 scalar aval.
     first, second, third = (
         fresh_var(core.ShapedArray((), np.dtype('int32'))) for _ in range(3))
     # Ordering follows creation order, in both directions.
     assert first < second < third
     assert third > second > first
     # No two of them compare equal.
     assert first != second and second != third and first != third
示例#2
0
 def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
              named_shape=None):
   """Build the abstract description of a sparse array.

   Args:
     shape: shape of the array, forwarded to the base class.
     dtype: dtype of the stored values.
     index_dtype: dtype of the stored indices.
     nnz: number of stored entries.
     weak_type: weak-type flag; forwarded to the data aval only.
     named_shape: optional named-axis mapping shared by the data and index
       avals; ``None`` is treated as an empty mapping.
   """
   super().__init__(shape, dtype)
   if named_shape is None:
     named_shape = {}
   self.index_dtype = index_dtype
   self.nnz = nnz
   # One stored value per entry.
   self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape)
   # One row of len(shape) coordinates per entry.
   self.indices_aval = core.ShapedArray(
       (nnz, len(shape)), index_dtype, named_shape=named_shape)
示例#3
0
    def test_typecheck_staging_nested(self):
        """check_jaxpr must catch ill-typed operands and results of a nested call.

        A jaxpr with a nested ``jax.jit`` call is staged with dynamic shapes.
        It is then corrupted in two ways, checking each time that
        ``core.check_jaxpr`` raises a TypeError with the expected message:
        first a wrong call operand, then a wrong result binder type.
        """
        # Two int32 scalars that serve as axis sizes for the array inputs below.
        n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        # DBIdx(0)/DBIdx(1) presumably refer to the 1st/2nd input binders (n, m),
        # giving a and b different leading dims; see the printed jaxpr below.
        a = core.DShapedArray((DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
        b = core.DShapedArray((DBIdx(1), ),
                              jnp.dtype('float32'),
                              weak_type=False)

        @lu.wrap_init
        def f(a, b):
            # Identity function under jit, so staging produces a nested call.
            @jax.jit
            def g(x):
                return x

            return g(a),

        # The axis-size inputs are implicit (keep_inputs=False); a and b are kept.
        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [n, m, a, b], keep_inputs=[False, False, True, True])
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (e,) }
        core.check_jaxpr(jaxpr)  # no problems here...

        # Let's introduce a type error by applying the called jaxpr to arguments
        # with types which aren't consistent with its input binders:
        _, _, c, d = jaxpr.invars
        # Operand 1 of the call (originally c:f32[a]) becomes d:f32[b], which no
        # longer matches the callee binder g:f32[f] with f bound to a.
        jaxpr.eqns[0].invars[1] = d
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a d   !!! type error here !!!
        #   in (e,) }
        with self.assertRaisesRegex(TypeError, "passes operand"):
            core.check_jaxpr(jaxpr)

        # Restore the original jaxpr:
        jaxpr.eqns[0].invars[1] = c
        core.check_jaxpr(jaxpr)  # no problems here...

        # Let's introduce another type error by setting the call result let binders
        # to have the wrong type:
        # The fresh Var carries d's aval (f32[b]) while the call produces f32[a].
        # NOTE(review): only eqns[0].outvars is replaced; jaxpr.outvars is left
        # untouched, so the "in (h,)" line below is illustrative — confirm naming.
        jaxpr.eqns[0].outvars[0] = core.Var(0, '', d.aval)
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[b] = xla_call[   !!! type error here !!!
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (h,) }
        with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
            core.check_jaxpr(jaxpr)
示例#4
0
  def test_lattice_join_named_shape(self):
    """lattice_join unions disjoint named shapes and rejects size conflicts."""
    aval_i10 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
    # Joining an aval with itself yields the same aval.
    self.assertEqual(core.lattice_join(aval_i10, aval_i10), aval_i10)

    # Disjoint named axes are merged into one mapping.
    aval_j5 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
    merged = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
    self.assertEqual(core.lattice_join(aval_i10, aval_j5), merged)

    # The same axis name with a different size cannot be joined.
    aval_i5 = core.ShapedArray((2, 3), np.float32, False, {'i': 5})
    with self.assertRaises(TypeError):
      core.lattice_join(aval_i10, aval_i5)
示例#5
0
文件: mlir.py 项目: rsepassi/jax
def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
    """Lower a remat call as an MHLO while loop that the compiler cannot elide.

    The loop carry is a single flat tuple laid out as::

        (counter, *flat_inputs, *flat_outputs)

    where ``flat_inputs``/``flat_outputs`` are the IR values of ``avals_in`` /
    ``avals_out`` flattened across each aval's IR types. Output slots start as
    dummy values and are filled by the body, which evaluates ``call_jaxpr`` on
    the inputs.

    Args:
      ctx: lowering context; its name stack is extended for the body.
      avals_in: abstract values of the inputs.
      avals_out: abstract values of the outputs.
      *args: per-input sequences of IR values, matching ``avals_in``.
      name: name used to annotate the body's name stack.
      call_jaxpr: the jaxpr to evaluate inside the loop body.

    Returns:
      Per-output lists of IR values, grouped like ``aval_to_ir_types`` of
      ``avals_out``.
    """
    input_types = map(aval_to_ir_types, avals_in)
    output_types = map(aval_to_ir_types, avals_out)
    flat_output_types = util.flatten(output_types)
    # Number of *flat* carry slots occupied by the inputs. An aval may lower
    # to zero or several IR types, so this can differ from len(avals_in).
    num_flat_inputs = len(util.flatten(input_types))
    int32_scalar_type = aval_to_ir_type(
        core.ShapedArray((), np.dtype(np.int32)))
    loop_carry_types = [(int32_scalar_type, )] + input_types + output_types
    flat_loop_carry_types = util.flatten(loop_carry_types)
    counter_init = ir_constants(np.array(0, np.int32))
    flat_args = flatten_lowering_ir_args((counter_init, ) + args + tuple(
        _dummy_like_aval(aval) for aval in avals_out))
    loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
    init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

    one = ir_constant(np.array(1, np.int32))
    while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

    # Loop condition: counter < rng, where rng is sampled from RngUniform(1, 2).
    # NOTE(review): the bounds presumably make rng always 1 (half-open range),
    # so the body runs once while staying opaque to the optimizer — confirm
    # against the XLA RngUniform semantics.
    cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(cond_block):
        bool_scalar_type = aval_to_ir_type(
            core.ShapedArray((), np.dtype(np.bool_)))
        two = ir_constant(np.array(2, np.int32))
        shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
        rng = mhlo.RngUniformOp(one, two, shape).result
        i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                                   i32_attr(0))
        cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                             ir.StringAttr.get("SIGNED")).result
        mhlo.ReturnOp([cmp])

    # Loop body: evaluate call_jaxpr on the inputs, bump the counter, and
    # write the results into the output slots of the carry.
    body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
    with ir.InsertionPoint(body_block):
        flat_body_args = [
            mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                                   i32_attr(i)).result
            for i, input_type in enumerate(flat_loop_carry_types)
        ]
        body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
        # Grouped per aval: [[counter]], per-input groups, per-output groups.
        ((i, ), ), y, _ = util.split_list(body_args, [1, len(avals_in)])
        body_ctx = ctx.replace(name_stack=xla.extend_name_stack(
            ctx.name_stack, xla.wrap_name(name, 'remat')))
        z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
        i_next = mhlo.AddOp(i, one).result
        new_carry = mhlo.TupleOp(loop_carry_tuple_type,
                                 [i_next, *util.flatten(y), *util.flatten(z)])
        mhlo.ReturnOp([new_carry.result])

    # Outputs follow the counter and the flat inputs in the flat carry tuple.
    outputs = [
        mhlo.GetTupleElementOp(output_type, while_op.result,
                               i32_attr(1 + num_flat_inputs + i)).result
        for i, output_type in enumerate(flat_output_types)
    ]
    return util.unflatten(outputs, map(len, output_types))
示例#6
0
 def __init__(self,
              shape,
              dtype,
              index_dtype,
              nnz,
              weak_type=False,
              named_shape=None):
     """Build the abstract description of a sparse array.

     Args:
       shape: shape of the array, forwarded to the base class.
       dtype: dtype of the stored values.
       index_dtype: dtype of the stored indices.
       nnz: number of stored entries.
       weak_type: weak-type flag; applies to the data aval only.
       named_shape: optional named-axis mapping shared by the data and index
         avals; ``None`` (the default) means an empty mapping. A fresh dict
         is created per instance instead of sharing a mutable default.
     """
     super(AbstractSparseArray, self).__init__(shape, dtype)
     if named_shape is None:
         named_shape = {}
     self.index_dtype = index_dtype
     self.nnz = nnz
     self.data_aval = core.ShapedArray((nnz, ), dtype, weak_type,
                                       named_shape)
     # Pass named_shape by keyword: the third positional parameter of
     # ShapedArray is weak_type, so passing it positionally put the dict there.
     self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype,
                                          named_shape=named_shape)
示例#7
0
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """Abstract evaluation rule for ``bcoo_dot_general``.

  Args:
    lhs_data, lhs_indices: abstract values of the sparse (BCOO) operand's
      data and indices buffers.
    rhs: abstract value of the dense right-hand operand.
    dimension_numbers: ``((lhs_contracting, rhs_contracting),
      (lhs_batch, rhs_batch))``, presumably following ``lax.dot_general``.
    lhs_shape: shape of the sparse left-hand operand.

  Returns:
    A ``core.ShapedArray`` whose shape is the batch dims, then the remaining
    lhs dims, then the remaining rhs dims, with the promoted dtype.

  Raises:
    NotImplementedError: if a batch dimension is not a batch dimension of the
      sparse representation, or if contraction touches lhs batch or dense
      dimensions.
  """
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  # The second-to-last dim of the indices counts the sparse dims; any leading
  # dims are batch dims.
  n_sparse = lhs_indices.shape[-2]
  n_batch = lhs_indices.ndim - 2
  _validate_bcoo(lhs_data, lhs_indices, lhs_shape)

  # Check for proper dimension_numbers
  for dims in [lhs_contracting, rhs_contracting, lhs_batch, rhs_batch]:
    assert len(dims) == len(set(dims))
  assert not set(lhs_contracting).intersection(lhs_batch)
  assert not set(rhs_contracting).intersection(rhs_batch)
  assert [lhs_shape[d] for d in lhs_contracting] == [rhs.shape[d] for d in rhs_contracting]
  assert [lhs_shape[d] for d in lhs_batch] == [rhs.shape[d] for d in rhs_batch]

  if lhs_batch and max(lhs_batch) >= n_batch:
    raise NotImplementedError(
      "bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representation.\n"
      f"got lhs_batch={lhs_batch}, n_batch={n_batch}")

  # TODO: support contraction of batch dimensions?
  if any(d < n_batch for d in lhs_contracting):
    raise NotImplementedError("bcoo_dot_general: contracting over batch dimensions.")

  # TODO: support contraction of dense dimensions?
  if any(d >= n_batch + n_sparse for d in lhs_contracting):
    raise NotImplementedError("bcoo_dot_general: contracting over dense dimensions.")

  out_dtype = jnp.promote_types(lhs_data.dtype, rhs.dtype)
  # Build the excluded-dimension sets once. This avoids rebuilding a
  # concatenation per dimension and accepts mixed list/tuple dimension numbers.
  lhs_consumed = set(lhs_contracting) | set(lhs_batch)
  rhs_consumed = set(rhs_contracting) | set(rhs_batch)
  out_shape = (tuple(lhs_shape[i] for i in lhs_batch) +
               tuple(s for i, s in enumerate(lhs_shape) if i not in lhs_consumed) +
               tuple(s for i, s in enumerate(rhs.shape) if i not in rhs_consumed))
  return core.ShapedArray(out_shape, out_dtype)
示例#8
0
def _bcoo_extract_abstract_eval(indices, mat):
  """Abstract evaluation rule for extracting BCOO data from a dense ``mat``.

  ``indices`` has shape ``(*batch, n_sparse, nse)``. The result keeps
  ``mat``'s batch dims, then ``nse``, then ``mat``'s trailing dense dims, and
  uses ``mat``'s dtype.
  """
  n_sparse, nse = indices.shape[-2:]
  n_batch = indices.ndim - 2
  batch_dims = mat.shape[:n_batch]
  # Batch dims of the matrix and of the indices must agree.
  assert batch_dims == indices.shape[:n_batch]
  # Whatever follows the batch and sparse dims of mat is dense.
  dense_dims = mat.shape[n_batch + n_sparse:]
  return core.ShapedArray(batch_dims + (nse,) + dense_dims, mat.dtype)
示例#9
0
def _todense_abstract_eval(*bufs, tree):
    """Abstract evaluation rule for densifying the pytree rebuilt from ``bufs``.

    If the rebuilt object is already a ``core.ShapedArray`` it is returned
    unchanged. Otherwise a ``ShapedArray`` is built from its shape and dtype,
    and it is marked weakly typed when its ``data`` is.
    """
    rebuilt = tree_util.tree_unflatten(tree, bufs)
    if not isinstance(rebuilt, core.ShapedArray):
        return core.ShapedArray(
            rebuilt.shape,
            rebuilt.dtype,
            weak_type=dtypes.is_weakly_typed(rebuilt.data))
    return rebuilt
示例#10
0
 def local_shards(self) -> Sequence[Shard]:
   """Return the local shards, filling in any missing ``data.aval`` first."""
   for shard in self._local_shards:
     # mypy believes ``data`` may be None, but local shards always have data
     # (enforced in ``_create_local_shards``), hence the ignores.
     if shard.data.aval is not None:  # type: ignore
       continue
     shard.data.aval = core.ShapedArray(shard.data.shape, shard.data.dtype)  # type: ignore
   return self._local_shards
示例#11
0
 def spvalue_to_aval(spvalue):
     """Map a sparse-env value to its abstract value.

     Units map to ``core.abstract_unit``. Anything else becomes a
     ``ShapedArray`` with the value's shape and the dtype and weak-type flag of
     its data buffer (looked up in the enclosing ``spenv``).
     """
     if not spvalue.is_unit():
         buf = spenv.data(spvalue)
         return core.ShapedArray(spvalue.shape, buf.dtype,
                                 buf.aval.weak_type)
     return core.abstract_unit
示例#12
0
def _nonzero_translation_rule(c, dims, avals, operands):
  """Lower a nonzero-style op with the XLA builder ``c``.

  Returns two results: ``out_dim``, the count of nonzero entries along the
  last axis, and ``out_val``, last-axis positions reordered so that positions
  of nonzero entries come first, in their original order (stable sort).
  ``dims`` and ``avals`` are not used here.
  """
  (vals,), = operands
  shape = c.get_shape(vals)
  last_axis = len(shape.dimensions()) - 1
  # Zero of the operand's dtype, broadcast to the operand's full shape.
  zeros = xops.Broadcast(xb.constant(c, np.zeros((), shape.numpy_dtype())),
                         shape.dimensions())
  s32_etype = xc.dtype_to_etype(np.dtype('int32'))
  # 1 where vals != 0, else 0, as int32.
  nonzero_indicators = xops.ConvertElementType(xops.Ne(vals, zeros), s32_etype)
  # int32 scalar aval used to build the add reducer below.
  i = core.ShapedArray((), np.dtype('int32'))
  # Sum the indicators along the last axis, giving the nonzero count.
  out_dim = xops.Reduce(c, [nonzero_indicators],
                        [xb.constant(c, np.array(0, np.dtype('int32')))],
                        xla.primitive_subcomputation(lax.add_p, i, i),
                        (last_axis,))
  c.get_shape(out_dim)  # xla type checking

  # Comparator for a 2-operand sort. It receives 4 scalar params, one
  # (lhs, rhs) pair per operand; only the indicators pair (params 0, 1) is
  # compared, in descending order.
  subc = xb.make_computation_builder("sort_gt_comparator")
  params = [xb.parameter(subc, i, xc.Shape.array_shape(s32_etype, ()))
            for i in range(4)]
  comparator = subc.build(xops.Gt(params[0], params[1]))
  # Iota along the last axis, so each element carries its own position.
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  # Stable sort keyed on the indicators; the iota operand is permuted along.
  ans = xops.Sort(c, [nonzero_indicators, xops.Iota(c, iota_shape, last_axis)],
                  is_stable=True, comparator=comparator)
  # Keep only the permuted positions (second sort result).
  _, out_val = xla.xla_destructure(c, ans)
  c.get_shape(out_val)  # xla type checking
  return [[out_dim], [out_val]]
示例#13
0
def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
    """Abstract evaluation rule for a COO sparse matrix times a dense matrix.

    Args:
      data, row, col: abstract values of the stored values and their row and
        column indices; all three must share one shape.
      B: abstract value of the dense right-hand matrix; must share ``data``'s
        dtype.
      shape: the 2-D shape of the sparse matrix.
      transpose: if True, the sparse matrix is used transposed.

    Returns:
      A ``core.ShapedArray`` of shape ``(out_rows, B.shape[1])`` with
      ``data``'s dtype. ``out_rows`` is ``shape[1]`` when transposed and
      ``shape[0]`` otherwise.
    """
    assert data.shape == row.shape == col.shape
    assert data.dtype == B.dtype
    assert len(shape) == 2
    # Bind the selected size first. Written inline, the conditional swallowed
    # the comparison, and the non-transposed case only asserted that shape[1]
    # was truthy.
    expected_rows = shape[0] if transpose else shape[1]
    assert B.shape[0] == expected_rows
    out_shape = shape[1] if transpose else shape[0]
    return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
示例#14
0
  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids, expected_is_fully_replicated):
    """Check the metadata and shard contents of an (8, 2) GDA on a 4x2 mesh."""
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_shape = (8, 2)
    source = np.arange(prod(global_shape)).reshape(global_shape)

    gda = GlobalDeviceArray.from_callback(
        global_shape, mesh, mesh_axes, lambda index: source[index])

    self.assertEqual(gda.ndim, 2)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.mesh_axes, mesh_axes)
    # The first two local shards cover the expected slices of the source.
    for k in (0, 1):
      self.assertEqual(gda.local_shards[k].index, expected_index[k])
      self.assertArraysEqual(gda.local_data(k), source[expected_index[k]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    self.assertListEqual([shard.replica_id for shard in gda.local_shards],
                         expected_replica_ids)
    self.assertListEqual([shard.device.id for shard in gda.local_shards],
                         list(range(8)))
    self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
    for shard in gda.local_shards:
      self.assertEqual(shard.data.aval,
                       core.ShapedArray(expected_shard_shape, shard.data.dtype))
    # Global and local shard views must agree pairwise.
    for global_shard, local_shard in safe_zip(gda.global_shards,
                                              gda.local_shards):
      self.assertEqual(global_shard.device, local_shard.device)
      self.assertEqual(global_shard.index, local_shard.index)
      self.assertEqual(global_shard.replica_id, local_shard.replica_id)
      self.assertEqual(global_shard.data.aval, local_shard.data.aval)
      self.assertArraysEqual(global_shard.data, local_shard.data)
示例#15
0
def _get_sharding_spec(global_shape, global_mesh, mesh_axes):
    """Build the sharding spec for an array of ``global_shape`` on ``global_mesh``.

    ``mesh_axes`` is converted to an array mapping that describes how array
    dimensions map onto mesh axes.
    """
    mapping = _get_array_mapping(mesh_axes)
    # The dtype does not matter for creating sharding specs; float32 is a
    # placeholder.
    placeholder = core.ShapedArray(global_shape, np.float32)
    make_spec = pxla.mesh_sharding_specs(global_mesh.shape,
                                         global_mesh.axis_names)
    return make_spec(placeholder, mapping)
示例#16
0
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                           named_shape_rule, *avals, **kwargs):
    """Generic abstract-eval for a single-result primitive.

    The output kind follows the least specialized input aval type (the one
    with the highest ``array_abstraction_level``):

    * ``ConcreteArray``: run ``prim.impl`` on the concrete values.
    * ``ShapedArray``: apply the shape, dtype and named-shape rules.
    * ``DShapedArray``: apply the shape rule. The result is a ``ShapedArray``
      when every dim is an ``int`` and a ``DShapedArray`` otherwise.
    * ``UnshapedArray``: apply only the dtype rule.

    Raises:
      TypeError: for any other aval type.
    """
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
    assert not prim.multiple_results
    weak_type = weak_type_rule(*avals, **kwargs)
    widest = _max(
        map(type, avals), key=operator.attrgetter('array_abstraction_level'))

    if widest is core.ConcreteArray:
        concrete_out = prim.impl(*[aval.val for aval in avals], **kwargs)
        return core.ConcreteArray(concrete_out.dtype, concrete_out,
                                  weak_type=weak_type)

    if widest is core.ShapedArray:
        out_shape = shape_rule(*avals, **kwargs)
        out_dtype = dtype_rule(*avals, **kwargs)
        out_named_shape = named_shape_rule(*avals, **kwargs)
        return core.ShapedArray(out_shape, out_dtype, weak_type=weak_type,
                                named_shape=out_named_shape)

    if widest is core.DShapedArray:
        out_shape = shape_rule(*avals, **kwargs)
        # Fully static shapes collapse back to an ordinary ShapedArray.
        if all(type(dim) is int for dim in out_shape):
            result_cls = core.ShapedArray
        else:
            result_cls = core.DShapedArray
        return result_cls(out_shape, dtype_rule(*avals, **kwargs), weak_type)

    if widest is core.UnshapedArray:
        return core.UnshapedArray(dtype_rule(*avals, **kwargs),
                                  weak_type=weak_type)

    raise TypeError(avals, widest)
示例#17
0
def standard_multi_result_abstract_eval(prim, shape_rule, dtype_rule,
                                        weak_type_rule, named_shape_rule,
                                        *avals, **kwargs):
    """Generic abstract-eval for a multiple-result primitive.

    It returns one aval per output. The output kind follows the least
    specialized input aval type, as in the single-result variant. The
    per-output rule results are zipped with ``safe_zip``, so their lengths
    must match.

    Raises:
      TypeError: for an aval type other than Concrete/Shaped/Unshaped arrays.
    """
    assert prim.multiple_results
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
    widest = _max(
        map(type, avals), key=operator.attrgetter('array_abstraction_level'))
    weak_types = weak_type_rule(*avals, **kwargs)

    if widest is core.ConcreteArray:
        concrete_outs = prim.impl(*[aval.val for aval in avals], **kwargs)
        return [core.ConcreteArray(out.dtype, out, weak_type=weak)
                for out, weak in safe_zip(concrete_outs, weak_types)]

    if widest is core.ShapedArray:
        out_shapes = shape_rule(*avals, **kwargs)
        out_dtypes = dtype_rule(*avals, **kwargs)
        out_named_shapes = named_shape_rule(*avals, **kwargs)
        per_output = safe_zip(out_shapes, out_dtypes, weak_types,
                              out_named_shapes)
        return [core.ShapedArray(shp, dt, weak_type=weak, named_shape=named)
                for shp, dt, weak, named in per_output]

    if widest is core.UnshapedArray:
        out_dtypes = dtype_rule(*avals, **kwargs)
        return [core.UnshapedArray(dt, weak_type=weak)
                for dt, weak in safe_zip(out_dtypes, weak_types)]

    raise TypeError(avals, widest)
示例#18
0
 def fix_float0(arg_jax, ct_arg_jax):
     """Replace a cotangent whose dtype differs from the expected tangent dtype.

     The expected dtype is the tangent dtype of ``arg_jax``'s primal dtype. If
     ``ct_arg_jax`` already has it, it is returned as-is. Otherwise zeros of
     that dtype are returned, shaped like ``arg_jax``.
     """
     # arg_jax may be a scalar, hence result_type rather than .dtype.
     tangent_dtype = core.primal_dtype_to_tangent_dtype(
         dtypes.result_type(arg_jax))
     if tangent_dtype == ct_arg_jax.dtype:
         return ct_arg_jax
     zeros_aval = core.ShapedArray(np.shape(arg_jax), tangent_dtype)
     return ad_util.zeros_like_aval(zeros_aval)
示例#19
0
 def argspec_to_aval(argspec):
     """Map an argspec to its abstract value.

     Units map to ``core.abstract_unit``. Anything else becomes a
     ``ShapedArray`` with the argspec's shape and the dtype and weak-type flag
     of its data (read through the enclosing ``spenv``).
     """
     if not argspec.is_unit():
         buf = argspec.data(spenv)
         return core.ShapedArray(argspec.shape, buf.dtype,
                                 buf.aval.weak_type)
     return core.abstract_unit
示例#20
0
def from_dlpack(dlpack, backend=None):
    """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.

    The returned `DeviceArray` shares memory with `dlpack`.

    Args:
      dlpack: a DLPack tensor, on either CPU or GPU.
      backend: deprecated, do not use. Only consulted when the XLA extension
        version is below 25.
    """
    if jax.lib._xla_extension_version < 25:
        # TODO(phawkins): drop the backend argument after deleting this case.
        legacy_backend = backend or xla_bridge.get_backend()
        legacy_client = getattr(legacy_backend, "client", legacy_backend)
        buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
            dlpack, legacy_client)
    else:
        cpu_backend = xla_bridge.get_backend("cpu")
        try:
            gpu_backend = xla_bridge.get_backend("gpu")
        except RuntimeError:
            # No GPU backend is available; pass None in its place.
            gpu_backend = None
        buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
            dlpack, cpu_backend, gpu_backend)

    xla_shape = buf.xla_shape()
    assert not xla_shape.is_tuple()
    aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
    return xla.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
示例#21
0
    def test_staging_primitive_applications(self):
        """Stage mul -> sin -> reduce_sum over inputs with a shared dynamic dim."""
        axis_size = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        # Both arrays use DBIdx(0), presumably the axis-size binder above.
        x_aval = core.DShapedArray((DBIdx(0), ),
                                   jnp.dtype('float32'),
                                   weak_type=False)
        y_aval = core.DShapedArray((DBIdx(0), ),
                                   jnp.dtype('float32'),
                                   weak_type=False)

        def body(x, y):
            product = lax.mul(x, y)
            sines = lax.sin(product)
            total = lax_internal._reduce_sum(sines, [0])
            return (total, )

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            lu.wrap_init(body), [axis_size, x_aval, y_aval],
            keep_inputs=[False, True, True])

        # One axis-size var plus the two array inputs.
        self.assertLen(jaxpr.invars, 1 + 2)
        self.assertLen(jaxpr.eqns, 3)
        mul_eqn = jaxpr.eqns[0]
        self.assertLen(mul_eqn.outvars, 1)
        # The elementwise product keeps the (dynamic) input shape.
        self.assertEqual(mul_eqn.outvars[0].aval.shape,
                         jaxpr.invars[1].aval.shape)

        # A single scalar result from the full reduction.
        self.assertLen(jaxpr.outvars, 1)
        self.assertEqual(jaxpr.outvars[0].aval.shape, ())
示例#22
0
def compile_and_get_sharding(pjitted_fn, mesh, global_inputs):
  """Compile ``pjitted_fn`` for the global inputs and report its shardings.

  The function is lowered on ``ShapedArray`` avals built from each input's
  shape and dtype (as global avals) and then compiled. The input/output
  shardings are read back from the compiled executable.

  Returns:
    ``_XLAShardingInfo`` with the input/output partition specs and the
    compiled object.
  """
  # TODO(yashkatariya): Check if the pjitted_fn comes from pjit.
  global_avals = [core.ShapedArray(arg.shape, arg.dtype)
                  for arg in global_inputs]
  lowered = pjitted_fn.lower(*global_avals, _global_avals=True)
  compiled = lowered.compile()
  in_pspec, out_pspec = pjit._get_sharding_from_executable(
      compiled.runtime_executable(), mesh)
  return _XLAShardingInfo(in_pspec=in_pspec, out_pspec=out_pspec,
                          compiled=compiled)
示例#23
0
def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
    """Abstract evaluation rule for a COO sparse matrix times a dense vector.

    Args:
      data, row, col: abstract values of the stored values and their row and
        column indices; all three share one shape, and ``row``/``col`` share
        one dtype.
      v: abstract value of the dense vector; must share ``data``'s dtype.
      shape: the 2-D shape of the sparse matrix.
      transpose: if True, the sparse matrix is used transposed.

    Returns:
      A 1-D ``core.ShapedArray`` with ``data``'s dtype, of length ``shape[1]``
      when transposed and ``shape[0]`` otherwise.
    """
    assert data.shape == row.shape == col.shape
    assert data.dtype == v.dtype
    assert row.dtype == col.dtype
    assert len(shape) == 2
    # Bind the selected shape first. Written inline, the conditional
    # swallowed the comparison, and the non-transposed case never checked v.
    expected_v_shape = (shape[0], ) if transpose else (shape[1], )
    assert v.shape == expected_v_shape
    out_shape = shape[1] if transpose else shape[0]
    return core.ShapedArray((out_shape, ), data.dtype)
示例#24
0
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
    """Compute the shard indices for an array of ``global_shape`` on a mesh.

    A sharding spec is built from ``mesh_axes`` and then converted into index
    tuples with ``pxla.spec_to_indices``.
    """
    mapping = _get_array_mapping(mesh_axes)
    # The dtype doesn't matter for creating sharding specs; float32 is a
    # placeholder.
    placeholder = core.ShapedArray(global_shape, np.float32)
    make_spec = pxla.mesh_sharding_specs(global_mesh.shape,
                                         global_mesh.axis_names)
    spec = make_spec(placeholder, mapping)
    return pxla.spec_to_indices(global_shape, spec)  # type: ignore
示例#25
0
文件: dispatch.py 项目: cloudhan/jax
 def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
     """Return the cached runtime token for ``eff`` on ``device``.

     On first use, a zero-length bool array is placed on ``device`` and
     cached. If the cached token lives on another device, its aval is reset
     and it is moved to ``device`` with ``device_put``. The cache maps each
     effect to a ``(token, device)`` pair.
     """
     if eff not in self.tokens:
         fresh_token = device_put(np.zeros(0, np.bool_), device)
         self.tokens[eff] = fresh_token, device
     elif self.tokens[eff][1] != device:
         (stale_token, ), _ = self.tokens[eff]
         # NOTE(review): the aval is reset before the transfer, presumably so
         # device_put sees a well-formed (0,) bool aval — confirm.
         stale_token.aval = core.ShapedArray((0, ), np.bool_)
         self.tokens[eff] = device_put(stale_token, device), device
     cached_token, _ = self.tokens[eff]
     return cached_token
示例#26
0
def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
    """Abstract evaluation rule for a CSR sparse matrix times a dense vector.

    Args:
      data, indices, indptr: 1-D abstract values of the CSR buffers; ``data``
        and ``indices`` share a shape, ``indices``/``indptr`` share a dtype,
        and ``indptr`` has ``shape[0] + 1`` entries.
      v: 1-D abstract value of the dense vector; must share ``data``'s dtype.
      shape: the 2-D shape of the sparse matrix.
      transpose: if True, the sparse matrix is used transposed.

    Returns:
      A 1-D ``core.ShapedArray`` with ``data``'s dtype, of length ``shape[1]``
      when transposed and ``shape[0]`` otherwise.
    """
    assert len(shape) == 2
    assert v.ndim == data.ndim == indices.ndim == indptr.ndim == 1
    assert data.shape == indices.shape
    assert data.dtype == v.dtype
    assert indices.dtype == indptr.dtype
    # indptr is 1-D (checked above), so shape[0] is its length; unlike len(),
    # this only needs a ``shape`` attribute.
    assert indptr.shape[0] == shape[0] + 1
    # Bind the selected shape first. Written inline, the conditional
    # swallowed the comparison, and the non-transposed case never checked v.
    expected_v_shape = (shape[0], ) if transpose else (shape[1], )
    assert v.shape == expected_v_shape
    out_shape = shape[1] if transpose else shape[0]
    return core.ShapedArray((out_shape, ), data.dtype)
示例#27
0
def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
    """Abstract evaluation rule for a CSR sparse matrix times a dense matrix.

    Args:
      data, indices, indptr: 1-D abstract values of the CSR buffers; ``data``
        and ``indices`` share a shape, ``indices``/``indptr`` share a dtype,
        and ``indptr`` has ``shape[0] + 1`` entries.
      B: 2-D abstract value of the dense matrix; must share ``data``'s dtype.
      shape: the shape of the sparse matrix.
      transpose: if True, the sparse matrix is used transposed.

    Returns:
      A ``core.ShapedArray`` of shape ``(out_rows, B.shape[1])`` with
      ``data``'s dtype. ``out_rows`` is ``shape[1]`` when transposed and
      ``shape[0]`` otherwise.
    """
    assert data.ndim == indices.ndim == indptr.ndim == 1
    assert B.ndim == 2
    assert data.shape == indices.shape
    assert data.dtype == B.dtype
    assert indices.dtype == indptr.dtype
    # indptr is 1-D (checked above), so shape[0] is its length.
    assert indptr.shape[0] == shape[0] + 1
    # Bind the selected size first. Written inline, the conditional swallowed
    # the comparison, and the non-transposed case only asserted that shape[1]
    # was truthy.
    expected_rows = shape[0] if transpose else shape[1]
    assert B.shape[0] == expected_rows
    out_shape = shape[1] if transpose else shape[0]
    return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
示例#28
0
def shaped_abstractify(x):
  """Return a shaped abstract value for ``x``.

  The first attempt is ``core.get_aval`` followed by ``core.raise_to_shaped``.
  If either raises ``TypeError``, a ``ShapedArray`` is built from
  ``np.shape(x)`` and ``_dtype(x)`` instead, taking ``weak_type`` and
  ``named_shape`` from ``x`` when present (defaults ``False`` and ``{}``).
  """
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    pass  # Not directly abstractable; fall back to duck-typed attributes.

  is_weak = getattr(x, 'weak_type', False)
  axis_names = getattr(x, 'named_shape', {})
  return core.ShapedArray(np.shape(x), _dtype(x), weak_type=is_weak,
                          named_shape=axis_names)
示例#29
0
def discharge_state(jaxpr: core.Jaxpr,
                    consts: Sequence[Any]) -> Tuple[core.Jaxpr, List[Any]]:
    """Converts a jaxpr that takes in `Ref`s into one that doesn't.

    Each ``ShapedArrayRef`` input aval is replaced by a plain ``ShapedArray``
    of the same shape and dtype. The jaxpr is then re-traced through
    ``_eval_jaxpr_discharge_state``.

    Returns:
      The new jaxpr and its constants.
    """
    def _plain_aval(aval):
        # Refs become ordinary arrays; every other aval passes through.
        if type(aval) is ShapedArrayRef:
            return core.ShapedArray(aval.shape, aval.dtype)
        return aval

    in_avals = [_plain_aval(var.aval) for var in jaxpr.invars]
    discharge_fn = lu.wrap_init(
        partial(_eval_jaxpr_discharge_state, jaxpr, consts))
    discharged, _, discharged_consts = pe.trace_to_jaxpr_dynamic(
        discharge_fn, in_avals)
    return discharged, discharged_consts
示例#30
0
文件: prng.py 项目: Jakob-Unfried/jax
def _threefry2x32_abstract_eval(*args):
  """Abstract evaluation rule for threefry2x32: two identical uint32 avals.

  If every argument is a ``ShapedArray``, the outputs take the broadcast
  shape and the joined named shape of the arguments. Otherwise they are
  ``UnshapedArray``.

  Raises:
    TypeError: if any argument's dtype is not uint32.
  """
  if any(a.dtype != jnp.uint32 for a in args):
    raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                    .format(args))
  if not all(isinstance(arg, core.ShapedArray) for arg in args):
    out_aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
  else:
    out_shape = lax._broadcasting_shape_rule(*args)
    out_named_shape = core.join_named_shapes(*(a.named_shape for a in args))
    out_aval = core.ShapedArray(out_shape, jnp.dtype(jnp.uint32),
                                named_shape=out_named_shape)
  # Both outputs share a single aval object.
  return (out_aval, out_aval)