def create_broadcast_scalar_to_shape(scalar_type: tf.DType,
                                     shape: tf.TensorShape) -> pb.Computation:
  """Builds a TF computation that broadcasts a scalar to a fixed shape.

  The resulting computation has type signature `(T -> U)`: `T` is a scalar
  `tff.TensorType` with dtype `scalar_type`, and `U` is the type captured from
  the output of `tf.broadcast_to(x, shape)`, i.e. a tensor of dtype
  `scalar_type` and shape `shape`.

  Args:
    scalar_type: A `tf.DType`, the dtype of the scalar input.
    shape: A `tf.TensorShape` giving the output shape. Must be fully defined.

  Returns:
    A `pb.Computation` with its `tensorflow` variant populated.

  Raises:
    TypeError: If `scalar_type` is not a `tf.DType` or if `shape` is not a
      `tf.TensorShape`.
    ValueError: If `shape` is not fully defined.
  """
  # Validate every input before any graph is created.
  py_typecheck.check_type(scalar_type, tf.DType)
  py_typecheck.check_type(shape, tf.TensorShape)
  shape.assert_is_fully_defined()

  scalar_param_type = computation_types.TensorType(scalar_type, shape=())

  with tf.Graph().as_default() as graph:
    scalar_in, in_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', scalar_param_type, graph)
    broadcasted = tf.broadcast_to(scalar_in, shape)
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        broadcasted, graph)

  serialized_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  tf_proto = pb.TensorFlow(
      graph_def=serialized_graph, parameter=in_binding, result=out_binding)
  fn_type = computation_types.FunctionType(scalar_param_type, out_type)
  return pb.Computation(
      type=type_serialization.serialize_type(fn_type), tensorflow=tf_proto)
def test_counts_correct_number_of_ops_with_function(self):
  """Ops inside a `tf.function` body are included in the op count."""

  @tf.function
  def add_one(x):
    return x + 1

  with tf.Graph().as_default() as graph:
    arg, arg_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', tf.int32, graph)
    twice_incremented = add_one(add_one(arg))
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        twice_incremented, graph)

  comp_proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(tf.int32, out_type)),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=arg_binding,
          result=out_binding))
  comp = building_blocks.ComputationBuildingBlock.from_proto(comp_proto)

  # Expected op count is 7:
  #   Inside the tf.function body (listed once, although `add_one` is called
  #   twice):
  #     - one constant
  #     - one addition
  #     - one identity on the result
  #   In the outer computation graph:
  #     - one placeholder (for the argument)
  #     - two partitioned calls (one per `add_one` invocation)
  #     - one identity on the computation's result
  self.assertEqual(
      building_block_analysis.count_tensorflow_ops_in(comp), 7)
def create_computation_for_py_fn(
    fn: types.FunctionType,
    parameter_type: Optional[computation_types.Type]) -> pb.Computation:
  """Traces `fn` into a fresh graph and wraps it as a TF computation.

  The resulting computation has type signature `(T -> U)`, where `T` is
  `parameter_type` (no parameter when `parameter_type` is `None`) and `U` is
  the type captured from whatever `fn` returns.

  Args:
    fn: A Python function. It is called with a single placeholder argument
      when `parameter_type` is given, and with no arguments otherwise.
    parameter_type: A `computation_types.Type`, or `None` if `fn` takes no
      argument.

  Returns:
    A `pb.Computation` with its `tensorflow` variant populated.

  Raises:
    TypeError: If `fn` is not a Python function, or if `parameter_type` is
      neither `None` nor a `computation_types.Type`.
  """
  py_typecheck.check_type(fn, types.FunctionType)
  has_parameter = parameter_type is not None
  if has_parameter:
    py_typecheck.check_type(parameter_type, computation_types.Type)

  with tf.Graph().as_default() as graph:
    if not has_parameter:
      arg_binding = None
      fn_output = fn()
    else:
      placeholder, arg_binding = tensorflow_utils.stamp_parameter_in_graph(
          'x', parameter_type, graph)
      fn_output = fn(placeholder)
    output_type, output_binding = tensorflow_utils.capture_result_from_graph(
        fn_output, graph)

  fn_type = computation_types.FunctionType(parameter_type, output_type)
  return pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=arg_binding,
          result=output_binding))
