def test_identify_input_output(self):
    """Check that GraphAnalyzer reports the fixture graph's inputs and outputs."""
    analyzer = GraphAnalyzer()
    analyzer.graph = self.input_graph
    analyzer.parse_graph()

    detected_inputs, detected_outputs = analyzer.get_graph_input_output()

    self.assertEqual(detected_inputs, self.inputs)
    self.assertEqual(detected_outputs, self.outputs)
def test_no_input_output_config(self):
    """Quantize the fixture graph using 'fake_yaml.yaml' and expect a non-empty result.

    Per the test name, the yaml presumably specifies no inputs/outputs, so
    the quantizer has to work them out itself (TODO confirm against the yaml).
    """
    analyzer = GraphAnalyzer()
    analyzer.graph = self.input_graph
    analyzer.parse_graph()
    fp32_graph_def = analyzer.dump_graph()

    from lpot import Quantization, common

    quantizer = Quantization('fake_yaml.yaml')
    dummy_dataset = quantizer.dataset('dummy', shape=(20, 224, 224, 3), label=True)
    loader_kwargs = {'batch_size': 2}
    # Calibration and evaluation share the same dummy data.
    quantizer.calib_dataloader = common.DataLoader(dummy_dataset, **loader_kwargs)
    quantizer.eval_dataloader = common.DataLoader(dummy_dataset, **loader_kwargs)
    quantizer.model = fp32_graph_def

    quantized_model = quantizer()

    self.assertGreater(len(quantized_model.graph_def.node), 0)
def test_replace_node(self):
    """replace_node must swap the original Add node for a new Add node 'add1'.

    Uses TestCase assertion methods rather than bare ``assert`` so the checks
    still run under ``python -O``.
    """
    graph_analyzer = GraphAnalyzer()
    graph_analyzer.graph = copy.deepcopy(self.graph_def)
    graph_analyzer.parse_graph()

    new_add_node = node_def_pb2.NodeDef()
    new_add_node.op = "Add"
    new_add_node.name = "add1"
    new_add_node.input.extend([self.input0_node.name, self.input1_node.name])

    # NOTE(review): the third argument presumably lists the consumers to be
    # rewired to the new node (here the Mul node) — confirm in GraphAnalyzer.
    graph_analyzer.replace_node(new_add_node, self.add_node.name,
                                [self.mul_node.name])

    result_nodes = list(graph_analyzer.dump_graph().node)
    self.assertNotIn(self.add_node, result_nodes)
    self.assertIn(new_add_node, result_nodes)
def test_invalid_input_output_config(self):
    """Quantize using 'fake_yaml_2.yaml' and check the configured I/O names are not kept.

    Per the test name, the yaml presumably carries invalid input/output names
    (['x'] / ['op_to_store']), so the quantizer is expected to detect the real
    ones instead — TODO confirm against the yaml.
    """
    analyzer = GraphAnalyzer()
    analyzer.graph = self.input_graph
    analyzer.parse_graph()
    fp32_graph_def = analyzer.dump_graph()

    from lpot import Quantization, common

    quantizer = Quantization('fake_yaml_2.yaml')
    dummy_dataset = quantizer.dataset('dummy', shape=(20, 224, 224, 3), label=True)
    loader_kwargs = {'batch_size': 2}
    quantizer.calib_dataloader = common.DataLoader(dummy_dataset, **loader_kwargs)
    quantizer.eval_dataloader = common.DataLoader(dummy_dataset, **loader_kwargs)
    quantizer.model = fp32_graph_def

    # The quantizer should detect the right inputs/outputs on its own.
    quantized_model = quantizer()

    self.assertNotEqual(quantized_model.input_node_names, ['x'])
    self.assertNotEqual(quantized_model.output_node_names, ['op_to_store'])
