Example #1
    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                e = torch.rand(4)
                add_2 = linear + e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
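        # get_size_of_all_nodes annotates every node in the traced graph with a
        # size_bytes estimate, which the size-based partitioner below relies on.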
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2)
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        for i, node in enumerate(dag.nodes):
            assert node.logical_device_ids == [i]
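
The snippets on this page are test methods lifted out of their original test file, so the imports they depend on are not shown. The rough shape of that setup is sketched below; the exact module paths of the experimental partitioner have changed across PyTorch releases, so treat the paths here as assumptions to be checked against your torch version.

    # Assumed imports for the snippets on this page (paths are approximate):
    import operator
    from typing import Dict, List

    import torch
    from torch.fx import GraphModule, Node, symbolic_trace
    from torch.fx.experimental import GraphManipulation          # renamed graph_manipulation in later releases
    from torch.fx.experimental.Partitioner import Partitioner    # later moved to accelerator_partitioner
    from torch.fx.experimental.partitioner_utils import (
        Device,
        NodeLatency,
        PartitionerConfig,
        get_latency_of_one_partition,
        get_latency_of_partitioned_graph,
        get_partition_to_latency_mapping,
    )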
Example #2
    def test_partition_combining(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device('dev_0', 120, 0), Device('dev_1', 144, 1)]
        partitioner_config = PartitionerConfig(devices, is_sparse_nn=False)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        assert dag.nodes[0].logical_device_ids == [0]
        assert dag.nodes[0].size_bytes == 80
        assert dag.nodes[1].logical_device_ids == [1]
        assert dag.nodes[1].size_bytes == 144
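
Device and PartitionerConfig are lightweight NamedTuple-style records. The approximation below is inferred from how the examples use them rather than copied from the library, so field names and defaults are assumptions; note also that Examples 5 and 12 show an older API in which Device took only two arguments and the device list was passed to partition_graph directly, without a PartitionerConfig.

    from typing import Dict, List, NamedTuple, Optional

    class Device(NamedTuple):
        name: str                    # e.g. 'dev_0'
        available_mem_bytes: int     # memory budget checked by the size-based partitioner
        logical_id: int              # echoed back in dag.nodes[*].logical_device_ids

    class PartitionerConfig(NamedTuple):
        devices: List[Device]
        is_sparse_nn: bool = False
        is_cost_aware: bool = False
        transfer_rate_bytes_per_sec: float = 0.0
        node_to_latency_mapping: Optional[Dict] = None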
Example #3
    def test_partition_device_mapping(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
        partitioner_config = PartitionerConfig(devices, is_sparse_nn=False)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        for i, node in enumerate(dag.nodes):
            if i == 1:
                assert node.logical_device_ids == [1]
            else:
                assert node.logical_device_ids == [0]
Example #4
    def test_partition_combining(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_0 = torch.nn.Linear(4, 4)

            def forward(self, a, b):
                add_1 = a + b
                c = self.linear_0(a)
                add_2 = c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2)
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert len(module_with_submodules.graph.nodes) == 5
Example #5
    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                e = torch.rand(4)
                add_2 = linear + e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125),
            Device('dev_1', 125),
            Device('dev_2', 125)
        ]
        ret = partitioner.partition_graph(traced, m, devices)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert len(module_with_submodules.graph.nodes) == 7
Example #6
    def test_sparse_nn_partition(self):
        class MyRecommendationModule(torch.nn.Module):
            def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
                layers = torch.nn.ModuleList()
                for _ in range(num_of_layers):
                    ll = torch.nn.Linear(input_size, output_size)
                    layers.append(ll)
                    layers.append(torch.nn.ReLU())
                return layers

            def __init__(self):
                super(MyRecommendationModule, self).__init__()
                layers = self.create_mlp(4, 4, 4)
                self.bottom_layers = torch.nn.Sequential(*layers)
                layers = self.create_mlp(3, 24, 24)
                self.top_layers = torch.nn.Sequential(*layers)
                self.embedding_layers = torch.nn.ModuleList()
                el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True)
                self.embedding_layers.append(el)
                for i in range(3):
                    el = torch.nn.EmbeddingBag(1000000, 4, mode='sum', sparse=True)
                    self.embedding_layers.append(el)
                el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True)
                self.embedding_layers.append(el)

            def forward(self, a, b, offset):
                x = self.bottom_layers(a)
                y = []
                c = []
                for i in range(len(self.embedding_layers)):
                    temp = torch.randint(10, (8, ))
                    c.append(temp + b)
                for i in range(len(self.embedding_layers)):
                    if i % 2 == 0:
                        y.append(self.embedding_layers[i](c[i], offset))
                    else:
                        y.append(self.embedding_layers[i](torch.randint(10, (8, )), offset))
                z = torch.cat([x] + y, dim=1)
                p = self.top_layers(z)
                return p

        m = MyRecommendationModule()
        a = torch.rand(2, 4)
        b = torch.randint(10, (8, ))
        offset = torch.randint(1, (2, ))
        traced = symbolic_trace(m)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b, offset])
        devices = [
            Device('dev_0', 33000000, 0),
            Device('dev_1', 33000000, 1),
            Device('dev_2', 33000000, 2)
        ]
        partitioner_config = PartitionerConfig(devices, is_sparse_nn=True)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
        assert len(module_with_submodules.graph.nodes) == 24
Example #7
    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {'output', 'placeholder', 'get_attr'}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1)
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size,
                            node.size_bytes.output_size)
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2),
            Device('dev_3', 125, 3)
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            is_sparse_nn=False,
            is_cost_aware=True,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping)
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec)
        assert critical_path_latency_sec == 160.
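
NodeLatency is the per-node cost record consumed by the cost-aware partitioner; the helper above builds one from the sizes that get_size_of_all_nodes attached to each node. A rough approximation of its shape (field names are an assumption based on how it is constructed here):

    from typing import NamedTuple

    class NodeLatency(NamedTuple):
        mem_latency_sec: float       # cost attributed to moving the node's data
        computed_latency_sec: float  # cost attributed to the computation itself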
Example #8
    def test_partition_latency(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Given a fx module, generate node latency for each node
            based on the size of each node
            """
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(devices, False)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        for p in partition_to_latency_mapping:
            if p.partition_id == 0:
                assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
            else:
                assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
        transfer_rate_bytes_per_sec = 0.5
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
        )
        assert critical_path_latency_sec == 208.0
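
The asserted numbers follow the partitioner's simple cost model: each partition accumulates memory and compute latency from its nodes, and data crossing a partition boundary is charged as bytes divided by transfer_rate_bytes_per_sec. A minimal sketch of that communication term only (the helper name is mine, not a library function):

    def communication_latency_sec(bytes_sent: float,
                                  transfer_rate_bytes_per_sec: float) -> float:
        """Bandwidth-only model of the time spent shipping data between partitions."""
        return bytes_sent / transfer_rate_bytes_per_sec

    # With transfer_rate_bytes_per_sec = 0.5 as above, moving 16 bytes between
    # partitions costs 16 / 0.5 = 32 seconds under this model.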
Example #9
    def test_partition_latency(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_4 = add_2 + linear_1
                add_5 = add_2 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Given a fx module, generate node latency for each node
               based on the size of each node
            """
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {'output', 'placeholder', 'get_attr'}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size,
                            2. * node.size_bytes.total_size)
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size,
                            node.size_bytes.output_size)
            return node_to_latency_mapping

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device('dev_0', 200, 0), Device('dev_1', 200, 1)]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(devices, False)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_latency_0 = get_latency_of_one_partition(
            partitions[0], node_to_latency_mapping)
        assert (128., 80., 160.) == partition_latency_0
        partition_latency_1 = get_latency_of_one_partition(
            partitions[1], node_to_latency_mapping)
        assert (16., 32., 32.) == partition_latency_1
Example #10
 def test_get_all_users_of(self):
     graph: torch.fx.Graph = torch.fx.Graph()
     a: torch.fx.Node = graph.create_node('placeholder', 'x')
     b: torch.fx.Node = graph.create_node('call_module',
                                          'linear_mod',
                                          args=(a, ))
     c: torch.fx.Node = graph.create_node('get_attr', 'y_attr')
     d: torch.fx.Node = graph.create_node('call_function',
                                          operator.add,
                                          args=(b, c))
     graph.output(d)
     linear_mod: torch.nn.Module = torch.nn.Linear(3, 4)
     add_param: torch.Tensor = torch.rand(3, 4)
     gm: torch.fx.GraphModule = torch.fx.GraphModule(
         {
             'linear_mod': linear_mod,
             'y_attr': add_param
         }, graph)
     expected_uses: Dict[int, List[int]] = {
         0: [1],
         1: [3],
         2: [3],
         3: [4],
         4: [],
     }
     for i, node in enumerate(graph.nodes):
         user_indexes = GraphManipulation.get_all_users_of(gm, i)
         assert user_indexes == expected_uses[i]
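
get_all_users_of addresses nodes by their position in the graph; the same information is available on stock torch.fx nodes through node.users. A minimal index-based equivalent built on that attribute (users_by_index is a hypothetical helper, not a library API):

    from typing import Dict, List
    import torch.fx

    def users_by_index(gm: torch.fx.GraphModule) -> Dict[int, List[int]]:
        """Map each node's position in the graph to the positions of its users."""
        nodes = list(gm.graph.nodes)
        index_of = {node: i for i, node in enumerate(nodes)}
        return {i: sorted(index_of[user] for user in node.users)
                for i, node in enumerate(nodes)}

For the graph built in this example, this reproduces the expected_uses mapping.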
Example #11
 def test_replace_target_nodes_with(self):
     class testModule(torch.nn.Module):
         def forward(self, a, b):
             return a + b
     m = testModule()
     traced = symbolic_trace(m)
     input1 = torch.randn(1)
     input2 = torch.randn(1)
     assert (input1 + input2) == traced(input1, input2)
     GraphManipulation.replace_target_nodes_with(
         fx_module=traced,
         old_op="call_function",
         old_target=operator.add,
         new_op="call_function",
         new_target=operator.mul,
     )
     assert (input1 * input2) == traced(input1, input2)
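
replace_target_nodes_with is a convenience wrapper; the same rewrite can be done with the core FX API by editing node.target and recompiling. A minimal sketch for the add-to-mul case above:

    import operator
    import torch.fx

    def swap_add_for_mul(gm: torch.fx.GraphModule) -> None:
        """Turn every call_function add node into a mul node, in place."""
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target == operator.add:
                node.target = operator.mul
        gm.recompile()  # regenerate forward() from the modified graph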
Example #12
    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125),
            Device('dev_1', 125),
            Device('dev_2', 125)
        ]
        ret = partitioner.partition_graph(traced, m, devices)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
Example #13
    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2)
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert dag.nodes[0].logical_device_ids == [0]
Example #14
 def get_output_nodes(self) -> List[Node]:
     """Output nodes are the nodes that without any user inside this partition."""
     output_nodes: List[Node] = []
     for node in self.nodes:
         index = self.graph_module.graph.nodes.index(node)
         user_indexes = GraphManipulation.get_all_users_of(
             self.graph_module, index)
         user_nodes = {
             self.graph_module.graph.nodes[i]
             for i in user_indexes
         }
         # check if user nodes has an intersection with self.nodes
         if not set(self.nodes).intersection(user_nodes):
             output_nodes.append(node)
     return output_nodes
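
The index round-trip through get_all_users_of can be avoided by consulting node.users directly: a node is an output of the partition exactly when none of its users belong to it. A standalone sketch of that formulation (not the library's method):

    from typing import List, Set
    from torch.fx import Node

    def partition_output_nodes(partition_nodes: Set[Node]) -> List[Node]:
        """Nodes whose users all live outside the given partition."""
        return [node for node in partition_nodes
                if not any(user in partition_nodes for user in node.users)]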
Example #15
    def test_serialize_graph(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.e = torch.rand(4)
                self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

            def forward(self, a, b, c):
                add_1 = a + b
                conv1 = self.conv(c)
                linear = self.linear(add_1 + conv1)
                add_2 = linear + self.e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        c = torch.rand(3, 3, 2, 2)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b, c])

        partitioner = Partitioner()
        devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
        partitioner_config = PartitionerConfig(devices, is_sparse_nn=True)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        # Fix for now to add type/shape to output
        for node in traced.graph.nodes:
            if node.op == "output":
                node.shape = a.shape
                node.dtype = a.dtype
        for mod in module_with_submodules.modules():
            if isinstance(mod, GraphModule):
                for node in mod.graph.nodes:
                    node.shape = a.shape
                    node.dtype = a.dtype
        for node in module_with_submodules.graph.nodes:
            node.shape = a.shape
            node.dtype = a.dtype

        agm1 = GraphManipulation.AcceleratedGraphModule(traced)
        agm2 = GraphManipulation.AcceleratedGraphModule(module_with_submodules)
        assert len(agm1.weights) == 4
        assert len(agm2.weights) == 4
        assert len(agm1.serialized_graph["nodes"]) == 10
        assert len(agm1.serialized_graph["weights"]) == 4
        assert len(agm1.serialized_graph["modules"]) == 0
        assert len(agm2.serialized_graph["nodes"]) == 6
        assert len(agm2.serialized_graph["weights"]) == 4
        assert len(agm2.serialized_graph["modules"]) == 1
        assert agm1.serialized_graph["weights"]["linear.weight"][
            "shape"] == "[4, 4]"
        assert (agm1.serialized_graph["weights"]["linear.weight"]["dtype"] ==
                "torch.float32")
        assert (
            agm1.serialized_graph["weights"]["linear.weight"]["is_quantized"]
            is False)
        assert agm1.serialized_graph["nodes"][0]["shape"] == "[4]"
        assert agm1.serialized_graph["nodes"][0]["dtype"] == "torch.float32"
        assert agm1.serialized_graph["nodes"][0]["target"] == "a"
        assert agm1.serialized_graph["nodes"][0]["op_code"] == "placeholder"
        assert agm1.serialized_graph["nodes"][0]["name"] == "a"
        assert agm1.serialized_graph["nodes"][6]["args"][0]["name"] == "add_2"
        assert agm1.serialized_graph["nodes"][6]["args"][0]["is_node"] is True

        # Test quantization info serialization.
        x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
        q_tensor_channel = torch.quantize_per_channel(
            x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0,
            torch.quint8)
        result = GraphManipulation.serialize_tensor_quantization(q_tensor)
        result2 = GraphManipulation.serialize_tensor_quantization(
            q_tensor_channel)
        assert result["q_scheme"] == "torch.per_tensor_affine"
        assert result["q_scale"] == 1.0
        assert result2["q_scheme"] == "torch.per_channel_affine"
        assert len(result2["q_per_channel_scales"]) == 2