def test_error_bad_consumer_id(self):
  """Registering an outfeed under the reserved consumer ID 0 must fail.

  The runtime is expected to reject the registration with a RuntimeError
  whose message names the reserved value.
  """
  builder = xla_bridge.make_computation_builder(self._testMethodName)
  outfeed_token = hcb.xops.CreateToken(builder)
  # The receiver must exist even when this test is run on its own.
  hcb._initialize_outfeed_receiver()
  receiver = hcb._outfeed_receiver.receiver
  reserved_consumer_id = 0
  expected_message = "Consumer ID cannot be a reserved value: 0"
  with self.assertRaisesRegex(RuntimeError, expected_message):
    payload = [
        xla_bridge.constant(builder, np.zeros((2, 3), dtype=np.float32))
    ]
    receiver.add_outfeed(builder, outfeed_token, reserved_consumer_id,
                         payload)
def test_error_different_shapes(self):
  """A consumer ID may not be re-registered with a different shape.

  The first registration for consumer 123 uses a float32[2, 3] operand.
  Registering the same ID again with a different element type (int32) or
  a different rank (float32[2]) must raise a RuntimeError.
  """
  builder = xla_bridge.make_computation_builder(self._testMethodName)
  outfeed_token = hcb.xops.CreateToken(builder)
  # The receiver must exist even when this test is run on its own.
  hcb._initialize_outfeed_receiver()
  receiver = hcb._outfeed_receiver.receiver
  consumer_id = 123

  # Establish the reference shape for this consumer ID.
  receiver.add_outfeed(
      builder, outfeed_token, consumer_id,
      [xla_bridge.constant(builder, np.zeros((2, 3), dtype=np.float32))])

  # The same pattern covers both mismatches: the error message prints each
  # shape starting with its "element_type:" field, even when only the
  # dimensions differ.
  mismatch_pattern = ".*does not match previous shape element_type.*"
  mismatched_specs = (((2, 3), np.int32), ((2, ), np.float32))
  for shape, dtype in mismatched_specs:
    with self.assertRaisesRegex(RuntimeError, mismatch_pattern):
      receiver.add_outfeed(
          builder, outfeed_token, consumer_id,
          [xla_bridge.constant(builder, np.zeros(shape, dtype=dtype))])