def test_serialize_tensorflow_with_dataset_not_optimized(self):
    """Serializes a dataset reduction and checks the graph for unwanted ops.

    A `tf.function` that sums a dataset is serialized, its type signature is
    verified, the unpacked graph is checked for the absence of the
    `OptimizeDataset` and `ModelDataset` ops, and the graph is then executed
    on `tf.data.Dataset.range(5)`.
    """

    @tf.function
    def test_foo(ds):
        return ds.reduce(np.int64(0), lambda x, y: x + y)

    def legacy_dataset_reducer_example(ds):
        return test_foo(ds)

    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack))
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(int64* -> int64)')
    self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    parameter = tf.data.Dataset.range(5)
    graph_def = serialization_utils.unpack_graph_def(
        comp.tensorflow.graph_def)
    # The second op name used to be misspelled ('ModelDataste'), which made
    # that half of the absence check pass no matter what the graph contained.
    self.assertGraphDoesNotContainOps(
        graph_def, ['OptimizeDataset', 'ModelDataset'])

    # Use the session as a context manager so it is closed once the test has
    # fetched its result instead of being left to the garbage collector.
    with tf.compat.v1.Session() as sess:
        results = sess.run(
            tf.import_graph_def(
                graph_def,
                {
                    comp.tensorflow.parameter.sequence.variant_tensor_name:
                        tf.data.experimental.to_variant(parameter)
                },
                [comp.tensorflow.result.tensor.tensor_name]))
    self.assertEqual(results, [10])
def test_serialize_tensorflow_with_table_no_variables(self):
    """Serializes a vocabulary-table lookup and runs it after initialization."""

    def table_lookup(word):
        vocab_table = tf.lookup.StaticVocabularyTable(
            tf.lookup.KeyValueTensorInitializer(
                ['a', 'b', 'c'], np.arange(3, dtype=np.int64)),
            num_oov_buckets=1)
        return vocab_table.lookup(word)

    word_type = computation_types.TensorType(dtype=tf.string, shape=(None, ))
    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            table_lookup, word_type, context_stack_impl.context_stack))

    expected_signature = '(string[?] -> int64[?])'
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        expected_signature)
    self.assertEqual(str(extra_type_spec), expected_signature)
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    # Import the serialized graph into a fresh graph so the tensor and op
    # names recorded in the proto can be used unprefixed.
    with tf.Graph().as_default() as graph:
        unpacked = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        tf.import_graph_def(unpacked, name='')

    with tf.compat.v1.Session(graph=graph) as session:
        # The table has to be initialized before it can be queried.
        session.run(fetches=comp.tensorflow.initialize_op)
        feed = {
            comp.tensorflow.parameter.tensor.tensor_name: ['b', 'c', 'a']
        }
        results = session.run(
            fetches=comp.tensorflow.result.tensor.tensor_name,
            feed_dict=feed)
    self.assertAllEqual(results, [1, 2, 0])
def test_returns_string_for_comp_with_left_overhang(self):
    """Checks the representations of `a(comp#bbbbb())`.

    The argument of the outer call is itself a no-argument call on a compiled
    computation with a long name.
    """
    reference = computation_building_blocks.Reference(
        'a', computation_types.FunctionType(tf.int32, tf.int32))
    proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda: tf.constant(1), None, context_stack_impl.context_stack)
    inner_call = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(proto, 'bbbbb'))
    comp = computation_building_blocks.Call(reference, inner_call)

    self.assertEqual(
        computation_building_blocks.compact_representation(comp),
        'a(comp#bbbbb())')
    self.assertEqual(
        computation_building_blocks.formatted_representation(comp),
        'a(comp#bbbbb())')

    structural_string = (
        computation_building_blocks.structural_representation(comp))
    # NOTE(review): the spaces inside the expected diagram may have been
    # collapsed in this copy of the file -- confirm against the real output.
    # pyformat: disable
    expected_structure = (
        ' Call\n'
        ' / \\\n'
        ' Ref(a) Call\n'
        ' /\n'
        'Compiled(bbbbb)')
    # pyformat: enable
    self.assertEqual(structural_string, expected_structure)
