def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
    """Returns a tensorflow computation representing an identity function.

    The returned computation has the type signature `(T -> T)`, where `T` is
    `type_signature`. NOTE: if `T` contains `computation_types.StructType`s
    without an associated container type, they will be given the container
    type `tuple` by this function.

    Args:
      type_signature: A `computation_types.Type` to use as the parameter type
        and result type of the identity function.

    Returns:
      The `ProtoAndType` built by `_tensorflow_comp` from the traced graph and
      its `FunctionType`.

    Raises:
      TypeError: If `type_signature` contains any types which cannot appear in
        TensorFlow bindings.
    """
    type_analysis.check_tensorflow_compatible_type(type_signature)

    with tf.Graph().as_default() as graph:
        stamped_value, in_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', type_signature, graph)
        # The stamped parameter itself is captured as the result.
        out_type, out_binding = tensorflow_utils.capture_result_from_graph(
            stamped_value, graph)

    fn_type = computation_types.FunctionType(type_signature, out_type)
    tf_proto = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=in_binding,
        result=out_binding)
    return _tensorflow_comp(tf_proto, fn_type)
def create_indexing_operator(
    operand_type: computation_types.TensorType,
    index_type: computation_types.TensorType,
) -> ComputationProtoAndType:
    """Returns a tensorflow computation computing an indexing operation.

    The computation takes a pair `<operand, index>` and gathers from `operand`
    at `index` using `tf.gather`.

    Args:
      operand_type: The `computation_types.TensorType` being indexed.
      index_type: A rank-0 `computation_types.TensorType` for the index.

    Returns:
      The result of `_tensorflow_comp` for a function of type
      `(<operand_type, index_type> -> R)`, where `R` is the gathered type.

    Raises:
      TypeError: If `index_type` does not have rank 0. Non-tensor arguments
        are rejected by `check_tensor` (exception type defined there).
    """
    for tensor_type in (operand_type, index_type):
        tensor_type.check_tensor()
    if index_type.shape.rank != 0:
        raise TypeError(
            f'Expected index type to be a scalar, found {index_type}.')

    with tf.Graph().as_default() as graph:
        operand, operand_binding = tensorflow_utils.stamp_parameter_in_graph(
            'indexing_operand', operand_type, graph)
        index, index_binding = tensorflow_utils.stamp_parameter_in_graph(
            'index', index_type, graph)
        gathered = tf.gather(operand, index)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            gathered, graph)

    # Both parameters are exposed as one two-element struct, in the same order
    # as the struct parameter type.
    parameter_type = computation_types.StructType((operand_type, index_type))
    parameter_binding = pb.TensorFlow.Binding(
        struct=pb.TensorFlow.StructBinding(
            element=[operand_binding, index_binding]))
    tensorflow = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    return _tensorflow_comp(
        tensorflow, computation_types.FunctionType(parameter_type, result_type))
