Ejemplo n.º 1
0
 def test_random_split_doesnt_device_put_during_tracing(self):
     """Jitting ``random.split`` should device_put only its key argument.

     Skipped unless omnistaging is enabled. The only device_put expected
     inside the counted region is the one for the key passed to the jitted
     function; tracing itself should add none.
     """
     if not config.omnistaging_enabled:
         raise SkipTest("test requires omnistaging")
     prng_key = random.PRNGKey(1).block_until_ready()
     with jtu.count_device_put() as device_put_count:
         jitted_split = api.jit(random.split)
         jitted_split(prng_key)
     # device_put_count[0] holds the number of device_put calls observed.
     expected_puts = 1  # 1 for the argument device_put
     self.assertEqual(device_put_count[0], expected_puts)
Ejemplo n.º 2
0
  def test_random_split_doesnt_device_put_during_tracing(self):
    """Check that jitting ``random.split`` device_puts only its key argument.

    Skipped unless omnistaging is enabled. Inside the counted region only the
    jitted call runs, so the single expected device_put is the one for its
    key argument; tracing ``random.split`` should add none.
    """
    # The previous unconditional ``raise SkipTest("broken test")`` made the
    # whole body unreachable; it is removed so the test actually runs.
    if not config.omnistaging_enabled:
      raise SkipTest("test is omnistaging-specific")

    # block_until_ready() matches the sibling versions of this test;
    # presumably it keeps pending work from key creation out of the counted
    # region below -- TODO confirm.
    key = random.PRNGKey(1).block_until_ready()
    with jtu.count_device_put() as count:
      # Only the jitted call is counted. An extra op-by-op split here would
      # add work the assertion below does not account for.
      api.jit(random.split)(key)
    self.assertEqual(count[0], 1)  # 1 for the argument device_put call
Ejemplo n.º 3
0
 def test_random_split_doesnt_device_put_during_tracing(self):
   """Tracing ``random.split`` under jit must not add device_put calls.

   The one device_put expected inside the counted region is for the key
   argument handed to the jitted function.
   """
   prng_key = random.PRNGKey(1).block_until_ready()
   with jtu.count_device_put() as device_put_count:
     jitted_split = api.jit(random.split)
     jitted_split(prng_key)
   # device_put_count[0] holds the number of device_put calls observed.
   expected_puts = 1  # 1 for the argument device_put
   self.assertEqual(device_put_count[0], expected_puts)