def test_serialize_tensorflow_with_structured_type_signature(self):
    """Checks that named-tuple containers survive serialization.

    Both the parameter and the result of the serialized function are Python
    named tuples; the extra type spec returned by the serializer is expected
    to remember those container types.
    """
    batch_type = collections.namedtuple('BatchType', ['x', 'y'])
    output_type = collections.namedtuple('OutputType', ['A', 'B'])
    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda z: output_type(2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y),
            batch_type(tf.int32, (tf.float32, [2])),
            context_stack_impl.context_stack))

    expected_signature = (
        '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)')
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        expected_signature)
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
    self.assertEqual(str(extra_type_spec), expected_signature)

    container_cls = computation_types.NamedTupleTypeWithPyContainerType
    # Parameter first, then result.
    self.assertIsInstance(extra_type_spec.parameter, container_cls)
    self.assertIs(
        container_cls.get_container_type(extra_type_spec.parameter),
        batch_type)
    self.assertIsInstance(extra_type_spec.result, container_cls)
    self.assertIs(
        container_cls.get_container_type(extra_type_spec.result),
        output_type)
def test_serialize_tensorflow_with_no_parameter(self):
    """Serializes a no-argument function returning a constant and runs it."""
    comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda: tf.constant(99), None, context_stack_impl.context_stack)
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    # Run inside a `with` block so the session is closed after the fetch
    # rather than leaked for the rest of the test run.
    with tf.Session() as sess:
        results = sess.run(
            tf.import_graph_def(
                comp.tensorflow.graph_def, None,
                [comp.tensorflow.result.tensor.tensor_name]))
    self.assertEqual(results, [99])
def test_returns_string_for_compiled_computation(self):
    """Checks all three representations of a named compiled computation."""
    proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda: tf.constant(1), None, context_stack_impl.context_stack)
    compiled = computation_building_blocks.CompiledComputation(proto, 'a')

    self.assertEqual(compiled.compact_representation(), 'comp#a')
    self.assertEqual(compiled.formatted_representation(), 'comp#a')
    self.assertEqual(compiled.structural_representation(), 'Compiled(a)')
def test_replace_compiled_computations_names_replaces_name(self):
    """Checks that the transformation gives a compiled computation a new name."""
    tf_comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda: tf.constant(1), None, context_stack_impl.context_stack)
    comp = computation_building_blocks.CompiledComputation(tf_comp)

    rename = (
        transformations.replace_compiled_computations_names_with_unique_names)
    transformed_comp = rename(comp)

    self.assertNotEqual(transformed_comp._name, comp._name)
def test_serialize_tensorflow_with_simple_add_three_lambda(self):
    """Serializes `lambda x: x + 3` over `tf.int32` and runs it on 1000."""
    comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(int32 -> int32)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    parameter = tf.constant(1000)
    # Run inside a `with` block so the session is closed after the fetch
    # rather than leaked for the rest of the test run.
    with tf.Session() as sess:
        results = sess.run(
            tf.import_graph_def(
                serialization_utils.unpack_graph_def(
                    comp.tensorflow.graph_def),
                {comp.tensorflow.parameter.tensor.tensor_name: parameter},
                [comp.tensorflow.result.tensor.tensor_name]))
    self.assertEqual(results, [1003])
def test_deserialize_and_call_tf_computation_with_add_one(self):
    """Deserializes an add-one computation into the default graph and runs it."""
    add_one, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda x: tf.add(x, 1, name='the_add'), tf.int32,
        context_stack_impl.context_stack)

    ten = tf.constant(10, name='the_ten')
    call = tensorflow_deserialization.deserialize_and_call_tf_computation
    init_op, result = call(add_one, ten, tf.get_default_graph())

    self.assertTrue(tf.contrib.framework.is_tensor(result))
    with tf.Session() as session:
        # Only run the initializer when the deserializer handed one back.
        if init_op:
            session.run(init_op)
        fetched = session.run(result)
    self.assertEqual(fetched, 11)
def _tf_wrapper_fn(target_fn, parameter_type, name=None):
    """Builds a TFF computation from a Python function containing TF logic.

    Args:
      target_fn: The Python function to serialize as a TensorFlow computation.
      parameter_type: The type of the parameter of `target_fn`; it must pass
        `type_utils.check_tf_comp_whitelisted`.
      name: Ignored.

    Returns:
      A `computation_impl.ComputationImpl` wrapping the serialized computation.

    Raises:
      TypeError: If `parameter_type` is rejected by the whitelist check.
    """
    del name  # Unused.
    if type_utils.check_tf_comp_whitelisted(parameter_type):
        stack = context_stack_impl.context_stack
        serialized = (
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                target_fn, parameter_type, stack))
        return computation_impl.ComputationImpl(serialized, stack)
    raise TypeError(
        '`tf_computation`s can accept only parameter types with '
        'constituents `SequenceType`, `NamedTupleType` '
        'and `TensorType`; you have attempted to create one '
        'with the type {}.'.format(parameter_type))
