Beispiel #1
0
 def _run_comp(self, comp_pb, comp_type, arg=None):
     """Wraps `comp_pb` in a `runtime.ComputationCallable` and invokes it.

     Args:
       comp_pb: The computation to run; asserted to be a `pb.Computation`.
       comp_type: Its type; asserted to be a `computation_types.FunctionType`.
       arg: Optional single argument. When `None`, the callable is invoked
         with no positional arguments at all.

     Returns:
       Whatever the `ComputationCallable` returns for this invocation.
     """
     self.assertIsInstance(comp_pb, pb.Computation)
     self.assertIsInstance(comp_type, computation_types.FunctionType)
     # NOTE(review): the test methods alongside this helper obtain the backend
     # via `jax.lib.xla_bridge.get_backend()`; consider aligning the two.
     local_backend = xla_client.get_local_backend(None)
     fn = runtime.ComputationCallable(comp_pb, comp_type, local_backend)
     # `None` stands for "no argument", so it is never forwarded positionally.
     call_args = () if arg is None else (arg,)
     return fn(*call_args)
 def test_computation_callable_return_one_number(self):
   """A no-argument XLA computation producing the int32 constant 10."""
   xla_builder = xla_client.XlaBuilder('comp')
   # The no-argument function still declares one parameter, shaped as an
   # empty tuple.
   empty_tuple_shape = xla_client.shape_from_pyval(tuple())
   xla_client.ops.Parameter(xla_builder, 0, empty_tuple_shape)
   xla_client.ops.Constant(xla_builder, np.int32(10))
   built = xla_builder.build()
   fn_type = computation_types.FunctionType(None, np.int32)
   serialized = xla_serialization.create_xla_tff_computation(
       built, [], fn_type)
   fn = runtime.ComputationCallable(
       serialized, fn_type, jax.lib.xla_bridge.get_backend())
   self.assertIsInstance(fn, runtime.ComputationCallable)
   self.assertEqual(str(fn.type_signature), '( -> int32)')
   self.assertEqual(fn(), 10)
Beispiel #3
0
def to_representation_for_type(value, type_spec, backend=None):
    """Checks `value` against `type_spec`, converting it to an executor payload.

    Accepted kinds of `value`:

    * Computations, given either as a `pb.Computation` or as a
      `computation_base.Computation` (such as `computation_impl.ComputationImpl`),
      whose proto is extracted via `ComputationImpl.get_proto`.
    * Numpy arrays and scalars, or Python scalars that are converted to Numpy.
    * Nested structures of the above.

    Args:
      value: The raw value to check against `type_spec` and possibly convert.
      type_spec: Anything `computation_types.to_type` accepts, typically a
        `tff.Type`. May be `None` for values that derive from
        `typed_object.TypedObject`.
      backend: An `xla_client.Client`, needed only for functional types. May be
        `None` when unused.

    Returns:
      Either `value` itself, or a converted version of it.

    Raises:
      TypeError: If `value` is not compatible with `type_spec`, or the
        reconciled type is neither a struct nor a tensor type.
      ValueError: If the arguments are incorrect.
    """
    if backend is not None:
        py_typecheck.check_type(backend, xla_client.Client)
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
    type_spec = executor_utils.reconcile_value_with_type_spec(value, type_spec)

    if isinstance(value, computation_base.Computation):
        # Unwrap to the proto form and run it through the whole pipeline again.
        proto = computation_impl.ComputationImpl.get_proto(value)
        return to_representation_for_type(proto, type_spec, backend)

    if isinstance(value, pb.Computation):
        proto_type = type_serialization.deserialize_type(value.type)
        if type_spec is not None:
            proto_type.check_equivalent_to(type_spec)
        return runtime.ComputationCallable(value, proto_type, backend)

    if isinstance(type_spec, computation_types.StructType):

        def _convert_element(element, element_type):
            # Each member is converted recursively against its own member type.
            return to_representation_for_type(element, element_type, backend)

        as_struct = structure.from_container(value, recursive=True)
        return structure.map_structure(_convert_element, as_struct, type_spec)

    if isinstance(type_spec, computation_types.TensorType):
        return runtime.normalize_tensor_representation(value, type_spec)

    raise TypeError('Unexpected type {}.'.format(type_spec))
 def test_computation_callable_add_two_numbers(self):
   """An XLA computation summing the two int32 elements of its tuple argument."""
   xla_builder = xla_client.XlaBuilder('comp')
   int32_pair = tuple([np.array(0, dtype=np.int32)] * 2)
   pair_param = xla_client.ops.Parameter(
       xla_builder, 0, xla_client.shape_from_pyval(int32_pair))
   first = xla_client.ops.GetTupleElement(pair_param, 0)
   second = xla_client.ops.GetTupleElement(pair_param, 1)
   xla_client.ops.Add(first, second)
   built = xla_builder.build()
   fn_type = computation_types.FunctionType((np.int32, np.int32), np.int32)
   serialized = xla_serialization.create_xla_tff_computation(
       built, [0, 1], fn_type)
   fn = runtime.ComputationCallable(
       serialized, fn_type, jax.lib.xla_bridge.get_backend())
   self.assertIsInstance(fn, runtime.ComputationCallable)
   self.assertEqual(str(fn.type_signature), '(<int32,int32> -> int32)')
   operands = structure.Struct([(None, np.int32(2)), (None, np.int32(3))])
   self.assertEqual(fn(operands), 5)