def _create_two_variable_tensorflow():
  """Builds a no-argument TF computation summing two newly created variables.

  The graph holds `variable1` (initial value 0) and `variable2` (initial
  value 1); the captured result is their sum.

  Returns:
    Whatever `_pack_noarg_graph` builds from the graph def and the captured
    result type and binding.
  """
  with tf.Graph().as_default() as graph:
    first = tf.Variable(0, name='variable1')
    second = tf.Variable(1, name='variable2')
    total = first + second
    captured_type, captured_binding = (
        tensorflow_utils.capture_result_from_graph(total, graph))
  return _pack_noarg_graph(graph.as_graph_def(), captured_type,
                           captured_binding)
def _create_proto_with_unnecessary_op():
  """Builds a `pb.Computation` over an int32 parameter with an unused op.

  The result is the parameter passed through `capture_result_from_graph`.
  The graph also contains an extra `tf.constant(0)` whose captured binding
  is discarded, so the computation's result binding never references it.

  Returns:
    A `pb.Computation` whose parameter type is `tf.int32`.
  """
  param_type = tf.int32
  with tf.Graph().as_default() as graph:
    param_value, param_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', param_type, graph)
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        param_value, graph)
    # Deliberately dead: captured, but its binding is thrown away.
    dead_op = tf.constant(0)
    tensorflow_utils.capture_result_from_graph(dead_op, graph)
  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(param_type, out_type))
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  return pb.Computation(
      type=serialized_type,
      tensorflow=pb.TensorFlow(
          graph_def=packed_graph, parameter=param_binding,
          result=out_binding))
def test_capture_result_with_attrs_of_constants(self):
  """Captures an attrs instance of constants and checks its named type.

  The captured type should print as `<x=int32,y=bool>` and keep `TestFoo`
  as its Python container.
  """
  # NOTE(review): another test method with this exact name exists in this
  # file; if both live in the same class only the later one is collected.

  @attr.s
  class TestFoo(object):
    x = attr.ib()
    y = attr.ib()

  # Build in a fresh graph (like the other capture tests) instead of the
  # process-wide default graph, so the ops created here do not leak into
  # other tests.
  with tf.Graph().as_default() as graph:
    type_spec, _ = tensorflow_utils.capture_result_from_graph(
        TestFoo(tf.constant(1), tf.constant(True)), graph)
  self.assertEqual(str(type_spec), '<x=int32,y=bool>')
  self.assertIs(type_spec.python_container, TestFoo)
def test_capture_result_with_ragged_tensor(self):
  """A captured `tf.RaggedTensor` becomes a struct of values and row splits."""
  with tf.Graph().as_default() as graph:
    ragged = tf.RaggedTensor.from_row_splits([0, 0, 0, 0], [0, 1, 4])
    type_spec, _ = tensorflow_utils.capture_result_from_graph(ragged, graph)
  row_splits_type = computation_types.StructWithPythonType(
      [(None, computation_types.TensorType(tf.int64, [3]))], tuple)
  expected_type = computation_types.StructWithPythonType([
      ('flat_values', computation_types.TensorType(tf.int32, [4])),
      ('nested_row_splits', row_splits_type),
  ], tf.RaggedTensor)
  self.assert_types_identical(type_spec, expected_type)
def test_capture_result_with_attrs_of_constants(self):
  """Captures an attrs instance of constants and checks its container type.

  The captured type should print as `<x=int32,y=bool>`, be a
  `NamedTupleTypeWithPyContainerType`, and report `TestFoo` as its container.
  """
  # NOTE(review): another test method with this exact name exists in this
  # file; if both live in the same class only the later one is collected.

  @attr.s
  class TestFoo(object):
    x = attr.ib()
    y = attr.ib()

  # Build in a fresh graph (like the other capture tests) instead of the
  # process-wide default graph, so the ops created here do not leak into
  # other tests.
  with tf.Graph().as_default() as graph:
    type_spec, _ = tensorflow_utils.capture_result_from_graph(
        TestFoo(tf.constant(1), tf.constant(True)), graph)
  self.assertEqual(str(type_spec), '<x=int32,y=bool>')
  self.assertIsInstance(
      type_spec, computation_types.NamedTupleTypeWithPyContainerType)
  self.assertIs(
      computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
          type_spec), TestFoo)
