def test_invoke_returns_result_with_tf_computation(self):
        # A few TF computations that will be composed inside `composite` below.
        const_ten = computations.tf_computation(lambda: tf.constant(10))
        plus_one = computations.tf_computation(lambda x: tf.add(x, 1), tf.int32)

        @computations.tf_computation
        def plus_one_via_v1(x):
            # Uses a variable so the sub-computation contributes init ops.
            v1 = tf.Variable(1, name='v1')
            return x + v1

        @computations.tf_computation
        def plus_one_via_v2(x):
            v2 = tf.Variable(1, name='v2')
            return x + v2

        @computations.tf_computation
        def composite():
            zero = tf.Variable(0, name='zero')
            ten = tf.Variable(const_ten())
            # 10 + 1 + 1 + 1 + 0 + 10 - 10 == 13.
            return (plus_one_via_v2(plus_one_via_v1(plus_one(const_ten()))) +
                    zero + ten - ten)

        graph = tf.compat.v1.Graph()
        context = tensorflow_computation_context.TensorFlowComputationContext(
            graph)

        self.assertEqual(composite.type_signature.compact_representation(),
                         '( -> int32)')
        invoked = context.invoke(composite, None)

        with tf.compat.v1.Session(graph=graph) as sess:
            # Sub-computation variables need explicit initialization first.
            if context.init_ops:
                sess.run(context.init_ops)
            value = sess.run(invoked)
        self.assertEqual(value, 13)
    def test_invoke_raises_value_error_with_federated_computation(self):
        # A federated (non-TF) computation: broadcasting a server value.
        @computations.federated_computation(
            computation_types.FederatedType(tf.int32,
                                            placement_literals.SERVER, True))
        def broadcaster(value):
            return intrinsics.federated_broadcast(value)

        context = tensorflow_computation_context.TensorFlowComputationContext(
            tf.compat.v1.get_default_graph())

        # A TF context can only invoke TF computations, so this must fail.
        with self.assertRaisesRegex(ValueError,
                                    'Expected a TensorFlow computation.'):
            context.invoke(broadcaster, None)
  def test_get_session_token(self):

    # A computation whose sole job is to surface the context's session token.
    @tensorflow_computation.tf_computation
    def read_token():
      return tensorflow_computation_context.get_session_token()

    with tf.compat.v1.Graph().as_default() as graph:
      context = tensorflow_computation_context.TensorFlowComputationContext(
          graph, tf.constant('test_token'))

    token_tensor = context.invoke(read_token, None)
    with tf.compat.v1.Session(graph=graph) as sess:
      observed = sess.run(token_tensor)
    self.assertEqual(observed, b'test_token')
# Ejemplo n.º 4
# 0
    def test_invoke_raises_value_error_with_federated_computation(self):
        # Build a computation proto whose variant is a Reference rather than
        # TensorFlow, so it is well-typed but not a TF computation.
        fn_type = computation_types.to_type(
            computation_types.FunctionType(tf.int32, tf.int32))
        bogus_proto = pb.Computation(
            type=type_serialization.serialize_type(fn_type),
            reference=pb.Reference(name='boogledy'))
        non_tf_computation = computation_impl.ComputationImpl(
            bogus_proto, context_stack_impl.context_stack)

        context = tensorflow_computation_context.TensorFlowComputationContext(
            tf.compat.v1.get_default_graph())

        with self.assertRaisesRegex(
                ValueError, 'Can only invoke TensorFlow in the body of '
                'a TensorFlow computation'):
            context.invoke(non_tf_computation, None)
# Ejemplo n.º 5
# 0
def tf_computation_serializer(parameter_type: Optional[computation_types.Type],
                              context_stack):
    """Serializes a TF computation with a given parameter type.

    This is a coroutine-style generator: the caller first receives the
    stamped parameter value, `send`s back the computation's result, and then
    receives the serialized proto together with its full type.

    Args:
      parameter_type: The parameter type specification if the target accepts
        a parameter, or `None` if the target doesn't declare any parameters.
        An instance of `computation_types.Type` when not `None`.
      context_stack: The context stack to use.

    Yields:
      First, a Python object (such as a dataset, a placeholder, or a
      `structure.Struct`) to be passed to the function to serialize; the
      result of that function should then be passed to the following `send`
      call. Second, a tuple of (`pb.Computation`, `tff.Type`), where the
      computation contains the instance with the `pb.TensorFlow` variant set,
      and the type is an instance of `tff.Type`, potentially including Python
      container annotations, for use by TensorFlow computation wrappers.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the signature of the target is not compatible with the
        given parameter type.
    """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.

    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if parameter_type is not None:
        py_typecheck.check_type(parameter_type, computation_types.Type)

    with tf.Graph().as_default() as graph:
        if parameter_type is not None:
            # Stamp the parameter into the graph, producing both the Python
            # value handed to the target function and the proto-level binding.
            parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                'arg', parameter_type, graph)
        else:
            parameter_value = None
            parameter_binding = None
        context = tensorflow_computation_context.TensorFlowComputationContext(
            graph)
        with context_stack.install(context):
            # Record every variable created while the target function runs so
            # that a single init op can be synthesized for all of them.
            with variable_utils.record_variable_creation_scope(
            ) as all_variables:
                # Hand the stamped parameter to the caller; the caller sends
                # back the result of the function being serialized.
                result = yield parameter_value
            initializer_ops = []
            if all_variables:
                # Use a readable but not-too-long name for the init_op.
                name = 'init_op_for_' + '_'.join(
                    [v.name.replace(':0', '') for v in all_variables])
                if len(name) > 50:
                    name = 'init_op_for_{}_variables'.format(
                        len(all_variables))
                initializer_ops.append(
                    tf.compat.v1.initializers.variables(all_variables,
                                                        name=name))
            # Lookup tables register their initializers in a dedicated graph
            # collection; include those alongside the variable initializers.
            initializer_ops.extend(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
            if initializer_ops:
                # Before running the main new init op, run any initializers for sub-
                # computations from context.init_ops. Variables from import_graph_def
                # will not make it into the global collections, and so will not be
                # initialized without this code path.
                with tf.compat.v1.control_dependencies(context.init_ops):
                    init_op_name = tf.group(*initializer_ops,
                                            name='grouped_initializers').name
            elif context.init_ops:
                init_op_name = tf.group(*context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                init_op_name = None

        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(parameter_type,
                                                    result_type)

    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=parameter_binding,
                               result=result_binding,
                               initialize_op=init_op_name)
    # Second yield: the serialized proto together with its full type.
    yield pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow), type_signature
