def _compiled_comp_equal(comp_1, comp_2):
    """Checks whether two compiled TensorFlow computations are identical.

    The initialize op names and the parameter and result bindings are compared
    first; the packed graphs are only unpacked and compared when all of those
    agree.

    Args:
      comp_1: A `building_blocks.CompiledComputation` to test.
      comp_2: A `building_blocks.CompiledComputation` to test.

    Returns:
      `True` iff the computations are entirely identical.

    Raises:
      TypeError: if `comp_1` or `comp_2` is not a
        `building_blocks.CompiledComputation`.
    """
    for candidate in (comp_1, comp_2):
        py_typecheck.check_type(candidate, building_blocks.CompiledComputation)

    lhs = comp_1.proto.tensorflow
    rhs = comp_2.proto.tensorflow
    if (lhs.initialize_op != rhs.initialize_op or
            lhs.parameter != rhs.parameter or lhs.result != rhs.result):
        return False

    # TODO(b/174605105): We prefer to mitigate nans comparing unequal for now,
    # given the severity of TFF's failure to handle this violation of its
    # assumption that trees_equal is an equivalence relation on its ASTs. But
    # this is not a long-term solution. To replace with more legitimate proto
    # comparison which treats nans as equal.
    lhs_graph, rhs_graph = (
        serialization_utils.unpack_graph_def(side.graph_def)
        for side in (lhs, rhs))
    return (lhs_graph.SerializeToString(deterministic=True) ==
            rhs_graph.SerializeToString(deterministic=True))
Exemple #2
0
def _compiled_comp_equal(comp_1, comp_2):
    """Returns `True` iff the computations are entirely identical.

    The initialize op names and the parameter and result bindings are compared
    first. The graphs are then compared through their deterministic
    serializations rather than with proto `==`: proto equality can report
    NaN-valued fields as unequal (see b/174605105), which would make a
    computation holding a NaN constant compare unequal to itself and break the
    assumption that this comparison is an equivalence relation.

    Args:
      comp_1: A `building_blocks.CompiledComputation` to test.
      comp_2: A `building_blocks.CompiledComputation` to test.

    Returns:
      `True` iff the computations are entirely identical.

    Raises:
      TypeError: if `comp_1` or `comp_2` is not a
        `building_blocks.CompiledComputation`.
    """
    py_typecheck.check_type(comp_1, building_blocks.CompiledComputation)
    py_typecheck.check_type(comp_2, building_blocks.CompiledComputation)

    tensorflow_1 = comp_1.proto.tensorflow
    tensorflow_2 = comp_2.proto.tensorflow
    if tensorflow_1.initialize_op != tensorflow_2.initialize_op:
        return False
    if tensorflow_1.parameter != tensorflow_2.parameter:
        return False
    if tensorflow_1.result != tensorflow_2.result:
        return False

    graphdef_1 = serialization_utils.unpack_graph_def(tensorflow_1.graph_def)
    graphdef_2 = serialization_utils.unpack_graph_def(tensorflow_2.graph_def)
    # TODO(b/174605105): Byte comparison sidesteps NaN != NaN in proto
    # equality; replace with a proto comparison that treats NaNs as equal.
    return (graphdef_1.SerializeToString(deterministic=True) ==
            graphdef_2.SerializeToString(deterministic=True))
