def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
  """Builds a TensorFlow computation that returns its argument unchanged.

  The resulting computation has the type signature `(T -> T)`, with `T` being
  `type_signature`.

  NOTE: if `T` contains `computation_types.StructType`s without an associated
  container type, they will be given the container type `tuple` by this
  function.

  Args:
    type_signature: A `computation_types.Type` used both as the parameter type
      and as the result type of the identity function.

  Returns:
    Whatever `_tensorflow_comp` produces for the constructed `pb.TensorFlow`
    proto and its `computation_types.FunctionType`.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  with tf.Graph().as_default() as graph:
    stamped_value, stamped_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', type_signature, graph)
    output_type, output_binding = tensorflow_utils.capture_result_from_graph(
        stamped_value, graph)
  function_type = computation_types.FunctionType(type_signature, output_type)
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  tf_proto = pb.TensorFlow(
      graph_def=packed_graph, parameter=stamped_binding, result=output_binding)
  return _tensorflow_comp(tf_proto, function_type)
def create_indexing_operator(
    operand_type: computation_types.TensorType,
    index_type: computation_types.TensorType,
) -> ComputationProtoAndType:
  """Builds a TensorFlow computation that applies `tf.gather(operand, index)`.

  The computation's parameter is the two-element struct
  `<operand_type, index_type>`; its result type is whatever TensorFlow yields
  for `tf.gather` on those inputs.

  Args:
    operand_type: A `computation_types.TensorType` for the value being indexed.
    index_type: A `computation_types.TensorType` for the index; must have
      rank 0.

  Returns:
    Whatever `_tensorflow_comp` produces for the constructed `pb.TensorFlow`
    proto and its `computation_types.FunctionType`.

  Raises:
    TypeError: If `index_type` does not have rank 0.
  """
  operand_type.check_tensor()
  index_type.check_tensor()
  if index_type.shape.rank != 0:
    raise TypeError(
        f'Expected index type to be a scalar, found {index_type}.')
  with tf.Graph().as_default() as graph:
    operand, operand_binding = tensorflow_utils.stamp_parameter_in_graph(
        'indexing_operand', operand_type, graph)
    index, index_binding = tensorflow_utils.stamp_parameter_in_graph(
        'index', index_type, graph)
    gathered = tf.gather(operand, index)
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        gathered, graph)
  function_type = computation_types.FunctionType(
      computation_types.StructType((operand_type, index_type)), result_type)
  # The two stamped parameters are exposed to callers as a single struct.
  struct_binding = pb.TensorFlow.StructBinding(
      element=[operand_binding, index_binding])
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=pb.TensorFlow.Binding(struct=struct_binding),
      result=result_binding)
  return _tensorflow_comp(tf_proto, function_type)
def create_replicate_input(type_spec, count: int) -> pb.Computation:
  """Returns a tensorflow computation which returns `count` clones of an input.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`,
  where `T` is `type_spec` and the length of the result is `count`.

  Args:
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type`.
    count: An integer, the number of times the input is replicated.

  Returns:
    A `pb.Computation` of the `tensorflow` variety whose serialized type is the
    function type described above.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_spec = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(type_spec)
  py_typecheck.check_type(count, int)
  with tf.Graph().as_default() as graph:
    parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', type_spec, graph)
    # The same stamped value is captured `count` times as the result.
    result = [parameter_value] * count
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)
  type_signature = computation_types.FunctionType(type_spec, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
def test_stateful_partitioned_call_nodes(self):
  # Build a no-arg graph whose result flows through a stateful `tf.function`
  # call, then check that the transformation disables grappler on call ops.
  with tf.Graph().as_default() as graph:
    counter = tf.Variable(0)

    @tf.function
    def test():
      return counter.assign_add(1)

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        test(), graph)

  tf_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(None, result_type)),
      tensorflow=tf_block)
  self.assertCallOpsGrapplerNotDisabled(proto)

  transformed_proto = (
      tensorflow_computation_transformations
      .disable_grappler_for_partitioned_calls(proto))
  self.assertCallOpsGrapplerDisabled(transformed_proto)
