Example #1
    def testSoftPmapAxisIndex(self):
        n = 4 * xla_bridge.device_count()

        def f(x):
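            # axis_index('i') is this element's position along the soft_pmap'd
            # axis, so the result below is 2 * [0, 1, ..., n - 1].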
            return x * lax.axis_index('i')

        ans = soft_pmap(f, 'i')(2 * np.ones(n))
        expected = 2 * onp.arange(n)
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #2
    def testSoftPmapPsum(self):
        n = 4 * xla_bridge.device_count()

        def f(x):
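            # psum sums x over the mapped axis 'i', so each entry of the
            # result is 1 / n.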
            return x / lax.psum(x, 'i')

        ans = soft_pmap(f, 'i')(np.ones(n))
        expected = onp.ones(n) / n
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #3
    def testSelect(self):
        p = onp.arange(15).reshape((5, 3)) % 4 == 1
        f = onp.zeros((5, 3))

        def fun(t):
            return lax.select(p, t, f)

        t = onp.ones((5, 3))
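        # _papply (a helper in the test module) returns a (parallelized_fun,
        # axis_name) pair, unpacked here directly into soft_pmap's arguments.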
        ans = soft_pmap(*_papply(fun))(t)
        expected = fun(t)
        self.assertAllClose(ans, expected, check_dtypes=True)
Example #4
    def testMax(self):
        pfun, axis_name = _papply(lambda x: np.max(x, axis=0))

        jaxpr = make_jaxpr(pfun)(onp.ones(3))
        expected_jaxpr = make_jaxpr(lambda x: lax.pmax(x, axis_name))(
            onp.zeros((5, 3)))
        assert repr(jaxpr) == repr(expected_jaxpr)

        arg = onp.arange(15.).reshape((5, 3))
        ans = soft_pmap(pfun, axis_name)(arg)[0]
        expected = onp.max(arg, axis=0)
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #5
    def testDot(self):
        raise SkipTest("known failure")  # TODO(frostig)
        x = onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2))

        def fun(x, y):
            return lax.dot(x, y)

        expected = fun(x, x)
        pfun, axis_name = _papply(fun)
        ans = soft_pmap(pfun, axis_name)(x, x)
        ans = self.dedup(ans, expected.ndim)
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #6
    def testAddBroadcasting(self):
        raise SkipTest("test doesn't pass yet")  # TODO(frostig)

        def fun(x):
            return x + 3

        x = onp.array([[1, 2], [3, 4]])
        expected = x + 3

        pfun, axis_name = _papply(fun)
        ans = soft_pmap(pfun, axis_name)(x)
        self.assertAllClose(ans, expected, check_dtypes=True)
Example #7
    def testSoftPmapDevicePersistence(self):
        device_count = xla_bridge.device_count()
        shape = (2 * 2 * device_count, 2, 3)

        # check that we can maintain device persistence across calls
        x = onp.arange(prod(shape)).reshape(shape)
        x = soft_pmap(lambda x: x)(x)
        self.assertIsInstance(x, pxla.ShardedDeviceArray)
        x._npy_value = onp.float32(onp.nan)  # can't be coerced to ndarray for xfer
        x = soft_pmap(lambda x: x)(x)  # doesn't crash
        self.assertIsInstance(x, pxla.ShardedDeviceArray)

        # check that we don't crash when we can't maintain device persistence
        x = onp.arange(prod(shape)).reshape(shape)
        x = soft_pmap(lambda x: x)(x)
        self.assertIsInstance(x, pxla.ShardedDeviceArray)
        y = x.reshape(device_count, -1)
        self.assertIsInstance(y, xla.DeviceArray)  # should have forced collection
        soft_pmap(lambda x: x)(y)  # doesn't crash
        z = x + 2
        self.assertIsInstance(z, xla.DeviceArray)  # should have forced collection
        x._npy_value = onp.float32(onp.nan)  # can't be coerced to ndarray for xfer
        self.assertRaisesRegex(
            RuntimeError,
            '.*does not match host shape or layout of computation parameter 0.*',
            lambda: x + 2)

        # check that different axis merges aren't a problem
        x = onp.arange(prod(shape)).reshape(shape)
        x = soft_pmap(lambda x: x)(x)
        self.assertIsInstance(x, pxla.ShardedDeviceArray)
        x = x.reshape(2 * device_count, 2, 2, 3)  # axis merge of the wrong size
        self.assertIsInstance(x, xla.DeviceArray)  # should have forced collection
Example #8
    def testDotGeneral(self, matching, coloring, split):
        BATCH, CONTRACT, _ = range(3)
        SPLIT_LHS, SPLIT_RHS, SPLIT_BOTH = range(3)

        x = onp.reshape(onp.arange(8.), (2, 2, 2))
        y = onp.reshape(onp.arange(8.), (2, 2, 2)) + 4.

        cdims = [(i, matching[i]) for i in range(3) if coloring[i] == CONTRACT]
        bdims = [(i, matching[i]) for i in range(3) if coloring[i] == BATCH]
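        # lax.dot_general expects dimension_numbers of the form
        # ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).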
        dimension_numbers = [
            list(zip(*cdims)) or [(), ()],
            list(zip(*bdims)) or [(), ()]
        ]

        def f(x, y):
            return lax.dot_general(x, y, dimension_numbers)

        if split == SPLIT_LHS:
            fun = lambda x: f(x, y)
        elif split == SPLIT_RHS:
            fun = lambda y: f(x, y)
        else:
            fun = f

        try:
            if split != SPLIT_BOTH:
                expected = fun(x)
                pfun, axis_name = _papply(fun)
                ans = soft_pmap(pfun, axis_name)(x)
            else:
                expected = fun(x, y)
                pfun, axis_name = _papply(fun)
                ans = soft_pmap(pfun, axis_name)(x, y)
        except (NotImplementedError, TypeError) as e:
            raise SkipTest(str(e)) from e

        ans = self.dedup(ans, expected.ndim)
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #9
    def testLogSoftmax(self):
        raise SkipTest("test doesn't pass yet")  # TODO(frostig)

        def fun(x):
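            # Under papply, the global np.sum becomes a psum over the mapped
            # axis, as the expected_jaxpr below spells out.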
            return x - np.log(np.sum(np.exp(x)))

        pfun, axis_name = _papply(fun)

        jaxpr = make_jaxpr(pfun)(onp.zeros(5))
        expected_jaxpr = make_jaxpr(
            lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
        assert repr(jaxpr) == repr(expected_jaxpr)

        ans = soft_pmap(pfun, axis_name)(onp.arange(1., 5.))
        expected = fun(onp.arange(1., 5.))
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #10
def _soft_parallel(ker_fun, batch_size):
    """Returns a function that computes a kernel in batches in parallel.

  The current implementation uses jax.soft_pmap to simulate a given batch size.
  However, it is possible that larger batches will be chosen if
    (n1 * n2 // batch_size ** 2) > physical_device_count
  In this case soft_pmap will attempt implicitly use a larger batch size. To
  use a fixed batch size independent of the physical device count, one should
  compose this function with serial.

  In a future CL, it might be a good idea to introduce an `auto` batching
  function that composes serial and parallel automatically.

  Args:
    ker_fun: A function that computes a kernel between two datasets,
        ker_fun(x1, x2). Here x1 and x2 are `np.ndarray`s of floats of shape
        [n1,] + input_shape and [n2,] + input_shape. The kernel function
        should return a PyTree.
    batch_size: Integer specifying the size of batches in which to split the
        data.

  Returns:
    A new function with the same signature as ker_fun that computes the kernel
    by batching over the dataset in parallel with the specified batch_size.
  """

    ker_fun = soft_pmap(soft_pmap(ker_fun))

    def parallel_fn(x1, x2=None, *args, **kwargs):
        if x2 is None:
            # TODO(schsam): Only compute the upper triangular part of the kernel.
            x2 = x1

        n1 = x1.shape[0]
        n2 = x2.shape[0]
        input_shape = x1.shape[1:]

        n1_batches, ragged = divmod(n1, batch_size)
        if ragged:
            # TODO(schsam): Relax this constraint.
            raise ValueError((
                'Batch size must divide the number of examples in x1. Found '
                '|x1| = {} and batch size = {}.').format(n1, batch_size))

        n2_batches, ragged = divmod(n2, batch_size)
        if ragged:
            # TODO(schsam): Relax this constraint.
            raise ValueError((
                'Batch size must divide the number of examples in x2. Found '
                '|x2| = {} and batch size = {}.').format(n2, batch_size))

        x1s = np.reshape(x1, (n1_batches, 1, batch_size) + input_shape)
        x1s = np.broadcast_to(
            x1s, (n1_batches, n2_batches, batch_size) + input_shape)

        x2s = np.reshape(x2, (1, n2_batches, batch_size) + input_shape)
        x2s = np.broadcast_to(
            x2s, (n1_batches, n2_batches, batch_size) + input_shape)

        kernel = ker_fun(x1s, x2s, *args, **kwargs)

        return _flatten_kernel(kernel)

    return parallel_fn
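
The following is a minimal, self-contained sketch (not part of the library) of the layout that _soft_parallel relies on: both inputs are tiled into an (n1_batches, n2_batches, batch_size, ...) grid and the kernel is mapped over the two leading axes with nested soft_pmap, mirroring ker_fun = soft_pmap(soft_pmap(ker_fun)) above. The toy_kernel function and the final transpose/reshape are assumptions for illustration; the library itself delegates the reassembly step to _flatten_kernel.

import numpy as onp
import jax.numpy as np
from jax import soft_pmap


def toy_kernel(x1, x2):
    # Hypothetical kernel: pairwise dot products between two batches of examples.
    return np.einsum('id,jd->ij', x1, x2)


batch_size = 2
x = onp.arange(8., dtype=onp.float32).reshape(8, 1)  # 8 examples, 1 feature
n_batches = x.shape[0] // batch_size

# Tile the data into an (n_batches, n_batches, batch_size, 1) grid of batch
# pairs, mirroring the x1s/x2s construction in parallel_fn.
x1s = np.broadcast_to(np.reshape(x, (n_batches, 1, batch_size, 1)),
                      (n_batches, n_batches, batch_size, 1))
x2s = np.broadcast_to(np.reshape(x, (1, n_batches, batch_size, 1)),
                      (n_batches, n_batches, batch_size, 1))

# Nested soft_pmap maps toy_kernel over the (n_batches, n_batches) grid,
# producing one (batch_size, batch_size) kernel block per pair of batches.
blocks = soft_pmap(soft_pmap(toy_kernel))(x1s, x2s)

# Reassemble the block grid into the full (8, 8) kernel matrix.
full = np.reshape(np.transpose(blocks, (0, 2, 1, 3)), (x.shape[0], x.shape[0]))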