def test_normalize_tensor_representation_int32x2x3(self):
   """Normalizing a (3, 2) int32 array keeps its type, shape and contents.

   The result must still be an `np.ndarray` with dtype int32 and shape
   (3, 2), and its flattened values must be 1 through 6 in order.
   """
   matrix = np.array(((1, 2), (3, 4), (5, 6)), dtype=np.int32)
   matrix_type = computation_types.TensorType(np.int32, (3, 2))
   normalized = runtime.normalize_tensor_representation(matrix, matrix_type)

   self.assertIsInstance(normalized, np.ndarray)
   self.assertEqual(normalized.dtype, np.int32)
   self.assertEqual(normalized.shape, (3, 2))
   flat_values = list(normalized.flatten())
   self.assertEqual(flat_values, [1, 2, 3, 4, 5, 6])
Example #2
0
def to_representation_for_type(value, type_spec, backend=None):
    """Checks `value` against `type_spec` and converts it to an executor payload.

    Accepted forms of `value`:

    * Computations: a `computation_base.Computation`, whose proto is obtained
      through `computation_impl.ComputationImpl.get_proto` and then converted
      recursively, or a raw `pb.Computation` proto.

    * Numpy arrays and scalars, or Python scalars. These are handed to
      `runtime.normalize_tensor_representation`.

    * Nested structures of the above, matched against a
      `computation_types.StructType` element by element.

    Args:
      value: The raw value to compare against `type_spec` and possibly
        convert.
      type_spec: Anything accepted by `computation_types.to_type`, or `None`.
        Before dispatching, it is reconciled with `value` by
        `executor_utils.reconcile_value_with_type_spec`.
      backend: Optional `xla_client.Client`. It is only forwarded when a
        `runtime.ComputationCallable` is built, and may be `None`.

    Returns:
      Either `value` itself or a converted version of it:
      * a `runtime.ComputationCallable` for computations;
      * a structure of converted elements for struct types;
      * the normalized tensor for tensor types.

    Raises:
      TypeError: If the reconciled type of a non-computation value is neither
        a struct type nor a tensor type. The helper checks (presumably
        including the `backend` check) may also raise TypeError for
        incompatible inputs.
      ValueError: If the arguments are incorrect.
    """
    if backend is not None:
        py_typecheck.check_type(backend, xla_client.Client)
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
    type_spec = executor_utils.reconcile_value_with_type_spec(value, type_spec)

    # A high-level computation is first lowered to its proto, then converted
    # through the `pb.Computation` branch by the recursive call.
    if isinstance(value, computation_base.Computation):
        proto = computation_impl.ComputationImpl.get_proto(value)
        return to_representation_for_type(proto, type_spec, backend)

    if isinstance(value, pb.Computation):
        deserialized_type = type_serialization.deserialize_type(value.type)
        if type_spec is not None:
            deserialized_type.check_equivalent_to(type_spec)
        return runtime.ComputationCallable(value, deserialized_type, backend)

    if isinstance(type_spec, computation_types.StructType):

        def _convert_element(element, element_type):
            # Each element is converted with the same backend.
            return to_representation_for_type(element, element_type, backend)

        container = structure.from_container(value, recursive=True)
        return structure.map_structure(_convert_element, container, type_spec)
    elif isinstance(type_spec, computation_types.TensorType):
        return runtime.normalize_tensor_representation(value, type_spec)
    else:
        raise TypeError('Unexpected type {}.'.format(type_spec))
 def test_normalize_tensor_representation_int32(self):
   """Normalizing the Python int 10 as an int32 scalar yields np.int32(10)."""
   scalar_type = computation_types.TensorType(np.int32)
   normalized = runtime.normalize_tensor_representation(10, scalar_type)
   self.assertIsInstance(normalized, np.int32)
   self.assertEqual(normalized, 10)