def testNestedPmapConstantDevices(self):
  raise SkipTest("Nested pmaps with devices not yet implemented")
  if xla_bridge.device_count() < 6:
    raise SkipTest("this test requires >= 6 devices")
  devices = xla_bridge.devices()[:-2]
  shuffle(devices)
  f = pmap(pmap(lambda x: 3), devices=devices)
  shape = (2, len(devices) // 2, 3)
  x = np.arange(prod(shape)).reshape(shape)
  ans = f(x)
  expected = 3 * onp.ones(shape[:2])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # Test that 'ans' was properly replicated across devices.
  expected_sharded = pmap(pmap(lambda x: x), devices=devices)(expected)
  self.assertEqual([b.device() for b in ans.device_buffers],
                   [b.device() for b in expected_sharded.device_buffers])
def testPsumMultiple(self):
  f = lambda x: lax.psum(x, ('i', 'j'))
  f = pmap(pmap(f, 'i'), 'j')

  def sum_and_broadcast(x, axis):
    return onp.repeat(onp.sum(x, axis, keepdims=True), x.shape[axis], axis)

  device_count = xla_bridge.device_count()
  num_pairs, ragged = divmod(device_count, 2)
  if num_pairs > 1 and not ragged:
    shape = (num_pairs, 2, 4)
  else:
    shape = (device_count, 1, 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

  ans = f(x)
  expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
  self.assertAllClose(ans, expected, check_dtypes=False)
def testShardedDeviceTuple(self):
  f = lambda x: core.pack((x, x))
  f = pmap(f)

  shape = (xla_bridge.device_count(), 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

  # test that we can pass in and out ShardedDeviceTuples (and unpack them)
  y = f(x)
  self.assertIsInstance(y, pxla.ShardedDeviceTuple)
  self.assertIsInstance(y, core.JaxTuple)
  self.assertAllClose(y, (x, x), check_dtypes=False)
  z = f(y)
  self.assertIsInstance(z, pxla.ShardedDeviceTuple)
  self.assertAllClose(z, (y, y), check_dtypes=True)

  # test that we can pass a ShardedDeviceTuple to a regular jit computation
  w = jit(lambda x: list(x)[0])(y)
  self.assertAllClose(w, x, check_dtypes=False)
def _irfft_transpose(t, fft_lengths):
  # The transpose of IRFFT is the RFFT of the cotangent times a scaling factor
  # and a mask. The mask scales the cotangent for the Hermitian-symmetric
  # components of the RFFT by a factor of two, since these components are
  # de-duplicated in the RFFT.
  x = fft(t, xla_client.FftType.RFFT, fft_lengths)
  n = x.shape[-1]
  is_odd = fft_lengths[-1] % 2
  full = partial(lax.full_like, t, dtype=t.dtype)
  mask = lax.concatenate(
      [full(1.0, shape=(1,)),
       full(2.0, shape=(n - 2 + is_odd,)),
       full(1.0, shape=(1 - is_odd,))],
      dimension=0)
  scale = 1 / prod(fft_lengths)
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  return out
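# Illustrative sketch (not part of the library code above): builds the same mask
# with plain NumPy for fft_lengths=(8,). The RFFT output has n = 8 // 2 + 1 = 5
# bins; every bin except the DC and Nyquist bins appears twice in the full
# spectrum, so it is weighted by 2, and the overall scale is 1 / prod(fft_lengths).
def _irfft_transpose_mask_example():
  import numpy as np
  fft_length = 8
  n, is_odd = fft_length // 2 + 1, fft_length % 2
  mask = np.concatenate(
      [np.ones(1), 2 * np.ones(n - 2 + is_odd), np.ones(1 - is_odd)])
  # mask == [1., 2., 2., 2., 1.]
  return mask / fft_length  # mask times the 1 / prod(fft_lengths) scale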
def testShardingConstraint(self):
  if jax.local_device_count() < 2:
    raise SkipTest("requires 2 devices")

  def f(x):
    y = x + 1
    y = with_sharding_constraint(y, P(1, 2))
    return y * 2

  shape = (8, 8)
  x = np.arange(prod(shape)).reshape(shape)
  expected = (x + 1) * 2

  # Matching sharded_jit partitions
  actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
  # the default.
  self.assertEqual(
      getattr(actual.device_buffers[0], "xla_shape",
              actual.device_buffers[0].shape)().dimensions(), (4, 8))
  self.assertEqual(
      getattr(actual.device_buffers[1], "xla_shape",
              actual.device_buffers[1].shape)().dimensions(), (4, 8))

  # Mismatched sharded_jit partitions
  with self.assertRaisesRegex(
      ValueError,
      r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
      r"\(total partitions: 2\) doesn't match expected number of partitions: "
      r"4. If these partitions look right, check outer sharded_jit and/or "
      r"other with_sharding_constraint calls."):
    sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

  # Replicated sharded_jit
  actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
  self.assertLen(actual.device_buffers, 2)
  self.assertAllClose(actual.device_buffers[0].to_py(),
                      actual.device_buffers[1].to_py(),
                      check_dtypes=False)
def testGradOfShardingConstraint(self):
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  @partial(sharded_jit, in_parts=P(4, 1), out_parts=None)
  def f(x):
    y = x + 1
    p, vjp_f = vjp(lambda z: jnp.sin(with_sharding_constraint(z, P(2, 2))), y)
    return vjp_f(p)

  def expected_f(x):
    y = x + 1
    p, vjp_f = vjp(lambda z: jnp.sin(z), y)
    return vjp_f(p)

  shape = (4, 4)
  x = jnp.arange(prod(shape), dtype=jnp.float32).reshape(shape)
  actual = f(x)
  expected = expected_f(x)
  self.assertAllClose(actual, expected, check_dtypes=False)
def testManyArgs(self):
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  num_args = 200

  def f(*args):
    return jnp.asarray(args).sum()

  shape = (2, 4, 4)
  args = [np.arange(prod(shape)).reshape(shape)] * num_args
  in_partitions = (P(2, 1),) * num_args
  out_partitions = None

  result = pmap(sharded_jit(
      f, in_parts=in_partitions, out_parts=out_partitions))(*args)
  expected = pmap(f)(*args)

  self.assertAllClose(result, expected, check_dtypes=False)
  self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
  self.assertEqual(len(result.device_buffers), 4)
def testShardingConstraint(self):
  if jax.local_device_count() < 4:
    raise SkipTest("requires 4 devices")

  @partial(sharded_jit, in_parts=None, out_parts=None)
  def f(x):
    y = jnp.dot(x, x)
    y = with_sharding_constraint(y, P(2, 1))
    return y * 2

  def expected_f(x):
    return jnp.dot(x, x) * 2

  shape = (2, 8, 8)
  x = np.arange(prod(shape)).reshape(shape)
  result = pmap(f)(x)
  expected = pmap(expected_f)(x)

  self.assertAllClose(result, expected, check_dtypes=False)
  self.assertIsInstance(result, pxla.ShardedDeviceArray)
  self.assertLen(result.device_buffers, 4)
def testInAxesNone(self):
  shape = (4, 4)
  replicas = 2
  in_partitions = (P(2, 1), None, None)
  out_partitions = P(2, 1)
  in_axes = (None, None, 0)
  x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  dummy = np.arange(replicas, dtype=np.float32) + 1
  num_shards = replicas * np.prod(in_partitions[0])
  if num_shards > jax.local_device_count():
    raise SkipTest("requires %d devices" % num_shards)

  def f(x, y, _):
    return x @ y

  result = pmap(
      sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions),
      in_axes=in_axes)(x, y, dummy)
  expected = pmap(f, in_axes=in_axes)(x, y, dummy)
  self.assertAllClose(result, expected, check_dtypes=True)
def testNestedPmapConstant(self):
  if xla_bridge.device_count() == 1:
    raise SkipTest("this test requires multiple devices")

  f = pmap(pmap(lambda x: 3))
  shape = (2, xla_bridge.device_count() // 2, 3)
  x = np.arange(prod(shape)).reshape(shape)
  ans = f(x)
  expected = 3 * onp.ones(shape[:2])
  self.assertAllClose(ans, expected, check_dtypes=False)

  # Test that 'ans' was properly replicated across devices.
  expected_sharded = pmap(pmap(lambda x: x))(expected)
  self.assertEqual([b.device() for b in ans.device_buffers],
                   [b.device() for b in expected_sharded.device_buffers])

  f = pmap(pmap(lambda x: (x, 3)))
  x_sharded, ans = f(x)
  self.assertAllClose(ans, expected, check_dtypes=False)
  self.assertEqual([b.device() for b in ans.device_buffers],
                   [b.device() for b in x_sharded.device_buffers])
def omnistaging_disabler() -> None:
  global axis_index

  psum_p.bind = partial(core.Primitive.bind, psum_p)
  psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))  # type: ignore
  pxla.parallel_pure_rules[psum_p] = \
      lambda *args, shape: (x * prod(shape) for x in args)  # type: ignore

  def _axis_index_bind(*, axis_name):
    dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
    frame = dynamic_axis_env[axis_name]
    sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame) + 1]
    nreps = dynamic_axis_env.nreps
    trace = frame.pmap_trace

    out_aval = ShapedArray((), np.int32)
    out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
    eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                            dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
                            source_info_util.current())
    out_tracer.recipe = eqn

    return out_tracer

  def _axis_index_translation_rule(c, nreps, sizes, axis_name):
    div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
    mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
    unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
    return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

  axis_index_p.def_custom_bind(_axis_index_bind)
  axis_index_p.def_abstract_eval(
      lambda *args, **params: ShapedArray((), np.int32))
  xla.translations[axis_index_p] = _axis_index_translation_rule
def _triangular_solve_gpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  shape = c.get_shape(a)
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch = prod(dims[:-2])
  if conjugate_a and not transpose_a:
    a = xops.Conj(a)
    conjugate_a = False
  if batch > 1 and m <= 32 and n <= 32:
    return cusolver.trsm(
        c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)
  else:
    # Use the XLA implementation for unbatched or large triangular solves.
    if not transpose_a:
      transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
    else:
      transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                   else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose)
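# Minimal reference sketch (illustrative helper, not part of the rule above) of
# the dispatch condition: batched stacks of small (<= 32 x 32) triangular
# systems go to the batched cuSOLVER trsm kernel, while unbatched or larger
# systems fall back to XLA's TriangularSolve op.
def _uses_batched_trsm(batch, m, n):
  # e.g. _uses_batched_trsm(8, 16, 16) -> True   (many small systems: cusolver)
  #      _uses_batched_trsm(1, 16, 16) -> False  (unbatched: XLA)
  #      _uses_batched_trsm(8, 64, 64) -> False  (large matrices: XLA)
  return batch > 1 and m <= 32 and n <= 32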
def testPartiallyMapped(self):
  f = pmap(lambda x, y: x, in_axes=(None, 0))
  g = pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i', in_axes=(None, 0))

  mesh_shape = (xla_bridge.device_count(),)
  shape = mesh_shape + (4,)
  x = onp.array(3., dtype=onp.float32)
  y = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

  f_expected = onp.broadcast_to(x, mesh_shape)
  f_ans = f(x, y)
  self.assertAllClose(f_ans, f_expected, check_dtypes=True)
  self.assertIsInstance(f_ans, pxla.ShardedDeviceArray)
  # the output is actually replicated (has the same values in each device
  # buffer) but out_axes is implicitly 0, so we shouldn't have replication in
  # the sharding spec.
  self.assertEqual(f_ans.sharding_spec.replication_factor, 1)

  g_expected = onp.broadcast_to(x - onp.sum(y, 0, keepdims=True), shape)
  g_ans = g(x, y)
  self.assertAllClose(g_ans, g_expected, check_dtypes=True)
  self.assertIsInstance(g_ans, pxla.ShardedDeviceArray)
  self.assertEqual(g_ans.sharding_spec.replication_factor, 1)
def testCollectiveConstantNested(self):
  device_count = xla_bridge.device_count()

  @partial(pmap, axis_name='i')
  def f(x):
    @partial(pmap, axis_name='j')
    def g(y):
      a = lax.psum(1, 'i')
      b = lax.psum(1, 'j')
      c = lax.psum(1, ('i', 'j'))
      return a, b, c
    return g(x)

  shape = (device_count, 1, 4)
  x = np.arange(prod(shape)).reshape(shape)
  a, b, c = f(x)

  self.assertEqual(a.shape, shape[:-1])
  self.assertEqual(b.shape, shape[:-1])
  self.assertEqual(c.shape, shape[:-1])

  self.assertEqual(a.ravel()[0], device_count)
  self.assertEqual(b.ravel()[0], 1)
  self.assertEqual(c.ravel()[0], device_count * 1)
def _make_arg(*shape):
  return np.arange(prod(shape)).reshape(shape)
def conv_general_dilated_patches(
    lhs: lax.Array,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Sequence[int] = None,
    rhs_dilation: Sequence[int] = None,
    dimension_numbers: lax.ConvGeneralDilatedDimensionNumbers = None,
    precision: lax.PrecisionType = None,
) -> lax.Array:
  """Extract patches subject to the receptive field of `conv_general_dilated`.

  Runs the input through a convolution with given parameters. The kernel of the
  convolution is constructed such that the output channel dimension `"C"`
  contains flattened image patches, so instead a single `"C"` dimension
  represents, for example, three dimensions `"chw"` collapsed. The order of
  these dimensions is `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`,
  where `rhs_spec == dimension_numbers[1]`, and the size of this `"C"`
  dimension is therefore the size of each patch, i.e.
  `np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`, where
  `lhs_spec == dimension_numbers[0]`.

  Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
      factor to apply in each spatial dimension of `lhs`. LHS dilation is also
      known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the dilation
      factor to apply in each spatial dimension of `rhs`. RHS dilation is also
      known as atrous convolution.
    dimension_numbers: either `None`, or a 3-tuple `(lhs_spec, rhs_spec,
      out_spec)`, where each element is a string of length `n+2`. `None`
      defaults to `("NCHWD..., OIHWD..., NCHWD...")`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    A rank `n+2` array containing the flattened image patches in the output
    channel (`"C"`) dimension. For example if
    `dimension_numbers = ("NcHW", "OIwh", "CNHW")`, the output has dimension
    numbers `"CNHW" = "{cwh}NHW"`, with the size of dimension `"C"` equal to
    the size of each patch
    (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
  """
  filter_shape = tuple(filter_shape)
  dimension_numbers = lax.conv_dimension_numbers(
      lhs.shape, (1, 1) + filter_shape, dimension_numbers)
  lhs_spec, rhs_spec, out_spec = dimension_numbers

  spatial_size = prod(filter_shape)
  n_channels = lhs.shape[lhs_spec[1]]

  # Move separate `lhs` spatial locations into separate `rhs` channels.
  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)
  rhs = rhs.reshape((spatial_size, 1) + filter_shape)
  rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
  rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

  out = lax.conv_general_dilated(
      lhs=lhs,
      rhs=rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=None if precision is None else (precision,
                                                lax.Precision.DEFAULT),
      feature_group_count=n_channels)
  return out
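# Usage sketch (illustrative, not part of the source): extract 3x3 patches from
# a single-channel NCHW image with 'VALID' padding and unit strides. Under the
# default dimension numbers, the output channel dimension has size
# prod(filter_shape) * n_channels = 9, and the spatial output is 3x3.
def _conv_patches_example():
  img = jnp.arange(25, dtype=jnp.float32).reshape(1, 1, 5, 5)
  patches = conv_general_dilated_patches(
      img, filter_shape=(3, 3), window_strides=(1, 1), padding='VALID')
  assert patches.shape == (1, 9, 3, 3)
  return patches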
def args_maker():
  flat_keys = np.arange(prod(shape), dtype=key_dtype)
  keys = self.rng().permutation(flat_keys).reshape(shape)
  values = rng(shape, val_dtype)
  return keys, values
def testTopKGrad(self, shape, dtype, k, rng_factory):
  flat_values = np.arange(prod(shape), dtype=dtype)
  values = self.rng().permutation(flat_values).reshape(shape)
  fun = lambda vs: lax.top_k(vs, k=k)[0]
  check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
def benchmark_fn():
  arr = pmap(lambda x: x)(jnp.arange(prod(shape)).reshape(shape))
  indices = indices_fn()
  for idx in indices:
    arr[idx]
  psum = partial(_allreduce_translation_rule, lax.add_p, c,
                 replica_groups=replica_groups)
  dtype = c.GetShape(val).numpy_dtype()
  if dtypes.issubdtype(dtype, onp.complexfloating):
    return c.Complex(psum(c.Real(val)), psum(c.Imag(val)))
  else:
    return psum(val)

psum_p = standard_pmap_primitive('psum')
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
pxla.parallel_pure_rules[psum_p] = lambda x, shape: x * prod(shape)
ad.deflinear(psum_p, lambda t, axis_name: [psum(t, axis_name)])
pxla.multi_host_supported_collectives.add(psum_p)

pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)
                   replica_groups=replica_groups)
    dtype = c.get_shape(val).numpy_dtype()
    if dtypes.issubdtype(dtype, onp.complexfloating):
      return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
    else:
      return psum(val)
  return xops.Tuple(c, list(map(_translate, args)))

psum_p = standard_pmap_primitive('psum', multiple_results=True)
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
pxla.parallel_pure_rules[psum_p] = \
    lambda *args, shape: (x * prod(shape) for x in args)
ad.deflinear(
    psum_p,
    lambda ts, axis_name, axis_index_groups: psum_p.bind(
        *ts, axis_name=axis_name, axis_index_groups=axis_index_groups))
pxla.multi_host_supported_collectives.add(psum_p)

pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
def _axis_index_translation_rule(c, nreps, sizes, axis_name):
  div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
  mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
  unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
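# Pure-Python sketch (illustrative helper, not part of the translation rule
# above) of the same index arithmetic: with `nreps` total replicas and nested
# mapped axes of the given `sizes`, the index along the innermost axis is
# (replica_id // (nreps // prod(sizes))) % sizes[-1].
def _axis_index_reference(replica_id, nreps, sizes):
  # e.g. nreps=8, sizes=(2, 4):
  #   [_axis_index_reference(r, 8, (2, 4)) for r in range(8)]
  #   == [0, 1, 2, 3, 0, 1, 2, 3]
  div = nreps // prod(sizes)
  return (replica_id // div) % sizes[-1]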
def testMakeJaxprOfOpenSpmd(self):
  f = lambda x: x - lax.psum(x, 'i')
  shape = (xla_bridge.device_count(), 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
  make_jaxpr(f)(x)  # doesn't crash