コード例 #1
0
    def test_error_bad_consumer_id(self):
        """Registering an outfeed under the reserved consumer ID 0 must fail.

        The runtime is expected to reject consumer ID 0 with a RuntimeError
        whose message names the reserved value.
        """
        builder = xla_bridge.make_computation_builder(self._testMethodName)
        tok = hcb.xops.CreateToken(builder)
        # Make sure the receiver exists even when this test runs on its own.
        hcb._initialize_outfeed_receiver()
        reserved_id = 0
        expected_msg = "Consumer ID cannot be a reserved value: 0"
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            # Build the operand inside the context, as before, so any error
            # raised while constructing it is also covered by the assertion.
            operands = [
                xla_bridge.constant(
                    builder, np.zeros((2, 3), dtype=np.float32))
            ]
            hcb._outfeed_receiver.receiver.add_outfeed(
                builder, tok, reserved_id, operands)
コード例 #2
0
 def test_error_different_shapes(self):
     """Registering a second, different shape under one consumer ID must fail.

     Consumer ID 123 is first registered with a float32 operand of shape
     (2, 3). Two later registrations under the same ID are each expected to
     raise a RuntimeError: one with a different dtype (int32, same shape) and
     one with a different shape (float32, shape (2,)).
     """
     builder = xla_bridge.make_computation_builder(self._testMethodName)
     tok = hcb.xops.CreateToken(builder)
     # Make sure the receiver exists even when this test runs on its own.
     hcb._initialize_outfeed_receiver()
     consumer_id = 123
     # NOTE(review): the same pattern is used for both failing cases, so it
     # presumably matches the printout of the previous shape rather than a
     # dtype-specific message -- confirm against the runtime's error text.
     mismatch_pattern = ".*does not match previous shape element_type.*"

     def register(shape, dtype):
         # Add a zero-filled constant of the given shape/dtype as the single
         # outfeed operand for `consumer_id`.
         operand = xla_bridge.constant(builder, np.zeros(shape, dtype=dtype))
         hcb._outfeed_receiver.receiver.add_outfeed(
             builder, tok, consumer_id, [operand])

     # The first registration sets the shape recorded for this consumer ID.
     register((2, 3), np.float32)
     # Same dimensions, different element type.
     with self.assertRaisesRegex(RuntimeError, mismatch_pattern):
         register((2, 3), np.int32)
     # Same element type, different dimensions.
     with self.assertRaisesRegex(RuntimeError, mismatch_pattern):
         register((2,), np.float32)