Example #3
0
def create_replicate_input(type_spec, count: int) -> pb.Computation:
  """Returns a tensorflow computation which returns `count` clones of an input.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`, where
  `T` is `type_spec` and the length of the result is `count`.

  Args:
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type`.
    count: An integer, the number of times the input is replicated.

  Returns:
    A `pb.Computation` of the `tensorflow` variety with the type signature
    described above.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_spec = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(type_spec)
  py_typecheck.check_type(count, int)

  with tf.Graph().as_default() as graph:
    parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', type_spec, graph)
    # Every element of the result refers to the same stamped parameter value.
    result = [parameter_value] * count
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(type_spec, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
Example #4
0
  def test_stateful_partitioned_call_nodes(self):
    # The result comes from a `tf.function` that mutates a variable; per the
    # test name this is presumably lowered to a `StatefulPartitionedCall` node.
    with tf.Graph().as_default() as graph:
      counter = tf.Variable(0)

      @tf.function
      def test():
        return counter.assign_add(1)

      result_type, result_binding = tensorflow_utils.capture_result_from_graph(
          test(), graph)

    tf_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=None,
        result=result_binding)
    proto = pb.Computation(
        type=type_serialization.serialize_type(
            computation_types.FunctionType(None, result_type)),
        tensorflow=tf_block)

    # Grappler starts enabled on the call ops and must be disabled afterwards.
    self.assertCallOpsGrapplerNotDisabled(proto)
    transformed = (
        tensorflow_computation_transformations
        .disable_grappler_for_partitioned_calls(proto))
    self.assertCallOpsGrapplerDisabled(transformed)
def create_computation_for_py_fn(
        fn: types.FunctionType,
        parameter_type: Optional[computation_types.Type]) -> ProtoAndType:
    """Returns a tensorflow computation returning the result of `fn`.

    The returned computation has the type signature `(T -> U)`, where `T` is
    `parameter_type` and `U` is the type returned by `fn`. When
    `parameter_type` is `None`, `fn` is called with no arguments and the
    computation has no parameter binding.

    Args:
      fn: A Python function, invoked once inside a fresh `tf.Graph`.
      parameter_type: A `computation_types.Type` or `None`.

    Returns:
      The result of `_tensorflow_comp` for the traced graph and its
      `FunctionType`.

    Raises:
      TypeError: (via `py_typecheck.check_type`) if `fn` is not a
        `types.FunctionType`, or `parameter_type` is neither `None` nor a
        `computation_types.Type`.
    """
    py_typecheck.check_type(fn, types.FunctionType)
    has_parameter = parameter_type is not None
    if has_parameter:
        py_typecheck.check_type(parameter_type, computation_types.Type)

    with tf.Graph().as_default() as graph:
        if has_parameter:
            arg, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                'x', parameter_type, graph)
            fn_result = fn(arg)
        else:
            parameter_binding = None
            fn_result = fn()
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            fn_result, graph)

    fn_type = computation_types.FunctionType(parameter_type, result_type)
    tensorflow = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    return _tensorflow_comp(tensorflow, fn_type)
Example #6
0
def create_identity(type_spec) -> pb.Computation:
  """Returns a tensorflow computation representing an identity function.

  The returned computation has the type signature `(T -> T)`, where `T` is
  `type_spec`.

  Args:
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type`.

  Returns:
    A `pb.Computation` of the `tensorflow` variety.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings.
  """
  param_type = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(param_type)

  with tf.Graph().as_default() as graph:
    stamped, in_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', param_type, graph)
    # The stamped parameter itself is captured as the result.
    out_type, out_binding = tensorflow_utils.capture_result_from_graph(
        stamped, graph)

  fn_type = computation_types.FunctionType(param_type, out_type)
  tf_section = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=in_binding,
      result=out_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(fn_type), tensorflow=tf_section)
Example #7
0
    def test_gets_all_explicit_placement(self):
        # Every op the test creates is pinned to the CPU explicitly.
        with tf.Graph().as_default() as g:
            with tf.device('/cpu:0'):
                lhs = tf.constant(0)
                rhs = tf.constant(1)
                total = lhs + rhs

        _, result_binding = tensorflow_utils.capture_result_from_graph(total, g)

        packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
        proto = pb.Computation(
            type=type_serialization.serialize_type(
                computation_types.FunctionType(None, tf.int32)),
            tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                     parameter=None,
                                     result=result_binding))
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            proto)
        device_placements = building_block_analysis.get_device_placement_in(
            building_block)
        placement_names = sorted(device_placements.keys())
        # Expect two placements: the explicit 'cpu' from above, and the empty
        # placement of the `tf.identity` op added to the captured result.
        self.assertLen(placement_names, 2)
        empty_placement, cpu_placement = placement_names
        self.assertEqual('', empty_placement)
        self.assertIn('CPU', cpu_placement)
        self.assertGreater(device_placements[cpu_placement], 0)
Example #8
0
    def test_counts_correct_variables_with_function(self):
        # The variable is created under `tf.init_scope`; the test expects
        # exactly one variable even though `add_one` is applied twice.
        @tf.function
        def add_one(x):
            with tf.init_scope():
                y = tf.Variable(1)
            return x + y

        with tf.Graph().as_default() as graph:
            arg, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                'x', tf.int32, graph)
            twice_incremented = add_one(add_one(arg))

        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            twice_incremented, graph)
        proto = pb.Computation(
            type=type_serialization.serialize_type(
                computation_types.FunctionType(tf.int32, result_type)),
            tensorflow=pb.TensorFlow(
                graph_def=serialization_utils.pack_graph_def(
                    graph.as_graph_def()),
                parameter=parameter_binding,
                result=result_binding))
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            proto)

        variable_count = building_block_analysis.count_tensorflow_variables_in(
            building_block)
        self.assertEqual(variable_count, 1)
Example #9
0
    def test_gets_some_explicit_some_none_placement(self):
        # Only the first constant is pinned to the CPU; the remaining ops have
        # no explicit device.
        with tf.Graph().as_default() as g:
            with tf.device('/cpu:0'):
                pinned = tf.constant(0)
            unpinned = tf.constant(1)
            total = pinned + unpinned

        _, result_binding = tensorflow_utils.capture_result_from_graph(total, g)

        packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
        proto = pb.Computation(
            type=type_serialization.serialize_type(
                computation_types.FunctionType(None, tf.int32)),
            tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                     parameter=None,
                                     result=result_binding))
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            proto)
        device_placements = building_block_analysis.get_device_placement_in(
            building_block)
        # Sorting puts the empty (unplaced) key first, independent of the
        # mapping's iteration order.
        placement_names = sorted(device_placements.keys())
        self.assertLen(placement_names, 2)
        empty_placement, cpu_placement = placement_names
        self.assertEqual('', empty_placement)
        self.assertIn('CPU', cpu_placement)
        for placement in placement_names:
            self.assertGreater(device_placements[placement], 0)
Example #10
0
def create_dummy_computation_tensorflow_add():
    """Returns a tensorflow computation and type.

    The computation adds two `tf.float32` values (stamped as `x` and `y`) and
    has the type signature `(<float32,float32> -> float32)`.

    Returns:
      A `(pb.Computation, computation_types.FunctionType)` pair.
    """
    operand_type = tf.float32

    with tf.Graph().as_default() as graph:
        stamped = [
            tensorflow_utils.stamp_parameter_in_graph(name, operand_type, graph)
            for name in ('x', 'y')
        ]
        (x_value, x_binding), (y_value, y_binding) = stamped
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            tf.add(x_value, y_value), graph)

    type_signature = computation_types.FunctionType(
        computation_types.StructType([operand_type, operand_type]),
        result_type)
    parameter_binding = pb.TensorFlow.Binding(
        struct=pb.TensorFlow.StructBinding(element=[x_binding, y_binding]))
    computation = pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=parameter_binding,
            result=result_binding))
    return computation, type_signature
Example #11
0
    def test_invalid_ops(self):
        @tf.function
        def test():
            return tf.constant(1)

        # Trace the function into a fresh graph and capture its output.
        with tf.Graph().as_default() as graph:
            result_type, result_binding = tensorflow_utils.capture_result_from_graph(
                test(), graph)

        tf_block = computation_pb2.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=None,
            result=result_binding)
        proto = computation_pb2.Computation(
            type=type_serialization.serialize_type(
                computation_types.FunctionType(None, result_type)),
            tensorflow=tf_block)

        # The traced function builds a `tf.constant`, so disallowing `Const`
        # is expected to raise.
        with self.assertRaises(
                tensorflow_computation_transformations
                .DisallowedOpInTensorFlowComputationError):
            tensorflow_computation_transformations.check_no_disallowed_ops(
                proto, frozenset(['Const']))
def create_dummy_computation_tensorflow_tuple(value=10.0):
    """Returns a tensorflow computation and type.

    `( -> <('a', T), ('b', T), ('c', T)>)`

    Args:
      value: The Python value passed to `tf.constant` for each of the three
        named elements; `T` is the type TensorFlow infers for it. Defaults to
        `10.0`.

    Returns:
      A `(pb.Computation, computation_types.FunctionType)` pair.
    """
    with tf.Graph().as_default() as graph:
        names = ['a', 'b', 'c']
        result = anonymous_tuple.AnonymousTuple(
            (n, tf.constant(value)) for n in names)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(None, result_type)
    tensorflow = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=None,
        result=result_binding)
    # A distinct name avoids rebinding (and shadowing) the `value` argument.
    computation = pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow)
    return computation, type_signature
Example #13
0
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ProtoAndType:
  """Returns a tensorflow computation returning `count` copies of its argument.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`, where
  `T` is `type_signature` and the length of the result is `count`.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of times the input is replicated.

  Returns:
    The result of `_tensorflow_comp` for the traced graph and its
    `FunctionType`.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)
  parameter_type = type_signature

  with tf.Graph().as_default() as graph:
    parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', parameter_type, graph)
    # Every element of the result refers to the same stamped parameter value.
    result = [parameter_value] * count
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  # Use a separate name rather than rebinding the `type_signature` argument
  # to a different kind of type.
  function_type = computation_types.FunctionType(parameter_type, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, function_type)
