def test_invoke_returns_result_with_tf_computation(self):
  """Invoking a composite tf_computation in the context yields 13.

  `foo` chains three increment computations onto a constant 10 (two of the
  increments create their own variables) and creates two variables of its
  own, so the context's `init_ops` must be run before fetching the result:
  10 + 1 + 1 + 1 + 0 + 10 - 10 == 13.
  """
  make_ten = computations.tf_computation(lambda: tf.constant(10))
  increment = computations.tf_computation(lambda x: tf.add(x, 1), tf.int32)

  @computations.tf_computation
  def add_one_with_v1(x):
    var_one = tf.Variable(1, name='v1')
    return x + var_one

  @computations.tf_computation
  def add_one_with_v2(x):
    var_two = tf.Variable(1, name='v2')
    return x + var_two

  @computations.tf_computation
  def foo():
    zero_var = tf.Variable(0, name='zero')
    ten_var = tf.Variable(make_ten())
    chained = add_one_with_v2(add_one_with_v1(increment(make_ten())))
    # `+ ten_var - ten_var` leaves the value unchanged; it appears to be there
    # so the result depends on a variable initialized from another
    # computation's output.
    return chained + zero_var + ten_var - ten_var

  test_graph = tf.compat.v1.Graph()
  tf_context = tensorflow_computation_context.TensorFlowComputationContext(
      test_graph)
  self.assertEqual(foo.type_signature.compact_representation(), '( -> int32)')

  result_tensor = tf_context.invoke(foo, None)
  with tf.compat.v1.Session(graph=test_graph) as session:
    pending_init_ops = tf_context.init_ops
    if pending_init_ops:
      session.run(pending_init_ops)
    value = session.run(result_tensor)
  self.assertEqual(value, 13)
def test_invoke_raises_value_error_with_federated_computation(self):
  """Invoking a federated computation in a TF context raises ValueError.

  NOTE(review): a second test with this exact method name appears later in
  this chunk. If both are methods of the same class, Python keeps only the
  later definition and this test never runs. The two tests also expect
  different messages ('Expected a TensorFlow computation.' vs 'Can only
  invoke TensorFlow in the body of a TensorFlow computation'). Confirm which
  message the context actually raises before renaming either test.
  """
  server_int_type = computation_types.FederatedType(
      tf.int32, placement_literals.SERVER, True)

  @computations.federated_computation(server_int_type)
  def foo(x):
    return intrinsics.federated_broadcast(x)

  tf_context = tensorflow_computation_context.TensorFlowComputationContext(
      tf.compat.v1.get_default_graph())
  expected_message = 'Expected a TensorFlow computation.'
  with self.assertRaisesRegex(ValueError, expected_message):
    tf_context.invoke(foo, None)
def test_get_session_token(self):
  """The session token given to the context is visible inside a computation.

  The context is built with the constant 'test_token'. Running the tensor
  returned by invoking `get_the_token` must produce those bytes.
  """

  @tensorflow_computation.tf_computation
  def get_the_token():
    return tensorflow_computation_context.get_session_token()

  with tf.compat.v1.Graph().as_default() as token_graph:
    tf_context = tensorflow_computation_context.TensorFlowComputationContext(
        token_graph, tf.constant('test_token'))
    token_tensor = tf_context.invoke(get_the_token, None)

  with tf.compat.v1.Session(graph=token_graph) as session:
    fetched_token = session.run(token_tensor)
  self.assertEqual(fetched_token, b'test_token')