def create_dummy_computation_tensorflow_empty():
  """Returns a tensorflow computation and type `( -> <>)`.

  Returns:
    A `(pb.Computation, computation_types.FunctionType)` pair for a
    no-argument computation whose result is an empty list captured from a
    fresh graph.
  """
  with tf.Graph().as_default() as graph:
    empty_type, empty_binding = tensorflow_utils.capture_result_from_graph(
        [], graph)
  fn_type = computation_types.FunctionType(None, empty_type)
  proto = pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=None,
          result=empty_binding))
  return proto, fn_type
def create_empty_tuple() -> ProtoAndType:
  """Returns a tensorflow computation returning an empty tuple.

  The returned computation has the type signature `( -> <>)`; its result is
  an empty `structure.Struct` captured from a fresh graph.
  """
  with tf.Graph().as_default() as graph:
    empty_struct = structure.Struct([])
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        empty_struct, graph)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=out_binding)
  return _tensorflow_comp(tf_proto,
                          computation_types.FunctionType(None, out_type))
def create_empty_tuple() -> pb.Computation:
  """Returns a tensorflow computation returning an empty tuple.

  The returned computation has the type signature `( -> <>)`; its result is
  an empty list captured from a fresh graph.
  """
  with tf.Graph().as_default() as graph:
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        [], graph)
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(None, out_type))
  return pb.Computation(
      type=serialized_type,
      tensorflow=pb.TensorFlow(
          graph_def=packed_graph, parameter=None, result=out_binding))
def create_dummy_computation_tensorflow_identity(type_spec=tf.int32):
  """Returns a tensorflow computation and type `(T -> T)`.

  Args:
    type_spec: The parameter type `T`; defaults to `tf.int32`.

  Returns:
    A `(pb.Computation, computation_types.FunctionType)` pair. The result
    type is whatever `capture_result_from_graph` reports for the stamped
    parameter.
  """
  with tf.Graph().as_default() as graph:
    param, param_binding = tensorflow_utils.stamp_parameter_in_graph(
        'a', type_spec, graph)
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        param, graph)
  fn_type = computation_types.FunctionType(type_spec, out_type)
  proto = pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=param_binding,
          result=out_binding))
  return proto, fn_type
def test_counts_no_variables(self):
  """A graph built only from constants reports zero TF variables."""
  with tf.Graph().as_default() as graph:
    lhs = tf.constant(0)
    rhs = tf.constant(1)
    total = lhs + rhs
    _, binding = tensorflow_utils.capture_result_from_graph(total, graph)
  proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(None, tf.int32)),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=None,
          result=binding))
  comp = building_blocks.ComputationBuildingBlock.from_proto(proto)
  variable_count = building_block_analysis.count_tensorflow_variables_in(comp)
  self.assertEqual(variable_count, 0)
def create_unary_operator(
    operator, operand_type: computation_types.Type) -> ComputationProtoAndType:
  """Returns a tensorflow computation computing a unary operation.

  The returned computation has the type signature `(T -> U)`, where `T` is
  `operand_type` and `U` is the result of applying `operator` to a value of
  type `T`.

  Args:
    operator: A callable taking one argument representing the operation to
      encode. For example: `tf.math.abs`.
    operand_type: A `computation_types.Type` to use as the argument to the
      constructed unary operator; must contain only named tuples and tensor
      types.

  Returns:
    The value produced by `_tensorflow_comp` for the serialized graph and
    its `(T -> U)` function type.

  Raises:
    TypeError: If the constraints of `operand_type` are violated or
      `operator` is not callable.
  """
  operand_is_supported = (
      operand_type is not None and
      type_analysis.is_generic_op_compatible_type(operand_type))
  if not operand_is_supported:
    raise TypeError(
        '`operand_type` contains a type other than '
        '`computation_types.TensorType` and `computation_types.StructType`; '
        f'this is disallowed in the generic operators. Got: {operand_type} ')
  py_typecheck.check_callable(operator)

  with tf.Graph().as_default() as graph:
    operand, operand_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', operand_type, graph)
    applied = operator(operand)
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        applied, graph)

  fn_type = computation_types.FunctionType(operand_type, out_type)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=operand_binding,
      result=out_binding)
  return _tensorflow_comp(tf_proto, fn_type)
