def testIsInContext(self):
    """Test that control_flow_util can check that we're in a TPU context.

    Builds one op outside a TPUReplicateContext and one op inside it, then
    checks that control_flow_util.IsInXLAContext reports False for the
    former and True for the latter.
    """
    # Created before the context is entered, so it must not be in XLA context.
    z1 = array_ops.identity(1)
    pivot = control_flow_ops.no_op()
    context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
    context.Enter()
    try:
        # Created while the replicate context is active.
        z2 = array_ops.identity(1)
    finally:
        # Always leave the context, even if op construction fails, so the
        # entered context does not leak onto ops built later on this graph.
        context.Exit()
    self.assertFalse(control_flow_util.IsInXLAContext(z1.op))
    self.assertTrue(control_flow_util.IsInXLAContext(z2.op))
def f():
    """Build ops in a TPUReplicateContext whose names collide with a capture.

    Captures the enclosing-scope tensor `z` inside the context, then creates
    two new tensors both requested with the name "a". According to the
    comment below, `z`'s name collides with the first of them.

    Returns:
      The sum `z1 + z2` of the two zero tensors created inside the context.
    """
    pivot = control_flow_ops.no_op()
    context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
    context.Enter()
    try:
        array_ops.identity(z)  # Capture z.
        z1 = array_ops.zeros([3, 2], name="a")
        # Explicit raise rather than `assert` so this sanity check is not
        # stripped when Python runs with -O.
        if z1.name != "a:0":
            raise AssertionError("Expected: a:0, Found: %s" % z1.name)
        z2 = array_ops.zeros([3, 2], name="a")
        # Prior to fixing b/166794533 this would fail with a shape mismatch
        # because context.AddValue would have cached `z` by its name which
        # collides with z1's name.
        result = z1 + z2
    finally:
        # Always leave the context, even if a check above fails.
        context.Exit()
    return result