def standard_multi_result_abstract_eval(
    prim, shape_rule, dtype_rule, weak_type_rule, named_shape_rule,
    *avals, **kwargs):
  assert prim.multiple_results
  assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
  least_specialized = _max(map(type, avals),
                           key=operator.attrgetter('array_abstraction_level'))
  weak_types = weak_type_rule(*avals, **kwargs)
  if least_specialized is core.ConcreteArray:
    out_vals = prim.impl(*[x.val for x in avals], **kwargs)
    return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
            for val, weak_type in safe_zip(out_vals, weak_types)]
  elif least_specialized is core.ShapedArray:
    out_shapes = shape_rule(*avals, **kwargs)
    out_dtypes = dtype_rule(*avals, **kwargs)
    out_named_shapes = named_shape_rule(*avals, **kwargs)
    return [core.ShapedArray(s, d, weak_type=weak_type, named_shape=named_shape)
            for s, d, weak_type, named_shape
            in safe_zip(out_shapes, out_dtypes, weak_types, out_named_shapes)]
  elif least_specialized is core.UnshapedArray:
    out_dtypes = dtype_rule(*avals, **kwargs)
    return [core.UnshapedArray(dtype, weak_type=weak_type)
            for dtype, weak_type in safe_zip(out_dtypes, weak_types)]
  else:
    raise TypeError(avals, least_specialized)
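# Multi-result abstract evaluation like the rule above is what lets
# jax.eval_shape compute output shapes and dtypes without running any FLOPs.
# A minimal illustration via the public API (jnp.linalg.eigh produces
# multiple results):
import jax
import jax.numpy as jnp

w, v = jax.eval_shape(jnp.linalg.eigh,
                      jax.ShapeDtypeStruct((4, 4), jnp.float32))
# w: ShapeDtypeStruct(shape=(4,), dtype=float32)
# v: ShapeDtypeStruct(shape=(4, 4), dtype=float32)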
def _broadcast_to(arr, shape):
  if hasattr(arr, "broadcast_to"):
    return arr.broadcast_to(shape)
  _check_arraylike("broadcast_to", arr)
  arr = arr if isinstance(arr, ndarray) else _asarray(arr)
  if not isinstance(shape, tuple) and np.ndim(shape) == 0:
    shape = (shape,)
  shape = core.canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = np.shape(arr)
  if core.symbolic_equal_shape(arr_shape, shape):
    return arr
  else:
    nlead = len(shape) - len(arr_shape)
    shape_tail = shape[nlead:]
    compatible = all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
                     for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
    if nlead < 0 or not compatible:
      msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
      raise ValueError(msg.format(arr_shape, shape))
    diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
                           for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
    new_dims = tuple(range(nlead)) + tuple(nlead + diff)
    kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
    return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
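# A minimal usage sketch via the public jnp.broadcast_to wrapper over
# _broadcast_to: missing leading axes are added and size-1 axes are expanded.
import jax.numpy as jnp

x = jnp.ones((1, 3))
y = jnp.broadcast_to(x, (4, 2, 3))
assert y.shape == (4, 2, 3)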
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  subc = xc.XlaBuilder(f"sharded_jit_{name}")

  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for i, (n, sharding) in enumerate(safe_zip(in_nodes, in_parts)):
    # We use xla.set_sharding instead of xla.with_sharding because inlined
    # calls shouldn't have shardings set directly on the inputs or outputs.
    arg = xla.parameter(subc, i, ctx.builder.GetShape(n))
    args.append(xla.set_sharding(subc, arg, sharding))

  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)

  out_parts = out_parts_thunk()
  assert len(out_parts) == len(out_nodes)
  out_nodes = [xla.set_sharding(subc, out, sharding)
               for out, sharding in safe_zip(out_nodes, out_parts)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xla.xla_destructure(ctx.builder,
                             xops.Call(ctx.builder, subc, list(in_nodes)))
def _get_shard_indices_replica_ids_uncached(
    global_shape: Shape, global_mesh: pxla.Mesh,
    mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
  indices = _get_indices(global_shape, global_mesh, mesh_axes)
  index_to_replica: Dict[int, int] = Counter()
  out = {}
  unique_shards = 0
  for device, index in safe_zip(global_mesh.devices.flat, indices):
    h_index = _hashed_index(index)
    replica_id = index_to_replica[h_index]
    if replica_id == 0:
      unique_shards += 1
    index_to_replica[h_index] += 1
    out[device] = (index, replica_id)

  shard_shape = get_shard_shape(global_shape, global_mesh, mesh_axes)
  expected_unique_shards = prod(
      [g // s for g, s in safe_zip(global_shape, shard_shape)
       if g != 0 or s != 0])
  if expected_unique_shards != unique_shards:
    raise RuntimeError(
        f'Number of expected unique shards is {expected_unique_shards} but '
        f'got {unique_shards}. Please file a bug at '
        'https://github.com/google/jax/issues.')
  return out
def CheckShapePolymorphism(self, f_jax: Callable, *,
                           input_signature: Sequence[tf.TensorSpec],
                           polymorphic_shapes: Optional[Sequence[Any]],
                           expected_output_signature: Optional[tf.TensorSpec] = None,
                           enable_xla: bool = True):
  """Converts a function using polymorphic shapes.

  Args:
    f_jax: a JAX function of `n` arguments
    input_signature: used as the input signature for the tf.function.
    polymorphic_shapes: specifies input shapes to be treated polymorphically
      during conversion.
    expected_output_signature: if given, this function tests whether the
      actual output signature is equal to this one.
    enable_xla: whether to enable XLA conversion for jax2tf.convert.
  """
  f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
                        enable_xla=enable_xla)
  f_tf = tf.function(f_tf, autograph=False, input_signature=input_signature)
  concrete_f_tf = f_tf.get_concrete_function(*input_signature)
  if expected_output_signature:
    # Strangely, output_shapes can be a single shape for a function with a
    # single result, or a list/tuple of shapes.
    concrete_output_tf_shape = concrete_f_tf.output_shapes
    if not isinstance(concrete_output_tf_shape, (tuple, list)):  # Single result
      assert not isinstance(expected_output_signature, (tuple, list))
      expected_output_signature = [expected_output_signature]
      concrete_output_tf_shape = [concrete_output_tf_shape]

    for expected, found in util.safe_zip(expected_output_signature,
                                         concrete_output_tf_shape):
      self.assertEqual(tuple(expected.shape), tuple(found))
  return f_tf
def _xla_sharded_args(c, avals, in_parts):
  xla_args = []
  for i, (sharding, aval) in enumerate(safe_zip(in_parts, avals)):
    param = xla.with_sharding(c, sharding, xla.parameter, c, i,
                              *xla.aval_to_xla_shapes(aval))
    xla_args.append(param)
  return xla_args
def CheckShapePolymorphism(self, f_jax: Callable, *,
                           input_signature: Sequence[tf.TensorSpec],
                           polymorphic_shapes: Optional[Sequence[Any]],
                           expected_output_signature: tf.TensorSpec):
  """Converts a function using polymorphic shapes.

  Args:
    f_jax: a JAX function of `n` arguments
    input_signature: used as the input signature for the tf.function.
    polymorphic_shapes: if given, it must be a sequence of `n` shape
      specifications and must match the `input_signature` (see
      jax2tf.convert).
    expected_output_signature: if given, this function tests whether the
      actual output signature is equal to this one.
  """
  f_tf = tf.function(
      jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes),
      autograph=False, input_signature=input_signature)
  concrete_f_tf = f_tf.get_concrete_function(*input_signature)
  if expected_output_signature:
    # Strangely, output_shapes can be a single shape for a function with a
    # single result, or a list/tuple of shapes.
    concrete_output_tf_shape = concrete_f_tf.output_shapes
    if not isinstance(concrete_output_tf_shape, (tuple, list)):  # Single result
      assert not isinstance(expected_output_signature, (tuple, list))
      expected_output_signature = [expected_output_signature]
      concrete_output_tf_shape = [concrete_output_tf_shape]

    for expected, found in util.safe_zip(expected_output_signature,
                                         concrete_output_tf_shape):
      self.assertEqual(tuple(expected.shape), tuple(found))
  return f_tf
def _fft_core(func_name, fft_type, a, s, axes, norm):
  full_name = "jax.numpy.fft." + func_name
  if s is not None:
    s = tuple(map(operator.index, s))
    if np.any(np.less(s, 0)):
      raise ValueError("Shape should be non-negative.")
  if norm is not None:
    raise NotImplementedError("%s only supports norm=None, got %s"
                              % (full_name, norm))
  if s is not None and axes is not None and len(s) != len(axes):
    # Same error as numpy.
    raise ValueError("Shape and axes have different lengths.")

  orig_axes = axes
  if axes is None:
    if s is None:
      axes = range(a.ndim)
    else:
      axes = range(a.ndim - len(s), a.ndim)

  if len(axes) != len(set(axes)):
    raise ValueError("%s does not support repeated axes. Got axes %s."
                     % (full_name, axes))

  if len(axes) > 3:
    # XLA does not support FFTs over more than 3 dimensions
    raise ValueError("%s only supports 1D, 2D, and 3D FFTs. "
                     "Got axes %s with input rank %s."
                     % (full_name, orig_axes, a.ndim))

  # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
  if orig_axes is not None:
    axes = tuple(range(a.ndim - len(axes), a.ndim))
    a = jnp.moveaxis(a, orig_axes, axes)

  if s is not None:
    a = jnp.asarray(a)
    in_s = list(a.shape)
    for axis, x in safe_zip(axes, s):
      in_s[axis] = x
    if fft_type == xla_client.FftType.IRFFT:
      in_s[-1] = (in_s[-1] // 2 + 1)
    # Cropping
    a = a[tuple(map(slice, in_s))]
    # Padding
    a = jnp.pad(a, [(0, x - y) for x, y in zip(in_s, a.shape)])
  else:
    if fft_type == xla_client.FftType.IRFFT:
      s = [a.shape[axis] for axis in axes[:-1]]
      if axes:
        s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
    else:
      s = [a.shape[axis] for axis in axes]
  transformed = lax.fft(a, fft_type, s)
  if orig_axes is not None:
    transformed = jnp.moveaxis(transformed, axes, orig_axes)
  return transformed
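# A minimal usage sketch of the public wrappers built on _fft_core; note the
# 3-axis limit above means higher-rank transforms need explicit `axes`.
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(4, 4)
freq = jnp.fft.fft2(x)            # FFT over the last two axes
back = jnp.fft.ifft2(freq).real   # round-trips to x up to float error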
def _get_shard_indices_replica_ids_uncached(
    global_shape: Shape, global_mesh: pxla.Mesh,
    mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
  indices = _get_indices(global_shape, global_mesh, mesh_axes)
  replica_ids = _calc_replica_ids(global_mesh, mesh_axes)
  return dict((d, (i, r)) for d, i, r in safe_zip(
      global_mesh.devices.flat, indices, replica_ids))
def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                      expected_replica_ids, expected_is_fully_replicated):
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)

  def cb(index):
    return global_input_data[index]

  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)
  self.assertEqual(gda.ndim, 2)
  self.assertEqual(gda.size, 16)
  self.assertEqual(gda.mesh_axes, mesh_axes)
  self.assertEqual(gda.local_shards[0].index, expected_index[0])
  self.assertArraysEqual(gda.local_data(0),
                         global_input_data[expected_index[0]])
  self.assertEqual(gda.local_shards[1].index, expected_index[1])
  self.assertArraysEqual(gda.local_data(1),
                         global_input_data[expected_index[1]])
  self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
  replica_ids = [i.replica_id for i in gda.local_shards]
  self.assertListEqual(replica_ids, expected_replica_ids)
  self.assertListEqual([i.device.id for i in gda.local_shards],
                       [0, 1, 2, 3, 4, 5, 6, 7])
  self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
  for s in gda.local_shards:
    self.assertEqual(s.data.aval,
                     core.ShapedArray(expected_shard_shape, s.data.dtype))
  for g, l in safe_zip(gda.global_shards, gda.local_shards):
    self.assertEqual(g.device, l.device)
    self.assertEqual(g.index, l.index)
    self.assertEqual(g.replica_id, l.replica_id)
    self.assertEqual(g.data.aval, l.data.aval)
    self.assertArraysEqual(g.data, l.data)
def test_gda_subset_devices(self, mesh_axes, expected_index,
                            expected_shard_shape, expected_replica_ids):
  global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  global_input_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)

  def cb(index):
    return global_input_data[index]

  gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                        mesh_axes, cb)
  self.assertEqual(gda.local_shards[0].index, expected_index[0])
  self.assertArraysEqual(gda.local_data(0),
                         global_input_data[expected_index[0]])
  self.assertEqual(gda.local_shards[1].index, expected_index[1])
  self.assertArraysEqual(gda.local_data(1),
                         global_input_data[expected_index[1]])
  self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
  replica_ids = [i.replica_id for i in gda.local_shards]
  self.assertListEqual(replica_ids, expected_replica_ids)
  for g, l in safe_zip(gda.global_shards, gda.local_shards):
    self.assertEqual(g.device, l.device)
    self.assertEqual(g.index, l.index)
    self.assertEqual(g.replica_id, l.replica_id)
    self.assertArraysEqual(g.data, l.data)
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
  indices = _get_indices(global_shape, global_mesh, mesh_axes)
  # The type: ignore is to ignore the type returned by `spec_to_indices`.
  return {d: i for d, i in safe_zip(global_mesh.devices.flat, indices)}  # type: ignore
def _xla_callable_args(c, avals, tuple_args, *,
                       replicated=None,
                       partitions=None,
                       partitions_proto: bool = False,
                       donated_invars=None,
                       filter_tokens=True):
  assert partitions is None or len(partitions) == len(avals)
  if not tuple_args:
    if replicated is None:
      replicated = [None] * len(avals)
    if partitions is None:
      parts: List[object] = [None] * len(avals)
    elif partitions_proto:
      parts = partitions
    else:
      parts = [_replicated_param if part is None else part
               for part in partitions]
    counts = it.count()
    xla_args = [_xla_param(c, next(counts), xla_shape, r, p,
                           partitions_proto, filter_tokens)
                for (a, r, p) in safe_zip(avals, replicated, parts)
                for xla_shape in aval_to_xla_shapes(a)]
    if donated_invars is not None:
      donated_invars = [d
                        for (a, _, _, d) in zip(avals, replicated, parts,
                                                donated_invars)
                        for xla_shape in aval_to_xla_shapes(a)]
    return xla_args, donated_invars
  else:
    if replicated is not None:
      replicated = [r for a, r in zip(avals, replicated)
                    if a is not abstract_token]
    if partitions is None:
      tuple_parts = None
    elif partitions_proto:
      tuple_parts = tuple_sharding_proto(partitions)
    else:
      tuple_parts = tuple(partitions)
    tuple_shape = xc.Shape.tuple_shape(
        [shape if not (filter_tokens and a is abstract_token)
         else _token_param_shape()
         for a in avals for shape in aval_to_xla_shapes(a)])
    tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts,
                             partitions_proto, filter_tokens)
    xla_args = [v if not (filter_tokens and a is abstract_token)
                else xops.CreateToken(c)
                for a, v in zip(avals, xla_destructure(c, tuple_param))]
    return xla_args, donated_invars
def _avals_to_results_handler(nrep, npart, partitions, out_avals):
  handlers = [_aval_to_result_handler(npart, parts, out_aval)
              for parts, out_aval in safe_zip(partitions, out_avals)]

  def handler(out_bufs):
    return [h(bufs) for h, bufs in zip(handlers, out_bufs)]

  return handler
def _map_to_tile(*args_flat):
  # Note: closes over `tile_size`, `in_axes_flat`, and `out_axes_flat` from
  # the enclosing scope.
  sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat)
           if i is not None)
  tile_size_ = tile_size or next(sizes, None)
  assert tile_size_ is not None, "No mapped arguments?"
  outputs_flat = yield map(tile_axis(tile_size=tile_size_),
                           args_flat, in_axes_flat), {}
  yield map(untile_axis, outputs_flat, out_axes_flat)
def _conv_general_vjp_lhs_padding(
    in_shape, window_dimensions, window_strides, out_shape, padding,
    lhs_dilation, rhs_dilation) -> List[Tuple[int, int]]:
  lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
  rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
  out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
  pad_before = np.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1
  pad_after = (np.add(lhs_dilated_shape, rhs_dilated_shape) - 1
               - out_dilated_shape - pad_before)
  return safe_zip(pad_before, pad_after)
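# Worked example of the padding arithmetic above, for a 1D VALID convolution
# with in_shape=(5,), window_dimensions=(3,), unit strides/dilations, and
# padding=[(0, 0)]:
#   lhs_dilated_shape = (5,), rhs_dilated_shape = (3,), out_dilated_shape = (3,)
#   pad_before = 3 - 0 - 1 = 2
#   pad_after  = 5 + 3 - 1 - 3 - 2 = 2
# i.e. the gradient w.r.t. the input uses "full" padding (2, 2): the classic
# result that the VJP of a VALID convolution is a FULL convolution.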
def _sharded_jit_lowering(ctx, *in_nodes,
                          in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  args = []
  for ns, sharding in safe_zip(
      safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
    if sharding is not None:
      args.append(
          [mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
           for n in ns])
    else:
      args.append(ns)

  sub_ctx = ctx.module_context.replace(
      name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
  fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                               core.ClosedJaxpr(call_jaxpr, ()))

  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)
  call = std.CallOp(flat_output_types,
                    ir.FlatSymbolRefAttr.get(fn.name.value),
                    mlir.flatten_lowering_ir_args(args))
  out_nodes = util.unflatten(call.results, safe_map(len, output_types))

  out_parts = out_parts_thunk()
  outputs = []
  for ns, sharding in safe_zip(out_nodes, out_parts):
    if sharding is not None:
      outputs.append(
          [mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
           for n in ns])
    else:
      outputs.append(ns)
  return outputs
def get_shard_indices_replica_ids(
    global_shape: Shape, global_mesh: pxla.Mesh,
    mesh_axes: MeshAxes) -> Mapping[Device, Tuple[Index, int]]:
  indices = _get_indices(global_shape, global_mesh, mesh_axes)
  index_to_replica: Dict[_HashableIndex, int] = Counter()
  out = {}
  for device, index in safe_zip(global_mesh.devices.flat, indices):
    h_index = _HashableIndex(index)
    replica_id = index_to_replica[h_index]
    index_to_replica[h_index] += 1
    out[device] = (index, replica_id)
  return out
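# A toy illustration of the replica-id bookkeeping above: devices that map to
# the same shard index are replicas of one another, numbered in visit order.
from collections import Counter

indices = ["shard0", "shard0", "shard1", "shard1"]  # two devices per shard
counts: Counter = Counter()
replica_ids = []
for idx in indices:
  replica_ids.append(counts[idx])  # 0 for the first holder, 1 for its replica
  counts[idx] += 1
assert replica_ids == [0, 1, 0, 1]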
def _gsda_shard_arg(x, devices, indices):
  pjit_mesh = maps.thread_resources.env.physical_mesh
  if x._global_mesh != pjit_mesh:
    raise ValueError(
        "Pjit's mesh and GDA's mesh should be equal. Got Pjit "
        f"mesh: {pjit_mesh},\n GDA mesh: {x._global_mesh}")
  assert all(g.index == i for g, i in safe_zip(x.global_shards, indices)), (
      "Indices calculated by GDA and pjit do not match. Please file a bug "
      "on https://github.com/google/jax/issues. "
      f"Got GDA indices: {[g.index for g in x.global_shards]},\n"
      f"pjit indices: {indices}")
  return [s.data for s in x.local_shards]
def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],
                    name: str, target_total_secs: int = None):
  """Benchmarks a function for several combinations of parameters.

  Prints the summarized results in a table.

  Args:
    prepare: given kwargs returns a benchmark function specialized to the
      kwargs.
    params_list: a list of kwargs on which to run the benchmark.
    name: the name of this benchmark suite.
    target_total_secs: the ``target_total_secs`` to pass to ``benchmark``.
  """
  # Sort parameters alphabetically so benchmark results print consistently.
  params_list = [OrderedDict(sorted(p.items())) for p in params_list]
  assert all(p.keys() == params_list[0].keys() for p in params_list)

  times = []
  for params in params_list:
    f = prepare(**params)
    subname = name + "".join("_%s=%s" % (n, _param_str(p))
                             for n, p in params.items())
    times.append(benchmark(f, name=subname,
                           target_total_secs=target_total_secs))

  param_names = list(params_list[0].keys())
  data_header = param_names + ["mean", "%std", "relative"]
  data = [list(map(_param_str, params.values())) +
          [t.mean(), _pstd(t), t.mean() / times[0].mean()]
          for params, t in safe_zip(params_list, times)]

  if FLAGS.baseline_dir:
    mean_idx = len(param_names)
    means = _get_baseline_means(FLAGS.baseline_dir, name)
    assert len(means) == len(data), (means, data)
    data_header.append("mean/baseline")
    for idx, mean in enumerate(means):
      data[idx].append(data[idx][mean_idx] / mean)

  print("---------Benchmark summary for %s---------" % name)
  print(tabulate(data, data_header))
  print()

  if FLAGS.export_dir:
    filename = _export_results(data_header, data, FLAGS.export_dir, name)
    print("Wrote %s results to %s" % (name, filename))
    print()
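# A minimal usage sketch, assuming the `benchmark` helper this module defines:
# `prepare` closes over the params and returns the zero-argument function to
# time; compilation happens outside the timed region.
import jax
import jax.numpy as jnp

def prepare_matmul(size):
  x = jnp.ones((size, size))
  f = jax.jit(lambda a: a @ a)
  f(x).block_until_ready()  # warm up / compile once
  return lambda: f(x).block_until_ready()

benchmark_suite(prepare_matmul, [{"size": 256}, {"size": 512}], "matmul")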
def __init__(self, global_shape: Shape, global_mesh: pxla.Mesh,
             mesh_axes: MeshAxes, device_buffers: Sequence[DeviceArray],
             _gda_fast_path_args: Optional[_GdaFastPathArgs] = None,
             _enable_checks: bool = True):
  self._global_shape = global_shape
  self._global_mesh = global_mesh
  self._mesh_axes = mesh_axes
  self._device_buffers = device_buffers
  # Optionally precomputed for performance.
  self._gda_fast_path_args = _gda_fast_path_args
  self._current_process = xb.process_index()

  if self._gda_fast_path_args is None:
    self._local_devices = self._global_mesh.local_devices
  else:
    self._local_devices = self._gda_fast_path_args.local_devices

  if _enable_checks or config.jax_enable_checks:
    for db, ld in safe_zip(device_buffers, self._local_devices):
      if db.device() != ld:
        raise ValueError(
            "The `global_mesh.local_devices` and `device_buffers` device "
            "order doesn't match. Please use `global_mesh.local_devices` to "
            "put arrays on devices instead of `jax.local_devices()`")

  if _enable_checks or config.jax_enable_checks:
    ss = get_shard_shape(self._global_shape, self._global_mesh,
                         self.mesh_axes)
    assert all(db.shape == ss for db in device_buffers), (
        f"Expected shard shape {ss} doesn't match the device buffer "
        f"shape, got: {[db.shape for db in device_buffers]}")

  dtype = device_buffers[0].dtype
  if _enable_checks or config.jax_enable_checks:
    assert all(db.dtype == dtype for db in device_buffers), (
        "Input arrays to GlobalDeviceArray must have matching dtypes, "
        f"got: {[db.dtype for db in device_buffers]}")
  self.dtype = dtype
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
                      mesh_axes: MeshAxes) -> Mapping[Device, Index]:
  # Import here to avoid cyclic import error when importing gda in pjit.py.
  from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources

  if not isinstance(mesh_axes, PartitionSpec):
    pspec = PartitionSpec(*mesh_axes)
  else:
    pspec = mesh_axes
  parsed_pspec, _, _ = _prepare_axis_resources(pspec, "mesh_axes")
  array_mapping = get_array_mapping(parsed_pspec)
  # The dtype doesn't matter for creating sharding specs.
  aval = core.ShapedArray(global_shape, np.float32)
  sharding_spec = pxla.mesh_sharding_specs(
      global_mesh.shape, global_mesh.axis_names)(aval, array_mapping)
  indices = pxla.spec_to_indices(global_shape, sharding_spec)
  for index in indices:
    assert isinstance(index, tuple)
    for idx in index:
      assert isinstance(idx, slice)
  # The type: ignore is to ignore the type returned by `spec_to_indices`.
  return dict((d, i) for d, i in safe_zip(global_mesh.devices.flat, indices))  # type: ignore
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(abstract_args, in_parts,
                                         local_in_parts)]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(global_out_avals, out_parts,
                                         local_out_parts)]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  axis_env = xla.AxisEnv(nrep, (), ())
  unordered_effects = [eff for eff in jaxpr.effects
                       if eff not in core.ordered_effects]
  ordered_effects = [eff for eff in jaxpr.effects
                     if eff in core.ordered_effects]
  module, _ = mlir.lower_jaxpr_to_module(
      f"spjit_{fun.__name__}",
      core.ClosedJaxpr(jaxpr, consts),
      unordered_effects,
      ordered_effects,
      platform=platform,
      axis_context=mlir.ReplicaAxisContext(axis_env),
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")),
      donated_args=[False] * len(in_parts),
      arg_shardings=safe_map(xla.sharding_to_proto, in_parts),
      result_shardings=safe_map(xla.sharding_to_proto, out_parts))
  built = xc._xla.mlir.mlir_module_to_xla_computation(
      mlir.module_to_string(module), use_tuple_args=False, return_tuple=True)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)]
  input_indices = [pxla.spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in zip(abstract_args, input_specs)]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
def devices_indices_map(
    self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
  indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
  return {d: i for d, i in safe_zip(self.devices.flat, indices)}  # type: ignore
def _sharded_callable(
    fun: lu.WrappedFun, nparts: Optional[int],
    in_parts: Tuple[pxla.PartitionsOrReplicated, ...],
    out_parts_thunk: Callable[[], Tuple[pxla.PartitionsOrReplicated, ...]],
    local_in_parts: Optional[Tuple[pxla.PartitionsOrReplicated, ...]],
    local_out_parts_thunk: Callable[[], Optional[Tuple[
        pxla.PartitionsOrReplicated, ...]]],
    local_nparts: Optional[int], name: str, *abstract_args):
  nrep = 1

  if local_in_parts is None:
    local_in_parts = in_parts

  global_abstract_args = [
      pxla.get_global_aval(arg, parts, lparts)
      for arg, parts, lparts in safe_zip(abstract_args, in_parts,
                                         local_in_parts)]

  if logging.vlog_is_on(2):
    logging.vlog(2, "abstract_args: %s", abstract_args)
    logging.vlog(2, "global_abstract_args: %s", global_abstract_args)
    logging.vlog(2, "in_parts: %s", in_parts)
    logging.vlog(2, "local_in_parts: %s", local_in_parts)

  jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
      fun, global_abstract_args)

  platform = xb.get_backend().platform
  if platform not in ["tpu", "gpu"]:
    # TODO(skye): fall back to regular jit?
    raise ValueError(f"sharded_jit not supported for {platform}")

  nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
  assert nparts is not None
  if nparts > xb.device_count():
    raise ValueError(
        f"sharded_jit computation requires {nparts} devices, "
        f"but only {xb.device_count()} devices are available.")
  if xb.local_device_count() < nparts < xb.device_count():
    raise NotImplementedError(
        f"sharded_jit across multiple hosts must use all available devices. "
        f"Got {nparts} out of {xb.device_count()} requested devices "
        f"(local device count: {xb.local_device_count()})")

  if local_nparts is None:
    if nparts > xb.local_device_count():
      raise ValueError(
          "Specify 'local_nparts' when using cross-process sharded_jit "
          "and all inputs and outputs are replicated.")
    else:
      local_nparts = nparts
  if local_nparts > xb.local_device_count():
    raise ValueError(
        f"sharded_jit computation requires {local_nparts} local devices, "
        f"but only {xb.local_device_count()} local devices are available.")

  if logging.vlog_is_on(2):
    logging.vlog(2, "nparts: %d local_nparts: %d", nparts, local_nparts)

  out_parts = out_parts_thunk()

  local_out_parts = local_out_parts_thunk()
  if local_out_parts is None:
    local_out_parts = out_parts

  if logging.vlog_is_on(2):
    logging.vlog(2, "out_parts: %s", out_parts)
    logging.vlog(2, "local_out_parts: %s", local_out_parts)

  local_out_avals = [
      pxla.get_local_aval(out, parts, lparts)
      for out, parts, lparts in safe_zip(global_out_avals, out_parts,
                                         local_out_parts)]

  log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
  logging.log(log_priority, "Compiling %s for %d devices with args %s.",
              fun.__name__, nparts, global_abstract_args)

  c = xc.XlaBuilder("spjit_{}".format(fun.__name__))
  xla_consts = _map(partial(xla.pyval_to_ir_constant, c), consts)
  xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
  axis_env = xla.AxisEnv(nrep, (), ())
  ctx = xla.TranslationContext(
      c, platform, axis_env,
      extend_name_stack(wrap_name(name, "sharded_jit")))
  out_nodes = xla.jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
  out_tuple = xla.with_sharding(c, out_parts, xops.Tuple, c, out_nodes)
  built = c.Build(out_tuple)

  if nparts <= xb.local_device_count():
    devices = xb.local_devices()[:nparts]
  else:
    assert nparts == xb.device_count()
    devices = xb.devices()
  device_assignment = np.array([[d.id for d in devices]])
  device_assignment = np.reshape(device_assignment, (-1, nparts))
  # device_assignment = None  # TODO(skye): replace with default device assignment?

  compiled = dispatch.backend_compile(
      xb.get_backend(), built,
      xb.get_compile_options(nrep, nparts, device_assignment))

  input_specs = [
      pxla.partitioned_sharding_spec(local_nparts, parts, aval)
      for parts, aval in zip(local_in_parts, abstract_args)]
  input_indices = [pxla.spec_to_indices(aval.shape, spec)
                   if spec is not None else None
                   for aval, spec in zip(abstract_args, input_specs)]

  handle_args = partial(pxla.shard_args, compiled.local_devices(),
                        input_indices)
  handle_outs = _avals_to_results_handler(nrep, local_nparts,  # type: ignore
                                          local_out_parts, local_out_avals)
  return partial(_execute_spatially_partitioned, compiled, handle_args,
                 handle_outs)
def solve_shape_vars(shape_spec: str, shape: Sequence[int]) -> Dict[str, int]:
  shape_polys = masking.parse_spec(shape_spec)
  return jax2tf.jax2tf._solve_shape_vars(util.safe_zip(shape_polys, shape))
def tree_reduce(function: Callable[[T, Any], T],
                tree: Any,
                initializer: Any = no_initializer) -> T:
  if initializer is no_initializer:
    return functools.reduce(function, tree_leaves(tree))
  else:
    return functools.reduce(function, tree_leaves(tree), initializer)


def tree_all(tree):
  return all(tree_leaves(tree))


register_pytree_node(
    collections.OrderedDict,
    lambda x: (tuple(x.values()), tuple(x.keys())),
    lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))

register_pytree_node(
    collections.defaultdict,
    lambda x: (tuple(x.values()), (x.default_factory, tuple(x.keys()))),
    lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)))  # type: ignore[index]


class _HashableCallableShim:
  """Object that delegates __call__, __hash__, and __eq__ to another object."""

  def __init__(self, fun):
    self.fun = fun

  def __call__(self, *args, **kw):
    return self.fun(*args, **kw)
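# A minimal usage sketch of tree_reduce: it folds over the flattened leaves of
# a pytree, so OrderedDict values participate via the registration above.
import collections
import jax

tree = collections.OrderedDict(a=1, b=(2, 3))
total = jax.tree_util.tree_reduce(lambda acc, x: acc + x, tree)
assert total == 6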