Exemple #1
0
    def testPmapConstant(self):
        # Checks both the values and the number of compiles recorded when a
        # pmapped function returns a Python constant.
        n_devices = xla_bridge.device_count()
        inputs = np.arange(n_devices)
        want = onp.repeat(3, n_devices)

        # Constant-only output: no compile is expected to be counted.
        constant_only = pmap(lambda x: 3)
        with jtu.count_jit_and_pmap_compiles() as compiles:
            got = constant_only(inputs)
        self.assertEqual(compiles[0], 0)
        self.assertAllClose(got, want, check_dtypes=False)

        # Returning the argument next to the constant: exactly one compile
        # is expected to be counted.
        arg_and_constant = pmap(lambda x: (x, 3))
        with jtu.count_jit_and_pmap_compiles() as compiles:
            _, got = arg_and_constant(inputs)
        self.assertEqual(compiles[0], 1)
        self.assertAllClose(got, want, check_dtypes=False)
Exemple #2
0
    def testCompilationCache(self):
        # Two calls of the same sharded_jit function with the same argument
        # must be counted as a single compile.
        shape = (2, )
        x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
        sharded_f = sharded_jit(lambda x: x + 1, in_parts=P(2), out_parts=P(2))

        with jtu.count_jit_and_pmap_compiles() as compiles:
            for _ in range(2):
                sharded_f(x)
        self.assertEqual(compiles[0], 1)
Exemple #3
0
  def testPmapConstantDevices(self):
    # Same constant-output check as the plain pmap case, but pinned to an
    # explicit, shuffled list of all devices except the last one.
    if xla_bridge.device_count() == 1:
      raise SkipTest("this test requires multiple devices")

    target_devices = xla_bridge.devices()[:-1]
    shuffle(target_devices)
    n_targets = len(target_devices)

    constant_fn = pmap(lambda x: 3, devices=target_devices)
    inputs = np.arange(n_targets)
    want = onp.repeat(3, n_targets)

    with jtu.count_jit_and_pmap_compiles() as compiles:
      got = constant_fn(inputs)
    self.assertEqual(compiles[0], 0)
    self.assertAllClose(got, want, check_dtypes=False)

    # The result's buffers must sit on the requested devices, in the
    # requested (shuffled) order.
    buffer_devices = [buf.device() for buf in got.device_buffers]
    self.assertEqual(buffer_devices, target_devices)
Exemple #4
0
  def testNestedPmapConstantDevices(self):
    # Disabled unconditionally: the skip below fires before anything else,
    # so the remainder of this body is intentionally unreachable and kept
    # only as the intended test.
    raise SkipTest("Nested pmaps with devices not yet implemented")

    if xla_bridge.device_count() < 6:
      raise SkipTest("this test requires >= 6 devices")

    target_devices = xla_bridge.devices()[:-2]
    shuffle(target_devices)

    shape = (2, len(target_devices) // 2, 3)
    inputs = np.arange(prod(shape)).reshape(shape)
    want = 3 * onp.ones(shape[:2])

    nested_constant = pmap(pmap(lambda x: 3), devices=target_devices)
    with jtu.count_jit_and_pmap_compiles() as compiles:
      got = nested_constant(inputs)
    self.assertEqual(compiles[0], 0)
    self.assertAllClose(got, want, check_dtypes=False)

    # The constant result must be laid out on the same devices as the
    # expected array pushed through an identity nested pmap.
    want_sharded = pmap(pmap(lambda x: x), devices=target_devices)(want)
    got_devices = [buf.device() for buf in got.device_buffers]
    want_devices = [buf.device() for buf in want_sharded.device_buffers]
    self.assertEqual(got_devices, want_devices)
Exemple #5
0
  def testNestedPmapConstant(self):
    # Nested pmap returning a constant: checks the values, the counted
    # compiles, and which devices hold the resulting buffers.
    if xla_bridge.device_count() == 1:
      raise SkipTest("this test requires multiple devices")

    def buffer_devices(sharded):
      # Device of every buffer backing `sharded`, in buffer order.
      return [buf.device() for buf in sharded.device_buffers]

    shape = (2, xla_bridge.device_count() // 2, 3)
    inputs = np.arange(prod(shape)).reshape(shape)
    want = 3 * onp.ones(shape[:2])

    nested_constant = pmap(pmap(lambda x: 3))
    with jtu.count_jit_and_pmap_compiles() as compiles:
      got = nested_constant(inputs)
    self.assertEqual(compiles[0], 0)
    self.assertAllClose(got, want, check_dtypes=False)

    # The constant result must be laid out on the same devices as the
    # expected array pushed through an identity nested pmap.
    want_sharded = pmap(pmap(lambda x: x))(want)
    self.assertEqual(buffer_devices(got), buffer_devices(want_sharded))

    # When the argument is returned alongside the constant, the constant
    # must land on the same devices as the returned argument.
    nested_arg_and_constant = pmap(pmap(lambda x: (x, 3)))
    inputs_sharded, got = nested_arg_and_constant(inputs)
    self.assertAllClose(got, want, check_dtypes=False)
    self.assertEqual(buffer_devices(got), buffer_devices(inputs_sharded))