def __init__(self, backend=None):
    """Sets up the executor with a local XLA backend.

    Args:
      backend: Optional name passed straight through to
        `xla_client.get_local_backend`. Defaults to `None`.
    """
    local_backend = xla_client.get_local_backend(backend)
    self._backend = local_backend
def __init__(self, device=None):
    """Sets up the executor with a local XLA backend.

    Args:
      device: Optional device name. Only `None` is accepted at present.

    Raises:
      ValueError: If `device` is anything other than `None`.
    """
    device_requested = device is not None
    if device_requested:
        # Reject the request before any backend lookup takes place.
        raise ValueError(
            'Explicitly specifying a device is currently not supported.')
    # The backend lookup is always performed with `None`.
    self._backend = xla_client.get_local_backend(None)
def _run_comp(self, comp_pb, comp_type, arg=None):
    """Wraps a serialized computation in a callable and invokes it.

    Args:
      comp_pb: The computation to run; asserted to be a `pb.Computation`.
      comp_type: The computation's type; asserted to be a
        `computation_types.FunctionType`.
      arg: Optional single argument. When `None`, the callable is invoked
        with no arguments.

    Returns:
      Whatever the constructed `runtime.ComputationCallable` returns.
    """
    self.assertIsInstance(comp_pb, pb.Computation)
    self.assertIsInstance(comp_type, computation_types.FunctionType)
    local_backend = xla_client.get_local_backend(None)
    runnable = runtime.ComputationCallable(comp_pb, comp_type, local_backend)
    # Zero or one positional argument, depending on whether `arg` was given.
    call_args = [] if arg is None else [arg]
    return runnable(*call_args)
def test_computation_callable_return_one_number(self):
    # Builds an XLA computation with one parameter (shaped from an empty
    # tuple) followed by an int32 constant 10, wraps it in a
    # `ComputationCallable`, and checks its type signature and result.
    xla_builder = xla_client.XlaBuilder('comp')
    empty_shape = xla_client.shape_from_pyval(())
    xla_client.ops.Parameter(xla_builder, 0, empty_shape)
    xla_client.ops.Constant(xla_builder, np.int32(10))
    built_comp = xla_builder.build()

    fn_type = computation_types.FunctionType(None, np.int32)
    serialized = xla_serialization.create_xla_tff_computation(
        built_comp, [], fn_type)

    local_backend = xla_client.get_local_backend(None)
    runnable = runtime.ComputationCallable(serialized, fn_type, local_backend)

    self.assertIsInstance(runnable, runtime.ComputationCallable)
    signature_text = str(runnable.type_signature)
    self.assertEqual(signature_text, '( -> int32)')

    # The callable takes no arguments.
    outcome = runnable()
    self.assertEqual(outcome, 10)
def test_computation_callable_add_two_numbers(self):
    # Builds an XLA computation that adds the two elements of a tuple
    # parameter, wraps it in a `ComputationCallable`, and checks that
    # calling it with (2, 3) yields 5.
    xla_builder = xla_client.XlaBuilder('comp')
    zero_scalar = np.array(0, dtype=np.int32)
    pair_shape = xla_client.shape_from_pyval((zero_scalar, zero_scalar))
    pair_param = xla_client.ops.Parameter(xla_builder, 0, pair_shape)
    lhs = xla_client.ops.GetTupleElement(pair_param, 0)
    rhs = xla_client.ops.GetTupleElement(pair_param, 1)
    xla_client.ops.Add(lhs, rhs)
    built_comp = xla_builder.build()

    fn_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
    serialized = xla_serialization.create_xla_tff_computation(
        built_comp, [0, 1], fn_type)

    local_backend = xla_client.get_local_backend(None)
    runnable = runtime.ComputationCallable(serialized, fn_type, local_backend)

    self.assertIsInstance(runnable, runtime.ComputationCallable)
    signature_text = str(runnable.type_signature)
    self.assertEqual(signature_text, '(<int32,int32> -> int32)')

    # Both operands are passed as unnamed elements of a single struct.
    operands = structure.Struct(
        [(None, np.int32(value)) for value in (2, 3)])
    outcome = runnable(operands)
    self.assertEqual(outcome, 5)
def setUp(self):
    """Runs the parent fixture setup, then stores a local XLA backend."""
    parent = super(ExecutorTest, self)
    parent.setUp()
    local_backend = xla_client.get_local_backend()
    self._backend = local_backend