def test_replace_constant_graph_with_constant_node(self):
        """Replace sub-graphs with constant nodes and check the node counts.

        Works on a deep copy of ``self.graph_def``. The replacements that
        target ``add_node``, ``mul_node`` and ``sqrt_node`` are applied one
        after another on the same analyzer; each must report success and
        the dumped graph must then hold 10, 8 and 7 nodes respectively.
        A final attempt targeting ``block_node`` must report failure
        (a falsy return value).
        """
        graph_analyzer = GraphAnalyzer()
        graph_analyzer.graph = copy.deepcopy(self.graph_def)
        graph_analyzer.parse_graph()

        def replace_with_random_constant(node_name):
            """Build a fresh random constant and try to swap it in."""
            value = np.random.random([4, 1])
            # The value is cast to float32 only to derive the TF dtype.
            dtype = tf.as_dtype(np.float32(value).dtype)
            constant_node = GraphRewriterHelper.create_constant_node(
                node_name + "_const", value, dtype)
            return graph_analyzer.replace_constant_graph_with_constant_node(
                constant_node, node_name)

        # The steps are cumulative: every replacement runs on the graph left
        # behind by the previous one, so the order of this table matters.
        successful_steps = (
            (self.add_node, 10),
            (self.mul_node, 8),
            (self.sqrt_node, 7),
        )
        for target_node, expected_node_count in successful_steps:
            # The call is made as a plain expression argument rather than
            # inside an ``assert`` statement, so it still runs under
            # ``python -O`` (which strips ``assert`` statements entirely).
            replaced = replace_with_random_constant(target_node.name)
            self.assertTrue(replaced, msg="target: " + target_node.name)
            result_graph = graph_analyzer.dump_graph()
            self.assertEqual(len(list(result_graph.node)),
                             expected_node_count,
                             msg="target: " + target_node.name)

        # NOTE(review): the reason block_node is rejected is not visible
        # here; confirm against the fixture that builds self.block_node.
        self.assertFalse(
            replace_with_random_constant(self.block_node.name))
    def test_no_input_output_config(self):
        """Quantize the parsed input graph driven by ``fake_yaml.yaml``.

        The graph held in ``self.input_graph`` is parsed and dumped through
        ``GraphAnalyzer``, handed to an lpot ``Quantization`` object together
        with dummy calibration/evaluation data, and the returned model must
        contain at least one node.

        NOTE(review): going by the test name, ``fake_yaml.yaml`` presumably
        omits the model inputs/outputs -- confirm against the fixture.
        """
        analyzer = GraphAnalyzer()
        analyzer.graph = self.input_graph
        analyzer.parse_graph()
        fp32_graph_def = analyzer.dump_graph()

        # lpot is imported lazily, only once the graph has been prepared.
        from lpot import Quantization
        from lpot import common

        quantizer = Quantization('fake_yaml.yaml')
        dummy_data = quantizer.dataset(
            'dummy', shape=(20, 224, 224, 3), label=True)

        # Calibration and evaluation get separate loaders over the same data.
        calib_loader = common.DataLoader(dummy_data, batch_size=2)
        quantizer.calib_dataloader = calib_loader
        eval_loader = common.DataLoader(dummy_data, batch_size=2)
        quantizer.eval_dataloader = eval_loader

        quantizer.model = fp32_graph_def
        quantized_model = quantizer()

        node_count = len(quantized_model.graph_def.node)
        self.assertGreater(node_count, 0)
    def test_replace_node(self):
        """Swap ``add_node`` for a new ``Add`` node and verify the result.

        Works on a deep copy of ``self.graph_def``. After the replacement
        the dumped graph must contain the new node and must no longer
        contain the original ``add_node``.
        """
        graph_analyzer = GraphAnalyzer()
        graph_analyzer.graph = copy.deepcopy(self.graph_def)
        graph_analyzer.parse_graph()

        # Replacement node: same two inputs, new name "add1".
        new_add_node = node_def_pb2.NodeDef(
            op="Add",
            name="add1",
            input=[self.input0_node.name, self.input1_node.name])

        # NOTE(review): the third argument is presumably the list of
        # consumer nodes to rewire -- confirm against
        # GraphAnalyzer.replace_node.
        graph_analyzer.replace_node(new_add_node, self.add_node.name,
                                    [self.mul_node.name])

        result_nodes = list(graph_analyzer.dump_graph().node)
        # unittest assertion methods are used instead of bare ``assert``
        # so the checks survive ``python -O`` and report the offending
        # values on failure.
        self.assertNotIn(self.add_node, result_nodes)
        self.assertIn(new_add_node, result_nodes)
    def test_invalid_input_output_config(self):
        """Quantize the parsed input graph driven by ``fake_yaml_2.yaml``.

        The resulting model must report input node names other than
        ``['x']`` and output node names other than ``['op_to_store']``.

        NOTE(review): going by the test name, ``fake_yaml_2.yaml`` presumably
        declares those names as (invalid) inputs/outputs -- confirm against
        the fixture.
        """
        analyzer = GraphAnalyzer()
        analyzer.graph = self.input_graph
        analyzer.parse_graph()
        fp32_graph_def = analyzer.dump_graph()

        # lpot is imported lazily, only once the graph has been prepared.
        from lpot import Quantization
        from lpot import common

        quantizer = Quantization('fake_yaml_2.yaml')
        dummy_data = quantizer.dataset(
            'dummy', shape=(20, 224, 224, 3), label=True)

        # Calibration and evaluation get separate loaders over the same data.
        calib_loader = common.DataLoader(dummy_data, batch_size=2)
        quantizer.calib_dataloader = calib_loader
        eval_loader = common.DataLoader(dummy_data, batch_size=2)
        quantizer.eval_dataloader = eval_loader

        quantizer.model = fp32_graph_def
        quantized_model = quantizer()

        # The right inputs/outputs are expected to be detected, so the
        # names on the returned model must differ from these two lists.
        detected_inputs = quantized_model.input_node_names
        self.assertNotEqual(detected_inputs, ['x'])
        detected_outputs = quantized_model.output_node_names
        self.assertNotEqual(detected_outputs, ['op_to_store'])