예제 #1
0
  def test_stateful_partitioned_call_nodes(self):
    """Grappler is disabled on call ops emitted for a stateful tf.function.

    A `tf.function` that calls `assign_add` on a variable is invoked inside
    `graph`, and its result is captured into a TensorFlow computation proto.
    Before the transformation the proto's call ops must not have grappler
    disabled. After `disable_grappler_for_partitioned_calls` they must.
    """
    with tf.Graph().as_default() as graph:
      counter = tf.Variable(0)

      # Keep the traced function's name as `test`: it becomes part of the
      # generated graph, so renaming it would change the serialized GraphDef.
      @tf.function
      def test():
        return counter.assign_add(1)

      result_type, result_binding = (
          tensorflow_utils.capture_result_from_graph(test(), graph))

    # The computation takes no parameter (None parameter type and binding).
    serialized_type = type_serialization.serialize_type(
        computation_types.FunctionType(None, result_type))
    packed_graph_def = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow_section = pb.TensorFlow(
        graph_def=packed_graph_def, parameter=None, result=result_binding)
    proto = pb.Computation(type=serialized_type, tensorflow=tensorflow_section)

    self.assertCallOpsGrapplerNotDisabled(proto)
    transformed_proto = (
        tensorflow_computation_transformations
        .disable_grappler_for_partitioned_calls(proto))
    self.assertCallOpsGrapplerDisabled(transformed_proto)
예제 #2
0
 def transform(self, comp):
   """Disables grappler for partitioned calls inside a compiled computation.

   Args:
     comp: The building block to (possibly) rewrite. It is only rewritten
       when `self.should_transform(comp)` is truthy, in which case it must be
       a `building_blocks.CompiledComputation`.

   Returns:
     A `(building_block, modified)` tuple. When no rewrite applies this is
     `(comp, False)`. Otherwise it is a new `CompiledComputation` built from
     the transformed proto with `comp`'s type signature, paired with `True`.

   Raises:
     TypeError: Presumably raised by `py_typecheck.check_type` when a
       transformable `comp` is not a `CompiledComputation`.
   """
   if self.should_transform(comp):
     py_typecheck.check_type(comp, building_blocks.CompiledComputation)
     rewritten_proto = (
         tensorflow_computation_transformations
         .disable_grappler_for_partitioned_calls(comp.proto))
     rewritten = building_blocks.CompiledComputation(
         rewritten_proto, type_signature=comp.type_signature)
     return rewritten, True
   return comp, False
예제 #3
0
 def test_raises_on_compiled_computation(self):
   """Passing a building block instead of a computation proto is a TypeError."""
   identity_block = building_block_factory.create_compiled_identity(
       computation_types.TensorType(tf.int32))
   with self.assertRaises(TypeError):
     (tensorflow_computation_transformations
      .disable_grappler_for_partitioned_calls(identity_block))
예제 #4
0
 def test_raises_on_none(self):
   """A `None` argument is rejected with a TypeError."""
   self.assertRaises(
       TypeError,
       tensorflow_computation_transformations
       .disable_grappler_for_partitioned_calls,
       None)