# Ejemplo n.º 6
# 0
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
    """Serializes the 'target' as a TF computation with a given parameter type.

    See also `serialize_tf2_as_tf_computation` for TensorFlow 2 serialization.

    Args:
      target: The entity to convert into and serialize as a TF computation.
        This can currently only be a Python function. In the future, we will
        add here support for serializing the various kinds of non-eager and
        eager functions, and eventually aim at full support for and compliance
        with TF 2.0. This function is currently required to declare either
        zero parameters if `parameter_type` is `None`, or exactly one
        parameter if it's not `None`. The nested structure of this parameter
        must correspond to the structure of the 'parameter_type'. In the
        future, we may support targets with multiple args/keyword args (to be
        documented in the API and referenced from here).
      parameter_type: The parameter type specification if the target accepts a
        parameter, or `None` if the target doesn't declare any parameters.
        Either an instance of `types.Type`, or something that's convertible to
        it by `types.to_type()`.
      context_stack: The context stack to use.

    Returns:
      A tuple of (`pb.Computation`, `tff.Type`), where the computation
      contains the instance with the `pb.TensorFlow` variant set, and the type
      is an instance of `tff.Type`, potentially including Python container
      annotations, for use by TensorFlow computation wrappers.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the signature of the target is not compatible with the
        given parameter type.
    """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.

    py_typecheck.check_type(target, types.FunctionType)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    parameter_type = computation_types.to_type(parameter_type)
    signature = function_utils.get_signature(target)

    with tf.Graph().as_default() as graph:
        if parameter_type is not None:
            # Exactly one declared parameter must match `parameter_type`.
            if len(signature.parameters) != 1:
                raise ValueError(
                    'Expected the target to declare exactly one parameter, found {!r}.'
                    .format(signature.parameters))
            parameter_name = next(iter(signature.parameters))
            # Stamp the parameter into the graph, producing both the Python
            # value passed to `target` and the proto-level binding.
            parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                parameter_name, parameter_type, graph)
        else:
            if signature.parameters:
                raise ValueError(
                    'Expected the target to declare no parameters, found {!r}.'
                    .format(signature.parameters))
            parameter_value = None
            parameter_binding = None
        context = tensorflow_computation_context.TensorFlowComputationContext(
            graph)
        with context_stack.install(context):
            # Record every variable created while `target` runs so a single
            # init op can be synthesized for all of them.
            with variable_utils.record_variable_creation_scope(
            ) as all_variables:
                if parameter_value is not None:
                    result = target(parameter_value)
                else:
                    result = target()
            initializer_ops = []
            if all_variables:
                # Use a readable but not-too-long name for the init_op.
                name = 'init_op_for_' + '_'.join(
                    [v.name.replace(':0', '') for v in all_variables])
                if len(name) > 50:
                    name = 'init_op_for_{}_variables'.format(
                        len(all_variables))
                initializer_ops.append(
                    tf.compat.v1.initializers.variables(all_variables,
                                                        name=name))
            # Lookup tables register their initializers in a dedicated graph
            # collection; include those alongside the variable initializers.
            initializer_ops.extend(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))
            if initializer_ops:
                # Before running the main new init op, run any initializers for sub-
                # computations from context.init_ops. Variables from import_graph_def
                # will not make it into the global collections, and so will not be
                # initialized without this code path.
                with tf.compat.v1.control_dependencies(context.init_ops):
                    init_op_name = tf.group(*initializer_ops,
                                            name='grouped_initializers').name
            elif context.init_ops:
                init_op_name = tf.group(*context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                init_op_name = None

        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(parameter_type,
                                                    result_type)

    # WARNING: we do not really want to be modifying the graph here if we can
    # avoid it. This is purely to work around performance issues uncovered with
    # the non-standard usage of Tensorflow and have been discussed with the
    # Tensorflow core team before being added.
    clean_graph_def = _clean_graph_def(graph.as_graph_def())
    tensorflow = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(clean_graph_def),
        parameter=parameter_binding,
        result=result_binding,
        initialize_op=init_op_name)
    return pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow), type_signature