def test_fetch_value_with_nested_datasets(self):
    """Checks that a computation returning two datasets yields both of them."""

    def return_two_datasets():
        return [tf.data.Dataset.range(5), tf.data.Dataset.range(5)]

    proto = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        return_two_datasets, None, context_stack_impl.context_stack)[0]
    executable_return_two_datasets = computation_impl.ComputationImpl(
        proto, context_stack_impl.context_stack)

    x = executable_return_two_datasets()

    expected = list(range(5))
    self.assertEqual(list(x[0]), expected)
    self.assertEqual(list(x[1]), expected)
def test_fetch_value_with_empty_dataset_and_tensors(self):
    """Checks fetching a tensor alongside a dataset from which nothing is taken."""

    def return_dataset():
        source = tf.data.Dataset.from_tensor_slices([[1, 1], [1, 1]])
        return [tf.constant([0., 0.]), source.batch(5).take(0)]

    proto = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        return_dataset, None, context_stack_impl.context_stack)[0]
    executable_return_dataset = computation_impl.ComputationImpl(
        proto, context_stack_impl.context_stack)

    x = executable_return_dataset()

    self.assertEqual(x[0][0], 0.)
    self.assertEqual(x[0][1], 0.)
    # The first element of the fetched dataset is compared, via its string
    # form, with an empty int32 array of shape [0, 2].
    expected_repr = str(np.zeros([0, 2], dtype=np.int32))
    self.assertEqual(str(x[1][0]), expected_repr)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
    """Builds a TFF computation from a Python function containing TF logic.

    Args:
      target_fn: The Python function to serialize as a TensorFlow computation.
      parameter_type: The type of the parameter of `target_fn`; it must pass
        `type_utils.is_tensorflow_compatible_type`.
      unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
      name: Ignored.

    Returns:
      A `computation_impl.ComputationImpl` built from the serialized
      computation and the extra type spec returned by the serializer.

    Raises:
      TypeError: If `parameter_type` is not TensorFlow-compatible.
    """
    del name  # Unused.
    # The callable is wrapped before the type check, as in the original flow.
    wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    if type_utils.is_tensorflow_compatible_type(parameter_type):
        stack = context_stack_impl.context_stack
        serialized, extra_type_spec = (
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                wrapped_fn, parameter_type, stack))
        return computation_impl.ComputationImpl(
            serialized, stack, extra_type_spec)
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `NamedTupleType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
def test_fetch_value_with_dataset_and_tensor(self):
    """Checks fetching a dataset that sits between two tensors in a list."""

    def return_dataset_and_tensor():
        return [tf.constant(0), tf.data.Dataset.range(5), tf.constant(5)]

    executable_return_dataset_and_tensor = computation_impl.ComputationImpl(
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            return_dataset_and_tensor, None,
            context_stack_impl.context_stack)[0],
        context_stack_impl.context_stack)

    x = executable_return_dataset_and_tensor()

    self.assertEqual(x[0], 0)
    # The dataset used to be materialized with a comprehension whose loop
    # variable was also called `x`. Where comprehension variables leak into
    # the enclosing scope (Python 2), that rebinds `x` and breaks the `x[2]`
    # check below. `list(...)` avoids the shadowing altogether.
    self.assertEqual(list(x[1]), list(range(5)))
    self.assertEqual(x[2], 5)
def test_basic_functionality_of_compiled_computation_class(self):
    """Checks the type, proto, repr and naming of `CompiledComputation`."""
    comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda x: x + 3, tf.int32, context_stack_impl.context_stack)

    # Without an explicit name, the repr and `tff_repr` carry a hex identifier.
    x = computation_building_blocks.CompiledComputation(comp)
    self.assertEqual(str(x.type_signature), '(int32 -> int32)')
    self.assertEqual(str(x.proto), str(comp))
    repr_pattern = (r'CompiledComputation\([0-9a-f]+, '
                    r'FunctionType\(TensorType\(tf\.int32\), '
                    r'TensorType\(tf\.int32\)\)\)')
    self.assertTrue(re.match(repr_pattern, repr(x)))
    self.assertTrue(re.match(r'comp#[0-9a-f]+', x.tff_repr))

    # With an explicit name, that name is used verbatim.
    y = computation_building_blocks.CompiledComputation(comp, name='foo')
    self.assertEqual(y.tff_repr, 'comp#foo')

    self._serialize_deserialize_roundtrip_test(x)
def test_fetch_value_with_empty_dataset_and_tensors(self):
    """Checks fetching a tensor alongside a dataset from which nothing is taken.

    The fetched dataset must keep its element spec and raise `StopIteration`
    on the first `next()`.
    """

    def return_dataset():
        source = tf.data.Dataset.from_tensor_slices([[1, 1], [1, 1]])
        return [tf.constant([0., 0.]), source.batch(5).take(0)]

    proto = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        return_dataset, None, context_stack_impl.context_stack)[0]
    executable_return_dataset = computation_impl.ComputationImpl(
        proto, context_stack_impl.context_stack)

    x = executable_return_dataset()

    self.assertAllEqual(x[0], [0., 0.])
    expected_spec = tf.TensorSpec(shape=(None, 2), dtype=tf.int32)
    self.assertEqual(x[1].element_spec, expected_spec)
    with self.assertRaises(StopIteration):
        _ = next(iter(x[1]))
def test_fetch_value_with_datasets_nested_at_second_level(self):
    """Checks fetching datasets nested one level below the top-level list."""

    def return_two_datasets():
        return [
            tf.constant(0),
            [tf.data.Dataset.range(5), tf.data.Dataset.range(5)]
        ]

    proto = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        return_two_datasets, None, context_stack_impl.context_stack)[0]
    executable_return_two_datasets = computation_impl.ComputationImpl(
        proto, context_stack_impl.context_stack)

    x = executable_return_two_datasets()

    expected = list(range(5))
    self.assertEqual(x[0], 0)
    self.assertEqual(x[1][0], expected)
    self.assertEqual(x[1][1], expected)
def _wrap_constant_as_value(const, context_stack):
    """Wraps the given Python constant as a `tff.Value`.

    The constant is embedded in a no-argument TensorFlow computation, and the
    returned value represents a call to that computation.

    Args:
      const: Python constant to be converted to TFF value. Anything
        convertible to Tensor via `tf.constant` can be passed in.
      context_stack: The context stack to use.

    Returns:
      An instance of `value_base.Value`.
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda: tf.constant(const), None, context_stack)
    call = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(serialized))
    return ValueImpl(call, context_stack)
def test_returns_string_for_call_with_no_arg(self):
    """Checks the representations of a no-argument call `comp#a()`."""
    proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda: tf.constant(1), None, context_stack_impl.context_stack)
    comp = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(proto, 'a'))

    self.assertEqual(
        computation_building_blocks.compact_representation(comp),
        'comp#a()')
    self.assertEqual(
        computation_building_blocks.formatted_representation(comp),
        'comp#a()')

    structural_string = (
        computation_building_blocks.structural_representation(comp))
    # NOTE(review): the spaces inside the expected diagram may have been
    # collapsed in this copy of the file -- confirm against the real output.
    # pyformat: disable
    expected_structure = (
        ' Call\n'
        ' /\n'
        'Compiled(a)')
    # pyformat: enable
    self.assertEqual(structural_string, expected_structure)