def create_computation_for_py_fn(
    fn: types.FunctionType,
    parameter_type: Optional[computation_types.Type]) -> ProtoAndType:
  """Wraps a Python function as a TensorFlow computation.

  The resulting computation has the type signature `(T -> U)`: `T` is
  `parameter_type` (no parameter when it is `None`), and `U` is the type of
  the value `fn` returns when traced in a fresh graph.

  Args:
    fn: A Python function. It is called with the stamped parameter when
      `parameter_type` is given, and with no arguments otherwise.
    parameter_type: A `computation_types.Type` or `None`.

  Returns:
    Whatever `_tensorflow_comp` produces for the constructed `pb.TensorFlow`
    proto and its `computation_types.FunctionType`.
  """
  py_typecheck.check_type(fn, types.FunctionType)
  takes_argument = parameter_type is not None
  if takes_argument:
    py_typecheck.check_type(parameter_type, computation_types.Type)

  with tf.Graph().as_default() as graph:
    if takes_argument:
      argument, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          'x', parameter_type, graph)
      fn_output = fn(argument)
    else:
      parameter_binding = None
      fn_output = fn()
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        fn_output, graph)

  function_type = computation_types.FunctionType(parameter_type, result_type)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tf_proto, function_type)
def create_identity(type_spec) -> pb.Computation:
  """Builds a TensorFlow computation that returns its argument unchanged.

  The resulting computation has the type signature `(T -> T)`, with `T` being
  `type_spec` after conversion via `computation_types.to_type`.

  Args:
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type`.

  Returns:
    A `pb.Computation` of the `tensorflow` variety.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings.
  """
  param_type = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(param_type)
  with tf.Graph().as_default() as graph:
    stamped_value, stamped_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', param_type, graph)
    output_type, output_binding = tensorflow_utils.capture_result_from_graph(
        stamped_value, graph)
  function_type = computation_types.FunctionType(param_type, output_type)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=stamped_binding,
      result=output_binding)
  serialized_type = type_serialization.serialize_type(function_type)
  return pb.Computation(type=serialized_type, tensorflow=tf_proto)
def test_gets_all_explicit_placement(self):
  with tf.Graph().as_default() as g:
    with tf.device('/cpu:0'):
      a = tf.constant(0)
      b = tf.constant(1)
      c = a + b
    _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)
  packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
  function_type = computation_types.FunctionType(None, tf.int32)
  proto = pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=pb.TensorFlow(
          graph_def=packed_graph_def, parameter=None, result=result_binding))
  building_block = building_blocks.ComputationBuildingBlock.from_proto(proto)
  device_placements = building_block_analysis.get_device_placement_in(
      building_block)
  # Sort once; the empty placement string always sorts first.
  all_device_placements = sorted(device_placements)
  # Expect two placements: the explicit 'cpu' from above, and the empty
  # placement of the `tf.identity` op added to the captured result.
  self.assertLen(all_device_placements, 2)
  self.assertEqual('', all_device_placements[0])
  self.assertIn('CPU', all_device_placements[1])
  self.assertGreater(device_placements[all_device_placements[1]], 0)
def test_counts_correct_variables_with_function(self):

  @tf.function
  def add_one(x):
    with tf.init_scope():
      y = tf.Variable(1)
    return x + y

  # `add_one` is applied twice to the stamped parameter; the test expects a
  # single variable to be counted.
  with tf.Graph().as_default() as graph:
    stamped, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', tf.int32, graph)
    twice_incremented = add_one(add_one(stamped))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        twice_incremented, graph)

  tf_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  signature = computation_types.FunctionType(tf.int32, result_type)
  proto = pb.Computation(
      type=type_serialization.serialize_type(signature), tensorflow=tf_block)
  building_block = building_blocks.ComputationBuildingBlock.from_proto(proto)

  variable_count = building_block_analysis.count_tensorflow_variables_in(
      building_block)
  self.assertEqual(variable_count, 1)
def test_gets_some_explicit_some_none_placement(self):
  with tf.Graph().as_default() as g:
    with tf.device('/cpu:0'):
      a = tf.constant(0)
    b = tf.constant(1)
    c = a + b
    _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)
  packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
  function_type = computation_types.FunctionType(None, tf.int32)
  proto = pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=pb.TensorFlow(
          graph_def=packed_graph_def, parameter=None, result=result_binding))
  building_block = building_blocks.ComputationBuildingBlock.from_proto(proto)
  device_placements = building_block_analysis.get_device_placement_in(
      building_block)

  # Two distinct placements are expected: the empty one and a CPU one. Sorting
  # puts the empty string first, regardless of dict ordering.
  unplaced, cpu_placed = sorted(device_placements) + [None] * (
      2 - len(device_placements)) if len(device_placements) <= 2 else (
          None, None)
  self.assertLen(device_placements, 2)
  self.assertEqual('', unplaced)
  self.assertIn('CPU', cpu_placed)
  for placement in (unplaced, cpu_placed):
    self.assertGreater(device_placements[placement], 0)