def test_replace_constant_graph_with_constant_node(self):
    """Replace constant sub-graphs with single Const nodes, one at a time.

    The Add, Mul and Sqrt sub-graphs are replaced in turn, and the node count
    is checked after each replacement. Replacing the block node must be
    refused (falsy return).

    The replacement calls deliberately stay outside ``assert`` statements.
    Under ``python -O`` an ``assert`` and its expression are removed
    entirely, so the graph would never be modified.
    """
    graph_analyzer = GraphAnalyzer()
    graph_analyzer.graph = copy.deepcopy(self.graph_def)
    graph_analyzer.parse_graph()

    def replace_with_random_constant(target_name):
        # Build a random [4, 1] float32 Const named '<target>_const' and ask
        # the analyzer to substitute it for the sub-graph rooted at target.
        value = np.random.random([4, 1])
        dtype = tf.as_dtype(np.float32(value).dtype)
        const_node = GraphRewriterHelper.create_constant_node(
            target_name + "_const", value, dtype)
        return graph_analyzer.replace_constant_graph_with_constant_node(
            const_node, target_name)

    # (node to replace, expected graph size afterwards)
    successful_replacements = (
        (self.add_node, 10),
        (self.mul_node, 8),
        (self.sqrt_node, 7),
    )
    for target_node, expected_node_count in successful_replacements:
        replaced = replace_with_random_constant(target_node.name)
        self.assertTrue(replaced)
        result_graph = graph_analyzer.dump_graph()
        self.assertEqual(len(result_graph.node), expected_node_count)

    replaced = replace_with_random_constant(self.block_node.name)
    self.assertFalse(replaced)
def test_identify_input_output(self):
    """Check input/output detection on the fixture graph and on three saved .pb models."""

    def detect_io(graph_def):
        # Fresh analyzer -> parse -> ask for (inputs, outputs).
        analyzer = GraphAnalyzer()
        analyzer.graph = graph_def
        analyzer.parse_graph()
        return analyzer.get_graph_input_output()

    def load_pb(path):
        # Deserialize a binary GraphDef from disk.
        graph_def = tf.compat.v1.GraphDef()
        with open(path, "rb") as pb_file:
            graph_def.ParseFromString(pb_file.read())
        return graph_def

    inputs, outputs = detect_io(self.input_graph)
    self.assertEqual(inputs, self.inputs)
    self.assertEqual(outputs, self.outputs)

    expectations = (
        ('model_1.pb', ['sub'], ['op_to_store']),
        ('model_2.pb', [], []),
        ('model_3.pb', [], []),
    )
    for pb_path, expected_inputs, expected_outputs in expectations:
        inputs, outputs = detect_io(load_pb(pb_path))
        self.assertEqual(inputs, expected_inputs)
        self.assertEqual(outputs, expected_outputs)
def test_tensorflow_concat_quantization(self):
    """Quantize the model and check that the quantized concat's inputs share one range.

    After quantization, the node
    'v0/cg/incept_v3_a0/concat_eightbit_quantized_concatv2' must exist. For
    each of its first four inputs, the producer node's last two inputs are
    Const nodes holding the frozen min/max values. All four mins must be
    identical, and likewise all four maxes.
    """
    output_graph_def = read_graph(self.pb_path)
    from lpot import Quantization

    quantizer = Quantization('fake_yaml.yaml')
    dataset = quantizer.dataset('dummy', shape=(100, 299, 299, 3), label=True)
    dataloader = quantizer.dataloader(dataset)
    output_graph = quantizer(output_graph_def,
                             q_dataloader=dataloader,
                             eval_dataloader=dataloader)

    target_concat_node_name = 'v0/cg/incept_v3_a0/concat_eightbit_quantized_concatv2'
    from lpot.adaptor.tf_utils.graph_rewriter.graph_util import GraphAnalyzer
    cur_graph = GraphAnalyzer()
    cur_graph.graph = output_graph.as_graph_def()
    graph_info = cur_graph.parse_graph()
    self.assertIn(target_concat_node_name, graph_info)

    min_out, max_out = [], []
    for input_conv_name in graph_info[target_concat_node_name].node.input[:4]:
        producer_inputs = graph_info[input_conv_name].node.input
        # The producer's last two inputs are presumably its frozen min/max
        # Const nodes; read the first float of each tensor value.
        min_node = graph_info[producer_inputs[-2]].node
        max_node = graph_info[producer_inputs[-1]].node
        min_out.append(min_node.attr['value'].tensor.float_val[0])
        max_out.append(max_node.attr['value'].tensor.float_val[0])

    self.assertEqual(len(set(min_out)), 1)
    self.assertEqual(len(set(max_out)), 1)
