コード例 #1
0
 def _checked_capture_result(self, result):
     """Captures `result` from the default graph, checking its binding first.

     Args:
       result: The value to capture from the current default graph.

     Returns:
       The type spec produced by `graph_utils.capture_result_from_graph`,
       returned only after `_assert_binding_matches_type_and_value` has
       checked the accompanying binding against `result`.
     """
     default_graph = tf.get_default_graph()
     captured_type, captured_binding = (
         graph_utils.capture_result_from_graph(result, default_graph))
     self._assert_binding_matches_type_and_value(
         captured_binding, captured_type, result, default_graph)
     return captured_type
コード例 #2
0
 def test_extract_tensor_names_from_binding_with_tuple_of_tensors(self):
     """Tensor names are extracted from a binding of a two-element structure."""
     with tf.Graph().as_default() as graph:
         _, binding = graph_utils.capture_result_from_graph(
             collections.OrderedDict([('foo', tf.constant(10, name='A')),
                                      ('bar', tf.constant(20, name='B'))]),
             graph)
     result = graph_utils.extract_tensor_names_from_binding(binding)
     # Compare the sorted names themselves rather than their `str()` rendering.
     # The list repr depends on the element string type; for example, Python 2
     # renders unicode names as `u'A:0'`.
     self.assertEqual(sorted(result), ['A:0', 'B:0'])
コード例 #3
0
def _create_two_variable_tensorflow():
    """Builds a no-argument TF computation that sums two fresh variables.

    The graph holds `variable1` (initial value 0) and `variable2` (initial
    value 1). The captured result is their sum.

    Returns:
      Whatever `_pack_noarg_graph` builds from the graph definition together
      with the captured result type and binding.
    """
    with tf.Graph().as_default() as graph:
        first = tf.Variable(0, name='variable1')
        second = tf.Variable(1, name='variable2')
        total = first + second

    total_type, total_binding = graph_utils.capture_result_from_graph(
        total, graph)
    graph_def = graph.as_graph_def()
    return _pack_noarg_graph(graph_def, total_type, total_binding)
コード例 #4
0
    def test_capture_result_with_attrs_of_constants(self):
        """An `attr.s` instance is captured as a named tuple type.

        The captured type must be a `NamedTupleTypeWithPyContainerType` whose
        container type is the original `attr.s` class.
        """
        @attr.s
        class TestFoo(object):
            x = attr.ib()
            y = attr.ib()

        # Build the constants in a fresh graph, as the other capture tests do.
        # Using the process-wide default graph would leave state behind for
        # other tests, and `tf.get_default_graph` is TF1-only API.
        with tf.Graph().as_default() as graph:
            type_spec, _ = graph_utils.capture_result_from_graph(
                TestFoo(tf.constant(1), tf.constant(True)), graph)
        self.assertEqual(str(type_spec), '<x=int32,y=bool>')
        self.assertIsInstance(
            type_spec, computation_types.NamedTupleTypeWithPyContainerType)
        self.assertIs(
            computation_types.NamedTupleTypeWithPyContainerType.
            get_container_type(type_spec), TestFoo)
コード例 #5
0
    def test_avoids_misdirection_with_name(self):
        """Constants merely named like variables are not counted as variables."""
        with tf.Graph().as_default() as graph:
            lhs = tf.constant(0, name='variable1')
            rhs = tf.constant(1, name='variable2')
            total = lhs + rhs

        _, binding = graph_utils.capture_result_from_graph(total, graph)

        packed = serialization_utils.pack_graph_def(graph.as_graph_def())
        fn_type = computation_types.FunctionType(None, tf.int32)
        comp_proto = pb.Computation(
            type=type_serialization.serialize_type(fn_type),
            tensorflow=pb.TensorFlow(
                graph_def=packed, parameter=None, result=binding))
        comp_block = (
            computation_building_blocks.ComputationBuildingBlock.from_proto(
                comp_proto))
        num_variables = (
            computation_building_block_utils.count_tensorflow_variables_in(
                comp_block))
        self.assertEqual(num_variables, 0)