def test_invoke_raises_value_error_with_federated_computation(self):
  """Invoking a computation whose body is not TensorFlow raises ValueError.

  The computation is a hand-built `pb.Computation` with an `int32 -> int32`
  type whose body is a `reference` named 'boogledy', not a `tensorflow`
  body.

  NOTE(review): despite the name, nothing federated is involved here, and
  another test with this exact method name appears earlier in this chunk. If
  both are in the same class, this later definition shadows the earlier one.
  A name such as `..._with_non_tf_computation` would fit this body better.
  The rename is left to a maintainer because it would revive the shadowed
  test, which expects a different error message.
  """
  int_to_int_type = computation_types.to_type(
      computation_types.FunctionType(tf.int32, tf.int32))
  bogus_proto = pb.Computation(
      type=type_serialization.serialize_type(int_to_int_type),
      reference=pb.Reference(name='boogledy'))
  non_tf_computation = computation_impl.ComputationImpl(
      bogus_proto, context_stack_impl.context_stack)

  tf_context = tensorflow_computation_context.TensorFlowComputationContext(
      tf.compat.v1.get_default_graph())
  expected_message = ('Can only invoke TensorFlow in the body of '
                      'a TensorFlow computation')
  with self.assertRaisesRegex(ValueError, expected_message):
    tf_context.invoke(non_tf_computation, None)
