def test_basic_functionality_of_tuple_class(self):
  """Checks construction, accessors, repr and proto form of `Tuple`."""
  x = building_blocks.Reference('foo', tf.int32)
  y = building_blocks.Reference('bar', tf.bool)
  # Mixes an unnamed element (`x`) with a named one (`y`).
  z = building_blocks.Tuple([x, ('y', y)])
  # An empty-string element name is rejected.
  with self.assertRaises(ValueError):
    _ = building_blocks.Tuple([('', y)])
  self.assertIsInstance(z, anonymous_tuple.AnonymousTuple)
  self.assertEqual(str(z.type_signature), '<int32,y=bool>')
  self.assertEqual(
      repr(z), 'Tuple([(None, Reference(\'foo\', TensorType(tf.int32))), (\'y\', '
      'Reference(\'bar\', TensorType(tf.bool)))])')
  self.assertEqual(z.compact_representation(), '<foo,y=bar>')
  # Only the named element is listed by dir() and reachable as an attribute.
  self.assertEqual(dir(z), ['y'])
  self.assertIs(z.y, y)
  # Positional access and iteration cover both elements.
  self.assertLen(z, 2)
  self.assertIs(z[0], x)
  self.assertIs(z[1], y)
  self.assertEqual(','.join(e.compact_representation() for e in iter(z)),
                   'foo,bar')
  z_proto = z.proto
  self.assertEqual(type_serialization.deserialize_type(z_proto.type),
                   z.type_signature)
  self.assertEqual(z_proto.WhichOneof('computation'), 'tuple')
  # In the proto, the unnamed element carries an empty-string name.
  self.assertEqual([e.name for e in z_proto.tuple.element], ['', 'y'])
  self._serialize_deserialize_roundtrip_test(z)
def test_returns_coputation(self):
  """`create_lambda_empty_tuple` yields a parameterless, empty-tuple lambda.

  NOTE(review): the method name misspells "computation"; it is kept as-is so
  existing test IDs and filters keep working.
  """
  computation_proto = computation_factory.create_lambda_empty_tuple()
  self.assertIsInstance(computation_proto, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(computation_proto.type),
      computation_types.FunctionType(None, []))
def test_basic_functionality_of_block_class(self):
  """Checks construction, accessors, repr and proto form of `Block`.

  The block binds `x` to a pair-typed reference and `y` to `x[0]`, then
  returns `y`.
  """
  x = building_blocks.Block([
      ('x', building_blocks.Reference('arg', (tf.int32, tf.int32))),
      ('y',
       building_blocks.Selection(
           building_blocks.Reference('x', (tf.int32, tf.int32)), index=0))
  ], building_blocks.Reference('y', tf.int32))
  self.assertEqual(str(x.type_signature), 'int32')
  self.assertEqual([(k, v.compact_representation()) for k, v in x.locals],
                   [('x', 'arg'), ('y', 'x[0]')])
  self.assertEqual(x.result.compact_representation(), 'y')
  self.assertEqual(
      repr(x), 'Block([(\'x\', Reference(\'arg\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
      '(\'y\', Selection(Reference(\'x\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
      'index=0))], '
      'Reference(\'y\', TensorType(tf.int32)))')
  self.assertEqual(x.compact_representation(), '(let x=arg,y=x[0] in y)')
  x_proto = x.proto
  self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                   x.type_signature)
  self.assertEqual(x_proto.WhichOneof('computation'), 'block')
  self.assertEqual(str(x_proto.block.result), str(x.result.proto))
  # Require the same number of locals on both sides before pairing them.
  # Otherwise a proto that dropped locals would pass silently.
  self.assertLen(x_proto.block.local, len(x.locals))
  for loc_proto, (loc_name, loc_value) in zip(x_proto.block.local, x.locals):
    self.assertEqual(loc_proto.name, loc_name)
    self.assertEqual(str(loc_proto.value), str(loc_value.proto))
  self._serialize_deserialize_roundtrip_test(x)
