def _run_comp(self, comp_pb, comp_type, arg=None):
  """Executes a serialized XLA computation and returns its result.

  Args:
    comp_pb: The `pb.Computation` to execute.
    comp_type: The `computation_types.FunctionType` describing `comp_pb`.
    arg: Optional argument for the computation. When `None`, the computation
      is invoked with no arguments at all.

  Returns:
    Whatever the `runtime.ComputationCallable` returns for this invocation.
  """
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertIsInstance(comp_type, computation_types.FunctionType)
  # Acquire the backend the same way the tests in this file do.
  # `xla_client.get_local_backend` is not available in newer jaxlib releases.
  backend = jax.lib.xla_bridge.get_backend()
  comp_callable = runtime.ComputationCallable(comp_pb, comp_type, backend)
  if arg is None:
    return comp_callable()
  return comp_callable(arg)
def test_computation_callable_return_one_number(self):
  """A no-argument XLA computation producing the constant 10 is callable."""
  xla_builder = xla_client.XlaBuilder('comp')
  # The XLA computation declares one parameter of empty-tuple shape, even
  # though its TFF type below takes no argument.
  empty_shape = xla_client.shape_from_pyval(())
  xla_client.ops.Parameter(xla_builder, 0, empty_shape)
  xla_client.ops.Constant(xla_builder, np.int32(10))
  built = xla_builder.build()

  fn_type = computation_types.FunctionType(None, np.int32)
  serialized = xla_serialization.create_xla_tff_computation(built, [], fn_type)
  callable_obj = runtime.ComputationCallable(
      serialized, fn_type, jax.lib.xla_bridge.get_backend())

  self.assertIsInstance(callable_obj, runtime.ComputationCallable)
  self.assertEqual(str(callable_obj.type_signature), '( -> int32)')
  self.assertEqual(callable_obj(), 10)
def to_representation_for_type(value, type_spec, backend=None):
  """Checks `value` against `type_spec` and converts it to an executor payload.

  Accepted kinds of `value`:
    * Computations, given as `pb.Computation` or as
      `computation_impl.ComputationImpl`.
    * Numpy arrays and scalars, and Python scalars (converted to Numpy).
    * Nested structures built from the above.

  Args:
    value: The raw value to check against `type_spec` and, where needed,
      convert.
    type_spec: A `tff.Type` (or something `computation_types.to_type`
      accepts). May be `None` for values that carry their own type, such as
      those deriving from `typed_object.TypedObject`.
    backend: An `xla_client.Client` instance, used only when constructing
      callables for functional values. May be `None` when unused.

  Returns:
    Either `value` unchanged or a converted version of it.

  Raises:
    TypeError: If `value` does not match `type_spec`, or if the (reconciled)
      type is neither a struct nor a tensor type for a non-computation value.
    ValueError: If the arguments are incorrect.
  """
  if backend is not None:
    py_typecheck.check_type(backend, xla_client.Client)

  if type_spec is not None:
    # Normalize first, then let the value and the declared type agree.
    type_spec = executor_utils.reconcile_value_with_type_spec(
        value, computation_types.to_type(type_spec))

  if isinstance(value, computation_base.Computation):
    # Unwrap to the serialized proto and handle it in the branch below.
    proto = computation_impl.ComputationImpl.get_proto(value)
    return to_representation_for_type(proto, type_spec, backend)

  if isinstance(value, pb.Computation):
    deserialized_type = type_serialization.deserialize_type(value.type)
    if type_spec is not None:
      deserialized_type.check_equivalent_to(type_spec)
    return runtime.ComputationCallable(value, deserialized_type, backend)

  if isinstance(type_spec, computation_types.StructType):

    def _convert_element(element, element_type):
      # Each element is converted recursively against its own member type.
      return to_representation_for_type(element, element_type, backend)

    container = structure.from_container(value, recursive=True)
    return structure.map_structure(_convert_element, container, type_spec)

  if isinstance(type_spec, computation_types.TensorType):
    return runtime.normalize_tensor_representation(value, type_spec)

  raise TypeError('Unexpected type {}.'.format(type_spec))
def test_computation_callable_add_two_numbers(self):
  """An XLA computation summing the two elements of an int32 pair returns 5."""
  xla_builder = xla_client.XlaBuilder('comp')
  # The single XLA parameter is a 2-tuple of int32 scalars.
  pair_of_int32 = (np.array(0, dtype=np.int32),) * 2
  tuple_param = xla_client.ops.Parameter(
      xla_builder, 0, xla_client.shape_from_pyval(pair_of_int32))
  lhs = xla_client.ops.GetTupleElement(tuple_param, 0)
  rhs = xla_client.ops.GetTupleElement(tuple_param, 1)
  xla_client.ops.Add(lhs, rhs)
  built = xla_builder.build()

  fn_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
  # NOTE(review): [0, 1] presumably indexes the two flattened parameter
  # tensors; confirm against xla_serialization.create_xla_tff_computation.
  serialized = xla_serialization.create_xla_tff_computation(
      built, [0, 1], fn_type)
  callable_obj = runtime.ComputationCallable(
      serialized, fn_type, jax.lib.xla_bridge.get_backend())

  self.assertIsInstance(callable_obj, runtime.ComputationCallable)
  self.assertEqual(
      str(callable_obj.type_signature), '(<int32,int32> -> int32)')
  two_and_three = structure.Struct([(None, np.int32(2)), (None, np.int32(3))])
  self.assertEqual(callable_obj(two_and_three), 5)