def tf_computation_serializer(parameter_type: Optional[computation_types.Type],
                              context_stack):
  """Serializes a TF computation with a given parameter type.

  This is a two-step generator. The caller first calls `next()` to receive
  the parameter value stamped into a fresh graph. It then traces the target
  with that value and passes the target's result back via `send()`. The
  generator then yields the serialized computation and its type.

  Args:
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. An
      instance of `computation_types.Type`.
    context_stack: The context stack to use.

  Yields:
    The first yielded value will be a Python object (such as a dataset, a
    placeholder, or a `structure.Struct`) to be passed to the function to
    serialize, or `None` when `parameter_type` is `None`. The result of the
    function should then be passed to the following `send` call.
    The next yielded value will be a tuple of (`pb.Computation`, `tff.Type`),
    where the computation contains the instance with the `pb.TensorFlow`
    variant set, and the type is an instance of `tff.Type`, potentially
    including Python container annotations, for use by TensorFlow computation
    wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the
      given parameter type. NOTE(review): this function never inspects the
      target itself; confirm whether a callee still raises this or whether
      the entry is stale.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if parameter_type is not None:
    py_typecheck.check_type(parameter_type, computation_types.Type)

  with tf.Graph().as_default() as graph:
    if parameter_type is not None:
      # The placeholder(s) for the argument are always stamped under the name
      # 'arg', independent of the target's own parameter name.
      parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          'arg', parameter_type, graph)
    else:
      parameter_value = None
      parameter_binding = None
    context = tensorflow_computation_context.TensorFlowComputationContext(
        graph)
    with context_stack.install(context):
      # Suspend here: the caller runs the target on `parameter_value` and
      # sends back its result. `all_variables` presumably collects the
      # variables created meanwhile (per `record_variable_creation_scope`).
      with variable_utils.record_variable_creation_scope() as all_variables:
        result = yield parameter_value
      initializer_ops = []
      if all_variables:
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            [v.name.replace(':0', '') for v in all_variables])
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(
              len(all_variables))
        initializer_ops.append(
            tf.compat.v1.initializers.variables(all_variables, name=name))
      # Lookup-table initializers live in their own graph collection and are
      # grouped together with the variable initializers.
      initializer_ops.extend(
          tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
      if initializer_ops:
        # Before running the main new init op, run any initializers for sub-
        # computations from context.init_ops. Variables from import_graph_def
        # will not make it into the global collections, and so will not be
        # initialized without this code path.
        with tf.compat.v1.control_dependencies(context.init_ops):
          init_op_name = tf.group(*initializer_ops,
                                  name='grouped_initializers').name
      elif context.init_ops:
        # Only sub-computation initializers exist; group them directly.
        init_op_name = tf.group(*context.init_ops,
                                name='subcomputation_init_ops').name
      else:
        # Nothing to initialize: the proto's `initialize_op` stays unset.
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(parameter_type, result_type)
  tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
      graph.as_graph_def()),
                             parameter=parameter_binding,
                             result=result_binding,
                             initialize_op=init_op_name)
  yield pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow), type_signature
def create_binary_operator_with_upcast(
    type_signature: computation_types.StructType,
    operator: Callable[[Any, Any], Any]) -> ProtoAndType:
  """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: A `computation_types.StructType` with two elements, both
      containing only tuples and tensors in their type tree. Exactly one of
      the following must hold:
      - both elements have the same type;
      - both elements are structs with the same structure (element names and
        nesting), in which case `operator` is applied leaf-by-leaf;
      - the second element is a tensor type that can be broadcast (upcast)
        to every tensor leaf of the first element.
    operator: Callable defining the operator.

  Returns:
    The `ProtoAndType` built by `_tensorflow_comp` from the constructed
    `pb.TensorFlow` and its function type `(<T1,T2> -> U)`. Here `<T1,T2>` is
    `type_signature` and `U` is the type of the operator's result.

  Raises:
    TypeError: If `type_signature` is not a two-element `StructType`, if it
      contains a `SequenceType`, if its two elements are structs with
      different structures, or if `operator` is not callable.
  """
  py_typecheck.check_type(type_signature, computation_types.StructType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not type_signature.is_struct() or len(type_signature) != 2:
    raise TypeError('To apply a binary operator, we must by definition have an '
                    'argument which is a `StructType` with 2 elements; '
                    'asked to create a binary operator for type: {t}'.format(
                        t=type_signature))
  if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
    raise TypeError(
        'Applying binary operators in TensorFlow is only '
        'supported on Tensors and StructTypes; you '
        'passed {t} which contains a SequenceType.'.format(t=type_signature))

  def _is_same_structure(first_type, second_type):
    """Whether two types nest structs identically (names, arity, nesting)."""
    if not (first_type.is_struct() and second_type.is_struct()):
      # Leaves match only if neither side is a struct.
      return first_type.is_struct() == second_type.is_struct()
    first_elems = list(structure.iter_elements(first_type))
    second_elems = list(structure.iter_elements(second_type))
    if len(first_elems) != len(second_elems):
      return False
    return all(
        first_name == second_name and
        _is_same_structure(first_elem, second_elem)
        for (first_name, first_elem), (second_name, second_elem) in zip(
            first_elems, second_elems))

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_struct():
      elem_iter = structure.iter_elements(type_spec)
      return structure.Struct([(elem_name, _pack_into_type(to_pack, elem_type))
                               for elem_name, elem_type in elem_iter])
    elif type_spec.is_tensor():
      return tf.broadcast_to(to_pack, type_spec.shape)
    # Previously this fell through and silently produced `None`.
    raise TypeError('Cannot broadcast a value into type {t}; only tensor and '
                    'struct types are supported.'.format(t=type_spec))

  first_type = type_signature[0]
  second_type = type_signature[1]
  both_structs = first_type.is_struct() and second_type.is_struct()
  if both_structs and not _is_same_structure(first_type, second_type):
    # `_pack_into_type` can only broadcast a single tensor into a structure.
    # A struct operand would otherwise reach `tf.broadcast_to` and fail with
    # an opaque TensorFlow error, so reject this combination explicitly.
    raise TypeError(
        'Cannot upcast one structure to a different structure. '
        '{x} -> {y}'.format(x=second_type, y=first_type))

  with tf.Graph().as_default() as graph:
    first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', first_type, graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', second_type, graph)
    if both_structs or first_type.is_equivalent_to(second_type):
      # Matching structures: `structure.map_structure` below pairs the leaves
      # of both operands pointwise, so no broadcasting is needed.
      second_arg = operand_2_value
    else:
      second_arg = _pack_into_type(operand_2_value, first_type)

    if first_type.is_tensor():
      result_value = operator(first_arg, second_arg)
    elif first_type.is_struct():
      result_value = structure.map_structure(operator, first_arg, second_arg)
    else:
      raise TypeError('Encountered unexpected type {t}; can only handle Tensor '
                      'and StructTypes.'.format(t=first_type))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  function_type = computation_types.FunctionType(type_signature, result_type)
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, function_type)
def create_binary_operator(
    operator, operand_type: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation computing a binary operation.

  The returned computation has the type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the result of applying the `operator` to a tuple of
  type `<T,T>`.

  Note: If `operand_type` is a `computation_types.StructType`, then
  `operator` will be applied pointwise. This places the burden on callers of
  this function to construct the correct values to pass into the returned
  function. For example, to divide `[2, 2]` by `2`, first `2` must be packed
  into the data structure `[x, x]`, before the division operator of the
  appropriate type is called.

  Args:
    operator: A callable taking two arguments representing the operation to
      encode. For example: `tf.math.add`, `tf.math.multiply`, and
      `tf.math.divide`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only struct and tensor types.

  Returns:
    The `ProtoAndType` built by `_tensorflow_comp` from the constructed
    `pb.TensorFlow` and its function type `(<T,T> -> U)`.

  Raises:
    TypeError: If `operand_type` is `None` or violates the constraints above,
      or if `operator` is not callable.
  """
  # `None` is rejected up front. Every later step needs a concrete type, and
  # the dispatch below has no branch for it, which previously could leave
  # `result_value` unbound.
  if (operand_type is None or
      not type_analysis.is_generic_op_compatible_type(operand_type)):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.StructType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)
  with tf.Graph().as_default() as graph:
    operand_1_value, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', operand_type, graph)

    if operand_type.is_tensor():
      result_value = operator(operand_1_value, operand_2_value)
    elif operand_type.is_struct():
      # Structs are combined leaf-by-leaf.
      result_value = structure.map_structure(operator, operand_1_value,
                                             operand_2_value)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The call to '
          '`type_analysis.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(
      computation_types.StructType((operand_type, operand_type)), result_type)
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2 serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero
      parameters if `parameter_type` is `None`, or exactly one parameter if
      it's not `None`. The nested structure of this parameter must correspond
      to the structure of the 'parameter_type'. In the future, we may support
      targets with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `computation_types.Type`, or something that's convertible
      to it by `computation_types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container
    annotations, for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the
      given parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.
  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  argspec = function_utils.get_argspec(target)

  with tf.Graph().as_default() as graph:
    args = []
    if parameter_type is not None:
      if len(argspec.args) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, found {!r}.'
            .format(argspec.args))
      # The placeholder is named after the target's own parameter.
      parameter_name = argspec.args[0]
      parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          parameter_name, parameter_type, graph)
      args.append(parameter_value)
    else:
      if argspec.args:
        raise ValueError(
            'Expected the target to declare no parameters, found {!r}.'.format(
                argspec.args))
      parameter_binding = None
    context = tf_computation_context.TensorFlowComputationContext(graph)
    with context_stack.install(context):
      result = target(*args)

      # TODO(b/122081673): This needs to change for TF 2.0. We may also
      # want to allow the person creating a tff.tf_computation to specify
      # a different initializer; e.g., if it is known that certain
      # variables will be assigned immediately to arguments of the function,
      # then it is wasteful to initialize them before this.
      #
      # The following is a bit of a work around: the collections below may
      # contain variables more than once, hence they are de-duplicated. TFF
      # needs to ensure all variables are initialized, but not all variables
      # are always in the collections we expect. tff.learning._KerasModel
      # tries to pull Keras variables (that may or may not be in
      # GLOBAL_VARIABLES) into VARS_FOR_TFF_TO_INITIALIZE for now.
      #
      # `dict.fromkeys` de-duplicates while keeping first-seen collection
      # order. A `set` would iterate in hash order, which can differ between
      # runs. That would make the init-op name and the serialized graph
      # nondeterministic.
      all_variables = list(
          dict.fromkeys(tf.compat.v1.global_variables() +
                        tf.compat.v1.local_variables() +
                        tf.compat.v1.get_collection(
                            graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE)))

      if all_variables:
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            [v.name.replace(':0', '') for v in all_variables])
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(len(all_variables))
        with tf.control_dependencies(context.init_ops):
          # Before running the main new init op, run any initializers for sub-
          # computations from context.init_ops. Variables from import_graph_def
          # will not make it into the global collections, and so will not be
          # initialized without this code path.
          init_op_name = tf.compat.v1.initializers.variables(
              all_variables, name=name).name
      elif context.init_ops:
        init_op_name = tf.group(
            *context.init_ops, name='subcomputation_init_ops').name
      else:
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=parameter_binding,
          result=result_binding,
          initialize_op=init_op_name)), annotated_type
def create_binary_operator_with_upcast(
    type_signature: computation_types.StructType,
    operator: Callable[[Any, Any], Any]) -> ComputationProtoAndType:
  """Builds a TF computation that upcasts its second operand, then applies `operator`.

  Args:
    type_signature: A `computation_types.StructType` of exactly two elements.
      Both may only contain structs or tensors in their type tree. The two
      elements must either share the same structure, or the second must be a
      single tensor type that is broadcast (upcast) onto every leaf of the
      first element's structure.
    operator: Callable defining the operator.

  Returns:
    Same as `create_binary_operator()`: the value of `_tensorflow_comp` for
    the constructed `pb.TensorFlow` and its function type.

  Raises:
    TypeError: If `type_signature` is not a two-element `StructType`, contains
      a `SequenceType`, or pairs two structs of different structure.
  """
  py_typecheck.check_type(type_signature, computation_types.StructType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not type_signature.is_struct() or len(type_signature) != 2:
    raise TypeError(
        'To apply a binary operator, we must by definition have an '
        'argument which is a `StructType` with 2 elements; '
        'asked to create a binary operator for type: {t}'.format(
            t=type_signature))
  if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
    raise TypeError('Applying binary operators in TensorFlow is only '
                    'supported on Tensors and StructTypes; you '
                    'passed {t} which contains a SequenceType.'.format(
                        t=type_signature))

  def _broadcast_into(value, target_type):
    """Broadcasts tensor `value` onto every tensor leaf of `target_type`."""
    if target_type.is_tensor():
      return tf.broadcast_to(value, target_type.shape)
    if target_type.is_struct():
      return structure.Struct([
          (name, _broadcast_into(value, member_type))
          for name, member_type in structure.iter_elements(target_type)
      ])
    return None

  lhs_type = type_signature[0]
  rhs_type = type_signature[1]

  with tf.Graph().as_default() as graph:
    lhs, lhs_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', lhs_type, graph)
    rhs, rhs_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', rhs_type, graph)

    if lhs_type.is_struct() and rhs_type.is_struct():
      # Two structs: nothing is broadcast. The structures must match so that
      # `structure.map_structure` below can pair their leaves pointwise.
      if not structure.is_same_structure(lhs_type, rhs_type):
        raise TypeError(
            'Cannot upcast one structure to a different structure. '
            '{x} -> {y}'.format(x=rhs_type, y=lhs_type))
      rhs_upcast = rhs
    elif lhs_type.is_equivalent_to(rhs_type):
      rhs_upcast = rhs
    else:
      rhs_upcast = _broadcast_into(rhs, lhs_type)

    if lhs_type.is_tensor():
      result_value = operator(lhs, rhs_upcast)
    elif lhs_type.is_struct():
      result_value = structure.map_structure(operator, lhs, rhs_upcast)
    else:
      raise TypeError(
          'Encountered unexpected type {t}; can only handle Tensor '
          'and StructTypes.'.format(t=lhs_type))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  fn_type = computation_types.FunctionType(type_signature, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=pb.TensorFlow.Binding(
          struct=pb.TensorFlow.StructBinding(
              element=[lhs_binding, rhs_binding])),
      result=result_binding)
  return _tensorflow_comp(tensorflow, fn_type)
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2 serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero
      parameters if `parameter_type` is `None`, or exactly one parameter if
      it's not `None`. The nested structure of this parameter must correspond
      to the structure of the 'parameter_type'. In the future, we may support
      targets with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `computation_types.Type`, or something that's convertible
      to it by `computation_types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container
    annotations, for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the
      given parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.
  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  signature = function_utils.get_signature(target)

  with tf.Graph().as_default() as graph:
    if parameter_type is not None:
      if len(signature.parameters) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, found {!r}.'
            .format(signature.parameters))
      # The placeholder is named after the target's single declared parameter.
      parameter_name = next(iter(signature.parameters))
      parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          parameter_name, parameter_type, graph)
    else:
      if signature.parameters:
        raise ValueError(
            'Expected the target to declare no parameters, found {!r}.'
            .format(signature.parameters))
      parameter_value = None
      parameter_binding = None
    context = tensorflow_computation_context.TensorFlowComputationContext(
        graph)
    with context_stack.install(context):
      # `all_variables` presumably collects the variables created while the
      # target runs (per `record_variable_creation_scope`); they are
      # initialized below.
      with variable_utils.record_variable_creation_scope() as all_variables:
        if parameter_value is not None:
          result = target(parameter_value)
        else:
          result = target()
      initializer_ops = []
      if all_variables:
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            [v.name.replace(':0', '') for v in all_variables])
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(
              len(all_variables))
        initializer_ops.append(
            tf.compat.v1.initializers.variables(all_variables, name=name))
      # Lookup-table initializers are tracked in a separate graph collection.
      initializer_ops.extend(
          tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
      if initializer_ops:
        # Before running the main new init op, run any initializers for sub-
        # computations from context.init_ops. Variables from import_graph_def
        # will not make it into the global collections, and so will not be
        # initialized without this code path.
        with tf.compat.v1.control_dependencies(context.init_ops):
          init_op_name = tf.group(*initializer_ops,
                                  name='grouped_initializers').name
      elif context.init_ops:
        # Only sub-computation initializers exist; group them directly.
        init_op_name = tf.group(*context.init_ops,
                                name='subcomputation_init_ops').name
      else:
        # Nothing to initialize: the proto's `initialize_op` stays unset.
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(parameter_type, result_type)

  # WARNING: we do not really want to be modifying the graph here if we can
  # avoid it. This is purely to work around performance issues uncovered with
  # the non-standard usage of Tensorflow and have been discussed with the
  # Tensorflow core team before being added.
  clean_graph_def = _clean_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(clean_graph_def),
      parameter=parameter_binding,
      result=result_binding,
      initialize_op=init_op_name)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow), type_signature
