def test_stateful_partitioned_call_nodes(self):
    """Check that the transformation disables Grappler on a variable-mutating graph.

    A graph is built that calls a `tf.function` which increments a variable,
    and the result is wrapped in a no-argument `pb.Computation`. The proto
    must satisfy `assertCallOpsGrapplerNotDisabled` before
    `disable_grappler_for_partitioned_calls` runs and
    `assertCallOpsGrapplerDisabled` afterwards.
    """
    with tf.Graph().as_default() as graph:
        counter = tf.Variable(0)

        # The traced function keeps its original name because that name may
        # end up in the serialized graph.
        @tf.function
        def test():
            return counter.assign_add(1)

        call_output = test()
        result_type, result_binding = tensorflow_utils.capture_result_from_graph(
            call_output, graph
        )

    # The computation takes no parameter, hence the `None` parameter type.
    no_arg_type = computation_types.FunctionType(None, result_type)
    serialized_type = type_serialization.serialize_type(no_arg_type)

    packed_graph = serialization_utils.pack_graph_def(graph.as_graph_def())
    tensorflow_section = pb.TensorFlow(
        graph_def=packed_graph,
        parameter=None,
        result=result_binding,
    )
    proto = pb.Computation(type=serialized_type, tensorflow=tensorflow_section)

    # Precondition: the untouched proto has not had Grappler disabled.
    self.assertCallOpsGrapplerNotDisabled(proto)

    disable_grappler = (
        tensorflow_computation_transformations.disable_grappler_for_partitioned_calls
    )
    self.assertCallOpsGrapplerDisabled(disable_grappler(proto))
def transform(self, comp):
    """Rebuild `comp` with Grappler disabled for its partitioned calls.

    Args:
      comp: The building block to consider. It is only touched when
        `self.should_transform(comp)` is truthy.

    Returns:
      A `(building_block, modified)` pair. When `should_transform` rejects
      `comp`, the pair is `(comp, False)` with `comp` returned unchanged.
      Otherwise it is a new `building_blocks.CompiledComputation` wrapping
      the transformed proto, carrying the original `type_signature`, paired
      with `True`.

    Raises:
      Whatever `py_typecheck.check_type` raises when an accepted `comp` is
      not a `building_blocks.CompiledComputation`.
    """
    if self.should_transform(comp):
        py_typecheck.check_type(comp, building_blocks.CompiledComputation)
        disable_grappler = (
            tensorflow_computation_transformations.disable_grappler_for_partitioned_calls
        )
        rewritten_proto = disable_grappler(comp.proto)
        rebuilt = building_blocks.CompiledComputation(
            rewritten_proto, type_signature=comp.type_signature
        )
        return rebuilt, True

    # Not a candidate for this transformation: hand it back untouched.
    return comp, False
def test_raises_on_compiled_computation(self):
    """Expect `TypeError` when the transformation is given a building block.

    The object returned by `create_compiled_identity` is passed directly to
    `disable_grappler_for_partitioned_calls`.
    NOTE(review): presumably the function expects the underlying computation
    proto rather than this wrapper - confirm against its implementation.
    """
    identity_comp = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32)
    )
    disable_grappler = (
        tensorflow_computation_transformations.disable_grappler_for_partitioned_calls
    )
    with self.assertRaises(TypeError):
        disable_grappler(identity_comp)
def test_raises_on_none(self):
    """Expect `TypeError` when the transformation is given `None`."""
    # Callable form of assertRaises: invokes the function with `None` and
    # checks that a TypeError is raised.
    self.assertRaises(
        TypeError,
        tensorflow_computation_transformations.disable_grappler_for_partitioned_calls,
        None,
    )