Example 1: collective reduce on a parallel device inside context.async_scope
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op


def test_collective_reduce_async_scope(self):
    # Note that ops on the parallel device currently don't execute
    # asynchronously. The test is just that we don't get deadlocks.
    # `self.device` is a two-component ParallelDevice fixture set up
    # elsewhere in the test class.
    x = self.device.pack(
        [constant_op.constant(-1.5),
         constant_op.constant(3.5)])
    with context.async_scope(), self.device:
        reduced = _collective_sum(x, num_replicas=2)
        outputs = self.device.unpack(reduced)
    # Each component ends up holding the sum -1.5 + 3.5 = 2.0.
    self.assertAllClose([2., 2.], outputs)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)
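
The `_collective_sum` helper used above is defined elsewhere in the test module and is not shown. A minimal sketch of what such a helper could look like, assuming `collective_ops.all_reduce` with fixed illustrative group and instance keys (the real helper's key management is not reproduced here, so everything below is an assumption, not the source's implementation):

from tensorflow.python.ops import collective_ops


def _collective_sum(inputs, num_replicas):
    # Sums `inputs` across the replicas executing this op; under the
    # parallel device, each component participates in the reduction.
    return collective_ops.all_reduce(
        t=inputs,
        group_size=num_replicas,
        group_key=0,      # illustrative fixed key, not from the source
        instance_key=0,   # illustrative fixed key, not from the source
        merge_op='Add',
        final_op='Id')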
Example 2: clearing an OutOfRangeError raised inside an async scope
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables


def test_out_of_range_with_async_scope(self):
    # Assumes a remote '/job:worker' server has been configured in the
    # test's setup.
    with ops.device('/job:worker/task:0'):
        dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
        dataset = dataset.batch(1, drop_remainder=False)
        iterator = iter(dataset)
        v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
        i = next(iterator)
        v.assign_add(math_ops.reduce_mean(i))

    # One more step than the dataset holds, so the last step raises
    # OutOfRangeError under the async scope.
    num_steps = 3
    try:
        with context.async_scope():
            for _ in range(num_steps):
                with ops.device('/job:worker/task:0'):
                    train_step(iterator)
    except errors.OutOfRangeError:
        # Clear the pending async error so later ops can execute.
        context.async_clear_error()

    # The two successful steps added 1.0 and 2.0 to the initial 1.0.
    self.assertAllEqual(v.numpy(), 4.0)
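
The same pattern is available through the public TF 2.x API as tf.experimental.async_scope and tf.experimental.async_clear_error, the exported names for the context helpers used above. A minimal local sketch, assuming no remote worker (the '/job:worker' device placement from the test is dropped):

import tensorflow as tf

v = tf.Variable(1.0)
iterator = iter(tf.data.Dataset.from_tensor_slices([1.0, 2.0]).batch(1))


@tf.function
def train_step(it):
    v.assign_add(tf.reduce_mean(next(it)))


try:
    with tf.experimental.async_scope():
        for _ in range(3):  # one more step than the dataset holds
            train_step(iterator)
except tf.errors.OutOfRangeError:
    # Clear the pending async error so later ops can run.
    tf.experimental.async_clear_error()

assert v.numpy() == 4.0  # 1.0 + 1.0 + 2.0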