def test_intrinsic_class_succeeds_simple_federated_map(self):
  """Builds a concretely typed `federated_map` intrinsic and inspects it.

  The map applies an `int32 -> float32` function to an int32 value placed at
  CLIENTS.
  """
  fn_type = computation_types.FunctionType(tf.int32, tf.float32)
  clients = placement_literals.CLIENTS
  arg_at_clients = computation_types.FederatedType(fn_type.parameter, clients)
  result_at_clients = computation_types.FederatedType(fn_type.result, clients)
  map_type = computation_types.FunctionType([fn_type, arg_at_clients],
                                            result_at_clients)
  intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                        map_type)

  self.assertIsInstance(intrinsic, building_blocks.Intrinsic)
  self.assertEqual(
      str(intrinsic.type_signature),
      '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
  self.assertEqual(intrinsic.uri, 'federated_map')
  self.assertEqual(intrinsic.compact_representation(), 'federated_map')

  intrinsic_proto = intrinsic.proto
  self.assertEqual(
      type_serialization.deserialize_type(intrinsic_proto.type),
      intrinsic.type_signature)
  self.assertEqual(intrinsic_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(intrinsic_proto.intrinsic.uri, intrinsic.uri)
  self._serialize_deserialize_roundtrip_test(intrinsic)
def test_serialize_tensorflow_with_table_no_variables(self):
  """A TF function that builds a static vocabulary table serializes and runs.

  The table maps 'a', 'b', 'c' to 0, 1, 2; the serialized graph is imported,
  its initializer run, and a lookup checked.
  """

  def table_lookup(word):
    initializer = tf.lookup.KeyValueTensorInitializer(
        ['a', 'b', 'c'], np.arange(3, dtype=np.int64))
    table = tf.lookup.StaticVocabularyTable(initializer, num_oov_buckets=1)
    return table.lookup(word)

  string_vector = computation_types.TensorType(dtype=tf.string, shape=(None,))
  comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      table_lookup, string_vector, context_stack_impl.context_stack)
  expected_signature = '(string[?] -> int64[?])'
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), expected_signature)
  self.assertEqual(str(extra_type_spec), expected_signature)
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

  tf_proto = comp.tensorflow
  with tf.Graph().as_default() as graph:
    tf.import_graph_def(
        serialization_utils.unpack_graph_def(tf_proto.graph_def), name='')
  with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(fetches=tf_proto.initialize_op)
    lookup_ids = sess.run(
        fetches=tf_proto.result.tensor.tensor_name,
        feed_dict={tf_proto.parameter.tensor.tensor_name: ['b', 'c', 'a']})
  self.assertAllEqual(lookup_ids, [1, 2, 0])
def test_basic_functionality_of_call_class(self):
  """Checks construction, argument validation, repr and proto form of `Call`."""
  # `foo` is an (int32 -> bool) function, applied to the int32 reference `bar`.
  x = building_blocks.Reference(
      'foo', computation_types.FunctionType(tf.int32, tf.bool))
  y = building_blocks.Reference('bar', tf.int32)
  z = building_blocks.Call(x, y)
  self.assertEqual(str(z.type_signature), 'bool')
  self.assertIs(z.function, x)
  self.assertIs(z.argument, y)
  self.assertEqual(
      repr(z), 'Call(Reference(\'foo\', '
      'FunctionType(TensorType(tf.int32), TensorType(tf.bool))), '
      'Reference(\'bar\', TensorType(tf.int32)))')
  self.assertEqual(z.compact_representation(), 'foo(bar)')
  # Calling a function that takes a parameter without an argument is rejected.
  with self.assertRaises(TypeError):
    building_blocks.Call(x)
  # A float32 argument does not match the int32 parameter and is rejected.
  w = building_blocks.Reference('bak', tf.float32)
  with self.assertRaises(TypeError):
    building_blocks.Call(x, w)
  z_proto = z.proto
  self.assertEqual(type_serialization.deserialize_type(z_proto.type),
                   z.type_signature)
  self.assertEqual(z_proto.WhichOneof('computation'), 'call')
  self.assertEqual(str(z_proto.call.function), str(x.proto))
  self.assertEqual(str(z_proto.call.argument), str(y.proto))
  self._serialize_deserialize_roundtrip_test(z)
