Example No. 1
    def test_counts_correct_variables_with_function(self):
        """Counts one variable for a tf.function that creates it in init_scope.

        `add_one` builds a `tf.Variable` under `tf.init_scope()` and is applied
        twice inside the same graph; the serialized computation is expected to
        report exactly one TensorFlow variable.
        """

        @tf.function
        def add_one(x):
            # Create the variable outside the function body via init_scope.
            with tf.init_scope():
                y = tf.Variable(1)
            return x + y

        with tf.Graph().as_default() as graph:
            param, param_binding = tensorflow_utils.stamp_parameter_in_graph(
                'x', tf.int32, graph)
            output = add_one(add_one(param))

        output_type, output_binding = (
            tensorflow_utils.capture_result_from_graph(output, graph))
        tf_proto = pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=param_binding,
            result=output_binding)
        comp_proto = pb.Computation(
            type=type_serialization.serialize_type(
                computation_types.FunctionType(tf.int32, output_type)),
            tensorflow=tf_proto)
        comp = building_blocks.ComputationBuildingBlock.from_proto(comp_proto)

        self.assertEqual(
            building_block_analysis.count_tensorflow_variables_in(comp), 1)
Example No. 2
 def _count_tf_vars(inner_comp):
     """Adds the TensorFlow variable count of `inner_comp` to `total_tf_vars[0]`.

     Only `building_blocks.CompiledComputation` instances whose proto carries
     the 'tensorflow' oneof contribute; all other computations are skipped.
     `total_tf_vars` comes from an enclosing scope not shown here; it is
     indexed and updated in place, so presumably a one-element list used as a
     mutable counter (confirm against the caller).

     Returns:
       The pair `(inner_comp, False)`: the computation itself, unmodified.
     """
     is_tensorflow_comp = (
         isinstance(inner_comp, building_blocks.CompiledComputation) and
         inner_comp.proto.WhichOneof('computation') == 'tensorflow')
     if is_tensorflow_comp:
         total_tf_vars[0] += (
             building_block_analysis.count_tensorflow_variables_in(inner_comp))
     return inner_comp, False
Example No. 3
 def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
   """A 2-element Struct of one computation has twice its variable count.

   NOTE: despite "op count" in the name, this test compares TensorFlow
   *variable* counts: the per-node count from `count_tensorflow_variables_in`
   against the tree-wide count from `count_tensorflow_variables_under`.
   """
   comp = _create_two_variable_tensorflow()
   per_node = building_block_analysis.count_tensorflow_variables_in(comp)
   pair = building_blocks.Struct([comp] * 2)
   self.assertEqual(
       tree_analysis.count_tensorflow_variables_under(pair), 2 * per_node)
Example No. 4
    def test_counts_correct_variables_with_function(self):
        """Counts one variable shared by two calls of a nested tf.function.

        `foo` creates a single `tf.Variable`; the inner `bar` increments it and
        is invoked twice. The resulting building block (from
        `to_building_block`) is expected to contain exactly one variable.
        """

        @computations.tf_computation(tf.int32)
        def foo(x):
            counter = tf.Variable(initial_value=0)

            @tf.function
            def bar(x):
                counter.assign_add(1)
                return x + counter, tf.shape(counter)

            # Feed the first element of the first call's output into a second call.
            return bar(bar(x)[0])

        comp = foo.to_building_block()
        self.assertEqual(
            building_block_analysis.count_tensorflow_variables_in(comp), 1)
  def test_counts_no_variables(self):
    """A graph holding only two constants and their sum has zero variables."""
    with tf.Graph().as_default() as graph:
      total = tf.constant(0) + tf.constant(1)

    # Capture the result binding before serializing the graph.
    _, result_binding = tensorflow_utils.capture_result_from_graph(
        total, graph)
    comp_proto = pb.Computation(
        type=type_serialization.serialize_type(
            computation_types.FunctionType(None, tf.int32)),
        tensorflow=pb.TensorFlow(
            graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
            parameter=None,
            result=result_binding))
    comp = building_blocks.ComputationBuildingBlock.from_proto(comp_proto)
    self.assertEqual(
        building_block_analysis.count_tensorflow_variables_in(comp), 0)
    def test_counts_correct_variables_with_function(self):
        """Counts one variable shared by two calls of a nested tf.function.

        Same scenario as the `to_building_block` variant, but the building
        block is reconstructed from the proto returned by
        `computation_impl.ComputationImpl.get_proto`.
        """

        @computations.tf_computation(tf.int32)
        def foo(x):
            counter = tf.Variable(initial_value=0)

            @tf.function
            def bar(x):
                counter.assign_add(1)
                return x + counter, tf.shape(counter)

            # Feed the first element of the first call's output into a second call.
            return bar(bar(x)[0])

        comp = building_blocks.ComputationBuildingBlock.from_proto(
            computation_impl.ComputationImpl.get_proto(foo))
        self.assertEqual(
            building_block_analysis.count_tensorflow_variables_in(comp), 1)
Example No. 7
 def _count_tf_vars(inner_comp):
     nonlocal count_vars
     if (inner_comp.is_compiled_computation() and
             inner_comp.proto.WhichOneof('computation') == 'tensorflow'):
         count_vars += building_block_analysis.count_tensorflow_variables_in(
             inner_comp)
Example No. 8
 def test_raises_on_none(self):
     """`count_tensorflow_variables_in(None)` must raise TypeError."""
     self.assertRaises(
         TypeError,
         building_block_analysis.count_tensorflow_variables_in,
         None)