def test_serialize_graph(self):
        # Serializes a symbolically traced module and its partitioned
        # counterpart, then checks node/weight/module counts and a few
        # individual fields. Ends with a check of the quantization-info
        # serializer on per-tensor and per-channel quantized tensors.
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.e = torch.rand(4)
                self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

            # The placeholder names (a, b, c) and the order of operations
            # below are asserted on later via node names and node indices.
            def forward(self, a, b, c):
                summed = a + b
                convolved = self.conv(c)
                projected = self.linear(summed + convolved)
                return projected + self.e

        model = TestModule()
        traced = symbolic_trace(model)
        input_a = torch.rand(4)
        input_b = torch.rand(4)
        input_c = torch.rand(3, 3, 2, 2)
        graph_manipulation.get_size_of_all_nodes(
            traced, [input_a, input_b, input_c])

        partitioner = Partitioner()
        config = PartitionerConfig(
            [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)],
            PartitionMode.sparse_nn,
        )
        partitioned = partitioner.partition_graph(
            traced, model, config).module_with_submodules

        # Fix for now to add type/shape to output: every stamped node gets
        # metadata extracted from the first input tensor.
        def stamp_tensor_meta(nodes):
            for graph_node in nodes:
                graph_node.meta['tensor_meta'] = extract_tensor_metadata(
                    input_a)

        # Only the output node of the original traced graph is patched ...
        stamp_tensor_meta(
            graph_node for graph_node in traced.graph.nodes
            if graph_node.op == "output")
        # ... whereas every node of every GraphModule in the partitioned
        # result is overwritten, followed by the top-level graph itself.
        for submodule in partitioned.modules():
            if isinstance(submodule, GraphModule):
                stamp_tensor_meta(submodule.graph.nodes)
        stamp_tensor_meta(partitioned.graph.nodes)

        traced_weights = {}
        partitioned_weights = {}
        traced_serialized = graph_manipulation.serialize_module(
            traced, traced_weights)
        partitioned_serialized = graph_manipulation.serialize_module(
            partitioned, partitioned_weights)

        # The weight dicts passed in are expected to be filled in place.
        assert len(traced_weights) == 4
        assert len(partitioned_weights) == 4

        # Top-level section sizes: flat traced graph first, then the
        # partitioned graph.
        assert len(traced_serialized["nodes"]) == 10
        assert len(traced_serialized["weights"]) == 4
        assert len(traced_serialized["modules"]) == 0
        assert len(partitioned_serialized["nodes"]) == 6
        assert len(partitioned_serialized["weights"]) == 4
        assert len(partitioned_serialized["modules"]) == 1

        # Serialized description of the linear layer's weight.
        linear_weight = traced_serialized["weights"]["linear.weight"]
        assert linear_weight["shape"] == "[4, 4]"
        assert linear_weight["dtype"] == "torch.float32"
        assert linear_weight["is_quantized"] is False

        # First serialized node is the placeholder for input `a`.
        first_node = traced_serialized["nodes"][0]
        assert first_node["shape"] == "[4]"
        assert first_node["dtype"] == "torch.float32"
        assert first_node["target"] == "a"
        assert first_node["op_code"] == "placeholder"
        assert first_node["name"] == "a"

        # First argument of the node at index 6 must be a reference to the
        # node named "add_1".
        seventh_node_arg = traced_serialized["nodes"][6]["args"][0]
        assert seventh_node_arg["name"] == "add_1"
        assert seventh_node_arg["is_node"] is True

        # Test quantization info serialization.
        values = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        per_tensor_q = torch.quantize_per_tensor(values, 1, 0, torch.qint32)
        channel_scales = torch.tensor([0.1, 0.01])
        channel_zero_points = torch.tensor([10, 0])
        per_channel_q = torch.quantize_per_channel(
            values, channel_scales, channel_zero_points, 0, torch.quint8)
        per_tensor_info = graph_manipulation.serialize_tensor_quantization(
            per_tensor_q)
        per_channel_info = graph_manipulation.serialize_tensor_quantization(
            per_channel_q)
        assert per_tensor_info["qscheme"] == "torch.per_tensor_affine"
        assert per_tensor_info["q_scale"] == 1.0
        assert per_channel_info["qscheme"] == "torch.per_channel_affine"
        assert len(per_channel_info["q_per_channel_scales"]) == 2
Пример #2
0
def serialize_module_json_to_file(fx_module: GraphModule, fname: str):
    """Serialize ``fx_module`` and write the result to ``fname`` as JSON.

    The module is handed to ``serialize_module`` together with a fresh, empty
    dict; that dict is local to this function and is discarded on return.
    The structure returned by ``serialize_module`` is encoded with
    ``json.dumps`` using a two-space indent.

    Encoding happens before the file is opened, so an exception raised while
    serializing leaves any existing file at ``fname`` untouched.

    Args:
        fx_module: The graph module to serialize.
        fname: Path of the output file. It is opened in text write mode
            (``"w"``), so existing content is overwritten.
    """
    weight_store: Dict = {}
    module_description = serialize_module(fx_module, weight_store)
    json_text = json.dumps(module_description, indent=2)
    with open(fname, "w") as json_file:
        json_file.write(json_text)