def run_tensorflow(computation_proto, arg=None):
  """Executes a serialized TensorFlow computation and returns its result.

  Args:
    computation_proto: An instance of `pb.Computation`.
    arg: The value to invoke the computation with. Pass `None` (the default)
      when the computation's type has no parameter and so expects no argument;
      in that case `arg` is ignored.

  Returns:
    The result of the computation, as fetched from a TF session.
  """
  with tf.Graph().as_default() as graph:
    fn_type = type_serialization.deserialize_type(computation_proto.type)
    param_type = fn_type.parameter
    # Only computations that declare a parameter get the argument stamped in.
    stamped_arg = (None if param_type is None else _stamp_value_into_graph(
        arg, param_type, graph))
    init_op, result_binding = tensorflow_deserialization.deserialize_and_call_tf_computation(
        computation_proto, stamped_arg, graph)
  with tf.compat.v1.Session(graph=graph) as sess:
    if init_op:
      sess.run(init_op)
    return tensorflow_utils.fetch_value_in_session(sess, result_binding)
def test_serialize_tensorflow_with_dataset_not_optimized(self):
  """A dataset reduction serializes without dataset optimization ops.

  Also runs the imported graph on `range(5)` and checks the sum is 10.
  """

  @tf.function
  def test_foo(ds):
    return ds.reduce(np.int64(0), lambda x, y: x + y)

  def legacy_dataset_reducer_example(ds):
    return test_foo(ds)

  comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      legacy_dataset_reducer_example, computation_types.SequenceType(tf.int64),
      context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '(int64* -> int64)')
  self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.data.Dataset.range(5)

  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  self.assertGraphDoesNotContainOps(
      graph_def, ['OptimizeDataset', 'OptimizeDatasetV2', 'ModelDataset'])
  # Close the session explicitly so its resources are released. try/finally
  # is used instead of `with` because entering a Session also makes its graph
  # the default, which would change the context the GraphDef is imported in.
  sess = tf.compat.v1.Session()
  try:
    results = sess.run(
        tf.import_graph_def(
            graph_def, {
                comp.tensorflow.parameter.sequence.variant_tensor_name:
                    tf.data.experimental.to_variant(parameter)
            }, [comp.tensorflow.result.tensor.tensor_name]))
  finally:
    sess.close()
  self.assertEqual(results, [10])
def test_basic_functionality_of_lambda_class(self):
  """Checks construction, accessors, repr and proto form of `Lambda`."""
  arg_name = 'arg'
  # The parameter is a named tuple holding a function `f` and an int32 `x`.
  arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]
  arg = building_blocks.Reference(arg_name, arg_type)
  arg_f = building_blocks.Selection(arg, name='f')
  arg_x = building_blocks.Selection(arg, name='x')
  # Body applies `f` twice: arg.f(arg.f(arg.x)).
  x = building_blocks.Lambda(
      arg_name, arg_type,
      building_blocks.Call(arg_f, building_blocks.Call(arg_f, arg_x)))
  self.assertEqual(
      str(x.type_signature), '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(x.parameter_name, arg_name)
  self.assertEqual(str(x.parameter_type), '<f=(int32 -> int32),x=int32>')
  self.assertEqual(x.result.compact_representation(), 'arg.f(arg.f(arg.x))')
  # repr of the parameter type, substituted for each {0} in the repr below.
  arg_type_repr = (
      'NamedTupleType(['
      '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
      '(\'x\', TensorType(tf.int32))])')
  self.assertEqual(
      repr(x), 'Lambda(\'arg\', {0}, '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
          arg_type_repr))
  self.assertEqual(x.compact_representation(),
                   '(arg -> arg.f(arg.f(arg.x)))')
  x_proto = x.proto
  self.assertEqual(
      type_serialization.deserialize_type(x_proto.type), x.type_signature)
  self.assertEqual(x_proto.WhichOneof('computation'), 'lambda')
  # `lambda` is a Python keyword, so the proto field is read via getattr().
  self.assertEqual(getattr(x_proto, 'lambda').parameter_name, arg_name)
  self.assertEqual(
      str(getattr(x_proto, 'lambda').result), str(x.result.proto))
  self._serialize_deserialize_roundtrip_test(x)
