Example #1
    def testNestedPmapConstantDevices(self):
        raise SkipTest("Nested pmaps with devices not yet implemented")

        if xla_bridge.device_count() < 6:
            raise SkipTest("this test requires >= 6 devices")

        devices = xla_bridge.devices()[:-2]
        shuffle(devices)
        f = pmap(pmap(lambda x: 3), devices=devices)
        shape = (2, len(devices) // 2, 3)
        x = np.arange(prod(shape)).reshape(shape)
        ans = f(x)
        expected = 3 * onp.ones(shape[:2])
        self.assertAllClose(ans, expected, check_dtypes=False)

        # Test that 'ans' was properly replicated across devices.
        expected_sharded = pmap(pmap(lambda x: x), devices=devices)(expected)
        self.assertEqual([b.device() for b in ans.device_buffers],
                         [b.device() for b in expected_sharded.device_buffers])
Example #2
    def testPsumMultiple(self):
        f = lambda x: lax.psum(x, ('i', 'j'))
        f = pmap(pmap(f, 'i'), 'j')

        def sum_and_broadcast(x, axis):
            return onp.repeat(onp.sum(x, axis, keepdims=True), x.shape[axis],
                              axis)

        device_count = xla_bridge.device_count()
        num_pairs, ragged = divmod(device_count, 2)
        if num_pairs > 1 and not ragged:
            shape = (num_pairs, 2, 4)
        else:
            shape = (device_count, 1, 4)
        x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

        ans = f(x)
        expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #3
    def testShardedDeviceTuple(self):
        f = lambda x: core.pack((x, x))
        f = pmap(f)

        shape = (xla_bridge.device_count(), 4)
        x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

        # test that we can pass in and out ShardedDeviceTuples (and unpack them)
        y = f(x)
        self.assertIsInstance(y, pxla.ShardedDeviceTuple)
        self.assertIsInstance(y, core.JaxTuple)
        self.assertAllClose(y, (x, x), check_dtypes=False)
        z = f(y)
        self.assertIsInstance(z, pxla.ShardedDeviceTuple)
        self.assertAllClose(z, (y, y), check_dtypes=True)

        # test that we can pass a ShardedDeviceTuple to a regular jit computation
        w = jit(lambda x: list(x)[0])(y)
        self.assertAllClose(w, x, check_dtypes=False)
Example #4
def _irfft_transpose(t, fft_lengths):
    # The transpose of IRFFT is the RFFT of the cotangent times a scaling
    # factor and a mask. The mask scales the cotangent for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these components
    # are de-duplicated in the RFFT.
    x = fft(t, xla_client.FftType.RFFT, fft_lengths)
    n = x.shape[-1]
    is_odd = fft_lengths[-1] % 2
    full = partial(lax.full_like, t, dtype=t.dtype)
    mask = lax.concatenate([
        full(1.0, shape=(1,)),
        full(2.0, shape=(n - 2 + is_odd,)),
        full(1.0, shape=(1 - is_odd,)),
    ], dimension=0)
    scale = 1 / prod(fft_lengths)
    out = scale * mask * x
    assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
    return out
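
A minimal numerical sketch of the mask-and-scale logic above, restated with the public NumPy-style FFT wrappers for a 1-D transform; it is illustrative only and simply re-derives the same mask and 1/n scale that the rule applies.

import jax.numpy as jnp

n = 8                                   # fft_lengths[-1]
t = jnp.arange(n, dtype=jnp.float32)    # cotangent of the IRFFT output
x = jnp.fft.rfft(t)                     # RFFT of the cotangent, length n//2 + 1
is_odd = n % 2
mask = jnp.concatenate([jnp.ones(1),
                        2.0 * jnp.ones(x.shape[-1] - 2 + is_odd),
                        jnp.ones(1 - is_odd)])
out = mask * x / n                      # same value the rule returns for 1-D, scale = 1 / prod(fft_lengths)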
Example #5
    def testShardingConstraint(self):
        if jax.local_device_count() < 2:
            raise SkipTest("requires 2 devices")

        def f(x):
            y = x + 1
            y = with_sharding_constraint(y, P(1, 2))
            return y * 2

        shape = (8, 8)
        x = np.arange(prod(shape)).reshape(shape)
        expected = (x + 1) * 2

        # Matching sharded_jit partitions
        actual = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        # TODO(jblespiau): We can simply use buf.xla_shape() when version 0.1.58 is
        # the default.
        self.assertEqual(
            getattr(actual.device_buffers[0], "xla_shape",
                    actual.device_buffers[0].shape)().dimensions(), (4, 8))
        self.assertEqual(
            getattr(actual.device_buffers[1], "xla_shape",
                    actual.device_buffers[1].shape)().dimensions(), (4, 8))

        # Mismatched sharded_jit partitions
        with self.assertRaisesRegex(
                ValueError,
                r"with_sharding_constraint with partitions=PartitionSpec\(1, 2\) "
                r"\(total partitions: 2\) doesn't match expected number of partitions: "
                r"4. If these partitions look right, check outer sharded_jit and/or "
                r"other with_sharding_constraint calls."):
            sharded_jit(f, in_parts=P(2, 2), out_parts=P(2, 2))(x)

        # Replicated sharded_jit
        actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
        self.assertLen(actual.device_buffers, 2)
        self.assertAllClose(actual.device_buffers[0].to_py(),
                            actual.device_buffers[1].to_py(),
                            check_dtypes=False)
Example #6
    def testGradOfShardingConstraint(self):
        if jax.local_device_count() < 4:
            raise SkipTest("requires 4 devices")

        @partial(sharded_jit, in_parts=P(4, 1), out_parts=None)
        def f(x):
            y = x + 1
            p, vjp_f = vjp(
                lambda z: jnp.sin(with_sharding_constraint(z, P(2, 2))), y)
            return vjp_f(p)

        def expected_f(x):
            y = x + 1
            p, vjp_f = vjp(lambda z: jnp.sin(z), y)
            return vjp_f(p)

        shape = (4, 4)
        x = jnp.arange(prod(shape), dtype=jnp.float32).reshape(shape)
        actual = f(x)
        expected = expected_f(x)
        self.assertAllClose(actual, expected, check_dtypes=False)
Example #7
    def testManyArgs(self):
        if jax.local_device_count() < 4:
            raise SkipTest("requires 4 devices")

        num_args = 200

        def f(*args):
            return jnp.asarray(args).sum()

        shape = (2, 4, 4)
        args = [np.arange(prod(shape)).reshape(shape)] * num_args
        in_partitions = (P(2, 1), ) * num_args
        out_partitions = None
        result = pmap(
            sharded_jit(f, in_parts=in_partitions,
                        out_parts=out_partitions))(*args)
        expected = pmap(f)(*args)

        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
        self.assertEqual(len(result.device_buffers), 4)
Example #8
    def testShardingConstraint(self):
        if jax.local_device_count() < 4:
            raise SkipTest("requires 4 devices")

        @partial(sharded_jit, in_parts=None, out_parts=None)
        def f(x):
            y = jnp.dot(x, x)
            y = with_sharding_constraint(y, P(2, 1))
            return y * 2

        def expected_f(x):
            return jnp.dot(x, x) * 2

        shape = (2, 8, 8)
        x = np.arange(prod(shape)).reshape(shape)
        result = pmap(f)(x)
        expected = pmap(expected_f)(x)

        self.assertAllClose(result, expected, check_dtypes=False)
        self.assertIsInstance(result, pxla.ShardedDeviceArray)
        self.assertLen(result.device_buffers, 4)
Example #9
    def testInAxesNone(self):
        shape = (4, 4)
        replicas = 2
        in_partitions = (P(2, 1), None, None)
        out_partitions = P(2, 1)
        in_axes = (None, None, 0)
        x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
        dummy = np.arange(replicas, dtype=np.float32) + 1
        num_shards = replicas * np.prod(in_partitions[0])
        if num_shards > jax.local_device_count():
            raise SkipTest("requires %d devices" % num_shards)

        def f(x, y, _):
            return x @ y

        result = pmap(sharded_jit(f,
                                  in_parts=in_partitions,
                                  out_parts=out_partitions),
                      in_axes=in_axes)(x, y, dummy)
        expected = pmap(f, in_axes=in_axes)(x, y, dummy)
        self.assertAllClose(result, expected, check_dtypes=True)
Example #10
    def testNestedPmapConstant(self):
        if xla_bridge.device_count() == 1:
            raise SkipTest("this test requires multiple devices")

        f = pmap(pmap(lambda x: 3))
        shape = (2, xla_bridge.device_count() // 2, 3)
        x = np.arange(prod(shape)).reshape(shape)
        ans = f(x)
        expected = 3 * onp.ones(shape[:2])
        self.assertAllClose(ans, expected, check_dtypes=False)

        # Test that 'ans' was properly replicated across devices.
        expected_sharded = pmap(pmap(lambda x: x))(expected)
        self.assertEqual([b.device() for b in ans.device_buffers],
                         [b.device() for b in expected_sharded.device_buffers])

        f = pmap(pmap(lambda x: (x, 3)))
        x_sharded, ans = f(x)
        self.assertAllClose(ans, expected, check_dtypes=False)
        self.assertEqual([b.device() for b in ans.device_buffers],
                         [b.device() for b in x_sharded.device_buffers])
Example #11
def omnistaging_disabler() -> None:
    global axis_index

    psum_p.bind = partial(core.Primitive.bind, psum_p)
    psum_p.def_impl(partial(pxla.apply_parallel_primitive,
                            psum_p))  # type: ignore
    pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (
        x * prod(shape) for x in args)  # type: ignore

    def _axis_index_bind(*, axis_name):
        dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
        frame = dynamic_axis_env[axis_name]
        sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame) + 1]
        nreps = dynamic_axis_env.nreps
        trace = frame.pmap_trace

        out_aval = ShapedArray((), np.int32)
        out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval),
                                    None)
        eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                                dict(nreps=nreps,
                                     sizes=sizes,
                                     axis_name=axis_name),
                                source_info_util.current())
        out_tracer.recipe = eqn

        return out_tracer

    def _axis_index_translation_rule(c, nreps, sizes, axis_name):
        div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
        mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
        unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
        return xops.ConvertElementType(unsigned_index,
                                       xb.dtype_to_etype(np.int32))

    axis_index_p.def_custom_bind(_axis_index_bind)
    axis_index_p.def_abstract_eval(lambda *args, **params: ShapedArray(
        (), np.int32))
    xla.translations[axis_index_p] = _axis_index_translation_rule
Example #12
def _triangular_solve_gpu_translation_rule(
    c, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  shape = c.get_shape(a)
  dims = shape.dimensions()
  m, n = dims[-2:]
  batch = prod(dims[:-2])
  if conjugate_a and not transpose_a:
    a = xops.Conj(a)
    conjugate_a = False
  if batch > 1 and m <= 32 and n <= 32:
    return cusolver.trsm(
      c, a, b, left_side, lower, transpose_a,
      conjugate_a, unit_diagonal)
  else:
    # Use the XLA implementation for unbatched triangular_solve.
    if not transpose_a:
      transpose = xops.TriangularSolveOptions_Transpose.NO_TRANSPOSE
    else:
      transpose = (xops.TriangularSolveOptions_Transpose.ADJOINT if conjugate_a
                   else xops.TriangularSolveOptions_Transpose.TRANSPOSE)
    return xops.TriangularSolve(a, b, left_side, lower, unit_diagonal,
                                transpose)
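
For context, a hedged sketch of a high-level call that lowers to a triangular solve (and, on GPU, through a translation rule like the one above); the matrix and right-hand side here are illustrative assumptions.

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

a = jnp.tril(jnp.ones((4, 4))) + 3.0 * jnp.eye(4)  # lower-triangular, well conditioned
b = jnp.arange(8.0).reshape(4, 2)
x = solve_triangular(a, b, lower=True)             # solves a @ x = b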
Example #13
  def testPartiallyMapped(self):
    f = pmap(lambda x, y: x, in_axes=(None, 0))
    g = pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i', in_axes=(None, 0))

    mesh_shape = (xla_bridge.device_count(),)
    shape = mesh_shape + (4,)
    x = onp.array(3., dtype=onp.float32)
    y = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

    f_expected = onp.broadcast_to(x, mesh_shape)
    f_ans = f(x, y)
    self.assertAllClose(f_ans, f_expected, check_dtypes=True)
    self.assertIsInstance(f_ans, pxla.ShardedDeviceArray)
    # the output is actually replicated (has the same values in each device buffer)
    # but out_axes is implicitly 0, so we shouldn't have replication in the
    # sharding spec.
    self.assertEqual(f_ans.sharding_spec.replication_factor, 1)

    g_expected = onp.broadcast_to(x - onp.sum(y, 0, keepdims=True), shape)
    g_ans = g(x, y)
    self.assertAllClose(g_ans, g_expected, check_dtypes=True)
    self.assertIsInstance(g_ans, pxla.ShardedDeviceArray)
    self.assertEqual(g_ans.sharding_spec.replication_factor, 1)
Example #14
  def testCollectiveConstantNested(self):
    device_count = xla_bridge.device_count()

    @partial(pmap, axis_name='i')
    def f(x):
      @partial(pmap, axis_name='j')
      def g(y):
        a = lax.psum(1, 'i')
        b = lax.psum(1, 'j')
        c = lax.psum(1, ('i', 'j'))
        return a, b, c
      return g(x)

    shape = (device_count, 1, 4)
    x = np.arange(prod(shape)).reshape(shape)
    a, b, c = f(x)

    self.assertEqual(a.shape, shape[:-1])
    self.assertEqual(b.shape, shape[:-1])
    self.assertEqual(c.shape, shape[:-1])

    self.assertEqual(a.ravel()[0], device_count)
    self.assertEqual(b.ravel()[0], 1)
    self.assertEqual(c.ravel()[0], device_count * 1)
Example #15
 def _make_arg(*shape):
     return np.arange(prod(shape)).reshape(shape)
Example #16
def conv_general_dilated_patches(
    lhs: lax.Array,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Sequence[int] = None,
    rhs_dilation: Sequence[int] = None,
    dimension_numbers: lax.ConvGeneralDilatedDimensionNumbers = None,
    precision: lax.PrecisionType = None,
) -> lax.Array:
  """Extract patches subject to the receptive field of `conv_general_dilated`.

  Runs the input through a convolution with given parameters. The kernel of the
  convolution is constructed such that the output channel dimension `"C"`
  contains flattened image patches, so that a single `"C"` dimension
  represents, for example, three dimensions `"chw"` collapsed. The order of
  these dimensions is `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`,
  where `rhs_spec == dimension_numbers[1]`, and the size of this `"C"`
  dimension is therefore the size of each patch, i.e.
  `np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`, where
  `lhs_spec == dimension_numbers[0]`.

  Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: either `None`, or a 3-tuple
      `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
      of length `n+2`. `None` defaults to `("NCHWD..., OIHWD..., NCHWD...")`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``).

  Returns:
    A rank `n+2` array containing the flattened image patches in the output
    channel (`"C"`) dimension. For example if
    `dimension_numbers = ("NcHW", "OIwh", "CNHW")`, the output has dimension
    numbers `"CNHW" = "{cwh}NHW"`, with the size of dimension `"C"` equal to
    the size of each patch
    (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).

  """
  filter_shape = tuple(filter_shape)
  dimension_numbers = lax.conv_dimension_numbers(
      lhs.shape, (1, 1) + filter_shape, dimension_numbers)

  lhs_spec, rhs_spec, out_spec = dimension_numbers

  spatial_size = prod(filter_shape)
  n_channels = lhs.shape[lhs_spec[1]]

  # Move separate `lhs` spatial locations into separate `rhs` channels.
  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)

  rhs = rhs.reshape((spatial_size, 1) + filter_shape)
  rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
  rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

  out = lax.conv_general_dilated(
      lhs=lhs,
      rhs=rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=None if precision is None else (precision,
                                                lax.Precision.DEFAULT),
      feature_group_count=n_channels
  )
  return out
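
A hedged usage sketch for the function above, under the default NCHW dimension numbers; the input shape and filter size are assumptions chosen for illustration.

import jax.numpy as jnp

img = jnp.arange(1 * 2 * 8 * 8, dtype=jnp.float32).reshape(1, 2, 8, 8)  # NCHW
patches = conv_general_dilated_patches(
    img, filter_shape=(3, 3), window_strides=(1, 1), padding='VALID')
# The channel dimension now holds each flattened 3x3 patch per input channel,
# so patches.shape == (1, 2 * 3 * 3, 6, 6).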
Example #17
 def args_maker():
   flat_keys = np.arange(prod(shape), dtype=key_dtype)
   keys = self.rng().permutation(flat_keys).reshape(shape)
   values = rng(shape, val_dtype)
   return keys, values
Example #18
 def testTopKGrad(self, shape, dtype, k, rng_factory):
   flat_values = np.arange(prod(shape), dtype=dtype)
   values = self.rng().permutation(flat_values).reshape(shape)
   fun = lambda vs: lax.top_k(vs, k=k)[0]
   check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
Example #19
 def benchmark_fn():
   arr = pmap(lambda x: x)(jnp.arange(prod(shape)).reshape(shape))
   indices = indices_fn()
   for idx in indices:
     arr[idx]
Example #20
    psum = partial(_allreduce_translation_rule,
                   lax.add_p,
                   c,
                   replica_groups=replica_groups)
    dtype = c.GetShape(val).numpy_dtype()
    if dtypes.issubdtype(dtype, onp.complexfloating):
        return c.Complex(psum(c.Real(val)), psum(c.Imag(val)))
    else:
        return psum(val)


psum_p = standard_pmap_primitive('psum')
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
pxla.parallel_pure_rules[psum_p] = lambda x, shape: x * prod(shape)
ad.deflinear(psum_p, lambda t, axis_name: [psum(t, axis_name)])
pxla.multi_host_supported_collectives.add(psum_p)

pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
pxla.split_axis_rules[pmin_p] = \
    partial(_allreduce_split_axis_rule, pmin_p, lax._reduce_min)
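
A hedged illustration of the parallel_pure_rules entry registered above: a psum of a per-device-constant value over a mapped axis of size N is simply the value times N. The device count comes from the local runtime.

import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()
out = jax.pmap(lambda x: lax.psum(1.0, 'i'), axis_name='i')(jnp.zeros(n))
# every entry of `out` equals n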
Example #21
                       replica_groups=replica_groups)
        dtype = c.get_shape(val).numpy_dtype()
        if dtypes.issubdtype(dtype, onp.complexfloating):
            return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
        else:
            return psum(val)

    return xops.Tuple(c, list(map(_translate, args)))


psum_p = standard_pmap_primitive('psum', multiple_results=True)
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.split_axis_rules[psum_p] = \
    partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape)
                                                         for x in args)
ad.deflinear(
    psum_p, lambda ts, axis_name, axis_index_groups: psum_p.bind(
        *ts, axis_name=axis_name, axis_index_groups=axis_index_groups))
pxla.multi_host_supported_collectives.add(psum_p)

pmax_p = standard_pmap_primitive('pmax')
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
pxla.split_axis_rules[pmax_p] = \
    partial(_allreduce_split_axis_rule, pmax_p, lax._reduce_max)

pmin_p = standard_pmap_primitive('pmin')
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
Example #22
 def _axis_index_translation_rule(c, nreps, sizes, axis_name):
   div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
   mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
   unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
   return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
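
A small, hedged restatement of the replica-id arithmetic above in plain Python, to make the div/mod concrete; the helper name and example sizes are illustrative only.

from math import prod

def axis_index_from_replica(replica_id, nreps, sizes):
    # Mirrors the rule above: (ReplicaId // (nreps // prod(sizes))) % sizes[-1]
    div = nreps // prod(sizes)
    return (replica_id // div) % sizes[-1]

# e.g. 8 replicas, an outer axis of size 2 and an inner axis of size 4:
assert [axis_index_from_replica(r, 8, (2, 4)) for r in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]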
Example #23
 def testMakeJaxprOfOpenSpmd(self):
   f = lambda x: x - lax.psum(x, 'i')
   shape = (xla_bridge.device_count(), 4)
   x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
   make_jaxpr(f)(x)  # doesn't crash