def test_gets_none_placement(self):
  """Ops built without a `tf.device` scope all map to the '' device."""
  with tf.Graph().as_default() as graph:
    first = tf.Variable(0, name='variable1')
    second = tf.Variable(1, name='variable2')
    total = first + second
    _, binding = tensorflow_utils.capture_result_from_graph(total, graph)
  proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(None, tf.int32)),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=None,
          result=binding))
  comp = building_blocks.ComputationBuildingBlock.from_proto(proto)
  placements = building_block_analysis.get_device_placement_in(comp)
  device_names = list(placements.keys())
  self.assertLen(device_names, 1)
  self.assertEqual(device_names[0], '')
  self.assertGreater(placements[''], 0)
def test_counts_correct_number_of_ops_simple_case(self):
  """Counts the ops in a graph that adds two constants."""
  with tf.Graph().as_default() as graph:
    lhs = tf.constant(0)
    rhs = tf.constant(1)
    total = lhs + rhs
    _, binding = tensorflow_utils.capture_result_from_graph(total, graph)
  proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(None, tf.int32)),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=None,
          result=binding))
  comp = building_blocks.ComputationBuildingBlock.from_proto(proto)
  op_count = building_block_analysis.count_tensorflow_ops_in(comp)
  # Expect 4 ops: two constants, one addition, and an identity on the result.
  self.assertEqual(op_count, 4)
def create_dummy_empty_tensorflow_computation():
  """Returns a `pb.Computation` representing a TensorFlow graph.

  The type signature of this `pb.Computation` is `( -> <>)`: it takes no
  argument and its result is an empty list captured from a fresh graph.

  Returns:
    A `pb.Computation`.
  """
  with tf.Graph().as_default() as graph:
    empty_type, empty_binding = tensorflow_utils.capture_result_from_graph(
        [], graph)
  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(None, empty_type))
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=empty_binding)
  return pb.Computation(type=serialized_type, tensorflow=tf_proto)
def create_dummy_computation_tensorflow_tuple():
  """Returns a tensorflow computation and type.

  The type is `( -> <('a', float32), ('b', float32), ('c', float32)>)`;
  every element is the constant `10.0`.

  Returns:
    A `(pb.Computation, computation_types.FunctionType)` pair.
  """
  fill_value = 10.0
  with tf.Graph().as_default() as graph:
    named_constants = structure.Struct(
        [(name, tf.constant(fill_value)) for name in ['a', 'b', 'c']])
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        named_constants, graph)
  fn_type = computation_types.FunctionType(None, out_type)
  proto = pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=None,
          result=out_binding))
  return proto, fn_type
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation representing an identity function.

  The returned computation has the type signature `(T -> T)`, where `T` is
  `type_signature`.

  NOTE: if `T` contains `computation_types.StructType`s without an associated
  container type, they will be given the container type `tuple` by this
  function.

  Args:
    type_signature: A `computation_types.Type` to use as the parameter type
      and result type of the identity function.

  Returns:
    The value produced by `_tensorflow_comp` for the serialized graph and
    its function type.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings, or if it is `None`.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  param_type = type_signature
  if param_type is None:
    raise TypeError('TensorFlow identity cannot be created for NoneType.')

  with tf.Graph().as_default() as graph:
    param, param_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', param_type, graph)
    # TF relies on feeds not-identical to fetches in certain circumstances,
    # so tensors (or every tensor in a struct) are routed through an identity.
    if param_type.is_tensor():
      output = tf.identity(param)
    elif param_type.is_struct():
      output = structure.map_structure(tf.identity, param)
    else:
      output = param
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        output, graph)

  fn_type = computation_types.FunctionType(param_type, out_type)
  tf_proto = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=param_binding,
      result=out_binding)
  return _tensorflow_comp(tf_proto, fn_type)