def create_dummy_computation_tensorflow_add():
  """Returns a tensorflow computation that adds two floats, and its type.

  The type is `(<float32,float32> -> float32)`.
  """
  scalar = tf.float32
  with tf.Graph().as_default() as graph:
    lhs, lhs_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', scalar, graph)
    rhs, rhs_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', scalar, graph)
    total = tf.add(lhs, rhs)
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        total, graph)

  type_signature = computation_types.FunctionType(
      computation_types.StructType([scalar, scalar]), result_type)
  # Both stamped inputs are exposed as one two-element struct parameter.
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(element=[lhs_binding, rhs_binding]))
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  computation = pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tf_proto)
  return computation, type_signature
def test_invalid_ops(self):

  @tf.function
  def test():
    return tf.constant(1)

  with tf.Graph().as_default() as graph:
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        test(), graph)

  tf_block = computation_pb2.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  proto = computation_pb2.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(None, result_type)),
      tensorflow=tf_block)

  # The graph contains a `Const` op, so forbidding `Const` must be reported.
  disallowed_op_names = frozenset(['Const'])
  expected_error = (
      tensorflow_computation_transformations
      .DisallowedOpInTensorFlowComputationError)
  with self.assertRaises(expected_error):
    tensorflow_computation_transformations.check_no_disallowed_ops(
        proto, disallowed_op_names)
