def _compiled_comp_equal(comp_1, comp_2):
  """Checks whether two compiled TensorFlow computations are identical.

  Args:
    comp_1: A `building_blocks.CompiledComputation` to test.
    comp_2: A `building_blocks.CompiledComputation` to test.

  Returns:
    `True` iff the initialize op, parameter binding, result binding and the
    deterministic serialization of the unpacked graph all match.

  Raises:
    TypeError: if `comp_1` or `comp_2` is not a
      `building_blocks.CompiledComputation`.
  """
  for candidate in (comp_1, comp_2):
    py_typecheck.check_type(candidate, building_blocks.CompiledComputation)
  block_1 = comp_1.proto.tensorflow
  block_2 = comp_2.proto.tensorflow
  # Cheap binding comparisons first; short-circuits before any graph work.
  bindings_match = (
      block_1.initialize_op == block_2.initialize_op and
      block_1.parameter == block_2.parameter and
      block_1.result == block_2.result)
  if not bindings_match:
    return False
  # TODO(b/174605105): We prefer to mitigate nans comparing unequal for now,
  # given the severity of TFF's failure to handle this violation of its
  # assumption that trees_equal is an equivalence relation on its ASTs. But this
  # is not a long-term solution. To replace with more legitimate proto
  # comparison which treats nans as equal.
  serialized = [
      serialization_utils.unpack_graph_def(block.graph_def).SerializeToString(
          deterministic=True) for block in (block_1, block_2)
  ]
  return serialized[0] == serialized[1]
def _compiled_comp_equal(comp_1, comp_2):
  """Returns `True` iff the computations are entirely identical.

  The initialize op, parameter binding and result binding are compared
  directly. The graphs are compared by their deterministic serializations.

  Args:
    comp_1: A `building_blocks.CompiledComputation` to test.
    comp_2: A `building_blocks.CompiledComputation` to test.

  Returns:
    `True` iff both computations have identical bindings, initialize op and
    serialized graph.

  Raises:
    TypeError: if `comp_1` or `comp_2` is not a
      `building_blocks.CompiledComputation`.
  """
  py_typecheck.check_type(comp_1, building_blocks.CompiledComputation)
  py_typecheck.check_type(comp_2, building_blocks.CompiledComputation)
  tensorflow_1 = comp_1.proto.tensorflow
  tensorflow_2 = comp_2.proto.tensorflow
  if tensorflow_1.initialize_op != tensorflow_2.initialize_op:
    return False
  if tensorflow_1.parameter != tensorflow_2.parameter:
    return False
  if tensorflow_1.result != tensorflow_2.result:
    return False
  graphdef_1 = serialization_utils.unpack_graph_def(tensorflow_1.graph_def)
  graphdef_2 = serialization_utils.unpack_graph_def(tensorflow_2.graph_def)
  # TODO(b/174605105): Message-level `==` compares float fields numerically.
  # A graph holding a NaN constant therefore compares unequal to itself, which
  # breaks the assumption that tree equality is an equivalence relation.
  # Comparing deterministic serializations treats identical NaN payloads as
  # equal. This is a mitigation, not a long-term replacement for a proper
  # NaN-aware proto comparison.
  return (graphdef_1.SerializeToString(deterministic=True) ==
          graphdef_2.SerializeToString(deterministic=True))
def _compiled_comp_equal(comp_1, comp_2):
  """Checks whether two compiled TensorFlow computations are identical.

  Newer TensorFlow versions (per the version gate below) compare graphs with
  `graph_defs_equal`, treating NaNs as equal. Otherwise the deterministic
  serializations of the graphs are compared.

  Args:
    comp_1: A `building_blocks.CompiledComputation` to test.
    comp_2: A `building_blocks.CompiledComputation` to test.

  Returns:
    `True` iff bindings, initialize op and graph all match.

  Raises:
    TypeError: if `comp_1` or `comp_2` is not a
      `building_blocks.CompiledComputation`.
  """
  py_typecheck.check_type(comp_1, building_blocks.CompiledComputation)
  py_typecheck.check_type(comp_2, building_blocks.CompiledComputation)
  lhs = comp_1.proto.tensorflow
  rhs = comp_2.proto.tensorflow
  for field_name in ('initialize_op', 'parameter', 'result'):
    if getattr(lhs, field_name) != getattr(rhs, field_name):
      return False
  lhs_graph = serialization_utils.unpack_graph_def(lhs.graph_def)
  rhs_graph = serialization_utils.unpack_graph_def(rhs.graph_def)
  # TODO(b/174605105): Remove this gating when TFF updates its TensorFlow
  # dependency.
  if not version_check.is_tensorflow_version_newer('2.6.0', tf):
    lhs_bytes = lhs_graph.SerializeToString(deterministic=True)
    rhs_bytes = rhs_graph.SerializeToString(deterministic=True)
    return lhs_bytes == rhs_bytes
  return tf.__internal__.graph_util.graph_defs_equal(
      lhs_graph, rhs_graph, treat_nan_as_equal=True)