def create_broadcast_scalar_to_shape(scalar_type: tf.DType,
                                     shape: tf.TensorShape) -> pb.Computation:
  """Returns a tensorflow computation returning the result of `tf.broadcast_to`.

  The returned computation has the type signature `(T -> U)`, where `T` is
  `scalar_type` and `U` is a `tff.TensorType` with a dtype of `scalar_type`
  and a `shape`.

  Args:
    scalar_type: A `tf.DType`, the type of the scalar to broadcast.
    shape: A `tf.TensorShape` to broadcast to. Must be fully defined.

  Returns:
    A `pb.Computation` taking a scalar of `scalar_type`.

  Raises:
    TypeError: If `scalar_type` is not a `tf.DType` or if `shape` is not a
      `tf.TensorShape`.
    ValueError: If `shape` is not fully defined.
  """
  py_typecheck.check_type(scalar_type, tf.DType)
  py_typecheck.check_type(shape, tf.TensorShape)
  shape.assert_is_fully_defined()

  scalar_param_type = computation_types.TensorType(scalar_type, shape=())
  with tf.Graph().as_default() as graph:
    scalar, scalar_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', scalar_param_type, graph)
    broadcast = tf.broadcast_to(scalar, shape)
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        broadcast, graph)

  fn_type = computation_types.FunctionType(scalar_param_type, out_type)
  return pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=scalar_binding,
          result=out_binding))
def test_partitioned_call_nodes(self):
  """Call ops have Grappler disabled only after the transformation runs."""

  @tf.function
  def test():
    return tf.constant(1)

  with tf.Graph().as_default() as graph:
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        test(), graph)
  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(None, out_type))
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  proto = pb.Computation(
      type=serialized_type,
      tensorflow=pb.TensorFlow(
          graph_def=packed_graph, parameter=None, result=out_binding))

  self.assertCallOpsGrapplerNotDisabled(proto)
  transformed = (
      tensorflow_computation_transformations
      .disable_grappler_for_partitioned_calls(proto))
  self.assertCallOpsGrapplerDisabled(transformed)
def test_gets_all_explicit_placement(self):
  """Ops built under `tf.device('/cpu:0')` share a single CPU placement."""
  with tf.Graph().as_default() as graph:
    # The capture (which adds an identity op) stays inside the device scope
    # so that every op in the graph gets the same placement.
    with tf.device('/cpu:0'):
      lhs = tf.constant(0)
      rhs = tf.constant(1)
      total = lhs + rhs
      _, binding = tensorflow_utils.capture_result_from_graph(total, graph)
  proto = pb.Computation(
      type=type_serialization.serialize_type(
          computation_types.FunctionType(None, tf.int32)),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=None,
          result=binding))
  comp = building_blocks.ComputationBuildingBlock.from_proto(proto)
  placements = building_block_analysis.get_device_placement_in(comp)
  device_names = list(placements.keys())
  self.assertLen(device_names, 1)
  only_device = device_names[0]
  self.assertIn('CPU', only_device)
  self.assertGreater(placements[only_device], 0)
def test_valid_ops(self):
  """A graph with no `ShardedFilename` op passes the disallowed-op check."""

  @tf.function
  def test():
    return tf.constant(1)

  with tf.Graph().as_default() as graph:
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        test(), graph)
  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(None, out_type))
  packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
  proto = computation_pb2.Computation(
      type=serialized_type,
      tensorflow=computation_pb2.TensorFlow(
          graph_def=packed_graph, parameter=None, result=out_binding))
  tensorflow_computation_transformations.check_no_disallowed_ops(
      proto, frozenset(['ShardedFilename']))
def test_counts_correct_number_of_ops_with_function(self):
  """Counts ops when the graph calls a `tf.function` twice."""

  @tf.function
  def add_one(x):
    return x + 1

  with tf.Graph().as_default() as graph:
    param, param_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', tf.int32, graph)
    incremented_twice = add_one(add_one(param))
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        incremented_twice, graph)

  fn_type = computation_types.FunctionType(tf.int32, out_type)
  proto = pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=param_binding,
          result=out_binding))
  comp = building_blocks.ComputationBuildingBlock.from_proto(proto)
  op_count = building_block_analysis.count_tensorflow_ops_in(comp)
  # Expect 7 ops:
  # Inside the tf.function:
  #   - one constant
  #   - one addition
  #   - one identity on the result
  # Inside the tff_computation:
  #   - one placeholder (for the argument)
  #   - two partitioned calls
  #   - one identity on the tff_computation result
  self.assertEqual(op_count, 7)