Exemple #3
0
def _compiled_comp_equal(comp_1, comp_2):
    """Checks whether two compiled TensorFlow computations are identical.

    Bindings and the initialize op are compared first. When TensorFlow is newer
    than 2.6.0 the graphs are compared with `graph_defs_equal`, treating NaNs
    as equal; otherwise their deterministic serializations are compared.

    Args:
      comp_1: A `building_blocks.CompiledComputation` to test.
      comp_2: A `building_blocks.CompiledComputation` to test.

    Returns:
      `True` iff the computations are entirely identical.

    Raises:
      TypeError: if `comp_1` or `comp_2` is not a
        `building_blocks.CompiledComputation`.
    """
    py_typecheck.check_type(comp_1, building_blocks.CompiledComputation)
    py_typecheck.check_type(comp_2, building_blocks.CompiledComputation)

    first = comp_1.proto.tensorflow
    second = comp_2.proto.tensorflow
    bindings_match = (first.initialize_op == second.initialize_op and
                      first.parameter == second.parameter and
                      first.result == second.result)
    if not bindings_match:
        return False

    first_graph = serialization_utils.unpack_graph_def(first.graph_def)
    second_graph = serialization_utils.unpack_graph_def(second.graph_def)
    # TODO(b/174605105): Remove this gating when TFF updates its TensorFlow
    # dependency.
    if not version_check.is_tensorflow_version_newer('2.6.0', tf):
        # Fallback for TensorFlow not newer than 2.6.0: byte-compare the
        # deterministic serializations.
        first_bytes = first_graph.SerializeToString(deterministic=True)
        second_bytes = second_graph.SerializeToString(deterministic=True)
        return first_bytes == second_bytes
    return tf.__internal__.graph_util.graph_defs_equal(
        first_graph, second_graph, treat_nan_as_equal=True)
    def test_serialize_tensorflow_with_table_no_variables(self):
        """Serializes a static vocabulary lookup and runs it in a v1 session."""

        def table_lookup(word):
            keys_and_values = tf.lookup.KeyValueTensorInitializer(
                ['a', 'b', 'c'], np.arange(3, dtype=np.int64))
            vocab_table = tf.lookup.StaticVocabularyTable(keys_and_values,
                                                          num_oov_buckets=1)
            return vocab_table.lookup(word)

        string_vector = computation_types.TensorType(dtype=tf.string,
                                                     shape=(None, ))
        comp, extra_type_spec = (
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                table_lookup, string_vector, context_stack_impl.context_stack))
        expected_signature = '(string[?] -> int64[?])'
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         expected_signature)
        self.assertEqual(str(extra_type_spec), expected_signature)
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

        tf_proto = comp.tensorflow
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(
                serialization_utils.unpack_graph_def(tf_proto.graph_def),
                name='')
        with tf.compat.v1.Session(graph=graph) as sess:
            # The table must be initialized before any lookup can run.
            sess.run(fetches=tf_proto.initialize_op)
            input_name = tf_proto.parameter.tensor.tensor_name
            results = sess.run(fetches=tf_proto.result.tensor.tensor_name,
                               feed_dict={input_name: ['b', 'c', 'a']})
        self.assertAllEqual(results, [1, 2, 0])
    def test_tf_wrapper_with_tf_add(self):
        """Wraps `tf.add` and evaluates the imported graph on several inputs."""
        foo = computation_wrapper_instances.tensorflow_wrapper(
            tf.add, (tf.int32, tf.int32))
        self.assertEqual(str(foo.type_signature), '(<int32,int32> -> int32)')

        # TODO(b/113112885): Remove this protected member access as noted above.
        comp = foo._computation_proto  # pylint: disable=protected-access
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

        tf_proto = comp.tensorflow
        x = tf.compat.v1.placeholder(tf.int32)
        y = tf.compat.v1.placeholder(tf.int32)
        param_elements = tf_proto.parameter.tuple.element
        input_map = {
            param_elements[0].tensor.tensor_name: x,
            param_elements[1].tensor.tensor_name: y,
        }
        result = tf.import_graph_def(
            serialization_utils.unpack_graph_def(tf_proto.graph_def),
            input_map, [tf_proto.result.tensor.tensor_name])
        with self.session() as sess:
            results = [
                sess.run(result, feed_dict={x: n, y: 3})
                for n in (1, 20, 5, 10, 30)
            ]
            self.assertEqual(results, [[4], [23], [8], [13], [33]])
    def function_to_wrap(*args):
        """Imports `comp`'s graph with `args` bound to its input tensors.

        The import is placed on 'cpu' when `must_pin_function_to_cpu` is set,
        otherwise on `device` when one is given, otherwise with no explicit
        device scope.
        """
        expected_count = len(input_tensor_names)
        if len(args) != expected_count:
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                expected_count, len(args)))
        tf_proto = comp.tensorflow
        graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
        if tf_proto.initialize_op:
            # Make the graph's ops depend on the init op so it runs first.
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, tf_proto.initialize_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def),
                input_map=dict(zip(input_tensor_names, args)),
                return_elements=output_tensor_names)

        if must_pin_function_to_cpu:
            placement = 'cpu'
        elif device is not None:
            placement = device
        else:
            return _import_fn()
        with tf.device(placement):
            return _import_fn()
    def test_serialize_tensorflow_with_dataset_not_optimized(self):
        """Serializes a dataset reduce and checks it is not optimized.

        The serialized graph must not contain `OptimizeDataset`,
        `OptimizeDatasetV2` or `ModelDataset` nodes, and running it on
        `tf.data.Dataset.range(5)` must produce `10`.
        """

        @tf.function
        def test_foo(ds):
            return ds.reduce(np.int64(0), lambda x, y: x + y)

        def legacy_dataset_reducer_example(ds):
            return test_foo(ds)

        comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack)
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         '(int64* -> int64)')
        self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
        parameter = tf.data.Dataset.range(5)

        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        self.assertGraphDoesNotContainOps(
            graph_def,
            ['OptimizeDataset', 'OptimizeDatasetV2', 'ModelDataset'])
        # Use the session as a context manager so it is closed once the result
        # has been fetched instead of being leaked.
        with tf.compat.v1.Session() as sess:
            results = sess.run(
                tf.import_graph_def(
                    graph_def, {
                        comp.tensorflow.parameter.sequence.variant_tensor_name:
                        tf.data.experimental.to_variant(parameter)
                    }, [comp.tensorflow.result.tensor.tensor_name]))
        self.assertEqual(results, [10])