def test_serialize_tensorflow_with_table_no_variables(self):
  """Serializes a static vocabulary lookup and runs it in a fresh graph."""

  def table_lookup(word):
    initializer = tf.lookup.KeyValueTensorInitializer(
        ['a', 'b', 'c'], np.arange(3, dtype=np.int64))
    table = tf.lookup.StaticVocabularyTable(initializer, num_oov_buckets=1)
    return table.lookup(word)

  parameter_type = computation_types.TensorType(
      dtype=tf.string, shape=(None,))
  comp, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          table_lookup, parameter_type, context_stack_impl.context_stack))
  expected_signature = '(string[?] -> int64[?])'
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), expected_signature)
  self.assertEqual(str(extra_type_spec), expected_signature)
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

  tf_proto = comp.tensorflow
  graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
  with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
  with tf.compat.v1.Session(graph=graph) as sess:
    # The table must be initialized before it can be looked up.
    sess.run(fetches=tf_proto.initialize_op)
    lookup_result = sess.run(
        fetches=tf_proto.result.tensor.tensor_name,
        feed_dict={tf_proto.parameter.tensor.tensor_name: ['b', 'c', 'a']})
  self.assertAllEqual(lookup_result, [1, 2, 0])
def test_tf_wrapper_with_tf_add(self):
  """Wraps `tf.add` and evaluates the imported graph on several inputs."""
  foo = computation_wrapper_instances.tensorflow_wrapper(
      tf.add, (tf.int32, tf.int32))
  self.assertEqual(str(foo.type_signature), '(<int32,int32> -> int32)')

  # TODO(b/113112885): Remove this protected member access.
  comp = foo._computation_proto  # pylint: disable=protected-access
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  tf_proto = comp.tensorflow

  x = tf.compat.v1.placeholder(tf.int32)
  y = tf.compat.v1.placeholder(tf.int32)
  param_elements = tf_proto.parameter.tuple.element
  input_map = {
      param_elements[0].tensor.tensor_name: x,
      param_elements[1].tensor.tensor_name: y,
  }
  result = tf.import_graph_def(
      serialization_utils.unpack_graph_def(tf_proto.graph_def), input_map,
      [tf_proto.result.tensor.tensor_name])
  with self.session() as sess:
    results = [
        sess.run(result, feed_dict={x: n, y: 3}) for n in [1, 20, 5, 10, 30]
    ]
  self.assertEqual(results, [[4], [23], [8], [13], [33]])
def function_to_wrap(*args):
  """Imports the graph backing `comp`, binding `args` to its input tensors.

  `comp`, `input_tensor_names`, `output_tensor_names`,
  `must_pin_function_to_cpu` and `device` are free variables taken from the
  enclosing scope, which is not visible here.

  Args:
    *args: One value per entry of `input_tensor_names`, in the same order.

  Returns:
    The elements named by `output_tensor_names` from the imported graph.

  Raises:
    RuntimeError: If the number of `args` differs from the number of inputs.
  """
  num_expected = len(input_tensor_names)
  num_given = len(args)
  if num_given != num_expected:
    raise RuntimeError('Expected {} arguments, found {}.'.format(
        num_expected, num_given))
  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  init_op = comp.tensorflow.initialize_op
  if init_op:
    # Make the graph's ops depend on the initializer.
    graph_def = tensorflow_utils.add_control_deps_for_init_op(
        graph_def, init_op)

  def _import_fn():
    return tf.import_graph_def(
        graph_merge.uniquify_shared_names(graph_def),
        input_map=dict(zip(input_tensor_names, args)),
        return_elements=output_tensor_names)

  # CPU pinning wins over an explicit device; with neither, no device scope.
  if must_pin_function_to_cpu:
    placement = 'cpu'
  elif device is not None:
    placement = device
  else:
    return _import_fn()
  with tf.device(placement):
    return _import_fn()