def test_fetch_value_with_empty_structured_dataset_and_tensors(self):
    """Checks fetching a tensor alongside an emptied dataset of dict elements."""

    def return_dataset():
        source = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])]))
        return [tf.constant([0., 0.]), source.batch(5).take(0)]

    proto = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        return_dataset, None, context_stack_impl.context_stack)[0]
    executable_return_dataset = computation_impl.ComputationImpl(
        proto, context_stack_impl.context_stack)

    x = executable_return_dataset()

    self.assertEqual(x[0][0], 0.)
    self.assertEqual(x[0][1], 0.)
    # Both fields of the first fetched element are expected to be empty int32
    # arrays.
    first_element = x[1][0]
    self.assertTrue(
        np.array_equal(first_element.a, np.zeros([0], dtype=np.int32)))
    self.assertTrue(
        np.array_equal(first_element.b, np.zeros([0], dtype=np.int32)))
def _create_call_to_py_fn(fn):
    r"""Creates a computation that calls a Python function.

             Call
            /
    Compiled
    Computation

    The function is serialized as a no-parameter TensorFlow computation.

    Args:
      fn: The Python function to wrap.

    Returns:
      An instance of `computation_building_blocks.Call` wrapping the Python
      function.
    """
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        fn, None, context_stack_impl.context_stack)
    return computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(serialized))