class PreOptimization(object):
    """Run precision-independent graph rewrites on a TensorFlow model.

    Args:
        model: the input model, converted with ``get_graph_def`` using the
            output node names.
        inputs (list of str): input node names; may be empty.
        outputs (list of str): output tensor/node names. Any ``:<index>``
            suffix is stripped.
    """

    def __init__(self, model, inputs, outputs):
        # Drop the ":<index>" tensor suffix and de-duplicate while keeping the
        # caller's order. list(set(...)) produced an order that depended on
        # the string hash seed, so it could differ from run to run.
        self.output_node_names = list(
            dict.fromkeys(output.split(":")[0] for output in outputs))
        self.input_graph = get_graph_def(model, self.output_node_names)
        # Treat the iterator as an output, presumably so later passes (e.g.
        # unused-node stripping) keep it.
        # NOTE(review): this appends the literal name 'MakeIterator', which
        # assumes the node with that op is also named so — confirm.
        if any(node.op == 'MakeIterator' for node in self.input_graph.node):
            self.output_node_names.append('MakeIterator')
        self.analyzer = GraphAnalyzer()
        self.analyzer.graph = self.input_graph
        self.analyzer.parse_graph()
        self.input_node_names = inputs
        self.logger = logging.getLogger()
        self._tmp_graph_def = None
        self._excluded_node_names = []
        # When either side is missing, let the analyzer detect both.
        if not self.input_node_names or not self.output_node_names:
            self.input_node_names, self.output_node_names = \
                self.analyzer.get_graph_input_output()

    def get_excluded_node_names(self):
        """Get the names of nodes excluded from optimization.

        The list is filled by ``get_optimized_graphdef`` from the names
        reported by UpdateEnterOptimizer.

        Returns:
            list of str: the excluded nodes' names.
        """
        return self._excluded_node_names

    @dump_elapsed_time("Pass Pre Optimization")
    def get_optimized_graphdef(self):
        """Execute the precision-independent graph optimization passes.

        The input graph is transformed by the following passes, in order:
            1. ConvertLayoutOptimizer.
            2. Remove training nodes such as Identity ops (output nodes are
               protected).
            3. Split shared inputs, such as a weights node feeding several
               Conv2D ops.
            4. Fold constant nodes as far as possible.
            5. Fuse a Mul node into the preceding Conv2D/MatMul if possible.
            6. Strip nodes unused between the input and output nodes.
            7. Common sub-expression elimination.
            8. Fold BatchNorm nodes into the preceding Conv2D if possible.
            9. UpdateEnterOptimizer; the node names it reports are recorded
               as excluded.

        Returns:
            GraphDef: the optimized graph. Its function library is copied
            from the input graph.
        """
        self.logger.debug("Start to pre optimize input model...")

        self._tmp_graph_def = ConvertLayoutOptimizer(
            self.input_graph, self.output_node_names).do_transformation()

        self._tmp_graph_def = RemoveTrainingNodesOptimizer(
            self._tmp_graph_def,
            protected_nodes=self.output_node_names).do_transformation()

        self._tmp_graph_def = SplitSharedInputOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = GraphFoldConstantOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FuseColumnWiseMulOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = StripUnusedNodesOptimizer(
            self._tmp_graph_def, self.input_node_names,
            self.output_node_names).do_transformation()

        self._tmp_graph_def = GraphCseOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FoldBatchNormNodesOptimizer(
            self._tmp_graph_def).do_transformation()

        #TODO we should handle all control ops elegantly not bypass it.
        self._tmp_graph_def, excluded_node_names = UpdateEnterOptimizer(
            self._tmp_graph_def).do_transformation()

        self._excluded_node_names.extend(excluded_node_names)
        self._tmp_graph_def.library.CopyFrom(self.input_graph.library)

        return self._tmp_graph_def

    def get_matched_nodes(self, patterns):
        """Search the optimized graph for nodes matching the given patterns.

        Must be called after ``get_optimized_graphdef``. It analyzes
        ``self._tmp_graph_def``, which is ``None`` until then.

        Args:
            patterns (list): a list of fusion patterns. Each pattern is passed
                to ``GraphAnalyzer.query_fusion_pattern_nodes`` and looks like
                [['MatMul'], ("BiasAdd"), ("Relu",)].

        Returns:
            list: the matches from every pattern, concatenated. Each match
            holds the matched node names followed by the matched op types,
            e.g. ['matched_node_a_name', 'matched_node_b_name', ['MatMul', 'BiasAdd']].
        """
        self.analyzer.graph = self._tmp_graph_def
        self.analyzer.parse_graph()
        res = []

        for sub_pattern in patterns:
            res.extend(self.analyzer.query_fusion_pattern_nodes(sub_pattern))
        return res

    def has_positive_input(self, node_name):
        """Check whether the specified node has positive input data.

        Delegates to the analyzer of the original input graph.

        Args:
            node_name (str): the node's name.

        Returns:
            bool: True if the node has positive input data, False if it has
            negative input data.
        """
        return self.analyzer.has_positive_input(node_name)