def test_serialize_tensorflow_with_dataset_not_optimized(self):
  """Checks a dataset reduce serializes without optimization ops and runs."""

  @tf.function
  def test_foo(ds):
    return ds.reduce(np.int64(0), lambda x, y: x + y)

  def legacy_dataset_reducer_example(ds):
    return test_foo(ds)

  comp, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          legacy_dataset_reducer_example,
          computation_types.SequenceType(tf.int64),
          context_stack_impl.context_stack))
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '(int64* -> int64)')
  self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.data.Dataset.range(5)

  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  self.assertGraphDoesNotContainOps(
      graph_def, ['OptimizeDataset', 'OptimizeDatasetV2', 'ModelDataset'])
  result_tensors = tf.import_graph_def(
      graph_def, {
          comp.tensorflow.parameter.sequence.variant_tensor_name:
              tf.data.experimental.to_variant(parameter)
      }, [comp.tensorflow.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed instead of leaked.
  with tf.compat.v1.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [10])
def get_device_placement_in(comp):
  """Gets a counter of device placements for TensorFlow computation `comp`.

  Nodes of the top-level graph and of every function in the graph's function
  library are counted.

  Args:
    comp: A `building_blocks.CompiledComputation` of the `tensorflow` variety.

  Returns:
    A `collections.Counter` mapping each node `device` string to the number of
    nodes carrying it.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
    ValueError: If `comp` is not a `tensorflow` `CompiledComputation`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  if (not isinstance(comp, building_blocks.CompiledComputation)) or (
      comp.proto.WhichOneof('computation') != 'tensorflow'):
    raise ValueError(
        'Please pass a '
        '`building_blocks.CompiledComputation` of the '
        '`tensorflow` variety to `get_device_placement_in`. (Got '
        'a [{t}]).'.format(t=type(comp)))
  graph_def = serialization_utils.unpack_graph_def(
      comp.proto.tensorflow.graph_def)
  counter = collections.Counter(node.device for node in graph_def.node)
  # Only `library.function` holds node definitions. `library.gradient` entries
  # are `GradientDef`s, which just pair a function name with its gradient
  # function's name and have no `node_def` field. The gradient functions
  # themselves are already stored in `library.function`.
  for function_def in graph_def.library.function:
    counter.update(node.device for node in function_def.node_def)
  return counter
def function_to_wrap():
  """No-arg function to import graph def.

  A no-arg function is handed to `tf.compat.v1.wrap_function` to avoid the
  leftover placeholders that binding arguments through `input_map` can leave
  behind. The caller attaches the correct signature later via a `prune` call,
  which is not visible here.

  `comp`, `must_pin_function_to_cpu` and `device` are free variables from the
  enclosing scope.

  Returns:
    Result of importing graphdef backing `comp`.
  """
  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  # TODO(b/159180073): clean raise after fixing dataset reduce.
  _check_dataset_reduce_in_multi_gpu(graph_def)
  init_op = comp.tensorflow.initialize_op
  if init_op:
    graph_def = tensorflow_utils.add_control_deps_for_init_op(
        graph_def, init_op)

  def _import_fn():
    return tf.import_graph_def(
        graph_merge.uniquify_shared_names(graph_def), name='')

  if must_pin_function_to_cpu:
    target_device = 'cpu'
  elif device is not None:
    target_device = device.name
  else:
    return _import_fn()
  with tf.device(target_device):
    return _import_fn()
def test_pack_unpack_roundtrip(self):
  """Packing then unpacking a GraphDef yields an equal GraphDef."""
  with tf.Graph().as_default() as graph:
    tf.constant(1.0)
  original_graph_def = graph.as_graph_def()
  packed = serialization_utils.pack_graph_def(original_graph_def)
  self.assertEqual(original_graph_def,
                   serialization_utils.unpack_graph_def(packed))
def _unpack_proto_into_graph_spec(tf_block_proto):
  """Unpacks a TF computation proto into a `graph_merge.GraphSpec`.

  Args:
    tf_block_proto: Instance of `computation_pb2.Computation` with `tensorflow`
      `computation` attribute.

  Returns:
    Instance of `graph_merge.GraphSpec` holding the unpacked graph, the
    initialize op name (`None` when the proto's is empty), and the tensor names
    from the parameter binding (empty list when unset) and result binding.
  """
  tf_proto = tf_block_proto.tensorflow
  graph = serialization_utils.unpack_graph_def(tf_proto.graph_def)
  # An empty initialize op name is normalized to `None`.
  init_op_name = tf_proto.initialize_op or None
  parameter_binding = tf_proto.parameter
  if parameter_binding.WhichOneof('binding') is None:
    parameter_names = []
  else:
    parameter_names = graph_utils.extract_tensor_names_from_binding(
        parameter_binding)
  result_names = graph_utils.extract_tensor_names_from_binding(tf_proto.result)
  return graph_merge.GraphSpec(graph, init_op_name, parameter_names,
                               result_names)
def count_tensorflow_variables_in(comp):
  """Counts TF Variables in `comp` if `comp` is a TF block.

  An op counts as a variable when its lowercased type name starts with
  `variable` (except `variableshape`) or equals `varhandleop`. Nodes of the
  top-level graph and of every library function are inspected.

  Args:
    comp: A `building_blocks.CompiledComputation` of the `tensorflow` variety.

  Returns:
    The number of variable ops found.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
    ValueError: If `comp` is not a `tensorflow` `CompiledComputation`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  is_tf_block = (
      isinstance(comp, building_blocks.CompiledComputation) and
      comp.proto.WhichOneof('computation') == 'tensorflow')
  if not is_tf_block:
    raise ValueError(
        'Please pass a '
        '`building_blocks.CompiledComputation` of the '
        '`tensorflow` variety to `count_tensorflow_variables_in`.')
  graph_def = serialization_utils.unpack_graph_def(
      comp.proto.tensorflow.graph_def)

  # TODO(b/137887596): Follow up on ways to count Variables on the GraphDef
  # level.
  def _is_variable(node):
    op_name = str(node.op).lower()
    if op_name == 'varhandleop':
      return True
    return op_name.startswith('variable') and op_name != 'variableshape'

  top_level_count = sum(_is_variable(node) for node in graph_def.node)
  library_count = sum(
      _is_variable(node)
      for function_def in graph_def.library.function
      for node in function_def.node_def)
  return top_level_count + library_count
def test_serialize_tensorflow_with_no_parameter(self):
  """Serializes a no-arg constant function and checks it evaluates to 99."""
  tf_proto, _ = self.assert_serializes(lambda _: tf.constant(99), None,
                                       '( -> int32)')
  result_tensors = tf.graph_util.import_graph_def(
      serialization_utils.unpack_graph_def(tf_proto.graph_def), None,
      [tf_proto.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed instead of leaked.
  with tf.compat.v1.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [99])
def function_to_wrap(*args):
  """Imports the graph of `comp`, feeding `args` into its input tensors.

  `comp`, `input_tensor_names` and `output_tensor_names` are free variables
  from the enclosing scope.

  Args:
    *args: One value per entry of `input_tensor_names`, in the same order.

  Returns:
    The elements named by `output_tensor_names` from the imported graph.

  Raises:
    RuntimeError: If the number of `args` differs from the number of inputs.
  """
  num_expected = len(input_tensor_names)
  num_given = len(args)
  if num_given != num_expected:
    raise RuntimeError('Expected {} arguments, found {}.'.format(
        str(num_expected), str(num_given)))
  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  input_map = {name: value for name, value in zip(input_tensor_names, args)}
  return tf.import_graph_def(
      graph_def, input_map=input_map, return_elements=output_tensor_names)
def _ensure_comp_runtime_compatible(comp: pb.Computation) -> pb.Computation:
  """Ensures `comp` is compatible with eager runtime backing EagerExecutor.

  The only check performed is the multi-GPU dataset-reduce check, which runs
  only when more than one logical GPU is visible.

  Args:
    comp: A `pb.Computation`; its `tensorflow.graph_def` is read when the
      check runs.

  Returns:
    `comp`, unchanged.
  """
  # TODO(b/159180073): clean raise after fixing dataset reduce.
  num_gpu_devices = len(tf.config.list_logical_devices('GPU'))
  if num_gpu_devices > 1:
    # Unpack the (possibly large) GraphDef only when the check needs it,
    # instead of on every call.
    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    _check_dataset_reduce_for_multi_gpu(graph_def)
  return comp
def test_serialize_tensorflow_with_no_parameter(self):
  """Serializes a no-arg constant function and checks it evaluates to 99."""
  comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(99), None, context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  result_tensors = tf.import_graph_def(
      serialization_utils.unpack_graph_def(comp.tensorflow.graph_def), None,
      [comp.tensorflow.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed instead of leaked.
  with tf.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [99])
def count_tensorflow_ops_in(comp):
  """Counts TF ops in `comp` if `comp` is a TF block.

  Only nodes of the top-level graph are counted; ops inside library functions
  are not included.

  Args:
    comp: A `building_blocks.CompiledComputation` of the `tensorflow` variety.

  Returns:
    The number of nodes in the top-level graph.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
    ValueError: If `comp` is not a `tensorflow` `CompiledComputation`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  is_tf_block = (
      isinstance(comp, building_blocks.CompiledComputation) and
      comp.proto.WhichOneof('computation') == 'tensorflow')
  if not is_tf_block:
    raise ValueError('Please pass a '
                     '`building_blocks.CompiledComputation` of the '
                     '`tensorflow` variety to `count_tensorflow_ops_in`.')
  unpacked = serialization_utils.unpack_graph_def(
      comp.proto.tensorflow.graph_def)
  return len(unpacked.node)
def test_serialize_tensorflow_with_simple_add_three_lambda(self):
  """Serializes `x + 3` and checks the imported graph maps 1000 to 1003."""
  tf_proto, _ = self.assert_serializes(lambda x: x + 3,
                                       computation_types.TensorType(tf.int32),
                                       '(int32 -> int32)')
  parameter = tf.constant(1000)
  result_tensors = tf.graph_util.import_graph_def(
      serialization_utils.unpack_graph_def(tf_proto.graph_def),
      {tf_proto.parameter.tensor.tensor_name: parameter},
      [tf_proto.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed instead of leaked.
  with tf.compat.v1.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [1003])
def _extract_call_ops(comp: pb.Computation) -> Iterator[tf.compat.v1.NodeDef]:
  """Returns an iterator over the call ops in TensorFlow computation `comp`.

  Nodes are drawn from the top-level graph and from every function in the
  graph's library. A node is yielded when its `op` is in
  `tensorflow_computation_transformations.CALL_OPS`.

  Args:
    comp: A `pb.Computation` of the `tensorflow` variety.

  Returns:
    An iterator of matching `NodeDef`s.

  Raises:
    TypeError: If `comp` is not of the `tensorflow` variety. Raised as soon as
      this function is called, not deferred until the first iteration.
  """
  computation_oneof = comp.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise TypeError('`_extract_call_ops` only accepts `Computation` '
                    'protos of the "tensorflow" variety; you have passed '
                    'one of variety {}.'.format(computation_oneof))
  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  all_nodes = itertools.chain(
      graph_def.node,
      itertools.chain.from_iterable(
          f.node_def for f in graph_def.library.function))
  return (node for node in all_nodes
          if node.op in tensorflow_computation_transformations.CALL_OPS)
def function_to_wrap(*args):
  """Imports `comp`'s graph with shared names uniquified and init deps added.

  `comp`, `input_tensor_names` and `output_tensor_names` are free variables
  from the enclosing scope.

  Args:
    *args: One value per entry of `input_tensor_names`, in the same order.

  Returns:
    The elements named by `output_tensor_names` from the imported graph.

  Raises:
    RuntimeError: If the number of `args` differs from the number of inputs.
  """
  wanted = len(input_tensor_names)
  got = len(args)
  if got != wanted:
    raise RuntimeError('Expected {} arguments, found {}.'.format(
        str(wanted), str(got)))
  tf_proto = comp.tensorflow
  graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
  if tf_proto.initialize_op:
    graph_def = graph_utils.add_control_deps_for_init_op(
        graph_def, tf_proto.initialize_op)
  return tf.import_graph_def(
      graph_merge.uniquify_shared_names(graph_def),
      input_map=dict(zip(input_tensor_names, args)),
      return_elements=output_tensor_names)
def test_serialize_tensorflow_with_simple_add_three_lambda(self):
  """Serializes `x + 3` and checks the imported graph maps 1000 to 1003."""
  comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '(int32 -> int32)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.constant(1000)
  result_tensors = tf.import_graph_def(
      serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
      {comp.tensorflow.parameter.tensor.tensor_name: parameter},
      [comp.tensorflow.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed instead of leaked.
  with tf.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [1003])
def _check_ops(proto: computation_pb2.Computation,
               allowed_op_names: Optional[FrozenSet[str]] = None,
               disallowed_op_names: Optional[FrozenSet[str]] = None):
  """Checks the ops in the TensorFlow computation.

  If `allowed_op_names` is non-empty, every op in `proto` must be in that set.
  Otherwise, if `disallowed_op_names` is non-empty, no op in `proto` may be in
  that set. If both are non-empty, `allowed_op_names` takes precedence and
  `disallowed_op_names` is ignored. Ops are collected from the top-level graph
  and from every function in the graph's library.

  Args:
    proto: Instance of `computation_pb2.Computation` with the `tensorflow` field
      populated.
    allowed_op_names: Set of allowed op names.
    disallowed_op_names: Set of disallowed op names.

  Raises:
    TypeError: If `proto` is not a `computation_pb2.Computation` of the
      `tensorflow` variety.
    DisallowedOpInTensorFlowComputationError: If the computation contains a
      disallowed op. The offending op names are listed in sorted order.
    RuntimeError: If both allowed_op_names and disallowed_op_names are empty.
  """
  py_typecheck.check_type(proto, computation_pb2.Computation)
  computation_oneof = proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise TypeError('`_check_ops` only accepts `Computation` '
                    'protos of the "tensorflow" variety; you have passed '
                    'one of variety {}.'.format(computation_oneof))
  # Validate the arguments before paying for the GraphDef unpack.
  if not allowed_op_names and not disallowed_op_names:
    raise RuntimeError(
        'One of allowed_op_names or disallowed_op_names must be non-empty')
  graph_def = serialization_utils.unpack_graph_def(
      proto.tensorflow.graph_def)
  all_nodes = itertools.chain(
      graph_def.node, *[f.node_def for f in graph_def.library.function])
  if allowed_op_names:
    found_disallowed_op_names = {
        node.op for node in all_nodes if node.op not in allowed_op_names
    }
  else:
    found_disallowed_op_names = {
        node.op for node in all_nodes if node.op in disallowed_op_names
    }
  if found_disallowed_op_names:
    # Sort so the error message does not depend on set iteration order.
    found_disallowed_op_names_str = ', '.join(sorted(found_disallowed_op_names))
    raise DisallowedOpInTensorFlowComputationError(
        f'Found disallowed ops: {found_disallowed_op_names_str}')
def count_tensorflow_variables_in(comp):
  """Counts TF Variables in `comp` if `comp` is a TF block.

  An op counts as a variable when its lowercased type name starts with
  `variable` (except `variableshape`) or equals `varhandleop`. Only nodes of
  the top-level graph are inspected.

  Args:
    comp: A `computation_building_blocks.CompiledComputation` of the
      `tensorflow` variety.

  Returns:
    The number of variable-creating ops in the top-level graph.

  Raises:
    TypeError: If `comp` is not a
      `computation_building_blocks.ComputationBuildingBlock`.
    ValueError: If `comp` is not a `tensorflow` `CompiledComputation`.
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)
  if (not isinstance(comp, computation_building_blocks.CompiledComputation)
     ) or (comp.proto.WhichOneof('computation') != 'tensorflow'):
    raise ValueError('Please pass a '
                     '`computation_building_blocks.CompiledComputation` of the '
                     '`tensorflow` variety to `count_tensorflow_variables_in`.')
  graph_def = serialization_utils.unpack_graph_def(
      comp.proto.tensorflow.graph_def)

  # TODO(b/137887596): Follow up on ways to count Variables on the GraphDef
  # level.
  def _node_is_variable(node):
    op_name = str(node.op).lower()
    # A plain substring test for 'variable' would also count ops that only
    # use a variable (`ReadVariableOp`, `AssignVariableOp`, `VariableShape`).
    # It would also miss resource variables created by `VarHandleOp`.
    return ((op_name.startswith('variable') and op_name != 'variableshape') or
            op_name == 'varhandleop')

  return sum(_node_is_variable(node) for node in graph_def.node)
def test_serialize_tensorflow_with_data_set_sum_lambda(self):
  """Serializes a dataset sum and checks it reduces range(5) to 10."""

  def _legacy_dataset_reducer_example(ds):
    return ds.reduce(np.int64(0), lambda x, y: x + y)

  tf_proto, _ = self.assert_serializes(
      _legacy_dataset_reducer_example,
      computation_types.SequenceType(tf.int64), '(int64* -> int64)')
  parameter = tf.data.Dataset.range(5)
  result_tensors = tf.graph_util.import_graph_def(
      serialization_utils.unpack_graph_def(tf_proto.graph_def), {
          tf_proto.parameter.sequence.variant_tensor_name:
              tf.data.experimental.to_variant(parameter)
      }, [tf_proto.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed instead of leaked.
  with tf.compat.v1.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [10])
def disable_grappler_for_partitioned_calls(proto):
  """Disables grappler for `PartitionedCall` and `StatefulPartitionedCall` nodes.

  TensorFlow serializes a `ConfigProto` into the `config_proto` `attr` of
  `PartitionedCall` and `StatefulPartitionedCall` graph nodes. This overrides
  any session config that might disable runtime grappler. To disable grappler
  for these nodes as well, this function overwrites the serialized
  `ConfigProto`, setting the `disable_meta_optimizer` field to `True`. A node
  in `CALL_OPS` that has no `config_proto` attr gets one added.

  Args:
    proto: Instance of `computation_pb2.Computation` with the `tensorflow` field
      populated.

  Returns:
    A transformed instance of `computation_pb2.Computation` with a `tensorflow`
    field.

  Raises:
    TypeError: If `proto` is not a `computation_pb2.Computation` of the
      `tensorflow` variety.
  """
  py_typecheck.check_type(proto, computation_pb2.Computation)
  computation_oneof = proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise TypeError('`disable_grappler_for_partitioned_calls` only accepts '
                    '`Computation` protos of the "tensorflow" variety; you '
                    'have passed one of variety {}.'.format(computation_oneof))
  original_tf = proto.tensorflow
  graph_def = serialization_utils.unpack_graph_def(original_tf.graph_def)
  all_nodes = itertools.chain(
      graph_def.node, *[f.node_def for f in graph_def.library.function])
  for node in all_nodes:
    if node.op not in CALL_OPS:
      continue
    # Index the message map rather than calling `.get`. When the attr is
    # missing, indexing inserts an empty `AttrValue` instead of returning
    # `None`, so the `.s` assignment below cannot fail.
    config_attr = node.attr['config_proto']
    # An empty serialization parses to a default `ConfigProto`.
    config_proto = tf.compat.v1.ConfigProto.FromString(config_attr.s)
    config_proto.graph_options.rewrite_options.disable_meta_optimizer = True
    config_attr.s = config_proto.SerializeToString(deterministic=True)
  tf_block = computation_pb2.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph_def),
      initialize_op=original_tf.initialize_op
      if original_tf.initialize_op else None,
      parameter=original_tf.parameter
      if original_tf.HasField('parameter') else None,
      result=original_tf.result)
  new_proto = computation_pb2.Computation(type=proto.type, tensorflow=tf_block)
  return new_proto
def function_to_wrap(*args):
  """Imports `comp`'s graph and gates its outputs on the initialize op.

  When `comp` has an initialize op, it is imported alongside the outputs. Each
  output is then wrapped in `tf.identity` under a control dependency on that
  op. `comp`, `input_tensor_names` and `output_tensor_names` are free
  variables from the enclosing scope.

  Args:
    *args: One value per entry of `input_tensor_names`, in the same order.

  Returns:
    The imported output elements, wrapped in `tf.identity` when there is an
    initialize op.

  Raises:
    RuntimeError: If the number of `args` differs from the number of inputs.
  """
  if len(args) != len(input_tensor_names):
    raise RuntimeError('Expected {} arguments, found {}.'.format(
        str(len(input_tensor_names)), str(len(args))))
  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  init_op = comp.tensorflow.initialize_op
  extra_elements = [init_op] if init_op else []
  imported = tf.import_graph_def(
      graph_merge.uniquify_shared_names(graph_def),
      input_map=dict(zip(input_tensor_names, args)),
      return_elements=output_tensor_names + extra_elements)
  if not extra_elements:
    return imported
  # The initialize op was requested last, so it is the final element.
  *outputs, init = imported
  with tf.control_dependencies([init]):
    return [tf.identity(output) for output in outputs]
def prune_tensorflow_proto(proto):
  """Extracts subgraph from `proto` preserving parameter, result and initialize.

  Args:
    proto: Instance of `pb.Computation` of the `tensorflow` variety whose
      `graphdef` attribute we wish to prune of extraneous ops.

  Returns:
    A transformed instance of `pb.Computation` of the `tensorflow` variety,
    whose `graphdef` attribute contains only the nodes that
    `extract_sub_graph` keeps for the parameter and result ops and the
    initialize op.

  Raises:
    TypeError: If `proto` is not a `pb.Computation` of the `tensorflow` variety.
  """
  py_typecheck.check_type(proto, pb.Computation)
  computation_oneof = proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise TypeError(
        '`prune_tensorflow_proto` only accepts `Computation` '
        'protos of the \'tensorflow\' variety; you have passed '
        'one of variety {}.'.format(computation_oneof))
  tf_proto = proto.tensorflow

  def _strip_output_index(tensor_names):
    # Drops everything from the last ':' on; a name with no ':' becomes ''.
    return [':'.join(name.split(':')[:-1]) for name in tensor_names]

  if tf_proto.parameter.WhichOneof('binding'):
    parameter_names = _strip_output_index(
        graph_utils.extract_tensor_names_from_binding(tf_proto.parameter))
  else:
    parameter_names = []
  return_names = _strip_output_index(
      graph_utils.extract_tensor_names_from_binding(tf_proto.result))
  graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)

  names_to_preserve = parameter_names + return_names
  if tf_proto.initialize_op:
    names_to_preserve.append(tf_proto.initialize_op)
  subgraph_def = tf.compat.v1.graph_util.extract_sub_graph(
      graph_def, names_to_preserve)
  pruned_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(subgraph_def),
      initialize_op=tf_proto.initialize_op,
      parameter=tf_proto.parameter,
      result=tf_proto.result)
  return pb.Computation(type=proto.type, tensorflow=pruned_block)