def test_serialize_jax_with_int32_to_int32(self):
  """Serializes the JAX function `x + 10` over int32 and inspects the XLA proto."""
  # Currently skipped: per the skip message, HLO pattern matching broke after
  # the linked JAX change. The statements after skipTest() do not run while the
  # skip is in place.
  self.skipTest('HLO pattern matching broken by '
                'https://github.com/google/jax/pull/10232')

  def traced_fn(x):
    return x + 10

  param_type = computation_types.to_type(np.int32)
  arg_fn = function_utils.create_argument_unpacking_fn(
      traced_fn, param_type)
  ctx_stack = context_stack_impl.context_stack
  comp_pb = jax_serialization.serialize_jax_computation(
      traced_fn, arg_fn, param_type, ctx_stack)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  type_spec = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(type_spec), '(int32 -> int32)')
  xla_comp = xla_serialization.unpack_xla_computation(
      comp_pb.xla.hlo_module)
  # Exact HLO instruction names; this is the fragile check the skip refers to.
  self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)',
                xla_comp.as_hlo_text())
  # Parameter and result bindings are expected to be identical here.
  self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
  self.assertEqual(str(comp_pb.xla.result), 'tensor {\n' ' index: 0\n' '}\n')
def __init__(self, computation_proto, context_stack, annotated_type=None):
  """Wraps `computation_proto` as a `ComputationImpl`.

  Args:
    computation_proto: The protocol buffer representing the computation, an
      instance of `pb.Computation`.
    context_stack: The `context_stack_base.ContextStack` to use.
    annotated_type: Optional type with additional annotations. When given, it
      is used in place of the type deserialized from
      `computation_proto.type`, which must be assignable from it.

  Raises:
    TypeError: If `annotated_type` is given and `computation_proto.type` is
      not assignable from it, or if the resulting type is not a function type.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  proto_type = type_serialization.deserialize_type(computation_proto.type)
  py_typecheck.check_type(proto_type, computation_types.Type)

  if annotated_type is None:
    type_spec = proto_type
  elif proto_type.is_assignable_from(annotated_type):
    type_spec = annotated_type
  else:
    raise TypeError(
        'annotated_type not compatible with computation_proto.type\n'
        'computation_proto.type: {!s}\n'
        'annotated_type: {!s}'.format(proto_type, annotated_type))

  type_analysis.check_well_formed(type_spec)
  if not type_spec.is_function():
    raise TypeError('{} is not a functional type, from proto: {}'.format(
        str(type_spec), str(computation_proto)))
  super().__init__(type_spec, context_stack)
  self._computation_proto = computation_proto
def deserialize_sequence_value(sequence_value_proto):
  """Deserializes a `tf.data.Dataset`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`, where the type is a
    `SequenceType` whose element type is read from
    `sequence_value_proto.element_type`.

  Raises:
    NotImplementedError: If the sequence is not encoded as a zipped SavedModel.
  """
  py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)
  which_value = sequence_value_proto.WhichOneof('value')
  if which_value == 'zipped_saved_model':
    ds = tensorflow_serialization.deserialize_dataset(
        sequence_value_proto.zipped_saved_model)
  else:
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'.
        format(which_value))

  element_type = type_serialization.deserialize_type(
      sequence_value_proto.element_type)

  # If a serialized dataset had elements of nested structures of tensors (e.g.
  # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
  # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
  #
  # Since the dataset will only be used inside TFF, we wrap the dictionary
  # coming from TF in an `OrderedDict` when necessary (a type that both TF and
  # TFF understand), using the field order stored in the TFF type stored during
  # serialization.
  ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      ds, element_type)
  return ds, computation_types.SequenceType(element=element_type)
def _serialize_computation(
    comp: computation_pb2.Computation,
    type_spec: Optional[computation_types.Type]) -> _SerializeReturnType:
  """Serializes a TFF computation.

  The type stored in `comp.type` is reconciled with the caller-supplied
  `type_spec` (which may be `None`). The proto is wrapped in a serialization
  `Value` and returned together with the reconciled type.
  """
  comp_type = type_serialization.deserialize_type(comp.type)
  reconciled_type = executor_utils.reconcile_value_type_with_type_spec(
      comp_type, type_spec)
  serialized_value = serialization_bindings.Value(computation=comp)
  return serialized_value, reconciled_type
def test_serialize_tensorflow_with_structured_type_signature(self):
  """Named-tuple containers on the parameter and result are preserved.

  The serialized proto carries only the structural type, while the returned
  `extra_type_spec` still records the namedtuple classes.
  """
  batch_type = collections.namedtuple('BatchType', ['x', 'y'])
  output_type = collections.namedtuple('OutputType', ['A', 'B'])
  comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda z: output_type(2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y),
      batch_type(tf.int32, (tf.float32, [2])),
      context_stack_impl.context_stack)

  expected_signature = '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)'
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), expected_signature)
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  self.assertEqual(str(extra_type_spec), expected_signature)

  container_cls = computation_types.NamedTupleTypeWithPyContainerType
  expectations = ((extra_type_spec.parameter, batch_type),
                  (extra_type_spec.result, output_type))
  for part_type, expected_container in expectations:
    self.assertIsInstance(part_type, container_cls)
    self.assertIs(
        container_cls.get_container_type(part_type), expected_container)
def _deserialize_sequence_value(
    sequence_value_proto: executor_pb2.Value.Sequence
) -> _DeserializeReturnType:
  """Deserializes a `tf.data.Dataset`.

  Two encodings are handled: the deprecated zipped SavedModel (which emits a
  `DeprecationWarning`) and a serialized `GraphDef`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`.

  Raises:
    NotImplementedError: If the sequence uses any other encoding.
  """
  element_type = type_serialization.deserialize_type(
      sequence_value_proto.element_type)
  which_value = sequence_value_proto.WhichOneof('value')
  if which_value == 'zipped_saved_model':
    warnings.warn(
        'Deserializing a sequence value that was encoded as a zipped SavedModel.'
        ' This is a deprecated path, please update the binary that is '
        'serializing the sequences.', DeprecationWarning)
    ds = _deserialize_dataset_from_zipped_saved_model(
        sequence_value_proto.zipped_saved_model)
    # Only this path re-applies the TFF element structure explicitly; the
    # GraphDef path receives `element_type` directly.
    ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, element_type)
  elif which_value == 'serialized_graph_def':
    ds = _deserialize_dataset_from_graph_def(
        sequence_value_proto.serialized_graph_def, element_type)
  else:
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'.
        format(which_value))
  return ds, computation_types.SequenceType(element=element_type)
async def _evaluate_lambda(
    self,
    comp: pb.Computation,
    scope: ReferenceResolvingExecutorScope,
) -> ReferenceResolvingExecutorValue:
  """Packages a lambda with its enclosing scope as an executor value.

  The lambda is wrapped in a `ScopedLambda` together with `scope`. Nothing in
  the lambda's body is evaluated here. The value's type is read from
  `comp.type`.
  """
  lambda_type = type_serialization.deserialize_type(comp.type)
  closure = ScopedLambda(comp, scope)
  return ReferenceResolvingExecutorValue(closure, type_spec=lambda_type)
async def _evaluate_to_delegate(
    self,
    comp: pb.Computation,
    scope: ReferenceResolvingExecutorScope,
) -> ReferenceResolvingExecutorValue:
  """Creates `comp` in the target executor and wraps the resulting value.

  The type passed to the target executor is read from `comp.type`. `scope`
  is not used by this method.
  """
  comp_type = type_serialization.deserialize_type(comp.type)
  target_value = await self._target_executor.create_value(comp, comp_type)
  return ReferenceResolvingExecutorValue(target_value)
def test_returns_computation_sequence(self):
  """The identity lambda over an int32 sequence has the unary-op type."""
  seq_type = computation_types.SequenceType(tf.int32)
  identity_proto = computation_factory.create_lambda_identity(seq_type)
  self.assertIsInstance(identity_proto, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(identity_proto.type),
      type_factory.unary_op(seq_type))
def test_returns_computation_tuple_unnamed(self):
  """The identity lambda over an unnamed <int32,float32> struct type-checks."""
  struct_type = computation_types.StructType([tf.int32, tf.float32])
  identity_proto = computation_factory.create_lambda_identity(struct_type)
  self.assertIsInstance(identity_proto, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(identity_proto.type),
      type_factory.unary_op(struct_type))
def _deserialize_type_spec(serialize_type_variable, python_container=None):
  """Rebuilds a TF structure from a serialized `tff.Type` held in a variable.

  Args:
    serialize_type_variable: Object whose `read_value().numpy()` yields the
      bytes of a serialized `computation_pb2.Type` message.
    python_container: Optional Python container class. If given and the
      deserialized type is a struct, the struct is re-wrapped so it carries
      this container.

  Returns:
    The output of `type_conversions.type_to_tf_structure` for the type.
  """
  raw_bytes = serialize_type_variable.read_value().numpy()
  type_proto = computation_pb2.Type.FromString(raw_bytes)
  tff_type = type_serialization.deserialize_type(type_proto)
  if tff_type.is_struct() and python_container is not None:
    tff_type = computation_types.StructWithPythonType(
        structure.iter_elements(tff_type), python_container)
  return type_conversions.type_to_tf_structure(tff_type)
def test_create_xla_tff_computation_int32x10_to_int32x10(self):
  """Wraps an int32[10] -> int32[10] XLA computation as a TFF proto."""
  xla_comp = _make_test_xla_comp_int32x10_to_int32x10()
  int32x10 = (np.int32, (10,))
  fn_type = computation_types.FunctionType(int32x10, int32x10)
  comp_pb = xla_serialization.create_xla_tff_computation(xla_comp, [0],
                                                         fn_type)
  self.assertIsInstance(comp_pb, pb.Computation)
  self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
  deserialized_type = type_serialization.deserialize_type(comp_pb.type)
  self.assertEqual(str(deserialized_type), '(int32[10] -> int32[10])')
def test_returns_computation(self, type_signature, value):
  """The TF identity computation has the unary-op type and echoes its input."""
  identity_proto = tensorflow_computation_factory.create_identity(
      type_signature)
  self.assertIsInstance(identity_proto, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(identity_proto.type),
      type_factory.unary_op(type_signature))
  self.assertEqual(test_utils.run_tensorflow(identity_proto, value), value)
def test_returns_computation(self):
  """The TF empty-tuple computation takes no argument and returns `<>`."""
  empty_proto, _ = tensorflow_computation_factory.create_empty_tuple()
  self.assertIsInstance(empty_proto, pb.Computation)
  deserialized_type = type_serialization.deserialize_type(empty_proto.type)
  computation_types.FunctionType(None, []).check_assignable_from(
      deserialized_type)
  run_output = test_utils.run_tensorflow(empty_proto)
  self.assertEqual(run_output, structure.Struct([]))
def test_returns_coputation(self):
  """The TF empty-tuple computation takes no argument and returns `<>`.

  NOTE(review): the method name misspells "computation"; it is kept as-is so
  existing test IDs and filters keep working.
  """
  empty_proto = tensorflow_computation_factory.create_empty_tuple()
  self.assertIsInstance(empty_proto, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(empty_proto.type),
      computation_types.FunctionType(None, []))
  run_output = test_utils.run_tensorflow(empty_proto)
  self.assertEqual(run_output, anonymous_tuple.AnonymousTuple([]))
def test_serialize_deserialize_named_tuple_types_py_container(self):
  """A round trip through the type proto drops the Python container.

  The round-tripped type is a plain `StructType` that compares unequal to the
  original yet is still equivalent to it.
  """
  original_type = computation_types.StructWithPythonType(
      (tf.int32, tf.bool), tuple)
  round_tripped = type_serialization.deserialize_type(
      type_serialization.serialize_type(original_type))
  self.assertNotEqual(original_type, round_tripped)
  self.assertIsInstance(round_tripped, computation_types.StructType)
  self.assertNotIsInstance(round_tripped,
                           computation_types.StructWithPythonType)
  original_type.check_equivalent_to(round_tripped)
def test_basic_functionality_of_placement_class(self):
  """`Placement(CLIENTS)` exposes the 'clients' URI and survives a round trip."""
  placement = building_blocks.Placement(placement_literals.CLIENTS)
  self.assertEqual(str(placement.type_signature), 'placement')
  self.assertEqual(placement.uri, 'clients')
  self.assertEqual(repr(placement), 'Placement(\'clients\')')
  self.assertEqual(placement.compact_representation(), 'CLIENTS')

  placement_proto = placement.proto
  self.assertEqual(
      type_serialization.deserialize_type(placement_proto.type),
      placement.type_signature)
  self.assertEqual(placement_proto.WhichOneof('computation'), 'placement')
  self.assertEqual(placement_proto.placement.uri, placement.uri)
  self._serialize_deserialize_roundtrip_test(placement)
def test_basic_functionality_of_reference_class(self):
  """An int32 `Reference` named 'foo' reports its name and type and round-trips."""
  ref = building_blocks.Reference('foo', tf.int32)
  self.assertEqual(ref.name, 'foo')
  self.assertEqual(str(ref.type_signature), 'int32')
  self.assertEqual(repr(ref), 'Reference(\'foo\', TensorType(tf.int32))')
  self.assertEqual(ref.compact_representation(), 'foo')

  ref_proto = ref.proto
  self.assertEqual(
      type_serialization.deserialize_type(ref_proto.type),
      ref.type_signature)
  self.assertEqual(ref_proto.WhichOneof('computation'), 'reference')
  self.assertEqual(ref_proto.reference.name, ref.name)
  self._serialize_deserialize_roundtrip_test(ref)
def test_returns_computation(self, type_signature, count, value):
  """The replicate-input computation returns `count` unnamed copies of its input."""
  replicate_proto, _ = tensorflow_computation_factory.create_replicate_input(
      type_signature, count)
  self.assertIsInstance(replicate_proto, pb.Computation)
  deserialized_type = type_serialization.deserialize_type(replicate_proto.type)
  replicated_types = [type_signature] * count
  computation_types.FunctionType(
      type_signature, replicated_types).check_assignable_from(deserialized_type)
  run_output = test_utils.run_tensorflow(replicate_proto, value)
  self.assertEqual(run_output, structure.Struct([(None, value)] * count))
def test_serialize_tensorflow_with_no_parameter(self):
  """A parameterless TF function serializes as `( -> int32)` and yields 99."""
  comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(99), None, context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
  self.assertEqual(str(extra_type_spec), '( -> int32)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  # Close the session explicitly so its resources are released. try/finally
  # is used instead of `with` because entering a Session also makes its graph
  # the default, which would change the context the GraphDef is imported in.
  sess = tf.compat.v1.Session()
  try:
    results = sess.run(
        tf.import_graph_def(
            serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
            None, [comp.tensorflow.result.tensor.tensor_name]))
  finally:
    sess.close()
  self.assertEqual(results, [99])
def to_representation_for_type(value, type_spec, backend=None):
  """Verifies or converts `value` to an executor payload matching `type_spec`.

  Supported kinds of `value`:
  * Computations, either `pb.Computation` or `computation_impl.ComputationImpl`.
  * Numpy arrays and scalars, or Python scalars that are converted to Numpy.
  * Nested structures of the above.

  Args:
    value: The raw representation of a value to compare against `type_spec`
      and potentially to be converted.
    type_spec: An instance of `tff.Type`. Can be `None` for values that derive
      from `typed_object.TypedObject`.
    backend: The backend to use; an instance of `xla_client.Client`. Only used
      for functional types. Can be `None` if unused.

  Returns:
    Either `value` itself, or a modified version of it.

  Raises:
    TypeError: If the `value` is not compatible with `type_spec`.
    ValueError: If the arguments are incorrect (for example, a tensor type
      whose shape is not fully defined, as rejected by
      `assert_is_fully_defined`).
  """
  if backend is not None:
    py_typecheck.check_type(backend, xla_client.Client)
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)
  type_spec = type_utils.reconcile_value_with_type_spec(value, type_spec)

  # Computation objects are unwrapped to their proto and handled below.
  if isinstance(value, computation_base.Computation):
    comp_proto = computation_impl.ComputationImpl.get_proto(value)
    return to_representation_for_type(comp_proto, type_spec, backend)

  if isinstance(value, pb.Computation):
    comp_type = type_serialization.deserialize_type(value.type)
    if type_spec is not None:
      comp_type.check_equivalent_to(type_spec)
    return _ComputationCallable(value, comp_type, backend)

  # Structs are converted element by element against their element types.
  if isinstance(type_spec, computation_types.StructType):
    convert_element = (
        lambda elem, elem_type: to_representation_for_type(
            elem, elem_type, backend))
    elements = structure.from_container(value, recursive=True)
    return structure.map_structure(convert_element, elements, type_spec)

  if not isinstance(type_spec, computation_types.TensorType):
    raise TypeError('Unexpected type {}.'.format(type_spec))

  type_spec.shape.assert_is_fully_defined()
  type_analysis.check_type(value, type_spec)
  tensor_rank = type_spec.shape.rank
  if tensor_rank == 0:
    return np.dtype(type_spec.dtype.as_numpy_dtype).type(value)
  if tensor_rank > 0:
    return np.array(value, dtype=type_spec.dtype.as_numpy_dtype)
  raise TypeError('Unsupported tensor shape {}.'.format(type_spec.shape))