def get_device_placement_in(comp):
    """Counts the device placements of the nodes in a TensorFlow computation.

    Args:
      comp: A `building_blocks.CompiledComputation` of the `tensorflow`
        variety.

    Returns:
      A `collections.Counter` mapping each node `device` string (an empty
      string for nodes with no device set) to the number of nodes carrying it,
      over the top-level graph and every function in its function library.

    Raises:
      TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
      ValueError: If `comp` is not a `building_blocks.CompiledComputation` of
        the `tensorflow` variety.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    if (not isinstance(comp, building_blocks.CompiledComputation)) or (
            comp.proto.WhichOneof('computation') != 'tensorflow'):
        raise ValueError(
            'Please pass a '
            '`building_blocks.CompiledComputation` of the '
            '`tensorflow` variety to `get_device_placement_in`. (Got '
            'a [{t}]).'.format(t=type(comp)))
    graph_def = serialization_utils.unpack_graph_def(
        comp.proto.tensorflow.graph_def)

    counter = collections.Counter(node.device for node in graph_def.node)
    # Only `library.function` is walked. `library.gradient` holds `GradientDef`
    # entries, which merely map a function name to its gradient function's name
    # and have no `node_def` field (reading it raised AttributeError). The
    # gradient functions themselves are ordinary `FunctionDef`s in `function`.
    for function_def in graph_def.library.function:
        counter.update(node.device for node in function_def.node_def)
    return counter
Exemple #9
0
    def function_to_wrap():
        """No-arg function to import graph def.

        A no-arg function is handed to `tf.compat.v1.wrap_function` to avoid
        the leftover placeholders that binding arguments to the imported
        graphdef via `input_map` can leave behind. The correct signature is
        attached to this function later, via the `prune` call below.

        Returns:
          Result of importing graphdef backing `comp`.
        """
        tf_proto = comp.tensorflow
        graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
        # TODO(b/159180073): clean raise after fixing dataset reduce.
        _check_dataset_reduce_in_multi_gpu(graph_def)

        if tf_proto.initialize_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, tf_proto.initialize_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def), name='')

        if must_pin_function_to_cpu:
            placement = 'cpu'
        elif device is not None:
            placement = device.name
        else:
            return _import_fn()
        with tf.device(placement):
            return _import_fn()
 def test_pack_unpack_roundtrip(self):
     """Packing then unpacking a GraphDef yields an equal GraphDef."""
     with tf.Graph().as_default() as graph:
         tf.constant(1.0)
     original_graph_def = graph.as_graph_def()
     roundtripped = serialization_utils.unpack_graph_def(
         serialization_utils.pack_graph_def(original_graph_def))
     self.assertEqual(original_graph_def, roundtripped)
Exemple #11
0
def _unpack_proto_into_graph_spec(tf_block_proto):
    """Unpacks a TF proto into a `graph_merge.GraphSpec`.

    Args:
      tf_block_proto: Instance of `computation_pb2.Computation` with `tensorflow`
        `computation` attribute.

    Returns:
      Instance of `graph_merge.GraphSpec` containing Python representations of
      the information present in `tf_block_proto`. An empty initialize op name
      is passed on as `None`, and an unset parameter binding as an empty list.
    """
    tf_proto = tf_block_proto.tensorflow
    graph = serialization_utils.unpack_graph_def(tf_proto.graph_def)
    # Normalize an empty (unset) initialize op name to `None`.
    graph_init_op_name = tf_proto.initialize_op or None

    if tf_proto.parameter.WhichOneof('binding') is not None:
        graph_parameter_list = graph_utils.extract_tensor_names_from_binding(
            tf_proto.parameter)
    else:
        graph_parameter_list = []
    graph_result_list = graph_utils.extract_tensor_names_from_binding(
        tf_proto.result)
    return graph_merge.GraphSpec(graph, graph_init_op_name,
                                 graph_parameter_list, graph_result_list)
def count_tensorflow_variables_in(comp):
    """Counts TF Variables in `comp` if `comp` is a TF block.

    Nodes of the top-level graph and of every function in its function library
    are inspected. A node counts when its lowercased op name starts with
    `variable` (other than `variableshape`) or equals `varhandleop`.

    Args:
      comp: A `building_blocks.CompiledComputation` of the `tensorflow`
        variety.

    Returns:
      The number of variable nodes found.

    Raises:
      TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
      ValueError: If `comp` is not a TF `CompiledComputation`.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    is_tf_block = (isinstance(comp, building_blocks.CompiledComputation) and
                   comp.proto.WhichOneof('computation') == 'tensorflow')
    if not is_tf_block:
        raise ValueError(
            'Please pass a '
            '`building_blocks.CompiledComputation` of the '
            '`tensorflow` variety to `count_tensorflow_variables_in`.')
    graph_def = serialization_utils.unpack_graph_def(
        comp.proto.tensorflow.graph_def)

    def _is_variable_op(node):
        # TODO(b/137887596): Follow up on ways to count Variables on the GraphDef
        # level.
        op_name = str(node.op).lower()
        if op_name == 'varhandleop':
            return True
        return op_name.startswith('variable') and op_name != 'variableshape'

    node_lists = [graph_def.node]
    node_lists.extend(func.node_def for func in graph_def.library.function)
    return sum(
        1 for nodes in node_lists for node in nodes if _is_variable_op(node))