def create_binary_operator(operator, operand_type) -> pb.Computation:
  """Builds a TF computation applying the binary `operator` to two operands.

  The resulting computation has type signature `(<T,T> -> U)`, where `T` is
  `operand_type` and `U` is the type of `operator`'s result on a `<T,T>`
  argument.

  Note: when `operand_type` is a `computation_types.NamedTupleType`, the
  operator is applied leaf-by-leaf. Callers must therefore pre-pack operands
  into matching structures. For example, dividing `[2, 2]` by `2` requires
  first packing `2` into `[x, x]` before calling the division operator of the
  appropriate type.

  Args:
    operator: A callable taking two arguments that encodes the operation, for
      example `tf.math.add`, `tf.math.multiply`, or `tf.math.divide`.
    operand_type: The type of each operand. Anything accepted by
      `computation_types.to_type`; it may only contain types compatible with
      the TFF generic operators (named tuples and tensors).

  Returns:
    A `pb.Computation` with its `tensorflow` variant populated.

  Raises:
    TypeError: If `operand_type` violates the constraints above, or if
      `operator` is not callable.
  """
  operand_type = computation_types.to_type(operand_type)
  if not type_utils.is_generic_op_compatible_type(operand_type):
    raise TypeError(
        'The type {} contains a type other than `computation_types.TensorType` '
        'and `computation_types.NamedTupleType`; this is disallowed in the '
        'generic operators.'.format(operand_type))
  py_typecheck.check_callable(operator)

  with tf.Graph().as_default() as graph:
    # Stamp the two operands in order: first 'x', then 'y'.
    (lhs, lhs_binding), (rhs, rhs_binding) = (
        tensorflow_utils.stamp_parameter_in_graph(name, operand_type, graph)
        for name in ('x', 'y'))

    if isinstance(operand_type, computation_types.TensorType):
      combined = operator(lhs, rhs)
    elif isinstance(operand_type, computation_types.NamedTupleType):
      combined = anonymous_tuple.map_structure(operator, lhs, rhs)
    else:
      raise TypeError(
          'Operand type {} cannot be used in generic operations. The whitelist '
          'in `type_utils.is_generic_op_compatible_type` has allowed it to '
          'pass, and should be updated.'.format(operand_type))

    combined_type, combined_binding = (
        tensorflow_utils.capture_result_from_graph(combined, graph))

  fn_type = computation_types.FunctionType([operand_type, operand_type],
                                           combined_type)
  pair_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(element=[lhs_binding, rhs_binding]))
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  return pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      tensorflow=pb.TensorFlow(
          graph_def=packed_graph,
          parameter=pair_binding,
          result=combined_binding))