def create_computation_for_py_fn(
    fn: types.FunctionType,
    parameter_type: Optional[computation_types.Type]) -> pb.Computation:
  """Returns a tensorflow computation returning the result of `fn`.

  The returned computation has the type signature `(T -> U)`, where `T` is
  `parameter_type` and `U` is the type returned by `fn`.

  Args:
    fn: A Python function. Called with the stamped parameter when
      `parameter_type` is not `None`, otherwise called with no arguments.
    parameter_type: A `computation_types.Type`, or `None` for a no-argument
      computation.

  Returns:
    A `pb.Computation` wrapping the graph built by calling `fn`.

  Raises:
    TypeError: (via `py_typecheck.check_type`) if `fn` is not a
      `types.FunctionType`, or if `parameter_type` is neither `None` nor a
      `computation_types.Type`.
  """
  py_typecheck.check_type(fn, types.FunctionType)
  takes_parameter = parameter_type is not None
  if takes_parameter:
    py_typecheck.check_type(parameter_type, computation_types.Type)

  with tf.Graph().as_default() as graph:
    if takes_parameter:
      param, param_binding = tensorflow_utils.stamp_parameter_in_graph(
          'x', parameter_type, graph)
      fn_output = fn(param)
    else:
      param_binding = None
      fn_output = fn()
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        fn_output, graph)

  fn_type = computation_types.FunctionType(parameter_type, out_type)
  return pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=param_binding,
          result=out_binding))
def test_capture_result_with_np_ndarray(self):
  """A (2, 0) int32 NumPy array is captured with type string `int32[2,0]`."""
  empty_array = np.ndarray(shape=(2, 0), dtype=np.int32)
  with tf.Graph().as_default() as graph:
    type_spec, binding = tensorflow_utils.capture_result_from_graph(
        empty_array, graph)
  self._assert_captured_result_eq_dtype(type_spec, binding, 'int32[2,0]')
def test_capture_result_with_np_bool(self):
  """A NumPy boolean scalar is captured with type string `bool`."""
  # `np.bool` was a deprecated alias of the builtin `bool` (NumPy 1.20) and
  # raises AttributeError on NumPy 1.24-1.26; `np.bool_` is the NumPy boolean
  # scalar type and exists in every NumPy version.
  with tf.Graph().as_default() as graph:
    type_spec, binding = tensorflow_utils.capture_result_from_graph(
        np.bool_(True), graph)
  self._assert_captured_result_eq_dtype(type_spec, binding, 'bool')
def test_capture_result_with_np_float64(self):
  """A NumPy float64 scalar is captured with type string `float64`."""
  scalar = np.float64(1.0)
  with tf.Graph().as_default() as graph:
    type_spec, binding = tensorflow_utils.capture_result_from_graph(
        scalar, graph)
  self._assert_captured_result_eq_dtype(type_spec, binding, 'float64')
def test_capture_result_with_int(self):
  """A plain Python int is captured with type string `int32`."""
  with tf.Graph().as_default() as graph:
    captured = tensorflow_utils.capture_result_from_graph(1, graph)
  type_spec, binding = captured
  self._assert_captured_result_eq_dtype(type_spec, binding, 'int32')