def tf_computation_serializer(parameter_type: Optional[computation_types.Type],
                              context_stack):
  """Serializes a TF computation with a given parameter type.

  This is a two-step generator. The caller runs the Python function being
  serialized *between* the two yields. While the generator is suspended at
  the first yield, a fresh `tf.Graph` is the default graph, a
  `TensorFlowComputationContext` for that graph is installed on
  `context_stack`, and every variable created is recorded. Typical use:

    serializer = tf_computation_serializer(parameter_type, context_stack)
    arg = next(serializer)
    result = fn(arg)  # or fn() when `parameter_type` is None
    comp_pb, type_signature = serializer.send(result)

  Args:
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. When
      not `None`, it must already be an instance of `computation_types.Type`.
      It is type-checked, not converted with `computation_types.to_type`.
    context_stack: The context stack to use.

  Yields:
    The first yielded value will be a Python object (such as a dataset, a
    placeholder, or a `structure.Struct`) to be passed to the function to
    serialize, or `None` if `parameter_type` is `None`. The result of the
    function should then be passed back in through the generator's `send`
    method. The next yielded value will be a tuple of (`pb.Computation`,
    `tff.Type`). The computation contains the instance with the
    `pb.TensorFlow` variant set. The type is an instance of `tff.Type`,
    potentially including Python container annotations, for use by
    TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type. NOTE(review): this function never inspects a target
      itself. Presumably such errors surface from the helpers it calls (e.g.
      while stamping the parameter or capturing the result); confirm.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  if parameter_type is not None:
    py_typecheck.check_type(parameter_type, computation_types.Type)

  with tf.Graph().as_default() as graph:
    # The parameter placeholder(s) are always stamped under the name 'arg';
    # there is no target signature here to take a parameter name from.
    if parameter_type is not None:
      parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          'arg', parameter_type, graph)
    else:
      parameter_value = None
      parameter_binding = None
    context = tensorflow_computation_context.TensorFlowComputationContext(
        graph)
    with context_stack.install(context):
      # Variables created by the caller's function while this generator is
      # suspended at the `yield` below are recorded in `all_variables`.
      with variable_utils.record_variable_creation_scope(
      ) as all_variables:
        result = yield parameter_value
      initializer_ops = []
      if all_variables:
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            [v.name.replace(':0', '') for v in all_variables])
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(
              len(all_variables))
        initializer_ops.append(
            tf.compat.v1.initializers.variables(all_variables, name=name))
      # Also run whatever ops are registered in this graph's TABLE_INITIALIZERS
      # collection.
      initializer_ops.extend(
          tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
      if initializer_ops:
        # Before running the main new init op, run any initializers for sub-
        # computations from context.init_ops. Variables from import_graph_def
        # will not make it into the global collections, and so will not be
        # initialized without this code path.
        with tf.compat.v1.control_dependencies(context.init_ops):
          init_op_name = tf.group(*initializer_ops,
                                  name='grouped_initializers').name
      elif context.init_ops:
        # Only sub-computation initializers exist; group them on their own.
        init_op_name = tf.group(*context.init_ops,
                                name='subcomputation_init_ops').name
      else:
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  # From here on the graph is no longer the default graph and the context has
  # been uninstalled. The final yield therefore does not leave either of them
  # active while the caller holds the result.
  type_signature = computation_types.FunctionType(parameter_type, result_type)
  tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
      graph.as_graph_def()),
                             parameter=parameter_binding,
                             result=result_binding,
                             initialize_op=init_op_name)
  yield pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow), type_signature
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2 serialization.

  The target is traced once inside a fresh `tf.Graph`, with a
  `TensorFlowComputationContext` installed on `context_stack`. The resulting
  graph is serialized together with its parameter/result bindings and the
  name of a single op that initializes every variable and lookup table the
  target created.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a plain Python function (an instance of
      `types.FunctionType`). In the future, we will add here support for
      serializing the various kinds of non-eager and eager functions, and
      eventually aim at full support for and compliance with TF 2.0. This
      function is currently required to declare either zero parameters if
      `parameter_type` is `None`, or exactly one parameter if it's not `None`.
      The nested structure of this parameter must correspond to the structure
      of the 'parameter_type'. In the future, we may support targets with
      multiple args/keyword args (to be documented in the API and referenced
      from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `computation_types.Type`, or something that's convertible
      to it by `computation_types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`). The computation contains the
    instance with the `pb.TensorFlow` variant set. The type is an instance of
    `tff.Type`, potentially including Python container annotations, for use
    by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.
  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  signature = function_utils.get_signature(target)

  with tf.Graph().as_default() as graph:
    if parameter_type is not None:
      if len(signature.parameters) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, found {!r}.'
            .format(signature.parameters))
      # The placeholder(s) are named after the target's single parameter.
      parameter_name = next(iter(signature.parameters))
      parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          parameter_name, parameter_type, graph)
    else:
      if signature.parameters:
        raise ValueError(
            'Expected the target to declare no parameters, found {!r}.'
            .format(signature.parameters))
      parameter_value = None
      parameter_binding = None

    context = tensorflow_computation_context.TensorFlowComputationContext(
        graph)
    with context_stack.install(context):
      with variable_utils.record_variable_creation_scope() as all_variables:
        # Choose the call arity with the same predicate that validated the
        # target's signature above, so the call always matches what was
        # checked (rather than depending on the stamped value being non-None).
        if parameter_type is not None:
          result = target(parameter_value)
        else:
          result = target()

      initializer_ops = []
      if all_variables:
        # Use a readable but not-too-long name for the init_op.
        name = 'init_op_for_' + '_'.join(
            v.name.replace(':0', '') for v in all_variables)
        if len(name) > 50:
          name = 'init_op_for_{}_variables'.format(len(all_variables))
        initializer_ops.append(
            tf.compat.v1.initializers.variables(all_variables, name=name))
      # Also run whatever ops are registered in this graph's TABLE_INITIALIZERS
      # collection.
      initializer_ops.extend(
          tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
      if initializer_ops:
        # Before running the main new init op, run any initializers for sub-
        # computations from context.init_ops. Variables from import_graph_def
        # will not make it into the global collections, and so will not be
        # initialized without this code path.
        with tf.compat.v1.control_dependencies(context.init_ops):
          init_op_name = tf.group(
              *initializer_ops, name='grouped_initializers').name
      elif context.init_ops:
        # Only sub-computation initializers exist; group them on their own.
        init_op_name = tf.group(
            *context.init_ops, name='subcomputation_init_ops').name
      else:
        init_op_name = None

    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(parameter_type, result_type)

  # WARNING: we do not really want to be modifying the graph here if we can
  # avoid it. This is purely to work around performance issues uncovered with
  # the non-standard usage of Tensorflow and have been discussed with the
  # Tensorflow core team before being added.
  clean_graph_def = _clean_graph_def(graph.as_graph_def())
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(clean_graph_def),
      parameter=parameter_binding,
      result=result_binding,
      initialize_op=init_op_name)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow), type_signature