Esempio n. 1
0
  def test_pmap(self):
    """Checks `fun1` under `api.pmap` against `fun1_equiv`, one input per device.

    Each local device receives the scalar `2. + device_index` (float32). The
    pmap'ed call runs inside an `hcb.outfeed_receiver` context named after
    this test (presumably because `fun1` emits host-callback outfeeds — confirm
    against its definition). The result is compared, ignoring dtypes, with
    `fun1_equiv` applied to the same per-device inputs and stacked.
    """
    num_devices = api.local_device_count()
    inputs = 2. + jnp.arange(num_devices, dtype=jnp.float32)

    parallel_fun1 = api.pmap(fun1, axis_name="i")
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      res = parallel_fun1(inputs)

    # Reference values computed on the host, one per device, then stacked to
    # match the leading device axis produced by pmap.
    per_device = [fun1_equiv(2. + d) for d in range(num_devices)]
    self.assertAllClose(jnp.stack(per_device), res, check_dtypes=False)
Esempio n. 2
0
    def test_pmap(self):
        """Runs `fun1` via `api.pmap` and compares it with `fun1_equiv`.

        One float32 scalar `2. + device_index` is fed to each local device.
        `hcb.barrier_wait()` is called after the pmap'ed call and before the
        comparison; presumably it waits for any outstanding host callbacks
        issued by `fun1` (confirm against the hcb API). Dtypes are not checked.
        """
        device_count = api.local_device_count()
        inputs = 2. + jnp.arange(device_count, dtype=jnp.float32)

        mapped_fun1 = api.pmap(fun1, axis_name="i")
        res = mapped_fun1(inputs)
        hcb.barrier_wait()

        expected = jnp.stack(
            [fun1_equiv(2. + idx) for idx in range(device_count)])
        self.assertAllClose(expected, res, check_dtypes=False)
Esempio n. 3
0
 def test_pmap_error_no_receiver(self):
     """Checks that a pmap'ed `hcb.id_print` fails when no receiver is running.

     This test does not enter an `hcb.outfeed_receiver` context, so calling
     the pmap'ed function on one float32 input per local device is expected
     to raise `ValueError` with the message "outfeed_receiver is not started".
     """
     vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32)
     with self.assertRaisesRegex(ValueError,
                                 "outfeed_receiver is not started"):
         api.pmap(lambda x: hcb.id_print(x))(vargs)