def tf_computation_serializer(parameter_type: Optional[computation_types.Type],
                              context_stack):
  """Serializes a TF computation with a given parameter type.

  This is a two-step generator. The caller first receives the parameter
  value, calls the target function with it, and `send`s the function's
  result back in. The generator then finishes the graph and yields the
  serialized computation.

  Args:
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. When
      not `None`, it must be an instance of `computation_types.Type`.
    context_stack: The context stack to use.

  Yields:
    The first yielded value will be a Python object (such as a dataset,
    a placeholder, or a `structure.Struct`) to be passed to the function to
    serialize. It is `None` when `parameter_type` is `None`. The result of
    the function should then be passed to the following `send` call.
    The next yielded value will be
    a tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container
    annotations, for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the
      given parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if parameter_type is not None:
    py_typecheck.check_type(parameter_type, computation_types.Type)

  with tf.Graph().as_default() as graph:
    if parameter_type is not None:
      parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          'arg', parameter_type, graph)
    else:
      parameter_value = None
      parameter_binding = None
    context = tensorflow_computation_context.TensorFlowComputationContext(
        graph)
    with context_stack.install(context):
      # Variables created while the caller runs the target function (between
      # this yield and the following `send`) are recorded in `all_variables`.
      with variable_utils.record_variable_creation_scope() as all_variables:
        result = yield parameter_value
      initializer_ops = []
      if all_variables:
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            [v.name.replace(':0', '') for v in all_variables])
        # Fall back to a count-based name once the joined names exceed 50
        # characters.
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(len(all_variables))
        initializer_ops.append(
            tf.compat.v1.initializers.variables(all_variables, name=name))
      initializer_ops.extend(
          tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
      # `init_op_name` is the name of a single grouped init op, or `None`
      # when there is nothing to initialize.
      if initializer_ops:
        # Before running the main new init op, run any initializers for sub-
        # computations from context.init_ops. Variables from import_graph_def
        # will not make it into the global collections, and so will not be
        # initialized without this code path.
        with tf.compat.v1.control_dependencies(context.init_ops):
          init_op_name = tf.group(
              *initializer_ops, name='grouped_initializers').name
      elif context.init_ops:
        init_op_name = tf.group(
            *context.init_ops, name='subcomputation_init_ops').name
      else:
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(parameter_type, result_type)

  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding,
      initialize_op=init_op_name)
  yield pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow), type_signature
def create_constant(scalar_value,
                    type_spec: computation_types.Type) -> ProtoAndType:
  """Returns a tensorflow computation returning a constant `scalar_value`.

  The returned computation has the type signature `( -> T)`, where `T` is
  `type_spec`.

  `scalar_value` must be a scalar, and cannot be a float if any of the tensor
  leaves of `type_spec` contain an integer data type. `type_spec` must contain
  only named tuples and tensor types, but these can be arbitrarily nested.

  Args:
    scalar_value: A scalar value to place in all the tensor leaves of
      `type_spec`.
    type_spec: A `computation_types.Type` describing the result of the
      constructed no-argument computation; must contain only named tuples and
      tensor types, and every tensor leaf must have a fully defined shape.

  Returns:
    The value produced by `_tensorflow_comp` for the serialized graph and the
    function type `( -> T)`.

  Raises:
    TypeError: If the constraints of `type_spec` are violated, if
      `scalar_value` is not a scalar, or if a non-integer `scalar_value` is
      used with a `type_spec` containing integer tensors.
    ValueError: If a tensor leaf of `type_spec` has a shape that is not fully
      defined (raised by `tf.TensorShape.assert_is_fully_defined`).
  """
  if not type_analysis.is_generic_op_compatible_type(type_spec):
    raise TypeError(
        'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
        'only nested tuples and tensors are permitted.'.format(type_spec))
  inferred_scalar_value_type = type_conversions.infer_type(scalar_value)
  if (not inferred_scalar_value_type.is_tensor() or
      inferred_scalar_value_type.shape != tf.TensorShape(())):
    raise TypeError(
        'Must pass a scalar value to `create_constant`; encountered '
        'a value {}'.format(scalar_value))

  tensor_dtypes_in_type_spec = []

  def _pack_dtypes(type_signature):
    """Appends the dtype of a tensor `type_signature` to the outer list."""
    if type_signature.is_tensor():
      tensor_dtypes_in_type_spec.append(type_signature.dtype)
    # Leave the type unchanged and report that no mutation happened.
    return type_signature, False

  type_transformations.transform_type_postorder(type_spec, _pack_dtypes)

  if (any(x.is_integer for x in tensor_dtypes_in_type_spec) and
      not inferred_scalar_value_type.dtype.is_integer):
    raise TypeError(
        'Only integers can be used as scalar values if our desired constant '
        'type spec contains any integer tensors; passed scalar {} of dtype {} '
        'for type spec {}.'.format(scalar_value,
                                   inferred_scalar_value_type.dtype, type_spec))

  result_type = type_spec

  def _create_result_tensor(type_spec, scalar_value):
    """Packs `scalar_value` into `type_spec` recursively."""
    if type_spec.is_tensor():
      type_spec.shape.assert_is_fully_defined()
      return tf.constant(
          scalar_value, dtype=type_spec.dtype, shape=type_spec.shape)
    # Structs become plain lists in element order; names are carried by
    # `result_type` in the function signature rather than by the value.
    return [
        _create_result_tensor(element_type, scalar_value)
        for _, element_type in structure.iter_elements(type_spec)
    ]

  with tf.Graph().as_default() as graph:
    result = _create_result_tensor(result_type, scalar_value)
    _, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(None, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def create_binary_operator_with_upcast(
    type_signature: computation_types.StructType,
    operator: Callable[[Any, Any], Any]) -> ProtoAndType:
  """Creates TF computation upcasting its argument and applying `operator`.

  Args:
    type_signature: A `computation_types.StructType` with two elements, both
      of the same type or the second able to be upcast to the first, as
      explained in `apply_binary_operator_with_upcast`, and both containing
      only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    The value produced by `_tensorflow_comp` (annotated `ProtoAndType`) for a
    TensorFlow computation that, when the two element types are not
    equivalent, broadcasts the second element of its argument into the
    structure and shapes of the first, and then applies `operator`.

  Raises:
    TypeError: If `type_signature` is not a two-element
      `computation_types.StructType` of TensorFlow-compatible types, if it
      contains a `SequenceType`, if `operator` is not callable, or if its
      first element is neither a tensor nor a struct.
  """
  py_typecheck.check_type(type_signature, computation_types.StructType)
  py_typecheck.check_callable(operator)
  type_analysis.check_tensorflow_compatible_type(type_signature)
  if not type_signature.is_struct() or len(type_signature) != 2:
    raise TypeError('To apply a binary operator, we must by definition have an '
                    'argument which is a `StructType` with 2 elements; '
                    'asked to create a binary operator for type: {t}'.format(
                        t=type_signature))
  if type_analysis.contains(type_signature, lambda t: t.is_sequence()):
    raise TypeError(
        'Applying binary operators in TensorFlow is only '
        'supported on Tensors and StructTypes; you '
        'passed {t} which contains a SequenceType.'.format(t=type_signature))

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if type_spec.is_struct():
      elem_iter = structure.iter_elements(type_spec)
      return structure.Struct([(elem_name, _pack_into_type(to_pack, elem_type))
                               for elem_name, elem_type in elem_iter])
    elif type_spec.is_tensor():
      return tf.broadcast_to(to_pack, type_spec.shape)
    # Only tensors and structs can be packed; fail loudly rather than
    # silently producing `None` inside the packed structure.
    raise TypeError('Cannot upcast a value into type {t}; only Tensor and '
                    'StructTypes are supported.'.format(t=type_spec))

  with tf.Graph().as_default() as graph:
    first_arg, operand_1_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', type_signature[0], graph)
    operand_2_value, operand_2_binding = tensorflow_utils.stamp_parameter_in_graph(
        'y', type_signature[1], graph)
    # Only broadcast the second operand when its type differs from the first.
    if type_signature[0].is_equivalent_to(type_signature[1]):
      second_arg = operand_2_value
    else:
      second_arg = _pack_into_type(operand_2_value, type_signature[0])

    if type_signature[0].is_tensor():
      result_value = operator(first_arg, second_arg)
    elif type_signature[0].is_struct():
      result_value = structure.map_structure(operator, first_arg, second_arg)
    else:
      raise TypeError('Encountered unexpected type {t}; can only handle Tensor '
                      'and StructTypes.'.format(t=type_signature[0]))

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result_value, graph)

  type_signature = computation_types.FunctionType(type_signature, result_type)
  # The two stamped operands are bound as a two-element struct parameter.
  parameter_binding = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[operand_1_binding, operand_2_binding]))
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)