def create_dummy_computation_tensorflow_tuple(value=10.0):
  """Returns a tensorflow computation and type.

  `( -> <('a', T), ('b', T), ('c', T)>)`, where `T` is the type of
  `tf.constant(value)` (e.g. `float32` for the default `10.0`).

  Args:
    value: An optional numeric value used for each of the three constants.
      Defaults to `10.0`.

  Returns:
    A `(pb.Computation, computation_types.FunctionType)` pair.
  """
  with tf.Graph().as_default() as graph:
    names = ['a', 'b', 'c']
    result = anonymous_tuple.AnonymousTuple(
        (n, tf.constant(value)) for n in names)
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)
  type_signature = computation_types.FunctionType(None, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  # Use a distinct name rather than rebinding the `value` parameter.
  computation = pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
  return computation, type_signature
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ProtoAndType:
  """Returns a tensorflow computation returning `count` copies of its argument.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`,
  where `T` is `type_signature` and the length of the result is `count`.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of times the input is replicated.

  Returns:
    Whatever `_tensorflow_comp` produces for the constructed `pb.TensorFlow`
    proto and its `computation_types.FunctionType`.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)
  parameter_type = type_signature
  with tf.Graph().as_default() as graph:
    parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', parameter_type, graph)
    # The same stamped value is captured `count` times as the result.
    result = [parameter_value] * count
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)
  type_signature = computation_types.FunctionType(parameter_type, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def _pack_noarg_graph(graph_def, return_type, result_binding):
  """Wraps a no-argument `GraphDef` in a `ComputationBuildingBlock`.

  Args:
    graph_def: The `GraphDef` to pack into the computation.
    return_type: The result type; the function type is `( -> return_type)`.
    result_binding: The `pb.TensorFlow.Binding` describing the result.

  Returns:
    The building block produced by
    `building_blocks.ComputationBuildingBlock.from_proto`.
  """
  no_arg_type = computation_types.FunctionType(None, return_type)
  tf_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph_def),
      parameter=None,
      result=result_binding)
  return building_blocks.ComputationBuildingBlock.from_proto(
      pb.Computation(
          type=type_serialization.serialize_type(no_arg_type),
          tensorflow=tf_block))
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ComputationProtoAndType:
  """Returns a tensorflow computation returning `count` copies of its argument.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`,
  where `T` is `type_signature` and the length of the result is `count`.

  It reuses the graph of `create_identity(type_signature)` and replicates that
  computation's result binding `count` times.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of times the input is replicated.

  Returns:
    Whatever `_tensorflow_comp` produces for the constructed `pb.TensorFlow`
    proto and its `computation_types.FunctionType`.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)
  parameter_type = type_signature
  identity_comp, _ = create_identity(parameter_type)
  # This manual proto manipulation is significantly faster than using TFF's
  # GraphDef serialization for large `count` arguments.
  tensorflow_comp = identity_comp.tensorflow
  single_result_binding = tensorflow_comp.result
  # Built once and shared by both branches below.
  replicated_result_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[single_result_binding] * count))
  if tensorflow_comp.parameter:
    new_tf_pb = pb.TensorFlow(
        graph_def=tensorflow_comp.graph_def,
        parameter=tensorflow_comp.parameter,
        result=replicated_result_binding)
  else:
    new_tf_pb = pb.TensorFlow(
        graph_def=tensorflow_comp.graph_def, result=replicated_result_binding)
  fn_type = computation_types.FunctionType(
      parameter_type,
      computation_types.StructType([(None, parameter_type)] * count))
  return _tensorflow_comp(new_tf_pb, fn_type)
def wrap_graph_parameter_as_tuple(comp, name=None):
  """Wraps the parameter of `comp` in a single-element tuple binding.

  This is meant as a preprocessing step for `pad_graph_inputs_to_match_type`.
  After it runs, that function can assume its argument always has a tuple
  binding, rather than handling an unwrapped tensor or sequence binding.

  Args:
    comp: Instance of `computation_building_blocks.CompiledComputation` whose
      parameter should be wrapped in a tuple binding.
    name: Optional string naming the single element of the constructed tuple.
      Defaults to `None`.

  Returns:
    A `computation_building_blocks.CompiledComputation` with the same graph,
    initialize op and result as `comp`. Its parameter is a tuple whose only
    element is the original parameter.

  Raises:
    TypeError: If `comp` is not a
      `computation_building_blocks.CompiledComputation`.
  """
  py_typecheck.check_type(comp, computation_building_blocks.CompiledComputation)
  if name is not None:
    py_typecheck.check_type(name, six.string_types)

  source_proto = comp.proto
  source_tf = source_proto.tensorflow
  source_type = type_serialization.deserialize_type(source_proto.type)

  wrapped_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(element=[source_tf.parameter]))
  wrapped_type = computation_types.FunctionType(
      [(name, source_type.parameter)], source_type.result)

  wrapped_proto = pb.Computation(
      type=type_serialization.serialize_type(wrapped_type),
      tensorflow=pb.TensorFlow(
          graph_def=source_tf.graph_def,
          initialize_op=source_tf.initialize_op,
          parameter=wrapped_binding,
          result=source_tf.result))
  return computation_building_blocks.CompiledComputation(wrapped_proto)
def create_empty_tuple() -> ProtoAndType:
  """Builds a no-argument TensorFlow computation that yields an empty struct.

  The resulting computation has the type signature `( -> <>)`.

  Returns:
    Whatever `_tensorflow_comp` produces for the constructed `pb.TensorFlow`
    proto and its `computation_types.FunctionType`.
  """
  with tf.Graph().as_default() as graph:
    empty_type, empty_binding = tensorflow_utils.capture_result_from_graph(
        structure.Struct([]), graph)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=empty_binding)
  return _tensorflow_comp(tf_proto,
                          computation_types.FunctionType(None, empty_type))
def create_dummy_computation_tensorflow_empty():
  """Returns a no-argument tensorflow computation yielding `[]`, and its type.

  The type is `( -> <>)`.
  """
  with tf.Graph().as_default() as graph:
    empty_type, empty_binding = tensorflow_utils.capture_result_from_graph(
        [], graph)
  signature = computation_types.FunctionType(None, empty_type)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=empty_binding)
  computation = pb.Computation(
      type=type_serialization.serialize_type(signature), tensorflow=tf_proto)
  return computation, signature
def disable_grappler_for_partitioned_calls(proto):
  """Disables grappler for `PartitionedCall` and `StatefulPartitionedCall` nodes.

  TensorFlow serializes a `ConfigProto` into the `config_proto` `attr` of
  `PartitionedCall` and `StatefulPartitionedCall` graph nodes. This overrides
  any session config that might disable runtime grappler.

  To disable grappler for these nodes as well, this function rewrites the
  serialized `ConfigProto` of every node whose op is in `CALL_OPS`. Such nodes
  are searched for both in the top-level graph and in every function of the
  graph's library. The rewrite sets
  `graph_options.rewrite_options.disable_meta_optimizer` to `True`. Call nodes
  that have no `config_proto` attr are given one.

  Args:
    proto: Instance of `computation_pb2.Computation` with the `tensorflow`
      field populated.

  Returns:
    A transformed instance of `computation_pb2.Computation` with a `tensorflow`
    field. Its type, parameter, result and initialize op are taken from
    `proto`.

  Raises:
    TypeError: If `proto` is not of the `tensorflow` variety.
  """
  py_typecheck.check_type(proto, computation_pb2.Computation)
  computation_oneof = proto.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise TypeError('`disable_grappler_for_partitioned_calls` only accepts '
                    '`Computation` protos of the "tensorflow" variety; you '
                    'have passed one of variety {}.'.format(computation_oneof))
  original_tf = proto.tensorflow
  graph_def = serialization_utils.unpack_graph_def(original_tf.graph_def)
  # Call ops may appear in the top-level graph or inside library functions.
  all_nodes = itertools.chain(
      graph_def.node, *[f.node_def for f in graph_def.library.function])
  for node in all_nodes:
    if node.op not in CALL_OPS:
      continue
    if 'config_proto' in node.attr:
      config_proto = tf.compat.v1.ConfigProto.FromString(
          node.attr['config_proto'].s)
    else:
      config_proto = tf.compat.v1.ConfigProto()
    config_proto.graph_options.rewrite_options.disable_meta_optimizer = True
    # Write through the map (indexing creates the entry if absent), so nodes
    # without a `config_proto` attr are handled instead of dereferencing None.
    node.attr['config_proto'].s = config_proto.SerializeToString(
        deterministic=True)
  tf_block = computation_pb2.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph_def),
      initialize_op=original_tf.initialize_op
      if original_tf.initialize_op else None,
      parameter=original_tf.parameter
      if original_tf.HasField('parameter') else None,
      result=original_tf.result)
  new_proto = computation_pb2.Computation(type=proto.type, tensorflow=tf_block)
  return new_proto
def test_get_wrapped_function_from_comp_raises_with_incorrect_binding(self):
  with tf.Graph().as_default() as graph:
    var = tf.Variable(initial_value=0.0, name='var1', import_scope='')
    incremented = var.assign_add(tf.constant(1.0))
    tf.add(1.0, incremented)

  # The result binding names a tensor ('Invalid') that does not exist in the
  # graph, so wrapping or invoking the function is expected to fail.
  bogus_result = pb.TensorFlow.Binding(
      tensor=pb.TensorFlow.TensorBinding(tensor_name='Invalid'))
  packed = serialization_utils.pack_graph_def(graph.as_graph_def())
  comp = pb.Computation(
      tensorflow=pb.TensorFlow(graph_def=packed, result=bogus_result))

  with self.assertRaises(TypeError):
    wrapped_fn = eager_tf_executor._get_wrapped_function_from_comp(
        comp, must_pin_function_to_cpu=False, param_type=None, device=None)
    wrapped_fn()
def prune_tensorflow_proto(proto):
  """Extracts subgraph from `proto` preserving parameter, result and initialize.

  Args:
    proto: Instance of `pb.Computation` of the `tensorflow` variety whose
      `graphdef` attribute we wish to prune of extraneous ops.

  Returns:
    A transformed instance of `pb.Computation` of the `tensorflow` variety.
    Its `graphdef` keeps only the ops that
    `tf.compat.v1.graph_util.extract_sub_graph` retains for the parameter ops,
    the result ops and the initialize op (when one is set).

  Raises:
    TypeError: If `proto` is not of the `tensorflow` variety.
  """
  py_typecheck.check_type(proto, pb.Computation)
  variety = proto.WhichOneof('computation')
  if variety != 'tensorflow':
    raise TypeError(
        '`prune_tensorflow_proto` only accepts `Computation` '
        'protos of the \'tensorflow\' variety; you have passed '
        'one of variety {}.'.format(variety))
  tf_proto = proto.tensorflow

  def _op_names_in(binding):
    # Drop the final ':'-separated component (the output index) from each
    # tensor name to obtain the producing op's name.
    return [
        ':'.join(tensor_name.split(':')[:-1])
        for tensor_name in graph_utils.extract_tensor_names_from_binding(
            binding)
    ]

  has_parameter = tf_proto.parameter.WhichOneof('binding')
  parameter_names = _op_names_in(tf_proto.parameter) if has_parameter else []
  return_names = _op_names_in(tf_proto.result)
  graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)

  init_op_name = tf_proto.initialize_op
  names_to_preserve = parameter_names + return_names
  if init_op_name:
    names_to_preserve.append(init_op_name)
  subgraph_def = tf.compat.v1.graph_util.extract_sub_graph(
      graph_def, names_to_preserve)

  pruned_tf = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(subgraph_def),
      initialize_op=tf_proto.initialize_op,
      parameter=tf_proto.parameter,
      result=tf_proto.result)
  return pb.Computation(type=proto.type, tensorflow=pruned_tf)
def create_dummy_computation_tensorflow_identity(type_spec=tf.int32):
  """Returns an identity tensorflow computation and its type `(T -> T)`.

  `T` is `type_spec`; the single parameter is stamped under the name 'a'.
  """
  with tf.Graph().as_default() as graph:
    stamped, stamped_binding = tensorflow_utils.stamp_parameter_in_graph(
        'a', type_spec, graph)
    output_type, output_binding = tensorflow_utils.capture_result_from_graph(
        stamped, graph)
  signature = computation_types.FunctionType(type_spec, output_type)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=stamped_binding,
      result=output_binding)
  computation = pb.Computation(
      type=type_serialization.serialize_type(signature), tensorflow=tf_proto)
  return computation, signature
def transform(self, comp):
  """Attaches a cache key derived from the graph and type to `comp`.

  Returns `(comp, False)` unchanged when `self.should_transform(comp)` is
  falsy. Otherwise it returns a new `building_blocks.CompiledComputation`
  paired with `True`. That computation's TensorFlow proto is a copy of
  `comp`'s with `cache_key.id` set.
  """
  if self.should_transform(comp):
    py_typecheck.check_type(comp, building_blocks.CompiledComputation)
    source_tf = comp.proto.tensorflow
    keyed_tf = computation_pb2.TensorFlow()
    keyed_tf.CopyFrom(source_tf)
    # The type signature must be hashed together with the serialized graph.
    # TFF can emit (<> -> <>) and (<> -> <<>>) functions that are backed by
    # the same single-NoOp graph, notably when compiling secure_sum intrinsics
    # over empty structures into MapReduceForm.
    # NOTE(review): built-in `hash` of str/bytes is salted per process
    # (PYTHONHASHSEED), so this key is presumably only stable within one
    # process — confirm that matches the intended cache scope.
    digest = hash((comp.type_signature, source_tf.graph_def.value))
    keyed_tf.cache_key.id = ctypes.c_uint64(digest).value
    keyed_proto = computation_pb2.Computation(
        type=comp.proto.type, tensorflow=keyed_tf)
    keyed_comp = building_blocks.CompiledComputation(
        keyed_proto, type_signature=comp.type_signature)
    return keyed_comp, True
  return comp, False
def create_empty_tuple() -> pb.Computation:
  """Builds a no-argument TensorFlow computation that yields an empty tuple.

  The resulting computation has the type signature `( -> <>)`.

  Returns:
    A `pb.Computation` of the `tensorflow` variety.
  """
  with tf.Graph().as_default() as graph:
    empty_type, empty_binding = tensorflow_utils.capture_result_from_graph(
        [], graph)
  signature = computation_types.FunctionType(None, empty_type)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=empty_binding)
  serialized_type = type_serialization.serialize_type(signature)
  return pb.Computation(type=serialized_type, tensorflow=tf_proto)
def test_counts_no_variables(self):
  # A graph made only of constants and an add must report zero variables.
  with tf.Graph().as_default() as g:
    a = tf.constant(0)
    b = tf.constant(1)
    c = a + b
    _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)

  tf_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(g.as_graph_def()),
      parameter=None,
      result=result_binding)
  proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(None, tf.int32)),
      tensorflow=tf_block)
  building_block = building_blocks.ComputationBuildingBlock.from_proto(proto)

  self.assertEqual(
      building_block_analysis.count_tensorflow_variables_in(building_block), 0)
def create_unary_operator(
    operator, operand_type: computation_types.Type) -> ComputationProtoAndType:
  """Builds a TensorFlow computation applying `operator` to a single operand.

  The resulting computation has the type signature `(T -> U)`. `T` is
  `operand_type`, and `U` is the type TensorFlow yields when `operator` is
  applied to a value of type `T`.

  Args:
    operator: A callable taking one argument representing the operation to
      encode, for example `tf.math.abs`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed unary operator; must contain only named tuples and tensor
      types.

  Returns:
    Whatever `_tensorflow_comp` produces for the constructed `pb.TensorFlow`
    proto and its `computation_types.FunctionType`.

  Raises:
    TypeError: If the constraints of `operand_type` are violated or `operator`
      is not callable.
  """
  is_valid_operand = (
      operand_type is not None and
      type_analysis.is_generic_op_compatible_type(operand_type))
  if not is_valid_operand:
    raise TypeError(
        '`operand_type` contains a type other than '
        '`computation_types.TensorType` and `computation_types.StructType`; '
        f'this is disallowed in the generic operators. Got: {operand_type} ')
  py_typecheck.check_callable(operator)

  with tf.Graph().as_default() as graph:
    operand, operand_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        operator(operand), graph)

  function_type = computation_types.FunctionType(operand_type, result_type)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=operand_binding,
      result=result_binding)
  return _tensorflow_comp(tf_proto, function_type)
def _create_proto_with_unnecessary_op():
  """Returns an int32 identity `pb.Computation` carrying an extra, unused op.

  A second result is captured for a `tf.constant(0)` but is not referenced by
  the computation's result binding, leaving extra nodes in the graph.
  """
  scalar_type = tf.int32
  with tf.Graph().as_default() as graph:
    stamped, stamped_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', scalar_type, graph)
    output_type, output_binding = tensorflow_utils.capture_result_from_graph(
        stamped, graph)
    # Captured only to add nodes; its binding is deliberately discarded.
    tensorflow_utils.capture_result_from_graph(tf.constant(0), graph)

  signature = type_serialization.serialize_type(
      computation_types.FunctionType(scalar_type, output_type))
  tf_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=stamped_binding,
      result=output_binding)
  return pb.Computation(type=signature, tensorflow=tf_block)
def test_counts_two_variables_correctly(self):
  # Two explicitly created variables feed the result, so two must be counted.
  with tf.Graph().as_default() as g:
    a = tf.Variable(0, name='variable1')
    b = tf.Variable(1, name='variable2')
    c = a + b
    _, result_binding = graph_utils.capture_result_from_graph(c, g)

  tf_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(g.as_graph_def()),
      parameter=None,
      result=result_binding)
  proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(None, tf.int32)),
      tensorflow=tf_block)
  building_block = (
      computation_building_blocks.ComputationBuildingBlock.from_proto(proto))

  variable_count = (
      computation_building_block_utils.count_tensorflow_variables_in(
          building_block))
  self.assertEqual(variable_count, 2)
def test_counts_correct_number_of_ops_simple_case(self):
  with tf.Graph().as_default() as g:
    a = tf.constant(0)
    b = tf.constant(1)
    c = a + b
    _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)

  tf_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(g.as_graph_def()),
      parameter=None,
      result=result_binding)
  proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(None, tf.int32)),
      tensorflow=tf_block)
  building_block = building_blocks.ComputationBuildingBlock.from_proto(proto)

  # Two constants, one add, and the identity attached to the captured result.
  expected_op_count = 4
  self.assertEqual(
      building_block_analysis.count_tensorflow_ops_in(building_block),
      expected_op_count)
def create_whimsy_computation_tensorflow_tuple():
  """Returns a no-argument tensorflow computation and its type.

  The type is `( -> <('a', float32), ('b', float32), ('c', float32)>)`; every
  element is the constant `10.0`.
  """
  constant_value = 10.0
  with tf.Graph().as_default() as graph:
    named_constants = structure.Struct(
        (label, tf.constant(constant_value)) for label in ['a', 'b', 'c'])
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        named_constants, graph)
  signature = computation_types.FunctionType(None, result_type)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  computation = pb.Computation(
      type=type_serialization.serialize_type(signature), tensorflow=tf_proto)
  return computation, signature