Exemple #13
0
 def test_serialize_tensorflow_with_no_parameter(self):
   """A no-argument computation returning a constant evaluates to 99."""
   tf_proto, _ = self.assert_serializes(lambda _: tf.constant(99), None,
                                        '( -> int32)')
   # Close the session once the result is fetched instead of leaking it.
   with tf.compat.v1.Session() as sess:
     results = sess.run(
         tf.graph_util.import_graph_def(
             serialization_utils.unpack_graph_def(tf_proto.graph_def), None,
             [tf_proto.result.tensor.tensor_name]))
   self.assertEqual(results, [99])
Exemple #14
0
 def function_to_wrap(*args):
   """Imports `comp`'s graph, binding `args` to its input tensors in order."""
   if len(args) != len(input_tensor_names):
     raise RuntimeError('Expected {} arguments, found {}.'.format(
         str(len(input_tensor_names)), str(len(args))))
   input_map = {name: arg for name, arg in zip(input_tensor_names, args)}
   return tf.import_graph_def(
       serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
       input_map=input_map,
       return_elements=output_tensor_names)
Exemple #15
0
def _ensure_comp_runtime_compatible(comp: pb.Computation) -> pb.Computation:
  """Ensures `comp` is compatible with eager runtime backing EagerExecutor.

  When more than one logical GPU is visible, the computation's graph is
  checked with `_check_dataset_reduce_for_multi_gpu`. `comp` itself is
  returned unmodified.

  Args:
    comp: A `pb.Computation` of the `tensorflow` variety.

  Returns:
    `comp`, unchanged.
  """
  # TODO(b/159180073): clean raise after fixing dataset reduce.
  num_gpu_devices = len(tf.config.list_logical_devices('GPU'))
  if num_gpu_devices > 1:
    # Only deserialize the GraphDef when the multi-GPU check actually needs it.
    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    _check_dataset_reduce_for_multi_gpu(graph_def)

  return comp
 def test_serialize_tensorflow_with_no_parameter(self):
   """A no-argument computation returning a constant evaluates to 99."""
   comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda: tf.constant(99), None, context_stack_impl.context_stack)
   self.assertEqual(
       str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   # Close the session once the result is fetched instead of leaking it.
   with tf.Session() as sess:
     results = sess.run(
         tf.import_graph_def(
             serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
             None, [comp.tensorflow.result.tensor.tensor_name]))
   self.assertEqual(results, [99])
Exemple #17
0
def count_tensorflow_ops_in(comp):
    """Counts TF ops in `comp` if `comp` is a TF block.

    Only nodes of the top-level graph are counted; nodes inside the graph's
    function library are not included.

    Args:
      comp: A `building_blocks.CompiledComputation` of the `tensorflow`
        variety.

    Returns:
      The number of nodes in the top-level `GraphDef`.

    Raises:
      TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
      ValueError: If `comp` is not a TF `CompiledComputation`.
    """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    is_tf_block = (isinstance(comp, building_blocks.CompiledComputation) and
                   comp.proto.WhichOneof('computation') == 'tensorflow')
    if not is_tf_block:
        raise ValueError('Please pass a '
                         '`building_blocks.CompiledComputation` of the '
                         '`tensorflow` variety to `count_tensorflow_ops_in`.')
    packed_graph = comp.proto.tensorflow.graph_def
    return len(serialization_utils.unpack_graph_def(packed_graph).node)
Exemple #18
0
 def test_serialize_tensorflow_with_simple_add_three_lambda(self):
   """`x + 3` serialized and run on 1000 yields 1003."""
   tf_proto, _ = self.assert_serializes(lambda x: x + 3,
                                        computation_types.TensorType(tf.int32),
                                        '(int32 -> int32)')
   parameter = tf.constant(1000)
   # Close the session once the result is fetched instead of leaking it.
   with tf.compat.v1.Session() as sess:
     results = sess.run(
         tf.graph_util.import_graph_def(
             serialization_utils.unpack_graph_def(tf_proto.graph_def),
             {tf_proto.parameter.tensor.tensor_name: parameter},
             [tf_proto.result.tensor.tensor_name]))
   self.assertEqual(results, [1003])
Exemple #19
0
def _extract_call_ops(comp: pb.Computation) -> Iterator[tf.compat.v1.NodeDef]:
  """Yields the call-op nodes of a `tensorflow` computation.

  Nodes are taken from the top-level graph and from every function in its
  function library; a node is yielded when its op is in
  `tensorflow_computation_transformations.CALL_OPS`.

  Because this is a generator, the type check below only runs once iteration
  starts (on the first `next()`), not when the function is called.

  Args:
    comp: A `pb.Computation` of the `tensorflow` variety.

  Yields:
    Each matching `NodeDef`.

  Raises:
    TypeError: If `comp` is not of the `tensorflow` variety.
  """
  computation_oneof = comp.WhichOneof('computation')
  if computation_oneof != 'tensorflow':
    raise TypeError('`_extract_call_ops` only accepts `Computation` '
                    'protos of the "tensorflow" variety; you have passed '
                    'one of variety {}.'.format(computation_oneof))
  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  all_nodes = itertools.chain(graph_def.node,
                              *[f.node_def for f in graph_def.library.function])
  for node in all_nodes:
    if node.op in tensorflow_computation_transformations.CALL_OPS:
      yield node
 def function_to_wrap(*args):
   """Imports `comp`'s graph with `args` bound to its inputs.

   If `comp` has an initialize op, control dependencies on it are added to the
   graph before import, and shared names are uniquified.
   """
   if len(args) != len(input_tensor_names):
     raise RuntimeError('Expected {} arguments, found {}.'.format(
         str(len(input_tensor_names)), str(len(args))))
   tf_proto = comp.tensorflow
   graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
   if tf_proto.initialize_op:
     graph_def = graph_utils.add_control_deps_for_init_op(
         graph_def, tf_proto.initialize_op)
   input_map = {name: arg for name, arg in zip(input_tensor_names, args)}
   return tf.import_graph_def(
       graph_merge.uniquify_shared_names(graph_def),
       input_map=input_map,
       return_elements=output_tensor_names)
 def test_serialize_tensorflow_with_simple_add_three_lambda(self):
   """`x + 3` serialized and run on 1000 yields 1003."""
   comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
   self.assertEqual(
       str(type_serialization.deserialize_type(comp.type)), '(int32 -> int32)')
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   parameter = tf.constant(1000)
   # Close the session once the result is fetched instead of leaking it.
   with tf.Session() as sess:
     results = sess.run(
         tf.import_graph_def(
             serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
             {comp.tensorflow.parameter.tensor.tensor_name: parameter},
             [comp.tensorflow.result.tensor.tensor_name]))
   self.assertEqual(results, [1003])
Exemple #22
0
def _check_ops(proto: computation_pb2.Computation,
               allowed_op_names: Optional[FrozenSet[str]] = None,
               disallowed_op_names: Optional[FrozenSet[str]] = None):
    """Checks the ops in the TensorFlow computation.

    If allowed_op_names is specified, then _check_ops checks the incoming proto
    contains only ops in the set. On the other hand, if disallowed_op_names is
    specified, then _check_ops checks the proto contains no ops contained in the
    set. One of the two op set arguments must be non-empty, and if both are,
    then allowed_op_names takes precedence. Ops are collected from the top-level
    graph and from every function in its function library.

    Args:
      proto: Instance of `computation_pb2.Computation` with the `tensorflow`
        field populated.
      allowed_op_names: Set of allowed op names.
      disallowed_op_names: Set of disallowed op names.

    Raises:
      TypeError: If `proto` is not a `computation_pb2.Computation` of the
        `tensorflow` variety.
      DisallowedOpInTensorFlowComputationError: If the computation contains a
        disallowed op. The offending op names are listed in sorted order.
      RuntimeError: If both allowed_op_names and disallowed_op_names are empty.
    """
    py_typecheck.check_type(proto, computation_pb2.Computation)
    computation_oneof = proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise TypeError('`_check_ops` only accepts `Computation` '
                        'protos of the "tensorflow" variety; you have passed '
                        'one of variety {}.'.format(computation_oneof))

    # Validate the arguments before paying for unpacking the graph.
    if allowed_op_names:

        def is_disallowed(op_name):
            return op_name not in allowed_op_names
    elif disallowed_op_names:

        def is_disallowed(op_name):
            return op_name in disallowed_op_names
    else:
        raise RuntimeError(
            'One of allowed_op_names or disallowed_op_names must be non-empty')

    graph_def = serialization_utils.unpack_graph_def(
        proto.tensorflow.graph_def)
    all_nodes = itertools.chain(
        graph_def.node, *[f.node_def for f in graph_def.library.function])
    found_disallowed_op_names = {
        node.op for node in all_nodes if is_disallowed(node.op)
    }

    if found_disallowed_op_names:
        # Sort so the error message is deterministic across runs.
        found_disallowed_op_names_str = ', '.join(
            sorted(found_disallowed_op_names))
        raise DisallowedOpInTensorFlowComputationError(
            f'Found disallowed ops: {found_disallowed_op_names_str}')
def count_tensorflow_variables_in(comp):
  """Counts TF Variables in `comp` if `comp` is a TF block.

  A node counts as a variable when its lowercased op name starts with
  `variable` (other than `variableshape`) or equals `varhandleop`. This
  matches ops such as `VariableV2` and `VarHandleOp` while excluding ops that
  merely mention variables (`ReadVariableOp`, `AssignVariableOp`,
  `VariableShape`, `IsVariableInitialized`), which a plain substring test
  would overcount. Nodes of every function in the graph's function library are
  counted as well.

  Args:
    comp: A `computation_building_blocks.CompiledComputation` of the
      `tensorflow` variety.

  Returns:
    The number of variable nodes found.

  Raises:
    TypeError: If `comp` is not a
      `computation_building_blocks.ComputationBuildingBlock`.
    ValueError: If `comp` is not a TF `CompiledComputation`.
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)
  if (not isinstance(comp, computation_building_blocks.CompiledComputation)
     ) or (comp.proto.WhichOneof('computation') != 'tensorflow'):
    raise ValueError('Please pass a '
                     '`computation_building_blocks.CompiledComputation` of the '
                     '`tensorflow` variety to `count_tensorflow_variables_in`.')
  graph_def = serialization_utils.unpack_graph_def(
      comp.proto.tensorflow.graph_def)

  def _node_is_variable(node):
    # TODO(b/137887596): Follow up on ways to count Variables on the GraphDef
    # level.
    op_name = str(node.op).lower()
    return ((op_name.startswith('variable') and op_name != 'variableshape') or
            op_name == 'varhandleop')

  node_lists = [graph_def.node]
  node_lists.extend(f.node_def for f in graph_def.library.function)
  return sum(
      1 for nodes in node_lists for node in nodes if _node_is_variable(node))
Exemple #24
0
  def test_serialize_tensorflow_with_data_set_sum_lambda(self):
    """Summing `range(5)` through a serialized dataset reduce yields 10."""

    def _legacy_dataset_reducer_example(ds):
      return ds.reduce(np.int64(0), lambda x, y: x + y)

    tf_proto, _ = self.assert_serializes(
        _legacy_dataset_reducer_example,
        computation_types.SequenceType(tf.int64), '(int64* -> int64)')
    parameter = tf.data.Dataset.range(5)
    # Close the session once the result is fetched instead of leaking it.
    with tf.compat.v1.Session() as sess:
      results = sess.run(
          tf.graph_util.import_graph_def(
              serialization_utils.unpack_graph_def(tf_proto.graph_def), {
                  tf_proto.parameter.sequence.variant_tensor_name:
                      tf.data.experimental.to_variant(parameter)
              }, [tf_proto.result.tensor.tensor_name]))
    self.assertEqual(results, [10])
Exemple #25
0
def disable_grappler_for_partitioned_calls(proto):
    """Disables grappler for `PartitionedCall` and `StatefulPartitionedCall` nodes in the graph.

    TensorFlow serializes a `ConfigProto` into the `config_proto` `attr` of
    `PartitionedCall` and `StatefulPartitionedCall` graph nodes. This overrides
    any session config that might disable runtime grappler. To disable grappler
    for these nodes as well, this function overwrites the serialized
    `ConfigProto` of every node whose op is in `CALL_OPS` (in the top-level
    graph and in every library function), setting the `disable_meta_optimizer`
    field to `True`. Nodes lacking a `config_proto` attr get one added.

    Args:
      proto: Instance of `computation_pb2.Computation` with the `tensorflow`
        field populated.

    Returns:
      A transformed instance of `computation_pb2.Computation` with a
      `tensorflow` field.

    Raises:
      TypeError: If `proto` is not a `computation_pb2.Computation` of the
        `tensorflow` variety.
    """
    py_typecheck.check_type(proto, computation_pb2.Computation)
    computation_oneof = proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise TypeError('`disable_grappler_for_partitioned_calls` only accepts '
                        '`Computation` protos of the "tensorflow" variety; you '
                        'have passed one of variety {}.'.format(
                            computation_oneof))
    original_tf = proto.tensorflow
    graph_def = serialization_utils.unpack_graph_def(original_tf.graph_def)
    all_nodes = itertools.chain(
        graph_def.node, *[f.node_def for f in graph_def.library.function])
    for node in all_nodes:
        if node.op not in CALL_OPS:
            continue
        if 'config_proto' in node.attr:
            config_proto = tf.compat.v1.ConfigProto.FromString(
                node.attr['config_proto'].s)
        else:
            config_proto = tf.compat.v1.ConfigProto()
        config_proto.graph_options.rewrite_options.disable_meta_optimizer = True
        # Indexing a message-valued proto map inserts a default entry when the
        # key is absent, so this also works for nodes without the attr (the
        # previous `attr.get(...)` returned None there and crashed on `.s`).
        node.attr['config_proto'].s = config_proto.SerializeToString(
            deterministic=True)
    tf_block = computation_pb2.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph_def),
        initialize_op=original_tf.initialize_op
        if original_tf.initialize_op else None,
        parameter=original_tf.parameter
        if original_tf.HasField('parameter') else None,
        result=original_tf.result)
    new_proto = computation_pb2.Computation(type=proto.type,
                                            tensorflow=tf_block)
    return new_proto
Exemple #26
0
 def function_to_wrap(*args):
     """Imports `comp`'s graph with `args` bound to its inputs.

     When `comp` has an initialize op, it is imported alongside the outputs and
     each output is wrapped in `tf.identity` under a control dependency on it.
     """
     if len(args) != len(input_tensor_names):
         raise RuntimeError('Expected {} arguments, found {}.'.format(
             str(len(input_tensor_names)), str(len(args))))
     tf_proto = comp.tensorflow
     graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
     init_op = tf_proto.initialize_op
     extra_returns = [init_op] if init_op else []
     imported = tf.import_graph_def(
         graph_merge.uniquify_shared_names(graph_def),
         input_map=dict(zip(input_tensor_names, args)),
         return_elements=output_tensor_names + extra_returns)
     if not extra_returns:
         return imported
     *outputs, init = imported
     with tf.control_dependencies([init]):
         return [tf.identity(tensor) for tensor in outputs]
Exemple #27
0
def prune_tensorflow_proto(proto):
    """Extracts subgraph from `proto` preserving parameter, result and initialize.

    Args:
      proto: Instance of `pb.Computation` of the `tensorflow` variety whose
        `graphdef` attribute we wish to prune of extraneous ops.

    Returns:
      A transformed instance of `pb.Computation` of the `tensorflow` variety,
      whose `graphdef` attribute contains only ops which can reach the
      parameter or result bindings, or initialize op.

    Raises:
      TypeError: If `proto` is not a `pb.Computation` of the `tensorflow`
        variety.
    """
    py_typecheck.check_type(proto, pb.Computation)
    computation_oneof = proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise TypeError(
            '`prune_tensorflow_proto` only accepts `Computation` '
            'protos of the \'tensorflow\' variety; you have passed '
            'one of variety {}.'.format(computation_oneof))

    def _op_names(tensor_names):
        # Drop the trailing ':<output index>' to turn tensor names into op names.
        return [name.rpartition(':')[0] for name in tensor_names]

    tf_proto = proto.tensorflow
    parameter_names = []
    if tf_proto.parameter.WhichOneof('binding'):
        parameter_names = _op_names(
            graph_utils.extract_tensor_names_from_binding(tf_proto.parameter))
    return_names = _op_names(
        graph_utils.extract_tensor_names_from_binding(tf_proto.result))
    graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)

    names_to_preserve = parameter_names + return_names
    if tf_proto.initialize_op:
        names_to_preserve.append(tf_proto.initialize_op)
    subgraph_def = tf.compat.v1.graph_util.extract_sub_graph(
        graph_def, names_to_preserve)
    return pb.Computation(
        type=proto.type,
        tensorflow=pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(subgraph_def),
            initialize_op=tf_proto.initialize_op,
            parameter=tf_proto.parameter,
            result=tf_proto.result))
    def test_serialize_tensorflow_with_data_set_sum_lambda(self):
        """Summing `range(5)` through a serialized dataset reduce yields 10."""

        def _legacy_dataset_reducer_example(ds):
            return ds.reduce(np.int64(0), lambda x, y: x + y)

        comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            _legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack)
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         '(int64* -> int64)')
        self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
        parameter = tf.data.Dataset.range(5)
        # Close the session once the result is fetched instead of leaking it.
        with tf.compat.v1.Session() as sess:
            results = sess.run(
                tf.import_graph_def(
                    serialization_utils.unpack_graph_def(
                        comp.tensorflow.graph_def), {
                            comp.tensorflow.parameter.sequence.variant_tensor_name:
                            tf.data.experimental.to_variant(parameter)
                        }, [comp.tensorflow.result.tensor.tensor_name]))
        self.assertEqual(results, [10])
    def test_serialize_tensorflow_with_data_set_sum_lambda(self):
        """Summing `range(5)` through a serialized dataset reduce yields 10."""

        def _legacy_dataset_reducer_example(ds):
            return ds.reduce(np.int64(0), lambda x, y: x + y)

        comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            _legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack)
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         '(int64* -> int64)')
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
        parameter = tf.data.Dataset.range(5)
        # Close the session once the result is fetched instead of leaking it.
        with tf.Session() as sess:
            results = sess.run(
                tf.import_graph_def(
                    serialization_utils.unpack_graph_def(
                        comp.tensorflow.graph_def),
                    {
                        comp.tensorflow.parameter.sequence.iterator_string_handle_name:
                        (parameter.make_one_shot_iterator().string_handle())
                    }, [comp.tensorflow.result.tensor.tensor_name]))
        self.assertEqual(results, [10])
    def test_tf_wrapper_with_one_op_py_fn(self):
        """Wraps `x > 10` and evaluates the imported graph on several inputs."""

        @computation_wrapper_instances.tensorflow_wrapper(tf.int32)
        def foo(x):
            return x > 10

        self.assertEqual(str(foo.type_signature), '(int32 -> bool)')

        # TODO(b/113112885): Remove this protected member access once the part of
        # the infrastructure that deals with invoking functions is present. At this
        # point, extracting the proto from within 'foo' is the only way to test the
        # wrapper works as intended.
        comp = foo._computation_proto  # pylint: disable=protected-access

        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
        x = tf.compat.v1.placeholder(tf.int32)
        result = tf.import_graph_def(
            serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
            {comp.tensorflow.parameter.tensor.tensor_name: x},
            [comp.tensorflow.result.tensor.tensor_name])
        # Reuse one session for all inputs and close it afterwards, instead of
        # creating (and leaking) a new session per input value.
        with tf.compat.v1.Session() as sess:
            outputs = [
                sess.run(result, feed_dict={x: n}) for n in [1, 20, 5, 10, 30]
            ]
        self.assertEqual(outputs, [[False], [True], [False], [False], [True]])