def test_serialize_tensorflow_with_dataset_not_optimized(self):
    """A `tf.function` dataset reduction serializes without tf.data optimizations.

    Checks the serialized type signature, checks that the serialized graph
    contains none of `OptimizeDataset`, `OptimizeDatasetV2` or `ModelDataset`,
    and checks that running the graph on `range(5)` yields the sum 10.
    """

    @tf.function
    def test_foo(ds):
        return ds.reduce(np.int64(0), lambda x, y: x + y)

    def legacy_dataset_reducer_example(ds):
        return test_foo(ds)

    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack))
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(int64* -> int64)')
    self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    graph_def = serialization_utils.unpack_graph_def(
        comp.tensorflow.graph_def)
    self.assertGraphDoesNotContainOps(
        graph_def, ['OptimizeDataset', 'OptimizeDatasetV2', 'ModelDataset'])

    # Build the input and import the computation into a fresh graph, and run
    # it in a session that is closed afterwards. Nothing leaks into the global
    # default graph, and no session outlives this test.
    with tf.Graph().as_default() as graph:
        parameter = tf.data.Dataset.range(5)
        outputs = tf.import_graph_def(
            graph_def,
            {
                comp.tensorflow.parameter.sequence.variant_tensor_name:
                    tf.data.experimental.to_variant(parameter)
            },
            [comp.tensorflow.result.tensor.tensor_name])
        with tf.compat.v1.Session(graph=graph) as sess:
            results = sess.run(outputs)
    self.assertEqual(results, [10])
def test_serialize_tensorflow_with_table_no_variables(self):
    """A static vocabulary lookup serializes and runs once initialized.

    The table maps 'a', 'b', 'c' to 0, 1, 2 and has one OOV bucket. After the
    computation's `initialize_op` runs, feeding ['b', 'c', 'a'] yields [1, 2, 0].
    """

    def table_lookup(word):
        initializer = tf.lookup.KeyValueTensorInitializer(
            ['a', 'b', 'c'], np.arange(3, dtype=np.int64))
        vocab_table = tf.lookup.StaticVocabularyTable(
            initializer, num_oov_buckets=1)
        return vocab_table.lookup(word)

    string_vector_type = computation_types.TensorType(
        dtype=tf.string, shape=(None,))
    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            table_lookup, string_vector_type,
            context_stack_impl.context_stack))

    expected_signature = '(string[?] -> int64[?])'
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        expected_signature)
    self.assertEqual(str(extra_type_spec), expected_signature)
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    tf_proto = comp.tensorflow
    with tf.Graph().as_default() as graph:
        imported_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
        tf.import_graph_def(imported_def, name='')
        with tf.compat.v1.Session(graph=graph) as sess:
            # The table must be initialized before any lookup can run.
            sess.run(fetches=tf_proto.initialize_op)
            results = sess.run(
                fetches=tf_proto.result.tensor.tensor_name,
                feed_dict={
                    tf_proto.parameter.tensor.tensor_name: ['b', 'c', 'a']
                })
    self.assertAllEqual(results, [1, 2, 0])
