def test_collective_reduce_async_scope(self):
  # Note that ops on the parallel device currently don't execute
  # asynchronously. The test is just that we don't get deadlocks.
  x = self.device.pack(
      [constant_op.constant(-1.5), constant_op.constant(3.5)])
  with context.async_scope(), self.device:
    reduced = _collective_sum(x, num_replicas=2)
    outputs = self.device.unpack(reduced)
  self.assertAllClose([2., 2.], outputs)
  self.assertIn(self.device.components[0], outputs[0].backing_device)
  self.assertIn(self.device.components[1], outputs[1].backing_device)
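
# `_collective_sum` is used above but not defined in this excerpt. Below is a
# minimal sketch, assuming it wraps `collective_ops.all_reduce`; the group and
# instance keys are illustrative placeholders, not the original file's values.
from tensorflow.python.ops import collective_ops


def _collective_sum(t, num_replicas):
  # Each replica contributes its component; 'Add' merges them into a sum and
  # 'Id' leaves the result unscaled (use 'Div' as the final op for a mean).
  return collective_ops.all_reduce(
      t,
      group_size=num_replicas,
      group_key=1,  # placeholder key for illustration
      instance_key=1,  # placeholder key for illustration
      merge_op='Add',
      final_op='Id')
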
def test_out_of_range_with_async_scope(self):
  with ops.device('/job:worker/task:0'):
    # Two elements batched one at a time: the iterator yields exactly two
    # batches before raising OutOfRangeError.
    dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
    dataset = dataset.batch(1, drop_remainder=False)
    iterator = iter(dataset)
    v = variables.Variable(1.0)

  @def_function.function
  def train_step(iterator):
    i = next(iterator)
    v.assign_add(math_ops.reduce_mean(i))

  # Run one step more than the dataset can supply so the third step fails.
  num_steps = 3
  try:
    with context.async_scope():
      for _ in range(num_steps):
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
  except errors.OutOfRangeError:
    # Clear the pending async error so subsequent ops aren't poisoned by it.
    context.async_clear_error()

  # Only the first two steps succeeded: 1.0 + 1.0 + 2.0 = 4.0.
  self.assertAllEqual(v.numpy(), 4.0)
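
# For reference, the imports these tests rely on, reconstructed from the
# symbols used above; this is an assumption about the original file's header,
# which is not shown in this excerpt.
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables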