def test_serialize_tensorflow_with_data_set_sum_lambda(self):
    """Serializes a dataset-summing function and runs it on `range(5)`.

    The dataset is fed through the iterator string handle named in the
    serialized computation's parameter binding.
    """

    def _legacy_dataset_reducer_example(ds):
        return ds.reduce(np.int64(0), lambda x, y: x + y)

    comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        _legacy_dataset_reducer_example,
        computation_types.SequenceType(tf.int64),
        context_stack_impl.context_stack)
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(int64* -> int64)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    parameter = tf.data.Dataset.range(5)
    # Run inside a `with` block so the session is closed after the fetch
    # rather than leaked for the rest of the test run.
    with tf.Session() as sess:
        results = sess.run(
            tf.import_graph_def(
                comp.tensorflow.graph_def,
                {
                    comp.tensorflow.parameter.sequence
                    .iterator_string_handle_name:
                        (parameter.make_one_shot_iterator().string_handle())
                },
                [comp.tensorflow.result.tensor.tensor_name]))
    self.assertEqual(results, [10])
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
    """Wrapper function to plug TensorFlow logic into the TFF framework.

    This function is passed through `computation_wrapper.ComputationWrapper`.
    Documentation of its arguments can be found inside the definition of that
    class.

    Raises:
      TypeError: If `parameter_type` is not TensorFlow-compatible.
    """
    del name  # Unused.
    # The callable is wrapped before the type check, as in the original flow.
    wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    if type_analysis.is_tensorflow_compatible_type(parameter_type):
        stack = context_stack_impl.context_stack
        serialized, extra_type_spec = (
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                wrapped_fn, parameter_type, stack))
        return computation_impl.ComputationImpl(
            serialized, stack, extra_type_spec)
    raise TypeError(
        '`tf_computation`s can accept only parameter types with '
        'constituents `SequenceType`, `StructType` '
        'and `TensorType`; you have attempted to create one '
        'with the type {}.'.format(parameter_type))
def test_fetch_value_with_empty_structured_dataset_and_tensors(self):
    """Checks fetching a tensor alongside an emptied dataset of dict elements.

    The fetched dataset must keep its ordered-dict structure and raise
    `StopIteration` on the first `next()`.
    """

    def return_dataset():
        source = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])]))
        return [tf.constant([0., 0.]), source.batch(5).take(0)]

    proto = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        return_dataset, None, context_stack_impl.context_stack)[0]
    executable_return_dataset = computation_impl.ComputationImpl(
        proto, context_stack_impl.context_stack)

    x = executable_return_dataset()

    self.assertAllEqual(x[0], [0., 0.])
    expected_structure = collections.OrderedDict([
        ('a', tf.TensorSpec(shape=(None, ), dtype=tf.int32)),
        ('b', tf.TensorSpec(shape=(None, ), dtype=tf.int32)),
    ])
    self.assertEqual(
        tf.data.experimental.get_structure(x[1]), expected_structure)
    with self.assertRaises(StopIteration):
        _ = next(iter(x[1]))
def test_serialize_tensorflow_with_data_set_sum_lambda(self):
    """Serializes a dataset-summing function and runs it on `range(5)`.

    The dataset is fed as a variant tensor under the name recorded in the
    serialized computation's parameter binding.
    """

    def _legacy_dataset_reducer_example(ds):
        return ds.reduce(np.int64(0), lambda x, y: x + y)

    comp, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            _legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack))
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(int64* -> int64)')
    self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    parameter = tf.data.Dataset.range(5)
    # Run inside a `with` block so the session is closed after the fetch
    # rather than leaked for the rest of the test run.
    with tf.compat.v1.Session() as sess:
        results = sess.run(
            tf.import_graph_def(
                serialization_utils.unpack_graph_def(
                    comp.tensorflow.graph_def),
                {
                    comp.tensorflow.parameter.sequence.variant_tensor_name:
                        tf.data.experimental.to_variant(parameter)
                },
                [comp.tensorflow.result.tensor.tensor_name]))
    self.assertEqual(results, [10])