def _pack_noarg_graph(graph_def, return_type, result_binding):
  """Wraps a parameterless graph into a `ComputationBuildingBlock`.

  Args:
    graph_def: The graph to embed; it is packed with
      `serialization_utils.pack_graph_def`.
    return_type: The result type of the function. The parameter type is `None`.
    result_binding: The `pb.TensorFlow.Binding` describing the graph's result.

  Returns:
    The building block produced by
    `building_blocks.ComputationBuildingBlock.from_proto`.
  """
  tf_section = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph_def),
      parameter=None,
      result=result_binding)
  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(None, return_type))
  return building_blocks.ComputationBuildingBlock.from_proto(
      pb.Computation(type=serialized_type, tensorflow=tf_section))
Example #15
0
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ComputationProtoAndType:
  """Returns a tensorflow computation returning `count` copies of its argument.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`, where
  `T` is `type_signature` and the length of the result is `count`.

  Rather than tracing a new graph, this reuses the graph produced by
  `create_identity` and repeats its result binding `count` times.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of times the input is replicated.

  Returns:
    The result of `_tensorflow_comp` for the computation described above.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)
  parameter_type = type_signature
  identity_comp, _ = create_identity(parameter_type)
  # This manual proto manipulation is significantly faster than using TFF's
  # GraphDef serialization for large `count` arguments.
  identity_tf = identity_comp.tensorflow
  replicated_result = pb.TensorFlow.Binding(
      struct=pb.TensorFlow.StructBinding(
          element=[identity_tf.result] * count))
  # Reading an unset sub-message field yields a default instance, so its
  # truthiness says nothing about presence; `HasField` is the reliable test.
  parameter_binding = (
      identity_tf.parameter if identity_tf.HasField('parameter') else None)
  new_tf_pb = pb.TensorFlow(
      graph_def=identity_tf.graph_def,
      parameter=parameter_binding,
      result=replicated_result)
  fn_type = computation_types.FunctionType(
      parameter_type,
      computation_types.StructType([(None, parameter_type)] * count))
  return _tensorflow_comp(new_tf_pb, fn_type)
def wrap_graph_parameter_as_tuple(comp, name=None):
    """Wraps the parameter of `comp` in a tuple binding.

    `wrap_graph_parameter_as_tuple` is intended as a preprocessing step to
    `pad_graph_inputs_to_match_type`, so that `pad_graph_inputs_to_match_type`
    can assume its argument `comp` always has a tuple binding, instead of
    dealing with the possibility of an unwrapped tensor or sequence binding.

    Args:
      comp: Instance of `computation_building_blocks.CompiledComputation` whose
        parameter we wish to wrap in a tuple binding.
      name: Optional string argument, the name to assign to the element type in
        the constructed tuple. Defaults to `None`.

    Returns:
      A transformed version of comp representing exactly the same computation,
      but accepting a tuple containing one element--the parameter of `comp`.

    Raises:
      TypeError: If `comp` is not a
        `computation_building_blocks.CompiledComputation`, or if `name` is
        neither `None` nor a string.
    """
    py_typecheck.check_type(comp,
                            computation_building_blocks.CompiledComputation)
    if name is not None:
        py_typecheck.check_type(name, six.string_types)
    original = comp.proto
    original_tf = original.tensorflow
    original_type = type_serialization.deserialize_type(original.type)

    # The existing binding becomes the single element of a tuple binding, and
    # the parameter type becomes a one-element (optionally named) tuple.
    wrapped_binding = pb.TensorFlow.Binding(
        tuple=pb.TensorFlow.NamedTupleBinding(element=[original_tf.parameter]))
    wrapped_type = computation_types.FunctionType(
        [(name, original_type.parameter)], original_type.result)

    wrapped_proto = pb.Computation(
        type=type_serialization.serialize_type(wrapped_type),
        tensorflow=pb.TensorFlow(graph_def=original_tf.graph_def,
                                 initialize_op=original_tf.initialize_op,
                                 parameter=wrapped_binding,
                                 result=original_tf.result))
    return computation_building_blocks.CompiledComputation(wrapped_proto)
Example #17
0
def create_empty_tuple() -> ProtoAndType:
  """Returns a tensorflow computation returning an empty tuple.

  The returned computation takes no parameter and has the type signature
  `( -> <>)`.

  Returns:
    The result of `_tensorflow_comp` for that graph and type.
  """
  with tf.Graph().as_default() as graph:
    empty_result = structure.Struct([])
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        empty_result, graph)

  no_arg_tf = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  return _tensorflow_comp(
      no_arg_tf, computation_types.FunctionType(None, result_type))
Example #18
0
def create_dummy_computation_tensorflow_empty():
    """Returns a tensorflow computation and type `( -> <>)`.

    Returns:
      A `(pb.Computation, computation_types.FunctionType)` pair for a
      parameterless computation whose result is captured from an empty list.
    """
    with tf.Graph().as_default() as graph:
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            [], graph)

    fn_type = computation_types.FunctionType(None, result_type)
    serialized_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
    computation = pb.Computation(
        type=type_serialization.serialize_type(fn_type),
        tensorflow=pb.TensorFlow(graph_def=serialized_graph,
                                 parameter=None,
                                 result=result_binding))
    return computation, fn_type
Example #19
0
def disable_grappler_for_partitioned_calls(proto):
    """Disables grappler for `PartitionedCall` and `StatefulPartitionedCall` nodes in the graph.

    TensorFlow serializes a `ConfigProto` into the `config_proto` `attr` of
    `PartitionedCall` and `StatefulPartitionedCall` graph nodes. This overrides
    any session config that might disable runtime grappler. To disable grappler
    for these nodes as well, this function overwrites the serialized
    `ConfigProto` of every node whose op is in `CALL_OPS`, setting the
    `disable_meta_optimizer` field to `True`. A call node without a
    `config_proto` attr is given one holding a fresh `ConfigProto` with only
    that field set.

    Args:
      proto: Instance of `computation_pb2.Computation` with the `tensorflow`
        field populated.

    Returns:
      A transformed instance of `computation_pb2.Computation` with a
      `tensorflow` field.

    Raises:
      TypeError: If `proto` is not a `computation_pb2.Computation` of the
        `tensorflow` variety.
    """
    py_typecheck.check_type(proto, computation_pb2.Computation)
    computation_oneof = proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise TypeError('`disable_grappler_for_partitioned_calls` only accepts '
                        '`Computation` protos of the "tensorflow" variety; you '
                        'have passed one of variety {}.'.format(
                            computation_oneof))
    original_tf = proto.tensorflow
    graph_def = serialization_utils.unpack_graph_def(original_tf.graph_def)
    # Call ops may appear both at the top level and inside library functions.
    all_nodes = itertools.chain(
        graph_def.node, *[f.node_def for f in graph_def.library.function])
    for node in all_nodes:
        if node.op not in CALL_OPS:
            continue
        existing_attr = node.attr.get('config_proto')
        if existing_attr is None:
            config_proto = tf.compat.v1.ConfigProto()
        else:
            config_proto = tf.compat.v1.ConfigProto.FromString(existing_attr.s)
        config_proto.graph_options.rewrite_options.disable_meta_optimizer = True
        # Write through the map index: `existing_attr` is `None` when the attr
        # is missing, while indexing inserts the entry if needed.
        node.attr['config_proto'].s = config_proto.SerializeToString(
            deterministic=True)
    tf_block = computation_pb2.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph_def),
        initialize_op=original_tf.initialize_op
        if original_tf.initialize_op else None,
        parameter=original_tf.parameter
        if original_tf.HasField('parameter') else None,
        result=original_tf.result)
    new_proto = computation_pb2.Computation(type=proto.type,
                                            tensorflow=tf_block)
    return new_proto
Example #20
0
  def test_get_wrapped_function_from_comp_raises_with_incorrect_binding(self):
    # Build a small stateful graph, then point the result binding at a tensor
    # name ('Invalid') that the graph does not contain.
    with tf.Graph().as_default() as graph:
      var = tf.Variable(initial_value=0.0, name='var1', import_scope='')
      increment = var.assign_add(tf.constant(1.0))
      tf.add(1.0, increment)

    bogus_binding = pb.TensorFlow.Binding(
        tensor=pb.TensorFlow.TensorBinding(tensor_name='Invalid'))
    packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
    comp = pb.Computation(
        tensorflow=pb.TensorFlow(graph_def=packed_graph, result=bogus_binding))
    with self.assertRaises(TypeError):
      fn = eager_tf_executor._get_wrapped_function_from_comp(
          comp, must_pin_function_to_cpu=False, param_type=None, device=None)
      fn()
Example #21
0
def prune_tensorflow_proto(proto):
    """Extracts subgraph from `proto` preserving parameter, result and initialize.

    Args:
      proto: Instance of `pb.Computation` of the `tensorflow` variety whose
        `graphdef` attribute we wish to prune of extraneous ops.

    Returns:
      A transformed instance of `pb.Computation` of the `tensorflow` variety,
      whose `graphdef` attribute contains only ops which can reach the
      parameter or result bindings, or initialize op.

    Raises:
      TypeError: If `proto` is not a `pb.Computation` of the `tensorflow`
        variety.
    """
    py_typecheck.check_type(proto, pb.Computation)
    computation_oneof = proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise TypeError(
            '`prune_tensorflow_proto` only accepts `Computation` '
            'protos of the \'tensorflow\' variety; you have passed '
            'one of variety {}.'.format(computation_oneof))
    tf_section = proto.tensorflow

    def _node_names(binding):
        # Drop everything from the last ':' onward (the output index) to get
        # the name of the node producing each bound tensor.
        return [
            tensor_name.rpartition(':')[0]
            for tensor_name in graph_utils.extract_tensor_names_from_binding(
                binding)
        ]

    if tf_section.parameter.WhichOneof('binding'):
        names_to_preserve = _node_names(tf_section.parameter)
    else:
        names_to_preserve = []
    names_to_preserve.extend(_node_names(tf_section.result))
    if tf_section.initialize_op:
        names_to_preserve.append(tf_section.initialize_op)

    subgraph_def = tf.compat.v1.graph_util.extract_sub_graph(
        serialization_utils.unpack_graph_def(tf_section.graph_def),
        names_to_preserve)
    return pb.Computation(
        type=proto.type,
        tensorflow=pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(subgraph_def),
            initialize_op=tf_section.initialize_op,
            parameter=tf_section.parameter,
            result=tf_section.result))
Example #22
0
def create_dummy_computation_tensorflow_identity(type_spec=tf.int32):
    """Returns a tensorflow computation and type `(T -> T)`.

    Args:
      type_spec: The type `T`; the parameter is stamped into the graph under
        the name `'a'`. Defaults to `tf.int32`.

    Returns:
      A `(pb.Computation, computation_types.FunctionType)` pair.
    """
    with tf.Graph().as_default() as graph:
        stamped, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
            'a', type_spec, graph)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            stamped, graph)

    type_signature = computation_types.FunctionType(type_spec, result_type)
    tf_section = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    computation = pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tf_section)
    return computation, type_signature
Example #23
0
 def transform(self, comp):
   """Attaches a cache key derived from `comp`'s type and serialized graph.

   Args:
     comp: The computation to (possibly) transform. When `should_transform`
       accepts it, it must be a `building_blocks.CompiledComputation`.

   Returns:
     A `(computation, modified)` pair: `(comp, False)` when
     `self.should_transform(comp)` is falsy, otherwise a new
     `CompiledComputation` whose TensorFlow proto has `cache_key.id` set,
     paired with `True`.
   """
   if self.should_transform(comp):
     py_typecheck.check_type(comp, building_blocks.CompiledComputation)
     source_tf = comp.proto.tensorflow
     keyed_tf = computation_pb2.TensorFlow()
     keyed_tf.CopyFrom(source_tf)
     # Important: we must also hash the type_signature because TFF might
     # produce (<> -> <>) or (<> -> <<>>) functions, which both could be
     # represented as the same graph with a single NoOp node. This can occur
     # particularly in MapReduceForm compilation for secure_sum intrinsics over
     # empty structures.
     # NOTE(review): the built-in `hash()` of bytes/str is salted per process
     # (PYTHONHASHSEED), so this id is presumably only stable within a single
     # process -- confirm that is all the cache requires.
     fingerprint = hash((comp.type_signature, source_tf.graph_def.value))
     # Reinterpret the (possibly negative) hash as an unsigned 64-bit id.
     keyed_tf.cache_key.id = ctypes.c_uint64(fingerprint).value
     keyed_proto = computation_pb2.Computation(
         type=comp.proto.type, tensorflow=keyed_tf)
     return building_blocks.CompiledComputation(
         keyed_proto, type_signature=comp.type_signature), True
   return comp, False
def create_empty_tuple() -> pb.Computation:
  """Returns a tensorflow computation returning an empty tuple.

  The returned computation has the type signature `( -> <>)`.

  Returns:
    A `pb.Computation` of the `tensorflow` variety with no parameter binding.
  """
  with tf.Graph().as_default() as graph:
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        [], graph)

  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(None, result_type))
  return pb.Computation(
      type=serialized_type,
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=None,
          result=result_binding))
  def test_counts_no_variables(self):
    # A graph built only from constants and an add contains no variables.
    with tf.Graph().as_default() as g:
      total = tf.constant(0) + tf.constant(1)

    _, result_binding = tensorflow_utils.capture_result_from_graph(total, g)

    no_arg_type = computation_types.FunctionType(None, tf.int32)
    proto = pb.Computation(
        type=type_serialization.serialize_type(no_arg_type),
        tensorflow=pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(g.as_graph_def()),
            parameter=None,
            result=result_binding))
    variable_count = building_block_analysis.count_tensorflow_variables_in(
        building_blocks.ComputationBuildingBlock.from_proto(proto))
    self.assertEqual(variable_count, 0)
def create_unary_operator(
        operator,
        operand_type: computation_types.Type) -> ComputationProtoAndType:
    """Returns a tensorflow computation computing a unary operation.

    The returned computation has the type signature `(T -> U)`, where `T` is
    `operand_type` and `U` is the result of applying the `operator` to a value
    of type `T`.

    Args:
      operator: A callable taking one argument representing the operation to
        encode. For example: `tf.math.abs`.
      operand_type: A `computation_types.Type` to use as the argument to the
        constructed unary operator; must contain only named tuples and tensor
        types.

    Returns:
      The `ComputationProtoAndType` built by `_tensorflow_comp` from the
      serialized graph and the `(T -> U)` function type.

    Raises:
      TypeError: If the constraints of `operand_type` are violated or
        `operator` is not callable.
    """
    if (operand_type is None
            or not type_analysis.is_generic_op_compatible_type(operand_type)):
        raise TypeError(
            '`operand_type` contains a type other than '
            '`computation_types.TensorType` and `computation_types.StructType`; '
            f'this is disallowed in the generic operators. Got: {operand_type} '
        )
    py_typecheck.check_callable(operator)

    with tf.Graph().as_default() as graph:
        # Bind the single parameter under the name 'x', apply the operator to
        # it, and capture whatever it returns as the computation's result.
        operand_value, operand_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', operand_type, graph)
        result_value = operator(operand_value)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result_value, graph)

    type_signature = computation_types.FunctionType(operand_type, result_type)
    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=operand_binding,
                               result=result_binding)
    return _tensorflow_comp(tensorflow, type_signature)
def _create_proto_with_unnecessary_op():
    """Builds a computation whose graph carries an op the result never uses.

    The computation takes a parameter `x` of type `tf.int32` and returns it
    unchanged. In addition, a `tf.constant(0)` is captured into the same
    graph without being wired into the computation's result binding.

    Returns:
      A `pb.Computation` wrapping the packed graph together with the
      parameter and result bindings of the identity path.
    """
    graph = tf.Graph()
    with graph.as_default():
        param_value, param_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', tf.int32, graph)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            param_value, graph)
        # Capture an extra constant whose binding is discarded, leaving an op
        # in the graph that is not needed to produce the result.
        tensorflow_utils.capture_result_from_graph(tf.constant(0), graph)

    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    return pb.Computation(
        type=type_serialization.serialize_type(
            computation_types.FunctionType(tf.int32, result_type)),
        tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                 parameter=param_binding,
                                 result=result_binding))
    def test_counts_two_variables_correctly(self):
        # Two `tf.Variable`s feeding an addition must both be counted. Uses the
        # same `tensorflow_utils` / `building_blocks` / `building_block_analysis`
        # modules as the sibling tests instead of the legacy module names.
        with tf.Graph().as_default() as g:
            a = tf.Variable(0, name='variable1')
            b = tf.Variable(1, name='variable2')
            c = a + b

        _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)

        packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
        function_type = computation_types.FunctionType(None, tf.int32)
        proto = pb.Computation(
            type=type_serialization.serialize_type(function_type),
            tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                     parameter=None,
                                     result=result_binding))
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            proto)
        tf_vars_in_graph = building_block_analysis.count_tensorflow_variables_in(
            building_block)
        self.assertEqual(tf_vars_in_graph, 2)
# Example #29
# 0
    def test_counts_correct_number_of_ops_simple_case(self):
        # Adding two constants should yield a small, predictable op count.
        graph = tf.Graph()
        with graph.as_default():
            lhs = tf.constant(0)
            rhs = tf.constant(1)
            total = lhs + rhs

        _, binding = tensorflow_utils.capture_result_from_graph(total, graph)

        serialized_type = type_serialization.serialize_type(
            computation_types.FunctionType(None, tf.int32))
        # Pack after capturing so the result's identity op is included.
        tf_proto = pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=None,
            result=binding)
        block = building_blocks.ComputationBuildingBlock.from_proto(
            pb.Computation(type=serialized_type, tensorflow=tf_proto))
        # Expect 4 ops: two constants, one addition, and an identity on the result.
        expected_op_count = 4
        self.assertEqual(
            building_block_analysis.count_tensorflow_ops_in(block),
            expected_op_count)
def create_whimsy_computation_tensorflow_tuple():
    """Returns a tensorflow computation and its type.

    The computation takes no argument and returns a named structure of three
    constants, each equal to 10.0:

    `( -> <('a', float32), ('b', float32), ('c', float32)>)`

    Returns:
      A `(pb.Computation, computation_types.FunctionType)` pair.
    """
    constant_value = 10.0
    graph = tf.Graph()
    with graph.as_default():
        named_constants = structure.Struct(
            [(name, tf.constant(constant_value)) for name in ('a', 'b', 'c')])
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            named_constants, graph)

    function_type = computation_types.FunctionType(None, result_type)
    tf_proto = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=None,
        result=result_binding)
    comp_proto = pb.Computation(
        type=type_serialization.serialize_type(function_type),
        tensorflow=tf_proto)
    return comp_proto, function_type