class PreOptimization():
    """Run precision-independent graph rewrites on a TensorFlow model wrapper.

    Args:
        model: the wrapped model. Its ``graph_def``, ``input_node_names``,
            ``output_node_names``, ``iter_op``, ``_model``,
            ``framework_specific_info`` and ``kwargs`` are read.
        optimization: forwarded to GrapplerOptimizer.
    """

    def __init__(self, model, optimization):
        self.model = model
        self.optimization = optimization
        # Take a copy: appending 'MakeIterator' below must not mutate a list
        # owned by ``model``. If the attribute returns a stored list, every
        # construction would otherwise append another 'MakeIterator' to it.
        self.output_node_names = list(model.output_node_names)
        self.input_node_names = model.input_node_names
        if model.iter_op is not None and 'MakeIterator' not in self.output_node_names:
            self.output_node_names.append('MakeIterator')

        self.analyzer = GraphAnalyzer()
        self.analyzer.graph = model.graph_def
        self.analyzer.parse_graph()
        self.logger = logging.getLogger()
        self._tmp_graph_def = None
        self._excluded_node_names = []

    def get_excluded_node_names(self):
        """Get the names of nodes excluded from optimization.

        The list is filled by ``get_optimized_model`` from the names reported
        by UpdateEnterOptimizer.

        Returns:
            list of str: the excluded nodes' names.
        """
        return self._excluded_node_names

    @dump_elapsed_time("Pass Pre Optimization")
    def get_optimized_model(self):
        """Execute the precision-independent graph optimization passes.

        The model's graph is transformed by the following passes, in order:
            1. ConvertLayoutOptimizer.
            2. GrapplerOptimizer, driven by ``self.optimization``.
            3. Remove training nodes such as Identity ops (output nodes are
               protected).
            4. Split shared inputs, such as a weights node feeding several
               Conv2D ops.
            5. Fold constant nodes as far as possible.
            6. Fuse a Mul node into the preceding Conv2D/MatMul if possible.
            7. Strip nodes unused between the input and output nodes.
            8. FuseGeluOptimizer.
            9. Common sub-expression elimination.
            10. Fold BatchNorm nodes into the preceding Conv2D if possible.
            11. UpdateEnterOptimizer; the node names it reports are recorded
                as excluded.
            12. InjectDummyBiasAddOptimizer (temporary, see TODO below).
            13. FuseTransposeReshapeOptimizer.

        Returns:
            TensorflowModel: a new model built from the original model's
            ``_model``, ``framework_specific_info`` and ``kwargs``, whose
            ``graph_def`` is the optimized graph. The graph's function
            library is copied from the original graph.
        """
        self.logger.debug("Start to pre optimize input model...")

        origin_model = TensorflowModel(self.model._model,
                                       self.model.framework_specific_info,
                                       **self.model.kwargs)

        self._tmp_graph_def = ConvertLayoutOptimizer(
            self.model.graph_def, self.output_node_names).do_transformation()

        self._tmp_graph_def = GrapplerOptimizer(
            self._tmp_graph_def, self.output_node_names,
            self.optimization).do_transformation()

        self._tmp_graph_def = RemoveTrainingNodesOptimizer(
            self._tmp_graph_def,
            protected_nodes=self.output_node_names).do_transformation()

        self._tmp_graph_def = SplitSharedInputOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = GraphFoldConstantOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FuseColumnWiseMulOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = StripUnusedNodesOptimizer(self._tmp_graph_def,
                                                        self.input_node_names,
                                                        self.output_node_names).do_transformation()

        self._tmp_graph_def = FuseGeluOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = GraphCseOptimizer(self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FoldBatchNormNodesOptimizer(
            self._tmp_graph_def).do_transformation()

        #TODO we should handle all control ops elegantly not bypass it.
        self._tmp_graph_def, excluded_node_names = UpdateEnterOptimizer(
            self._tmp_graph_def).do_transformation()

        #TODO we need to remove below optimizer once the TF enabled the single
        # matmul op quantization
        self._tmp_graph_def = InjectDummyBiasAddOptimizer(
            self._tmp_graph_def).do_transformation()

        self._tmp_graph_def = FuseTransposeReshapeOptimizer(
            self._tmp_graph_def).do_transformation()

        self._excluded_node_names.extend(excluded_node_names)
        self._tmp_graph_def.library.CopyFrom(self.model.graph_def.library)

        origin_model.graph_def = self._tmp_graph_def

        return origin_model

    def get_matched_nodes(self, patterns):
        """Search the optimized graph for nodes matching the given patterns.

        Must be called after ``get_optimized_model``. It analyzes
        ``self._tmp_graph_def``, which is ``None`` until then.

        Args:
            patterns (list): a list of fusion patterns. Each pattern is passed
                to ``GraphAnalyzer.query_fusion_pattern_nodes`` and looks like
                [['MatMul'], ("BiasAdd"), ("Relu",)].

        Returns:
            list: the matches from every pattern, in pattern order. A match
            already returned for an earlier pattern is not added again. Each
            match looks like
            ['matched_node_a_name', 'matched_node_b_name', ['MatMul', 'BiasAdd']].
        """
        self.analyzer.graph = self._tmp_graph_def
        self.analyzer.parse_graph()
        res = []

        for sub_pattern in patterns:
            # ``res`` is only extended after the comprehension finishes, so
            # this removes duplicates across patterns, not within one pattern.
            res.extend([i for i in self.analyzer.query_fusion_pattern_nodes(
                sub_pattern) if i not in res])
        return res

    def has_positive_input(self, node_name):
        """Check whether the specified node has positive input data.

        Delegates to the analyzer of the original model graph.

        Args:
            node_name (str): the node's name.

        Returns:
            bool: True if the node has positive input data, False if it has
            negative input data.
        """
        return self.analyzer.has_positive_input(node_name)
def test_graph_cse(self):
    """Build Const->Relu->MatMul->BiasAdd->Relu->Identity and query a fusion pattern.

    Despite its name, this test checks
    ``GraphAnalyzer.query_fusion_pattern_nodes``. It expects the last entry
    of the first match for the MatMul/BiasAdd/Relu pattern to have three
    elements.
    """
    tf.compat.v1.disable_eager_execution()
    helper = QuantizeGraphHelper
    fp32 = dtypes.float32

    # Every node is fully configured (attrs set) before being copied into
    # the GraphDef: repeated-field extend stores copies, not references.
    input_constant = helper.create_constant_node(
        "input_constant", value=list(range(1, 13)), dtype=fp32, shape=[1, 2, 6, 1])

    relu = helper.create_node("Relu", "relu", ["input_constant"])
    helper.set_attr_dtype(relu, "T", fp32)

    b_constant = helper.create_constant_node(
        "b_constant", value=list(range(1, 13)), dtype=fp32, shape=[2, 6])

    mat_mul = helper.create_node("MatMul", "mat_mul", ["relu", "b_constant"])
    helper.set_attr_dtype(mat_mul, "T", fp32)
    for transpose_flag in ("transpose_a", "transpose_b"):
        helper.set_attr_bool(mat_mul, transpose_flag, False)

    offset_constant = helper.create_constant_node(
        "offset_constant", value=list(range(1, 7)), dtype=fp32, shape=[6])

    bias_add = helper.create_node("BiasAdd", "bias_add", ["mat_mul", "offset_constant"])
    helper.set_attr_dtype(bias_add, "T", fp32)

    post_relu = helper.create_node("Relu", "post_relu", ["bias_add"])
    last_identity = helper.create_node("Identity", "last_identity", ["post_relu"])

    float_graph_def = graph_pb2.GraphDef()
    float_graph_def.node.extend([
        input_constant, relu, b_constant, mat_mul,
        offset_constant, bias_add, post_relu, last_identity,
    ])

    analyzer = GraphAnalyzer()
    analyzer.graph = float_graph_def
    analyzer.parse_graph()
    # Note: the second and third pattern entries are plain strings, not 1-tuples.
    matches = analyzer.query_fusion_pattern_nodes([['MatMul'], "BiasAdd", "Relu"])
    self.assertEqual(3, len(matches[0][-1]))