def test_comparing_var(self):
  newsym = core.gensym()
  a = newsym(core.ShapedArray((), np.dtype('int32')))
  b = newsym(core.ShapedArray((), np.dtype('int32')))
  c = newsym(core.ShapedArray((), np.dtype('int32')))
  assert a < b < c
  assert c > b > a
  assert a != b and b != c and a != c
def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
             named_shape=None):
  super().__init__(shape, dtype)
  named_shape = {} if named_shape is None else named_shape
  self.index_dtype = index_dtype
  self.nnz = nnz
  self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape)
  self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype,
                                       named_shape=named_shape)
def test_typecheck_staging_nested(self):
  n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(a, b):
    @jax.jit
    def g(x):
      return x
    return g(a),

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [n, m, a, b], keep_inputs=[False, False, True, True])
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a c
  #   in (e,) }
  core.check_jaxpr(jaxpr)  # no problems here...

  # Let's introduce a type error by applying the called jaxpr to arguments
  # with types which aren't consistent with its input binders:
  _, _, c, d = jaxpr.invars
  jaxpr.eqns[0].invars[1] = d
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[a] = xla_call[
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a d   !!! type error here !!!
  #   in (e,) }
  with self.assertRaisesRegex(TypeError, "passes operand"):
    core.check_jaxpr(jaxpr)

  # Restore the original jaxpr:
  jaxpr.eqns[0].invars[1] = c
  core.check_jaxpr(jaxpr)  # no problems here...

  # Let's introduce another type error by setting the call result let binders
  # to have the wrong type:
  jaxpr.eqns[0].outvars[0] = core.Var(0, '', d.aval)
  # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
  #     e:f32[b] = xla_call[   !!! type error here !!!
  #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
  #       name=g
  #     ] a c
  #   in (e,) }
  with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
    core.check_jaxpr(jaxpr)
def test_lattice_join_named_shape(self):
  aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
  self.assertEqual(core.lattice_join(aval1, aval1), aval1)

  aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
  expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
  self.assertEqual(core.lattice_join(aval1, aval2), expected)

  aval3 = core.ShapedArray((2, 3), np.float32, False, {'i': 5})
  self.assertRaises(TypeError, lambda: core.lattice_join(aval1, aval3))
def _remat_using_while(ctx, avals_in, avals_out, *args, name, call_jaxpr):
  input_types = map(aval_to_ir_types, avals_in)
  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  int32_scalar_type = aval_to_ir_type(
      core.ShapedArray((), np.dtype(np.int32)))
  loop_carry_types = [(int32_scalar_type,)] + input_types + output_types
  flat_loop_carry_types = util.flatten(loop_carry_types)
  counter_init = ir_constants(np.array(0, np.int32))
  flat_args = flatten_lowering_ir_args(
      (counter_init,) + args +
      tuple(_dummy_like_aval(aval) for aval in avals_out))
  loop_carry_tuple_type = ir.TupleType.get_tuple(flat_loop_carry_types)
  init_carry = mhlo.TupleOp(loop_carry_tuple_type, flat_args)

  one = ir_constant(np.array(1, np.int32))
  while_op = mhlo.WhileOp([loop_carry_tuple_type], [init_carry.result])

  # Loop condition
  cond_block = while_op.regions[0].blocks.append(loop_carry_tuple_type)
  with ir.InsertionPoint(cond_block):
    bool_scalar_type = aval_to_ir_type(
        core.ShapedArray((), np.dtype(np.bool_)))
    two = ir_constant(np.array(2, np.int32))
    shape = ir_constant(np.array((), np.int64), canonicalize_types=False)
    rng = mhlo.RngUniformOp(one, two, shape).result
    i = mhlo.GetTupleElementOp(int32_scalar_type, cond_block.arguments[0],
                               i32_attr(0))
    cmp = mhlo.CompareOp(bool_scalar_type, i, rng, ir.StringAttr.get("LT"),
                         ir.StringAttr.get("SIGNED")).result
    mhlo.ReturnOp([cmp])

  body_block = while_op.regions[1].blocks.append(loop_carry_tuple_type)
  with ir.InsertionPoint(body_block):
    flat_body_args = [
        mhlo.GetTupleElementOp(input_type, body_block.arguments[0],
                               i32_attr(i)).result
        for i, input_type in enumerate(flat_loop_carry_types)]
    body_args = util.unflatten(flat_body_args, map(len, loop_carry_types))
    ((i,),), y, _ = util.split_list(body_args, [1, len(avals_in)])
    body_ctx = ctx.replace(
        name_stack=xla.extend_name_stack(ctx.name_stack,
                                         xla.wrap_name(name, 'remat')))
    z = jaxpr_subcomp(body_ctx, call_jaxpr, (), *y)
    i_next = mhlo.AddOp(i, one).result
    new_carry = mhlo.TupleOp(
        loop_carry_tuple_type,
        [i_next, *util.flatten(y), *util.flatten(z)])
    mhlo.ReturnOp([new_carry.result])

  outputs = [
      mhlo.GetTupleElementOp(output_type, while_op.result,
                             i32_attr(1 + len(avals_in) + i)).result
      for i, output_type in enumerate(flat_output_types)]
  return util.unflatten(outputs, map(len, output_types))
def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
             named_shape={}):
  super(AbstractSparseArray, self).__init__(shape, dtype)
  self.index_dtype = index_dtype
  self.nnz = nnz
  self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape)
  # Pass named_shape by keyword: the third positional parameter of
  # ShapedArray is weak_type, not named_shape.
  self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype,
                                       named_shape=named_shape)
def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *,
                                    dimension_numbers, lhs_shape):
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  n_sparse = lhs_indices.shape[-2]
  n_batch = lhs_indices.ndim - 2
  _validate_bcoo(lhs_data, lhs_indices, lhs_shape)

  # Check for proper dimension_numbers
  for dims in [lhs_contracting, rhs_contracting, lhs_batch, rhs_batch]:
    assert len(dims) == len(set(dims))
  assert not set(lhs_contracting).intersection(lhs_batch)
  assert not set(rhs_contracting).intersection(rhs_batch)
  assert ([lhs_shape[d] for d in lhs_contracting]
          == [rhs.shape[d] for d in rhs_contracting])
  assert [lhs_shape[d] for d in lhs_batch] == [rhs.shape[d] for d in rhs_batch]

  if lhs_batch and max(lhs_batch) >= n_batch:
    raise NotImplementedError(
        "bcoo_dot_general batch dimensions must be among the batch dimensions "
        "in the sparse representation.\n"
        f"got lhs_batch={lhs_batch}, n_batch={n_batch}")

  # TODO: support contraction of batch dimensions?
  if any(d < n_batch for d in lhs_contracting):
    raise NotImplementedError(
        "bcoo_dot_general: contracting over batch dimensions.")
  # TODO: support contraction of dense dimensions?
  if any(d >= n_batch + n_sparse for d in lhs_contracting):
    raise NotImplementedError(
        "bcoo_dot_general: contracting over dense dimensions.")

  out_dtype = jnp.promote_types(lhs_data.dtype, rhs.dtype)
  out_shape = (
      tuple(lhs_shape[i] for i in lhs_batch) +
      tuple(s for i, s in enumerate(lhs_shape)
            if i not in lhs_contracting + lhs_batch) +
      tuple(s for i, s in enumerate(rhs.shape)
            if i not in rhs_contracting + rhs_batch))
  return core.ShapedArray(out_shape, out_dtype)
def _bcoo_extract_abstract_eval(indices, mat):
  n_sparse, nse = indices.shape[-2:]
  n_batch = indices.ndim - 2
  n_dense = mat.ndim - n_sparse - n_batch
  assert mat.shape[:n_batch] == indices.shape[:n_batch]
  out_shape = mat.shape[:n_batch] + (nse,) + mat.shape[mat.ndim - n_dense:]
  return core.ShapedArray(out_shape, mat.dtype)
def _todense_abstract_eval(*bufs, tree):
  arr = tree_util.tree_unflatten(tree, bufs)
  if isinstance(arr, core.ShapedArray):
    return arr
  return core.ShapedArray(arr.shape, arr.dtype,
                          weak_type=dtypes.is_weakly_typed(arr.data))
def local_shards(self) -> Sequence[Shard]:
  for s in self._local_shards:
    # Ignore the type because mypy thinks data is None but local_shards
    # cannot have data=None which is checked in `_create_local_shards`.
    if s.data.aval is None:  # type: ignore
      s.data.aval = core.ShapedArray(s.data.shape, s.data.dtype)  # type: ignore
  return self._local_shards
def spvalue_to_aval(spvalue):
  if spvalue.is_unit():
    return core.abstract_unit
  else:
    # `spenv` is not a parameter here; it comes from the enclosing scope.
    data = spenv.data(spvalue)
    return core.ShapedArray(spvalue.shape, data.dtype, data.aval.weak_type)
def _nonzero_translation_rule(c, dims, avals, operands):
  (vals,), = operands
  shape = c.get_shape(vals)
  last_axis = len(shape.dimensions()) - 1
  zeros = xops.Broadcast(xb.constant(c, np.zeros((), shape.numpy_dtype())),
                         shape.dimensions())
  s32_etype = xc.dtype_to_etype(np.dtype('int32'))
  nonzero_indicators = xops.ConvertElementType(xops.Ne(vals, zeros), s32_etype)
  i = core.ShapedArray((), np.dtype('int32'))
  out_dim = xops.Reduce(
      c, [nonzero_indicators],
      [xb.constant(c, np.array(0, np.dtype('int32')))],
      xla.primitive_subcomputation(lax.add_p, i, i), (last_axis,))
  c.get_shape(out_dim)  # xla type checking

  subc = xb.make_computation_builder("sort_gt_comparator")
  params = [xb.parameter(subc, i, xc.Shape.array_shape(s32_etype, ()))
            for i in range(4)]
  comparator = subc.build(xops.Gt(params[0], params[1]))
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  ans = xops.Sort(c, [nonzero_indicators, xops.Iota(c, iota_shape, last_axis)],
                  is_stable=True, comparator=comparator)
  _, out_val = xla.xla_destructure(c, ans)
  c.get_shape(out_val)  # xla type checking
  return [[out_dim], [out_val]]
def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
  assert data.shape == row.shape == col.shape
  assert data.dtype == B.dtype
  assert len(shape) == 2
  # Parenthesize the conditional: without parens the assert would reduce to
  # `(B.shape[0] == shape[0]) if transpose else shape[1]`, which is always
  # truthy in the non-transpose case.
  assert B.shape[0] == (shape[0] if transpose else shape[1])
  out_shape = shape[1] if transpose else shape[0]
  return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids, expected_is_fully_replicated):
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)

  def cb(index):
    return global_input_data[index]

  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)
  self.assertEqual(gda.ndim, 2)
  self.assertEqual(gda.size, 16)
  self.assertEqual(gda.mesh_axes, mesh_axes)
  self.assertEqual(gda.local_shards[0].index, expected_index[0])
  self.assertArraysEqual(gda.local_data(0),
                         global_input_data[expected_index[0]])
  self.assertEqual(gda.local_shards[1].index, expected_index[1])
  self.assertArraysEqual(gda.local_data(1),
                         global_input_data[expected_index[1]])
  self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
  replica_ids = [i.replica_id for i in gda.local_shards]
  self.assertListEqual(replica_ids, expected_replica_ids)
  self.assertListEqual([i.device.id for i in gda.local_shards],
                       [0, 1, 2, 3, 4, 5, 6, 7])
  self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
  for s in gda.local_shards:
    self.assertEqual(s.data.aval,
                     core.ShapedArray(expected_shard_shape, s.data.dtype))
  for g, l in safe_zip(gda.global_shards, gda.local_shards):
    self.assertEqual(g.device, l.device)
    self.assertEqual(g.index, l.index)
    self.assertEqual(g.replica_id, l.replica_id)
    self.assertEqual(g.data.aval, l.data.aval)
    self.assertArraysEqual(g.data, l.data)
def _get_sharding_spec(global_shape, global_mesh, mesh_axes):
  array_mapping = _get_array_mapping(mesh_axes)
  # The dtype doesn't matter for creating sharding specs.
  aval = core.ShapedArray(global_shape, np.float32)
  return pxla.mesh_sharding_specs(global_mesh.shape,
                                  global_mesh.axis_names)(aval, array_mapping)
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                           named_shape_rule, *avals, **kwargs):
  assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
  assert not prim.multiple_results
  weak_type = weak_type_rule(*avals, **kwargs)
  least_specialized = _max(map(type, avals),
                           key=operator.attrgetter('array_abstraction_level'))
  if least_specialized is core.ConcreteArray:
    out = prim.impl(*[x.val for x in avals], **kwargs)
    return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
  elif least_specialized is core.ShapedArray:
    return core.ShapedArray(shape_rule(*avals, **kwargs),
                            dtype_rule(*avals, **kwargs), weak_type=weak_type,
                            named_shape=named_shape_rule(*avals, **kwargs))
  elif least_specialized is core.DShapedArray:
    shape = shape_rule(*avals, **kwargs)
    ty = (core.ShapedArray if all(type(d) is int for d in shape)
          else core.DShapedArray)
    return ty(shape, dtype_rule(*avals, **kwargs), weak_type)
  elif least_specialized is core.UnshapedArray:
    return core.UnshapedArray(dtype_rule(*avals, **kwargs),
                              weak_type=weak_type)
  else:
    raise TypeError(avals, least_specialized)
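# A minimal sketch, not part of the surrounding code, of how an abstract-eval
# rule like those built by `standard_abstract_eval` gets attached to a
# primitive. `my_unary_p` and its rule are illustrative assumptions; for a
# shape-preserving unary op the rule can simply echo the input's ShapedArray.
my_unary_p = core.Primitive('my_unary')  # hypothetical primitive

@my_unary_p.def_abstract_eval
def _my_unary_abstract_eval(x):
  # Same shape/dtype/named_shape as the input; only the value is unknown.
  return core.ShapedArray(x.shape, x.dtype, weak_type=x.weak_type,
                          named_shape=x.named_shape)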
def standard_multi_result_abstract_eval(prim, shape_rule, dtype_rule,
                                        weak_type_rule, named_shape_rule,
                                        *avals, **kwargs):
  assert prim.multiple_results
  assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
  least_specialized = _max(map(type, avals),
                           key=operator.attrgetter('array_abstraction_level'))
  weak_types = weak_type_rule(*avals, **kwargs)
  if least_specialized is core.ConcreteArray:
    out_vals = prim.impl(*[x.val for x in avals], **kwargs)
    return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
            for val, weak_type in safe_zip(out_vals, weak_types)]
  elif least_specialized is core.ShapedArray:
    out_shapes = shape_rule(*avals, **kwargs)
    out_dtypes = dtype_rule(*avals, **kwargs)
    out_named_shapes = named_shape_rule(*avals, **kwargs)
    return [core.ShapedArray(s, d, weak_type=weak_type, named_shape=named_shape)
            for s, d, weak_type, named_shape
            in safe_zip(out_shapes, out_dtypes, weak_types, out_named_shapes)]
  elif least_specialized is core.UnshapedArray:
    out_dtypes = dtype_rule(*avals, **kwargs)
    return [core.UnshapedArray(dtype, weak_type=weak_type)
            for dtype, weak_type in safe_zip(out_dtypes, weak_types)]
  else:
    raise TypeError(avals, least_specialized)
def fix_float0(arg_jax, ct_arg_jax):
  arg_dtype = dtypes.result_type(arg_jax)  # May be scalar
  ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
  if ct_arg_dtype != ct_arg_jax.dtype:
    return ad_util.zeros_like_aval(
        core.ShapedArray(np.shape(arg_jax), ct_arg_dtype))
  return ct_arg_jax
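# Hedged illustration of why `fix_float0` is needed: integer (and bool)
# primals have float0 tangents in JAX, so a cotangent that arrives with a
# real dtype must be replaced by a float0 zero of the primal's shape. For
# example, core.primal_dtype_to_tangent_dtype(np.dtype('int32')) yields
# dtypes.float0, while float dtypes map to themselves.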
def argspec_to_aval(argspec):
  if argspec.is_unit():
    return core.abstract_unit
  else:
    # `spenv` is not a parameter here; it comes from the enclosing scope.
    data = argspec.data(spenv)
    return core.ShapedArray(argspec.shape, data.dtype, data.aval.weak_type)
def from_dlpack(dlpack, backend=None):
  """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.

  The returned `DeviceArray` shares memory with `dlpack`.

  Args:
    dlpack: a DLPack tensor, on either CPU or GPU.
    backend: deprecated, do not use.
  """
  if jax.lib._xla_extension_version >= 25:
    cpu_backend = xla_bridge.get_backend("cpu")
    try:
      gpu_backend = xla_bridge.get_backend("gpu")
    except RuntimeError:
      gpu_backend = None
    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
        dlpack, cpu_backend, gpu_backend)
  else:
    # TODO(phawkins): drop the backend argument after deleting this case.
    backend = backend or xla_bridge.get_backend()
    client = getattr(backend, "client", backend)
    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)

  xla_shape = buf.xla_shape()
  assert not xla_shape.is_tuple()
  aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
  return xla.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
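# Hedged usage sketch for `from_dlpack`, assuming the public `jax.dlpack`
# wrappers around it: round-trip a JAX array through a DLPack capsule. Note
# that `to_dlpack` may take ownership of the source buffer, so `x` should not
# be reused after exporting it.
import jax.numpy as jnp
from jax import dlpack as jax_dlpack

x = jnp.arange(4.0)
capsule = jax_dlpack.to_dlpack(x)    # export x as a DLPack capsule
y = jax_dlpack.from_dlpack(capsule)  # shares memory with the capsule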
def test_staging_primitive_applications(self):
  n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
  a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
  b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

  @lu.wrap_init
  def f(x, y):
    z = lax.mul(x, y)
    w = lax.sin(z)
    u = lax_internal._reduce_sum(w, [0])
    return (u,)

  jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
      f, [n, a, b], keep_inputs=[False, True, True])
  self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
  self.assertLen(jaxpr.eqns, 3)
  self.assertLen(jaxpr.eqns[0].outvars, 1)
  self.assertEqual(jaxpr.eqns[0].outvars[0].aval.shape,
                   jaxpr.invars[1].aval.shape)
  self.assertLen(jaxpr.outvars, 1)
  self.assertEqual(jaxpr.outvars[0].aval.shape, ())
def compile_and_get_sharding(pjitted_fn, mesh, global_inputs):
  # TODO(yashkatariya): Check if the pjitted_fn comes from pjit.
  inputs = [core.ShapedArray(i.shape, i.dtype) for i in global_inputs]
  compiled = pjitted_fn.lower(*inputs, _global_avals=True).compile()
  in_sharding, out_sharding = pjit._get_sharding_from_executable(
      compiled.runtime_executable(), mesh)
  return _XLAShardingInfo(in_pspec=in_sharding, out_pspec=out_sharding,
                          compiled=compiled)
def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
  assert data.shape == row.shape == col.shape
  assert data.dtype == v.dtype
  assert row.dtype == col.dtype
  assert len(shape) == 2
  # Parenthesize the conditional: without parens the assert would reduce to
  # `(v.shape == (shape[0],)) if transpose else (shape[1],)`, which is always
  # truthy in the non-transpose case.
  assert v.shape == ((shape[0],) if transpose else (shape[1],))
  out_shape = shape[1] if transpose else shape[0]
  return core.ShapedArray((out_shape,), data.dtype)
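# Hedged illustration of the COO layout these matvec/matmat rules assume:
# a matrix is stored as three equal-length 1D arrays (values, row indices,
# column indices). The 2x3 example below is illustrative only.
#   M = [[1, 0, 2],
#        [0, 3, 0]]          # shape == (2, 3)
#   data = [1., 2., 3.]
#   row  = [0, 0, 1]
#   col  = [0, 2, 1]
# With transpose=False, M @ v needs v.shape == (3,) and yields shape (2,);
# with transpose=True the roles of shape[0] and shape[1] swap.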
def _get_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                 mesh_axes: MeshAxes) -> Tuple[Index, ...]:
  array_mapping = _get_array_mapping(mesh_axes)
  # The dtype doesn't matter for creating sharding specs.
  aval = core.ShapedArray(global_shape, np.float32)
  sharding_spec = pxla.mesh_sharding_specs(
      global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
  indices = pxla.spec_to_indices(global_shape, sharding_spec)
  return indices  # type: ignore
def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
  if eff not in self.tokens:
    self.tokens[eff] = device_put(np.zeros(0, np.bool_), device), device
  elif self.tokens[eff][1] != device:
    (old_token,), _ = self.tokens[eff]
    old_token.aval = core.ShapedArray((0,), np.bool_)
    self.tokens[eff] = device_put(old_token, device), device
  return self.tokens[eff][0]
def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
  assert len(shape) == 2
  assert v.ndim == data.ndim == indices.ndim == indptr.ndim == 1
  assert data.shape == indices.shape
  assert data.dtype == v.dtype
  assert indices.dtype == indptr.dtype
  assert len(indptr) == shape[0] + 1
  out_shape = shape[1] if transpose else shape[0]
  # Parenthesize the conditional so the assert compares v.shape against the
  # selected tuple rather than short-circuiting to a truthy tuple.
  assert v.shape == ((shape[0],) if transpose else (shape[1],))
  return core.ShapedArray((out_shape,), data.dtype)
def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
  assert data.ndim == indices.ndim == indptr.ndim == 1
  assert B.ndim == 2
  assert data.shape == indices.shape
  assert data.dtype == B.dtype
  assert indices.dtype == indptr.dtype
  assert len(indptr) == shape[0] + 1
  out_shape = shape[1] if transpose else shape[0]
  # Parenthesize the conditional so the assert compares B.shape[0] against
  # the selected dimension rather than short-circuiting to a truthy value.
  assert B.shape[0] == (shape[0] if transpose else shape[1])
  return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
def shaped_abstractify(x):
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    pass
  weak_type = getattr(x, 'weak_type', False)
  named_shape = getattr(x, 'named_shape', {})
  return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
                          named_shape=named_shape)
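# Hedged usage sketch for `shaped_abstractify`: it maps a concrete value to
# the ShapedArray describing its shape and dtype. The array below is
# illustrative only.
x = np.ones((2, 3), np.float32)
assert shaped_abstractify(x) == core.ShapedArray((2, 3), np.dtype('float32'))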
def discharge_state(jaxpr: core.Jaxpr,
                    consts: Sequence[Any]) -> Tuple[core.Jaxpr, List[Any]]:
  """Converts a jaxpr that takes in `Ref`s into one that doesn't."""
  in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype)
              if type(v.aval) is ShapedArrayRef else v.aval
              for v in jaxpr.invars]
  eval_jaxpr = lu.wrap_init(
      partial(_eval_jaxpr_discharge_state, jaxpr, consts))
  new_jaxpr, _, new_consts = pe.trace_to_jaxpr_dynamic(eval_jaxpr, in_avals)
  return new_jaxpr, new_consts
def _threefry2x32_abstract_eval(*args):
  if any(a.dtype != jnp.uint32 for a in args):
    raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                    .format(args))
  if all(isinstance(arg, core.ShapedArray) for arg in args):
    shape = lax._broadcasting_shape_rule(*args)
    named_shape = core.join_named_shapes(*(a.named_shape for a in args))
    aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32),
                            named_shape=named_shape)
  else:
    aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
  return (aval,) * 2