def test_serialize_tensorflow_with_data_set_sum_lambda(self):
  """Serializes a dataset sum and checks it reduces range(5) to 10."""

  def _legacy_dataset_reducer_example(ds):
    return ds.reduce(np.int64(0), lambda x, y: x + y)

  comp, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          _legacy_dataset_reducer_example,
          computation_types.SequenceType(tf.int64),
          context_stack_impl.context_stack))
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '(int64* -> int64)')
  self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.data.Dataset.range(5)
  result_tensors = tf.import_graph_def(
      serialization_utils.unpack_graph_def(comp.tensorflow.graph_def), {
          comp.tensorflow.parameter.sequence.variant_tensor_name:
              tf.data.experimental.to_variant(parameter)
      }, [comp.tensorflow.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed instead of leaked.
  with tf.compat.v1.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [10])
def test_serialize_tensorflow_with_data_set_sum_lambda(self):
  """Serializes a dataset sum (iterator-handle binding) and checks the sum."""

  def _legacy_dataset_reducer_example(ds):
    return ds.reduce(np.int64(0), lambda x, y: x + y)

  comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      _legacy_dataset_reducer_example,
      computation_types.SequenceType(tf.int64),
      context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '(int64* -> int64)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.data.Dataset.range(5)
  result_tensors = tf.import_graph_def(
      serialization_utils.unpack_graph_def(comp.tensorflow.graph_def), {
          comp.tensorflow.parameter.sequence.iterator_string_handle_name:
              (parameter.make_one_shot_iterator().string_handle())
      }, [comp.tensorflow.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed instead of leaked.
  with tf.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [10])
def test_tf_wrapper_with_one_op_py_fn(self):
  """Wraps `x > 10` and evaluates the imported graph on several inputs."""

  @computation_wrapper_instances.tensorflow_wrapper(tf.int32)
  def foo(x):
    return x > 10

  self.assertEqual(str(foo.type_signature), '(int32 -> bool)')

  # TODO(b/113112885): Remove this protected member access once the part of
  # the infrastructure that deals with invoking functions is present. At this
  # point, extracting the proto from within 'foo' is the only way to test the
  # wrapper works as intended.
  comp = foo._computation_proto  # pylint: disable=protected-access

  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  x = tf.compat.v1.placeholder(tf.int32)
  result = tf.import_graph_def(
      serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
      {comp.tensorflow.parameter.tensor.tensor_name: x},
      [comp.tensorflow.result.tensor.tensor_name])
  # Use one session for all inputs and close it afterwards, instead of
  # creating (and leaking) a new session per input.
  with tf.compat.v1.Session() as sess:
    outputs = [sess.run(result, feed_dict={x: n}) for n in [1, 20, 5, 10, 30]]
  self.assertEqual(outputs, [[False], [True], [False], [False], [True]])