def _wrap_sequence_as_value(elements, element_type, context_stack):
    """Wraps `elements` as a TFF sequence with elements of type `element_type`.

    Args:
      elements: Python object to be wrapped as a TFF sequence value; it must
        be a `list`.
      element_type: An instance of `Type` that determines the type of elements
        of the sequence.
      context_stack: The context stack to use.

    Returns:
      An instance of `tff.Value`.

    Raises:
      TypeError: If `elements` and `element_type` are of incompatible types.
    """
    # TODO(b/113116813): Add support for other representations of sequences.
    py_typecheck.check_type(elements, list)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    # Every individual element must be assignable to the element type
    # requested for the sequence as a whole; stop at the first one that isn't.
    for item in elements:
        inferred_type = type_utils.infer_type(item)
        if type_utils.is_assignable_from(element_type, inferred_type):
            continue
        raise TypeError(
            'Expected all sequence elements to be {}, found {}.'.format(
                str(element_type), str(inferred_type)))

    # No-arg function that builds a `tf.data.Dataset` from the elements.
    def _create_dataset_from_elements():
        return graph_utils.make_data_set_from_elements(
            tf.get_default_graph(), elements, element_type)

    # The dataset becomes a value backed by a call to a no-argument
    # TensorFlow computation.
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        _create_dataset_from_elements, None, context_stack)
    compiled = computation_building_blocks.CompiledComputation(serialized)
    call = computation_building_blocks.Call(compiled)
    return ValueImpl(call, context_stack)
def _create_lambda_to_cast(dtype1, dtype2):
    r"""Creates a computation that casts from `dtype1` to `dtype2` in TensorFlow.

    Lambda
          \
           Call
          /    \
    Compiled    Reference
    Computation

    Where `CompiledComputation` is a TensorFlow computation casting from
    `dtype1` to `dtype2`.

    Each `dtype` argument may be a `tf.dtypes.DType` or a
    `computation_types.TensorType`; in the latter case the tensor type's
    `dtype` attribute is used.

    Args:
      dtype1: The type of the argument.
      dtype2: The type to cast the argument to.

    Returns:
      An instance of `computation_building_blocks.Lambda` wrapping a function
      that casts TensorFlow dtype1 to dtype2.
    """
    # Reduce tensor types to their underlying dtypes before validating.
    if isinstance(dtype1, computation_types.TensorType):
        dtype1 = dtype1.dtype
    if isinstance(dtype2, computation_types.TensorType):
        dtype2 = dtype2.dtype
    py_typecheck.check_type(dtype1, tf.dtypes.DType)
    py_typecheck.check_type(dtype2, tf.dtypes.DType)

    parameter_ref = computation_building_blocks.Reference('arg', dtype1)
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda x: tf.cast(x, dtype2), dtype1,
        context_stack_impl.context_stack)
    cast_call = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(serialized),
        parameter_ref)
    return computation_building_blocks.Lambda(
        parameter_ref.name, dtype1, cast_call)
def test_replace_compiled_computations_names_replaces_multiple_names(self):
    """Checks that ten compiled computations all get new, distinct names."""

    def _compiled_constant():
        tf_comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(1), None, context_stack_impl.context_stack)
        return computation_building_blocks.CompiledComputation(tf_comp)

    comp = computation_building_blocks.Tuple(
        [_compiled_constant() for _ in range(10)])

    rename = (
        transformations.replace_compiled_computations_names_with_unique_names)
    transformed_comp = rename(comp)

    comp_names = [element._name for element in comp]
    transformed_comp_names = [element._name for element in transformed_comp]
    self.assertNotEqual(transformed_comp_names, comp_names)
    self.assertEqual(
        len(transformed_comp_names), len(set(transformed_comp_names)),
        'The transformed computation names are not unique: {}.'.format(
            transformed_comp_names))
def _create_compiled_computation(py_fn, arg_type):
    """Serializes `py_fn` and wraps it as a `CompiledComputation`.

    Args:
      py_fn: The Python function to serialize as a TensorFlow computation.
      arg_type: The parameter type passed to the serializer.

    Returns:
      A `building_blocks.CompiledComputation` built from the serialized proto;
      the second value returned by the serializer is discarded.
    """
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        py_fn, arg_type, context_stack_impl.context_stack)
    proto, _ = serialized
    return building_blocks.CompiledComputation(proto)