def test_pack_graph_seed_set_raises(self):
    """`pack_graph_def` rejects a graph that sets a graph-level random seed."""
    graph = tf.Graph()
    with graph.as_default():
        tf.random.set_random_seed(1234)
        tf.random.normal([1])
    graph_def = graph.as_graph_def()
    with self.assertRaisesRegex(ValueError, 'graph-level random seed'):
        serialization_utils.pack_graph_def(graph_def)
Example #2
0
def create_computation_for_py_fn(
    fn: types.FunctionType,
    parameter_type: Optional[computation_types.Type]) -> ProtoAndType:
  """Builds a TensorFlow computation that evaluates `fn` in a fresh graph.

  The resulting computation has type signature `(T -> U)`: `T` is
  `parameter_type` and `U` is the type of the value `fn` returns.

  Args:
    fn: A Python function. It is called with one stamped parameter value when
      `parameter_type` is not `None`, and with no arguments otherwise.
    parameter_type: A `computation_types.Type`, or `None` for a computation
      that takes no parameter.

  Returns:
    The value produced by `_tensorflow_comp` for the packed graph and the
    computed function type.

  Raises:
    TypeError: If `fn` is not a function, or `parameter_type` is neither `None`
      nor a `computation_types.Type` (via `py_typecheck.check_type`).
  """
  py_typecheck.check_type(fn, types.FunctionType)
  has_parameter = parameter_type is not None
  if has_parameter:
    py_typecheck.check_type(parameter_type, computation_types.Type)

  graph = tf.Graph()
  with graph.as_default():
    if has_parameter:
      stamped_arg, parameter_binding = (
          tensorflow_utils.stamp_parameter_in_graph('x', parameter_type, graph))
      fn_output = fn(stamped_arg)
    else:
      parameter_binding = None
      fn_output = fn()
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        fn_output, graph)

  function_type = computation_types.FunctionType(parameter_type, result_type)
  packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
  tensorflow_block = pb.TensorFlow(
      graph_def=packed_graph_def,
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow_block, function_type)
Example #3
0
    def test_invalid_ops(self):
        """`check_no_disallowed_ops` rejects a graph containing a `Const` op."""

        @tf.function
        def test():
            return tf.constant(1)

        graph = tf.Graph()
        with graph.as_default():
            result_type, result_binding = (
                tensorflow_utils.capture_result_from_graph(test(), graph))

        function_type = computation_types.FunctionType(None, result_type)
        tensorflow_block = computation_pb2.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=None,
            result=result_binding)
        proto = computation_pb2.Computation(
            type=type_serialization.serialize_type(function_type),
            tensorflow=tensorflow_block)

        expected_error = (tensorflow_computation_transformations
                          .DisallowedOpInTensorFlowComputationError)
        with self.assertRaises(expected_error):
            tensorflow_computation_transformations.check_no_disallowed_ops(
                proto, frozenset(['Const']))