コード例 #6
0
    def test_counts_correct_number_of_ops_simple_case(self):
        """A graph of two constants and one addition holds exactly three ops."""
        with tf.Graph().as_default() as graph:
            lhs = tf.constant(0)
            rhs = tf.constant(1)
            total = lhs + rhs

        _, binding = graph_utils.capture_result_from_graph(total, graph)

        packed = serialization_utils.pack_graph_def(graph.as_graph_def())
        fn_type = computation_types.FunctionType(None, tf.int32)
        comp_proto = pb.Computation(
            type=type_serialization.serialize_type(fn_type),
            tensorflow=pb.TensorFlow(
                graph_def=packed, parameter=None, result=binding))
        comp_block = (
            computation_building_blocks.ComputationBuildingBlock.from_proto(
                comp_proto))
        num_ops = computation_building_block_utils.count_tensorflow_ops_in(
            comp_block)
        self.assertEqual(num_ops, 3)
コード例 #7
0
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
    """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2
  serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero parameters
      if `parameter_type` is `None`, or exactly one parameter if it's not
      `None`.  The nested structure of this parameter must correspond to the
      structure of the 'parameter_type'. In the future, we may support targets
      with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `tff.Type`, or something that's convertible to it by
      `computation_types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container annotations,
    for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.

    py_typecheck.check_type(target, types.FunctionType)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    parameter_type = computation_types.to_type(parameter_type)
    # `inspect.getargspec` is deprecated and was removed in Python 3.11. Prefer
    # `getfullargspec` where it exists (Python 3) and fall back on Python 2.
    # Both results expose the positional parameter names as `.args`.
    try:
        get_argspec = inspect.getfullargspec
    except AttributeError:
        get_argspec = inspect.getargspec  # pylint: disable=deprecated-method
    argspec = get_argspec(target)

    with tf.Graph().as_default() as graph:
        args = []
        if parameter_type is not None:
            if len(argspec.args) != 1:
                raise ValueError(
                    'Expected the target to declare exactly one parameter, '
                    'found {}.'.format(repr(argspec.args)))
            parameter_name = argspec.args[0]
            parameter_value, parameter_binding = graph_utils.stamp_parameter_in_graph(
                parameter_name, parameter_type, graph)
            args.append(parameter_value)
        else:
            if argspec.args:
                raise ValueError(
                    'Expected the target to declare no parameters, found {}.'.
                    format(repr(argspec.args)))
            parameter_binding = None
        context = tf_computation_context.TensorFlowComputationContext(graph)
        with context_stack.install(context):
            result = target(*args)

            # TODO(b/122081673): This needs to change for TF 2.0. We may also
            # want to allow the person creating a tff.tf_computation to specify
            # a different initializer; e.g., if it is known that certain
            # variables will be assigned immediately to arguments of the function,
            # then it is wasteful to initialize them before this.
            #
            # The following is a bit of a work around: the collections below may
            # contain variables more than once, hence we throw into a set. TFF needs
            # to ensure all variables are initialized, but not all variables are
            # always in the collections we expect. tff.learning._KerasModel tries to
            # pull Keras variables (that may or may not be in GLOBAL_VARIABLES) into
            # TFF_MODEL_VARIABLES for now.
            #
            # After deduplicating, sort by name. Set iteration order is not
            # guaranteed to be stable across runs, and it would otherwise leak
            # into the init op's name and input order, making the serialized
            # graph nondeterministic.
            all_variables = sorted(
                set(tf.compat.v1.global_variables() +
                    tf.compat.v1.local_variables() +
                    tf.compat.v1.get_collection(
                        graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE)),
                key=lambda v: v.name)
            if all_variables:
                # Use a readable but not-too-long name for the init_op.
                name = 'init_op_for_' + '_'.join(
                    [v.name.replace(':0', '') for v in all_variables])
                if len(name) > 50:
                    name = 'init_op_for_{}_variables'.format(
                        len(all_variables))
                with tf.control_dependencies(context.init_ops):
                    # Before running the main new init op, run any initializers for sub-
                    # computations from context.init_ops. Variables from import_graph_def
                    # will not make it into the global collections, and so will not be
                    # initialized without this code path.
                    init_op_name = tf.compat.v1.initializers.variables(
                        all_variables, name=name).name
            elif context.init_ops:
                init_op_name = tf.group(*context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                init_op_name = None

        result_type, result_binding = graph_utils.capture_result_from_graph(
            result, graph)

    annotated_type = computation_types.FunctionType(parameter_type,
                                                    result_type)

    serialized_function_type = pb.Type(function=pb.FunctionType(
        parameter=type_serialization.serialize_type(parameter_type),
        result=type_serialization.serialize_type(result_type)))
    tensorflow_proto = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding,
        initialize_op=init_op_name)
    return pb.Computation(type=serialized_function_type,
                          tensorflow=tensorflow_proto), annotated_type
コード例 #8
0
def deserialize_and_call_tf_computation(computation_proto, arg, graph):
    """Deserializes a TF computation and inserts it into `graph`.

  This method performs an action that can be considered roughly the opposite of
  what `tensorflow_serialization.serialize_py_fn_as_tf_computation` does. At
  the moment, it simply imports the graph in the current context. A future
  implementation may rely on different mechanisms. The caller should not be
  concerned with the specifics of the implementation. At this point, the method
  is expected to only be used within the body of another TF computation (within
  an instance of `tf_computation_context.TensorFlowComputationContext` at the
  top of the stack), and potentially also in certain types of interpreted
  execution contexts (TBD).

  Args:
    computation_proto: An instance of `pb.Computation` whose `computation`
      oneof is set to `tensorflow`; this is the computation to deserialize and
      call.
    arg: The argument to invoke the computation with, or None if the computation
      does not specify a parameter type and does not expect one.
    graph: The graph to stamp into.

  Returns:
    A tuple (init_op, result) where:
       init_op: The element `tf.import_graph_def` returned for the computation's
           `initialize_op` name (normally a `tf.Operation`, since that field
           names an op rather than a tensor), or `None` if the computation
           declares no initialization op.
       result: The results to be fetched from TensorFlow. Depending on
           the type of the result, this can be `tf.Tensor` or `tf.data.Dataset`
           instances, or a nested structure (such as an
           `anonymous_tuple.AnonymousTuple`).

  Raises:
    TypeError: If the arguments are of the wrong types, or if `arg` is missing,
      unexpected, or not assignable to the declared parameter type.
    ValueError: If `computation_proto` is not a TensorFlow computation proto.
  """
    py_typecheck.check_type(computation_proto, pb.Computation)
    computation_oneof = computation_proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise ValueError('Expected a TensorFlow computation, got {}.'.format(
            computation_oneof))
    py_typecheck.check_type(graph, tf.Graph)
    with graph.as_default():
        type_spec = type_serialization.deserialize_type(computation_proto.type)
        if type_spec.parameter is None:
            if arg is not None:
                raise TypeError(
                    'The computation declared no parameters; encountered an unexpected '
                    'argument {}.'.format(str(arg)))
            input_map = None
        elif arg is None:
            raise TypeError(
                'The computation declared a parameter of type {}, but the argument '
                'was not supplied.'.format(str(type_spec.parameter)))
        else:
            arg_type, arg_binding = graph_utils.capture_result_from_graph(
                arg, graph)
            if not type_utils.is_assignable_from(type_spec.parameter,
                                                 arg_type):
                raise TypeError(
                    'The computation declared a parameter of type {}, but the argument '
                    'is of a mismatching type {}.'.format(
                        str(type_spec.parameter), str(arg_type)))
            # Keys name tensors in the serialized graph (presumably taken from
            # the computation's parameter binding). Values are the tensors of
            # `arg` already present in `graph` that should be fed in their
            # place.
            input_map = {
                k: graph.get_tensor_by_name(v)
                for k, v in six.iteritems(
                    graph_utils.compute_map_from_bindings(
                        computation_proto.tensorflow.parameter,
                        arg_binding))
            }
        return_elements = graph_utils.extract_tensor_names_from_binding(
            computation_proto.tensorflow.result)
        orig_init_op_name = computation_proto.tensorflow.initialize_op
        if orig_init_op_name:
            return_elements.append(orig_init_op_name)
        # N. B. Unlike MetaGraphDef, the GraphDef alone contains no information
        # about collections, and hence, when we import a graph with Variables,
        # those Variables are not added to global collections, and hence
        # functions like tf.compat.v1.global_variables_initializers() will not
        # contain their initialization ops.
        output_tensors = tf.import_graph_def(
            serialization_utils.unpack_graph_def(
                computation_proto.tensorflow.graph_def),
            input_map,
            return_elements,
            # N. B. It is very important not to return any names from the original
            # computation_proto.tensorflow.graph_def, those names might or might not
            # be valid in the current graph. Using a different scope makes the graph
            # somewhat more readable, since _N style de-duplication of graph
            # node names is less likely to be needed.
            name='subcomputation')

        output_map = dict(zip(return_elements, output_tensors))
        # For an op name, `tf.import_graph_def` returns the imported op object
        # itself, not a string name. An empty (unset) `initialize_op` is never
        # a key, so this yields None in that case.
        new_init_op = output_map.pop(orig_init_op_name, None)
        return (new_init_op,
                graph_utils.assemble_result_from_graph(
                    type_spec.result, computation_proto.tensorflow.result,
                    output_map))
コード例 #9
0
 def test_capture_result_with_np_ndarray(self):
     """An empty int32 ndarray of shape (2, 0) is captured as `int32[2,0]`."""
     empty_array = np.ndarray(shape=(2, 0), dtype=np.int32)
     with tf.Graph().as_default() as graph:
         captured_type, captured_binding = (
             graph_utils.capture_result_from_graph(empty_array, graph))
     self._assert_captured_result_eq_dtype(
         captured_type, captured_binding, 'int32[2,0]')
コード例 #10
0
 def test_capture_result_with_np_bool(self):
     """A NumPy boolean scalar is captured with the `bool` dtype."""
     with tf.Graph().as_default() as graph:
         # `np.bool` was only a deprecated alias of the builtin `bool`
         # (deprecated in NumPy 1.20, removed in 1.24). `np.bool_` is the
         # actual NumPy boolean scalar type this test is meant to exercise.
         type_spec, binding = graph_utils.capture_result_from_graph(
             np.bool_(True), graph)
     self._assert_captured_result_eq_dtype(type_spec, binding, 'bool')
コード例 #11
0
 def test_capture_result_with_np_float64(self):
     """A NumPy float64 scalar is captured with the `float64` dtype."""
     scalar = np.float64(1.0)
     with tf.Graph().as_default() as graph:
         captured_type, captured_binding = (
             graph_utils.capture_result_from_graph(scalar, graph))
     self._assert_captured_result_eq_dtype(
         captured_type, captured_binding, 'float64')
コード例 #12
0
 def test_capture_result_with_np_int32(self):
     """A NumPy int32 scalar is captured with the `int32` dtype."""
     scalar = np.int32(1)
     with tf.Graph().as_default() as graph:
         captured_type, captured_binding = (
             graph_utils.capture_result_from_graph(scalar, graph))
     self._assert_captured_result_eq_dtype(
         captured_type, captured_binding, 'int32')
コード例 #13
0
 def test_capture_result_with_string(self):
     """A plain Python string is captured with the `string` dtype."""
     text = 'a'
     with tf.Graph().as_default() as graph:
         captured_type, captured_binding = (
             graph_utils.capture_result_from_graph(text, graph))
     self._assert_captured_result_eq_dtype(
         captured_type, captured_binding, 'string')