def test_replace_constant_graph_with_constant_node(self):
    """Collapse constant subgraphs into single Const nodes via GraphAnalyzer.

    Replaces the subgraphs rooted at ``add_node``, ``mul_node`` and
    ``sqrt_node`` one after another on the same analyzer, checking the node
    count of the dumped graph after each step. It then verifies that the
    replacement is refused for ``block_node``. The expected counts depend on
    the ``self.graph_def`` fixture, which is built outside this method.
    """
    graph_analyzer = GraphAnalyzer()
    # Work on a deep copy so the shared fixture graph is not mutated.
    graph_analyzer.graph = copy.deepcopy(self.graph_def)
    graph_analyzer.parse_graph()

    def _replace_with_random_const(target_name):
        """Build a Const named ``<target_name>_const`` and try to swap it in.

        Returns the result of
        ``graph_analyzer.replace_constant_graph_with_constant_node``.
        """
        value = np.random.random([4, 1])
        dtype = tf.as_dtype(np.float32(value).dtype)
        const_node = GraphRewriterHelper.create_constant_node(
            target_name + "_const", value, dtype)
        return graph_analyzer.replace_constant_graph_with_constant_node(
            const_node, target_name)

    # Replacements accumulate on the same analyzer, so each expected count
    # reflects all previous replacements as well.
    for target_name, expected_node_count in (
            (self.add_node.name, 10),
            (self.mul_node.name, 8),
            (self.sqrt_node.name, 7)):
        self.assertTrue(_replace_with_random_const(target_name),
                        msg=target_name)
        result_graph = graph_analyzer.dump_graph()
        self.assertEqual(len(result_graph.node), expected_node_count,
                         msg=target_name)

    # Replacing the subgraph rooted at block_node must be rejected.
    self.assertFalse(_replace_with_random_const(self.block_node.name))
def test_no_input_output_config(self):
    """Quantize the float graph using ``fake_yaml.yaml``; expect a non-empty graph.

    Judging by the test name, that YAML presumably does not specify
    input/output nodes (TODO confirm against the fixture file).
    """
    analyzer = GraphAnalyzer()
    analyzer.graph = self.input_graph
    analyzer.parse_graph()
    fp32_graph_def = analyzer.dump_graph()

    from lpot import Quantization, common
    quantizer = Quantization('fake_yaml.yaml')
    dummy_data = quantizer.dataset('dummy', shape=(20, 224, 224, 3), label=True)
    # Calibration and evaluation each get their own loader over the same data.
    for loader_attr in ('calib_dataloader', 'eval_dataloader'):
        setattr(quantizer, loader_attr,
                common.DataLoader(dummy_data, batch_size=2))
    quantizer.model = fp32_graph_def

    quantized = quantizer()
    self.assertGreater(len(quantized.graph_def.node), 0)
def test_replace_node(self):
    """Replace ``add_node`` with a freshly built ``Add`` node named ``add1``.

    The new node takes ``input0_node`` and ``input1_node`` as inputs.
    ``[mul_node.name]`` is passed as the third argument of ``replace_node``
    (presumably the consumers to rewire; confirm in GraphAnalyzer). Afterwards
    the dumped graph must contain the new node and no longer contain the
    original ``add_node``.
    """
    graph_analyzer = GraphAnalyzer()
    # Work on a deep copy so the shared fixture graph is not mutated.
    graph_analyzer.graph = copy.deepcopy(self.graph_def)
    graph_analyzer.parse_graph()

    new_add_node = node_def_pb2.NodeDef()
    new_add_node.op = "Add"
    new_add_node.name = "add1"
    new_add_node.input.extend([self.input0_node.name, self.input1_node.name])

    graph_analyzer.replace_node(new_add_node, self.add_node.name,
                                [self.mul_node.name])

    # Materialize the node list once. Membership compares protobuf messages
    # by field value.
    result_nodes = list(graph_analyzer.dump_graph().node)
    self.assertNotIn(self.add_node, result_nodes)
    self.assertIn(new_add_node, result_nodes)
def test_invalid_input_output_config(self):
    """Quantize using ``fake_yaml_2.yaml`` and check the configured I/O names are not kept.

    The test name suggests that YAML carries invalid input/output settings,
    presumably ``x`` / ``op_to_store`` (TODO confirm against the fixture file).
    The returned model's node names must differ from those values.
    """
    analyzer = GraphAnalyzer()
    analyzer.graph = self.input_graph
    analyzer.parse_graph()
    fp32_graph_def = analyzer.dump_graph()

    from lpot import Quantization, common
    quantizer = Quantization('fake_yaml_2.yaml')
    dummy_data = quantizer.dataset('dummy', shape=(20, 224, 224, 3), label=True)
    for loader_attr in ('calib_dataloader', 'eval_dataloader'):
        setattr(quantizer, loader_attr,
                common.DataLoader(dummy_data, batch_size=2))
    quantizer.model = fp32_graph_def
    quantized_model = quantizer()

    # The quantizer will detect the right inputs/outputs rather than use these.
    unexpected_inputs = ['x']
    unexpected_outputs = ['op_to_store']
    self.assertNotEqual(quantized_model.input_node_names, unexpected_inputs)
    self.assertNotEqual(quantized_model.output_node_names, unexpected_outputs)