Example #4
0
def create_replicate_input(type_signature: computation_types.Type,
                           count: int) -> ProtoAndType:
  """Returns a tensorflow computation returning `count` copies of its argument.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`, where
  `T` is `type_signature` and the length of the result is `count`.

  Args:
    type_signature: A `computation_types.Type` to replicate.
    count: An integer, the number of times the input is replicated. A value
      `<= 0` yields an empty result, because `[x] * count` is then empty.

  Returns:
    The value produced by `_tensorflow_comp` for the packed graph and the
    function type `(T -> <T, ..., T>)`.

  Raises:
    TypeError: If `type_signature` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_analysis.check_tensorflow_compatible_type(type_signature)
  py_typecheck.check_type(count, int)
  parameter_type = type_signature

  with tf.Graph().as_default() as graph:
    parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', parameter_type, graph)
    result = [parameter_value] * count
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(parameter_type, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
Example #5
0
    def test_counts_correct_variables_with_function(self):
        """A variable created under `tf.init_scope` in a `tf.function` counts once.

        `add_one` is applied twice (nested), yet the test expects exactly one
        TensorFlow variable in the resulting building block.
        """

        @tf.function
        def add_one(x):
            with tf.init_scope():
                y = tf.Variable(1)
            return x + y

        graph = tf.Graph()
        with graph.as_default():
            x_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
                'x', tf.int32, graph)
            output = add_one(add_one(x_value))

        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            output, graph)
        function_type = computation_types.FunctionType(tf.int32, result_type)
        tensorflow_block = pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=parameter_binding,
            result=result_binding)
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            pb.Computation(
                type=type_serialization.serialize_type(function_type),
                tensorflow=tensorflow_block))

        self.assertEqual(
            building_block_analysis.count_tensorflow_variables_in(
                building_block), 1)
Example #6
0
    def test_gets_all_explicit_placement(self):
        """Reports both the explicit CPU placement and the empty placement."""

        with tf.Graph().as_default() as g:
            with tf.device('/cpu:0'):
                a = tf.constant(0)
                b = tf.constant(1)
                c = a + b

        _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)

        packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
        function_type = computation_types.FunctionType(None, tf.int32)
        proto = pb.Computation(
            type=type_serialization.serialize_type(function_type),
            tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                     parameter=None,
                                     result=result_binding))
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            proto)
        device_placements = building_block_analysis.get_device_placement_in(
            building_block)
        # `sorted` already returns a list; '' sorts before any non-empty
        # device string, so the unplaced entry comes first.
        all_device_placements = sorted(device_placements.keys())
        # Expect two placements: the explicit 'cpu' from above, and the empty
        # placement of the `tf.identity` op added to the captured result.
        self.assertLen(all_device_placements, 2)
        self.assertEqual('', all_device_placements[0])
        self.assertIn('CPU', all_device_placements[1])
        self.assertGreater(device_placements[all_device_placements[1]], 0)
Example #7
0
  def test_deserialize_and_call_tf_computation_with_add_one(self):
    """Deserializes an int32 computation and calls it on the constant 10."""
    # NOTE(review): despite the test name, the stamped computation is an
    # identity (`tf.identity`), so the expected output equals the input, 10.
    graph = tf.Graph()
    with graph.as_default():
      x, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
          'x', tf.int32, graph)
      result_type, result_binding = tensorflow_utils.capture_result_from_graph(
          tf.identity(x), graph)
    function_type = computation_types.FunctionType(
        computation_types.TensorType(tf.int32), result_type)
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    computation_proto = pb.Computation(
        type=type_serialization.serialize_type(function_type),
        tensorflow=tensorflow_block)

    init_op, output = tensorflow_utils.deserialize_and_call_tf_computation(
        computation_proto, tf.constant(10), tf.compat.v1.get_default_graph())
    self.assertTrue(tf.is_tensor(output))
    with tf.compat.v1.Session() as sess:
      if init_op:
        sess.run(init_op)
      output_value = sess.run(output)
    self.assertEqual(output_value, 10)
Example #8
0
    def test_gets_some_explicit_some_none_placement(self):
        """Reports an explicit CPU placement alongside the unplaced ops."""
        with tf.Graph().as_default() as g:
            with tf.device('/cpu:0'):
                a = tf.constant(0)
            b = tf.constant(1)
            c = a + b

        _, result_binding = tensorflow_utils.capture_result_from_graph(c, g)

        packed_graph_def = serialization_utils.pack_graph_def(g.as_graph_def())
        function_type = computation_types.FunctionType(None, tf.int32)
        proto = pb.Computation(
            type=type_serialization.serialize_type(function_type),
            tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                     parameter=None,
                                     result=result_binding))
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            proto)
        device_placements = building_block_analysis.get_device_placement_in(
            building_block)
        # Sorting fixes the order: '' sorts before any non-empty device
        # string, so the unplaced entry comes first and the CPU one second.
        all_device_placements = sorted(device_placements.keys())
        self.assertLen(all_device_placements, 2)
        self.assertEqual('', all_device_placements[0])
        self.assertIn('CPU', all_device_placements[1])
        for placement in all_device_placements:
            self.assertGreater(device_placements[placement], 0)
def create_indexing_operator(
    operand_type: computation_types.TensorType,
    index_type: computation_types.TensorType,
) -> ComputationProtoAndType:
    """Builds a TensorFlow computation that gathers `operand` at `index`.

    The computation takes a two-element struct `<operand, index>` and returns
    `tf.gather(operand, index)`.

    Args:
      operand_type: The tensor type of the value being indexed.
      index_type: The tensor type of the index; its shape must have rank 0.

    Returns:
      The value produced by `_tensorflow_comp` for the packed graph and the
      type `(<operand_type, index_type> -> result_type)`.

    Raises:
      TypeError: If `index_type.shape.rank` is anything other than 0
        (including an unknown rank).
    """
    operand_type.check_tensor()
    index_type.check_tensor()
    if index_type.shape.rank != 0:
        raise TypeError(
            f'Expected index type to be a scalar, found {index_type}.')

    graph = tf.Graph()
    with graph.as_default():
        operand, operand_binding = tensorflow_utils.stamp_parameter_in_graph(
            'indexing_operand', operand_type, graph)
        index, index_binding = tensorflow_utils.stamp_parameter_in_graph(
            'index', index_type, graph)
        gathered = tf.gather(operand, index)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            gathered, graph)

    parameter_type = computation_types.StructType((operand_type, index_type))
    function_type = computation_types.FunctionType(parameter_type, result_type)
    struct_binding = pb.TensorFlow.StructBinding(
        element=[operand_binding, index_binding])
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=pb.TensorFlow.Binding(struct=struct_binding),
        result=result_binding)
    return _tensorflow_comp(tensorflow_block, function_type)
def create_dummy_computation_tensorflow_tuple(value=10.0):
    """Returns a tensorflow computation and type.

    The type is `( -> <('a', T), ('b', T), ('c', T)>)`, where `T` is the tensor
    type of `tf.constant(value)` (`float32` for the default `10.0`).

    Args:
      value: A Python scalar passed to `tf.constant` for each of the three named
        elements. Defaults to the float `10.0`.

    Returns:
      A tuple `(computation_proto, type_signature)`.
    """
    with tf.Graph().as_default() as graph:
        names = ['a', 'b', 'c']
        result = anonymous_tuple.AnonymousTuple(
            (n, tf.constant(value)) for n in names)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(None, result_type)
    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=None,
                               result=result_binding)
    # Use a new name instead of rebinding the `value` parameter.
    computation = pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow)
    return computation, type_signature
Example #11
0
def create_dummy_computation_tensorflow_add():
    """Builds a TensorFlow computation that adds two `float32` values.

    Type signature: `(<float32,float32> -> float32)`.

    Returns:
      A tuple `(computation_proto, type_signature)`.
    """
    type_spec = tf.float32

    graph = tf.Graph()
    with graph.as_default():
        x, x_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', type_spec, graph)
        y, y_binding = tensorflow_utils.stamp_parameter_in_graph(
            'y', type_spec, graph)
        total = tf.add(x, y)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            total, graph)

    type_signature = computation_types.FunctionType(
        computation_types.StructType([type_spec, type_spec]), result_type)
    parameter_binding = pb.TensorFlow.Binding(
        struct=pb.TensorFlow.StructBinding(element=[x_binding, y_binding]))
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    computation = pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow_block)
    return computation, type_signature
 def test_pack_unpack_roundtrip(self):
     """Packing and then unpacking a GraphDef reproduces an equal GraphDef."""
     graph = tf.Graph()
     with graph.as_default():
         tf.constant(1.0)
     original_graph_def = graph.as_graph_def()
     packed = serialization_utils.pack_graph_def(original_graph_def)
     self.assertEqual(original_graph_def,
                      serialization_utils.unpack_graph_def(packed))
def create_identity(type_signature: computation_types.Type) -> ProtoAndType:
    """Builds a TensorFlow computation that returns its argument unchanged.

    The computation has type signature `(T -> T)`, where `T` is
    `type_signature`. NOTE: if `T` contains `computation_types.StructType`s
    without an associated container type, they will be given the container
    type `tuple` by this function.

    Args:
      type_signature: A `computation_types.Type` used as both the parameter
        type and the result type of the identity function.

    Returns:
      The value produced by `_tensorflow_comp` for the packed graph and the
      identity function type.

    Raises:
      TypeError: If `type_signature` contains any types which cannot appear in
        TensorFlow bindings.
    """
    type_analysis.check_tensorflow_compatible_type(type_signature)
    parameter_type = type_signature

    graph = tf.Graph()
    with graph.as_default():
        stamped, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', parameter_type, graph)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            stamped, graph)

    function_type = computation_types.FunctionType(parameter_type, result_type)
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    return _tensorflow_comp(tensorflow_block, function_type)
Example #14
0
  def test_stateful_partitioned_call_nodes(self):
    """Grappler gets disabled on call ops from a variable-mutating tf.function."""
    graph = tf.Graph()
    with graph.as_default():
      v = tf.Variable(0)

      @tf.function
      def test():
        return v.assign_add(1)

      call_output = test()
      result_type, result_binding = tensorflow_utils.capture_result_from_graph(
          call_output, graph)

    function_type = computation_types.FunctionType(None, result_type)
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=None,
        result=result_binding)
    proto = pb.Computation(
        type=type_serialization.serialize_type(function_type),
        tensorflow=tensorflow_block)

    self.assertCallOpsGrapplerNotDisabled(proto)
    transformed_proto = (
        tensorflow_computation_transformations
        .disable_grappler_for_partitioned_calls(proto))
    self.assertCallOpsGrapplerDisabled(transformed_proto)
Example #15
0
def create_replicate_input(type_spec, count: int) -> pb.Computation:
  """Returns a tensorflow computation which returns `count` clones of an input.

  The returned computation has the type signature `(T -> <T, T, T, ...>)`, where
  `T` is `type_spec` and the length of the result is `count`.

  Args:
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type`.
    count: An integer, the number of times the input is replicated. A value
      `<= 0` yields an empty result, because `[x] * count` is then empty.

  Returns:
    A `pb.Computation` of the `tensorflow` variety.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings or if `count` is not an integer.
  """
  type_spec = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(type_spec)
  py_typecheck.check_type(count, int)

  with tf.Graph().as_default() as graph:
    parameter_value, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
        'x', type_spec, graph)
    result = [parameter_value] * count
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(type_spec, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)
Example #16
0
def create_identity(type_spec) -> pb.Computation:
  """Builds a `pb.Computation` for the identity function `(T -> T)`.

  `T` is `type_spec` after conversion with `computation_types.to_type`.

  Args:
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type`.

  Returns:
    A `pb.Computation` of the `tensorflow` variety.

  Raises:
    TypeError: If `type_spec` contains any types which cannot appear in
      TensorFlow bindings.
  """
  parameter_type = computation_types.to_type(type_spec)
  type_analysis.check_tensorflow_compatible_type(parameter_type)

  graph = tf.Graph()
  with graph.as_default():
    stamped_value, parameter_binding = (
        tensorflow_utils.stamp_parameter_in_graph('x', parameter_type, graph))
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        stamped_value, graph)

  function_type = computation_types.FunctionType(parameter_type, result_type)
  tensorflow_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=parameter_binding,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=tensorflow_block)
def _pack_noarg_graph(graph_def, return_type, result_binding):
  """Wraps `graph_def` as a no-argument TensorFlow building block.

  Args:
    graph_def: A `GraphDef` to pack into the computation.
    return_type: The result type; the computation has type `( -> return_type)`.
    result_binding: The `pb.TensorFlow.Binding` describing the graph's result.

  Returns:
    A building block produced by
    `building_blocks.ComputationBuildingBlock.from_proto`.
  """
  tensorflow_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph_def),
      parameter=None,
      result=result_binding)
  serialized_type = type_serialization.serialize_type(
      computation_types.FunctionType(None, return_type))
  return building_blocks.ComputationBuildingBlock.from_proto(
      pb.Computation(type=serialized_type, tensorflow=tensorflow_block))
Example #18
0
def create_empty_tuple() -> ProtoAndType:
  """Builds a no-argument TensorFlow computation whose result is empty.

  The computation has type signature `( -> <>)`.

  Returns:
    The value produced by `_tensorflow_comp` for the packed graph and the
    type `( -> <>)`.
  """
  graph = tf.Graph()
  with graph.as_default():
    empty_struct = structure.Struct([])
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        empty_struct, graph)

  function_type = computation_types.FunctionType(None, result_type)
  return _tensorflow_comp(
      pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=None,
          result=result_binding),
      function_type)
Example #19
0
def create_dummy_computation_tensorflow_empty():
    """Builds a no-argument computation returning `<>`.

    Returns:
      A tuple `(computation_proto, type_signature)` where the type is
      `( -> <>)`.
    """
    graph = tf.Graph()
    with graph.as_default():
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            [], graph)

    type_signature = computation_types.FunctionType(None, result_type)
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=None,
        result=result_binding)
    computation = pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow_block)
    return computation, type_signature
Example #20
0
def disable_grappler_for_partitioned_calls(proto):
    """Disables grappler for `PartitionedCall` and `StatefulPartitionedCall` nodes.

    TensorFlow serializes a `ConfigProto` into the `config_proto` `attr` of
    `PartitionedCall` and `StatefulPartitionedCall` graph nodes. That embedded
    config overrides any session config that might disable runtime grappler. To
    disable grappler for these nodes as well, this function rewrites the
    serialized `ConfigProto` of every node whose op is in `CALL_OPS` (in the
    top-level graph and in every function of the graph's library), setting
    `graph_options.rewrite_options.disable_meta_optimizer` to `True`. A call
    node without a `config_proto` attr gets a new one containing only that
    setting.

    Args:
      proto: Instance of `computation_pb2.Computation` with the `tensorflow`
        field populated.

    Returns:
      A new `computation_pb2.Computation` with the same `type` and a
      `tensorflow` field holding the rewritten graph. `initialize_op` and
      `parameter` are carried over only when they are set on the input.

    Raises:
      TypeError: If `proto` is not a `computation_pb2.Computation`, or is not
        of the `tensorflow` variety.
    """
    py_typecheck.check_type(proto, computation_pb2.Computation)
    computation_oneof = proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise TypeError('`disable_grappler_for_partitioned_calls` only accepts '
                        '`Computation` protos of the "tensorflow" variety; '
                        'you have passed one of variety {}.'.format(
                            computation_oneof))
    original_tf = proto.tensorflow
    graph_def = serialization_utils.unpack_graph_def(original_tf.graph_def)
    # Call ops can appear both in the top-level graph and inside library
    # functions, so visit both.
    all_nodes = itertools.chain(
        graph_def.node, *[f.node_def for f in graph_def.library.function])
    for node in all_nodes:
        if node.op not in CALL_OPS:
            continue
        existing_attr = node.attr.get('config_proto')
        if existing_attr is None:
            config_proto = tf.compat.v1.ConfigProto()
        else:
            config_proto = tf.compat.v1.ConfigProto.FromString(existing_attr.s)
        config_proto.graph_options.rewrite_options.disable_meta_optimizer = True
        # Index the map rather than reusing the `get` result (which is None
        # when the attr is absent); indexing a message-valued proto map
        # creates the entry if it is missing.
        node.attr['config_proto'].s = config_proto.SerializeToString(
            deterministic=True)
    tf_block = computation_pb2.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph_def),
        initialize_op=original_tf.initialize_op
        if original_tf.initialize_op else None,
        parameter=original_tf.parameter
        if original_tf.HasField('parameter') else None,
        result=original_tf.result)
    new_proto = computation_pb2.Computation(type=proto.type,
                                            tensorflow=tf_block)
    return new_proto
Example #21
0
  def test_get_wrapped_function_from_comp_raises_with_incorrect_binding(self):
    """A result binding naming a missing tensor ('Invalid') raises TypeError."""
    graph = tf.Graph()
    with graph.as_default():
      var = tf.Variable(initial_value=0.0, name='var1', import_scope='')
      increment = var.assign_add(tf.constant(1.0))
      tf.add(1.0, increment)

    bad_binding = pb.TensorFlow.Binding(
        tensor=pb.TensorFlow.TensorBinding(tensor_name='Invalid'))
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    comp = pb.Computation(
        tensorflow=pb.TensorFlow(graph_def=packed_graph_def,
                                 result=bad_binding))
    # Building or invoking the wrapped function is expected to fail.
    with self.assertRaises(TypeError):
      wrapped_fn = eager_tf_executor._get_wrapped_function_from_comp(
          comp, must_pin_function_to_cpu=False, param_type=None, device=None)
      wrapped_fn()
Example #22
0
def prune_tensorflow_proto(proto):
    """Drops graph ops that the bindings and initialize op do not need.

    Args:
      proto: Instance of `pb.Computation` of the `tensorflow` variety whose
        `graphdef` attribute we wish to prune of extraneous ops.

    Returns:
      A new `pb.Computation` of the `tensorflow` variety. Its graph is the
      subgraph returned by `tf.compat.v1.graph_util.extract_sub_graph` for the
      nodes named by the parameter binding, result binding and initialize op.
      `type`, `initialize_op`, `parameter` and `result` are copied from `proto`.

    Raises:
      TypeError: If `proto` is not of the `tensorflow` variety.
    """
    py_typecheck.check_type(proto, pb.Computation)
    computation_oneof = proto.WhichOneof('computation')
    if computation_oneof != 'tensorflow':
        raise TypeError(
            '`prune_tensorflow_proto` only accepts `Computation` '
            'protos of the \'tensorflow\' variety; you have passed '
            'one of variety {}.'.format(computation_oneof))
    tf_proto = proto.tensorflow

    def _node_names(binding):
        # Tensor names are expected to look like '<node>:<output index>';
        # everything before the last ':' is the node name.
        return [
            tensor_name.rpartition(':')[0]
            for tensor_name in graph_utils.extract_tensor_names_from_binding(
                binding)
        ]

    names_to_preserve = []
    if tf_proto.parameter.WhichOneof('binding'):
        names_to_preserve.extend(_node_names(tf_proto.parameter))
    names_to_preserve.extend(_node_names(tf_proto.result))
    graph_def = serialization_utils.unpack_graph_def(tf_proto.graph_def)
    if tf_proto.initialize_op:
        names_to_preserve.append(tf_proto.initialize_op)
    subgraph_def = tf.compat.v1.graph_util.extract_sub_graph(
        graph_def, names_to_preserve)
    pruned_tensorflow = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(subgraph_def),
        initialize_op=tf_proto.initialize_op,
        parameter=tf_proto.parameter,
        result=tf_proto.result)
    return pb.Computation(type=proto.type, tensorflow=pruned_tensorflow)
Example #23
0
def create_dummy_computation_tensorflow_identity(type_spec=tf.int32):
    """Builds an identity TensorFlow computation of type `(T -> T)`.

    Args:
      type_spec: The parameter type `T`; defaults to `tf.int32`.

    Returns:
      A tuple `(computation_proto, type_signature)`.
    """
    graph = tf.Graph()
    with graph.as_default():
        stamped, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
            'a', type_spec, graph)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            stamped, graph)

    type_signature = computation_types.FunctionType(type_spec, result_type)
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    computation = pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow_block)
    return computation, type_signature
def create_empty_tuple() -> pb.Computation:
  """Builds a no-argument `pb.Computation` whose result is an empty tuple.

  The computation has type signature `( -> <>)`.
  """
  graph = tf.Graph()
  with graph.as_default():
    result_type, result_binding = tensorflow_utils.capture_result_from_graph(
        [], graph)

  function_type = computation_types.FunctionType(None, result_type)
  tensorflow_block = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(function_type),
      tensorflow=tensorflow_block)
def _create_proto_with_unnecessary_op():
    """Builds an int32 identity computation whose graph carries an unused op.

    Besides the identity path, the graph contains a `tf.constant(0)` that is
    captured via a `capture_result_from_graph` call whose bindings are
    discarded, so nothing in the returned proto's bindings refers to it
    (presumably so that pruning logic has something to remove).
    """
    parameter_type = tf.int32

    graph = tf.Graph()
    with graph.as_default():
        x, parameter_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', parameter_type, graph)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            x, graph)
        # The extra op and its capture are deliberately left unbound.
        tensorflow_utils.capture_result_from_graph(tf.constant(0), graph)

    serialized_function_type = type_serialization.serialize_type(
        computation_types.FunctionType(parameter_type, result_type))
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=parameter_binding,
        result=result_binding)
    return pb.Computation(type=serialized_function_type,
                          tensorflow=tensorflow_block)
  def test_counts_no_variables(self):
    """A graph built only from constants and an add has zero variables."""
    graph = tf.Graph()
    with graph.as_default():
      total = tf.constant(0) + tf.constant(1)

    _, result_binding = tensorflow_utils.capture_result_from_graph(total, graph)

    function_type = computation_types.FunctionType(None, tf.int32)
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=None,
        result=result_binding)
    building_block = building_blocks.ComputationBuildingBlock.from_proto(
        pb.Computation(
            type=type_serialization.serialize_type(function_type),
            tensorflow=tensorflow_block))
    self.assertEqual(
        building_block_analysis.count_tensorflow_variables_in(building_block),
        0)
def create_unary_operator(
        operator,
        operand_type: computation_types.Type) -> ComputationProtoAndType:
    """Builds a TensorFlow computation that applies `operator` to one operand.

    The computation has type signature `(T -> U)`, where `T` is `operand_type`
    and `U` is the type of `operator` applied to a value of type `T`.

    Args:
      operator: A callable of one argument encoding the operation, for
        example `tf.math.abs`.
      operand_type: The `computation_types.Type` of the operand; must contain
        only named tuples and tensor types.

    Returns:
      The value produced by `_tensorflow_comp` for the packed graph and the
      type `(T -> U)`.

    Raises:
      TypeError: If `operand_type` is `None` or not compatible with generic
        operators, or if `operator` is not callable.
    """
    is_supported_operand = (
        operand_type is not None
        and type_analysis.is_generic_op_compatible_type(operand_type))
    if not is_supported_operand:
        raise TypeError(
            '`operand_type` contains a type other than '
            '`computation_types.TensorType` and `computation_types.StructType`; '
            f'this is disallowed in the generic operators. Got: {operand_type} '
        )
    py_typecheck.check_callable(operator)

    graph = tf.Graph()
    with graph.as_default():
        operand, operand_binding = tensorflow_utils.stamp_parameter_in_graph(
            'x', operand_type, graph)
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            operator(operand), graph)

    function_type = computation_types.FunctionType(operand_type, result_type)
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=operand_binding,
        result=result_binding)
    return _tensorflow_comp(tensorflow_block, function_type)
    def test_counts_two_variables_correctly(self):
        """Two explicitly created variables are both counted."""
        graph = tf.Graph()
        with graph.as_default():
            first = tf.Variable(0, name='variable1')
            second = tf.Variable(1, name='variable2')
            total = first + second

        _, result_binding = graph_utils.capture_result_from_graph(total, graph)

        function_type = computation_types.FunctionType(None, tf.int32)
        tensorflow_block = pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=None,
            result=result_binding)
        proto = pb.Computation(
            type=type_serialization.serialize_type(function_type),
            tensorflow=tensorflow_block)
        building_block = (computation_building_blocks.ComputationBuildingBlock
                          .from_proto(proto))
        variable_count = (
            computation_building_block_utils.count_tensorflow_variables_in(
                building_block))
        self.assertEqual(variable_count, 2)
Example #29
0
    def test_counts_correct_number_of_ops_simple_case(self):
        """Counts the ops of a constant-plus-constant graph."""
        graph = tf.Graph()
        with graph.as_default():
            total = tf.constant(0) + tf.constant(1)

        _, result_binding = tensorflow_utils.capture_result_from_graph(
            total, graph)

        function_type = computation_types.FunctionType(None, tf.int32)
        tensorflow_block = pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=None,
            result=result_binding)
        building_block = building_blocks.ComputationBuildingBlock.from_proto(
            pb.Computation(
                type=type_serialization.serialize_type(function_type),
                tensorflow=tensorflow_block))
        op_count = building_block_analysis.count_tensorflow_ops_in(
            building_block)
        # Expect 4 ops: two constants, one addition, and an identity on the
        # captured result.
        self.assertEqual(op_count, 4)
def create_whimsy_computation_tensorflow_tuple():
    """Builds a no-argument computation returning three named float constants.

    Type signature: `( -> <('a', float32), ('b', float32), ('c', float32)>)`.
    Every element holds the constant `10.0`.

    Returns:
      A tuple `(computation_proto, type_signature)`.
    """
    value = 10.0

    graph = tf.Graph()
    with graph.as_default():
        named_constants = structure.Struct(
            (name, tf.constant(value)) for name in ['a', 'b', 'c'])
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            named_constants, graph)

    type_signature = computation_types.FunctionType(None, result_type)
    tensorflow_block = pb.TensorFlow(
        graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
        parameter=None,
        result=result_binding)
    computation = pb.Computation(
        type=type_serialization.serialize_type(type_signature),
        tensorflow=tensorflow_block)
    return computation, type_signature