def test_serialize_tensorflow_with_no_parameter(self):
    """A zero-argument function serializes to `( -> int32)` and returns 99."""
    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(99), None, context_stack_impl.context_stack))
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
    self.assertEqual(str(extra_type_spec), '( -> int32)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    # Import into a fresh graph and close the session when done, so neither
    # the imported ops nor the session leak beyond this test.
    with tf.Graph().as_default() as graph:
        outputs = tf.import_graph_def(
            serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
            None, [comp.tensorflow.result.tensor.tensor_name])
        with tf.compat.v1.Session(graph=graph) as sess:
            results = sess.run(outputs)
    self.assertEqual(results, [99])
def test_serialize_tensorflow_with_simple_add_three_lambda(self):
    """`lambda x: x + 3` over int32 serializes and maps 1000 to 1003."""
    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda x: x + 3, tf.int32, context_stack_impl.context_stack))
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(int32 -> int32)')
    self.assertEqual(str(extra_type_spec), '(int32 -> int32)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    # Create the input and import the computation into a fresh graph, and run
    # it in a session that is closed afterwards, so nothing leaks into the
    # global default graph.
    with tf.Graph().as_default() as graph:
        parameter = tf.constant(1000)
        outputs = tf.import_graph_def(
            serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
            {comp.tensorflow.parameter.tensor.tensor_name: parameter},
            [comp.tensorflow.result.tensor.tensor_name])
        with tf.compat.v1.Session(graph=graph) as sess:
            results = sess.run(outputs)
    self.assertEqual(results, [1003])
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
    """Wrapper function to plug TensorFlow logic into the TFF framework.

    This function is passed through `computation_wrapper.ComputationWrapper`.
    Documentation for its arguments can be found inside the definition of that
    class.

    Args:
      target_fn: The Python callable to serialize as a TensorFlow computation.
      parameter_type: The TFF type of the computation's parameter. It must be
        TensorFlow-compatible, i.e. built from `SequenceType`, `StructType`
        and `TensorType` constituents.
      unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
      name: Ignored.

    Returns:
      A `computation_impl.ComputationImpl` wrapping the serialized computation
      and its extra type spec.

    Raises:
      TypeError: If `parameter_type` is not TensorFlow-compatible.
    """
    del name  # Unused.
    # Validate before doing any wrapping work. An incompatible type then always
    # surfaces as this descriptive TypeError rather than as whatever the
    # wrapping step might raise, and no wrapper is built for a rejected type.
    if not type_analysis.is_tensorflow_compatible_type(parameter_type):
        raise TypeError('`tf_computation`s can accept only parameter types with '
                        'constituents `SequenceType`, `StructType` '
                        'and `TensorType`; you have attempted to create one '
                        'with the type {}.'.format(parameter_type))
    target_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    ctx_stack = context_stack_impl.context_stack
    comp_pb, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            target_fn, parameter_type, ctx_stack))
    return computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
def test_serialize_tensorflow_with_structured_type_signature(self):
    """Named-tuple containers survive serialization on both signature sides.

    The parameter is a `BatchType(x=int32, y=float32[2])`, and the function
    returns an `OutputType(A, B)`. Both the type strings and the recorded
    Python containers are checked.
    """
    batch_type = collections.namedtuple('BatchType', ['x', 'y'])
    output_type = collections.namedtuple('OutputType', ['A', 'B'])
    input_spec = batch_type(tf.int32, (tf.float32, [2]))

    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda z: output_type(2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y),
            input_spec, context_stack_impl.context_stack))

    expected_signature = (
        '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)')
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        expected_signature)
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
    self.assertEqual(str(extra_type_spec), expected_signature)

    # Parameter first, then result: each must remember its Python container.
    for struct_type, container in ((extra_type_spec.parameter, batch_type),
                                   (extra_type_spec.result, output_type)):
        self.assertIsInstance(struct_type,
                              computation_types.StructWithPythonType)
        self.assertIs(struct_type.python_container, container)
def test_serialize_tensorflow_with_data_set_sum_lambda(self):
    """A plain dataset reduction serializes and sums `range(5)` to 10."""

    def _legacy_dataset_reducer_example(ds):
        return ds.reduce(np.int64(0), lambda x, y: x + y)

    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            _legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack))
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(int64* -> int64)')
    self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    # Build the input dataset and import the computation into a fresh graph,
    # and close the session when done, so nothing leaks beyond this test.
    with tf.Graph().as_default() as graph:
        parameter = tf.data.Dataset.range(5)
        outputs = tf.import_graph_def(
            serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
            {
                comp.tensorflow.parameter.sequence.variant_tensor_name:
                    tf.data.experimental.to_variant(parameter)
            },
            [comp.tensorflow.result.tensor.tensor_name])
        with tf.compat.v1.Session(graph=graph) as sess:
            results = sess.run(outputs)
    self.assertEqual(results, [10])