Example 1
    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {'output', 'placeholder', 'get_attr'}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1)
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size,
                            node.size_bytes.output_size)
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2),
            Device('dev_3', 125, 3)
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.cost_aware,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping)
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec)
        assert critical_path_latency_sec == 160.

    def test_kl_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.b = torch.rand(4)
                self.c = torch.rand(4)
                self.d = torch.rand(4)

            def forward(self, a):
                add_1 = a + self.b
                add_2 = add_1 + self.c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + self.d
                add_5 = add_3 + add_4
                return add_4

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        # Assumes a get_node_to_latency_mapping helper like the one defined
        # in test_cost_aware_partition above.
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        transfer_rate_bytes_per_sec = 2
        devices = [
            Device('dev_0', 200, 0),
            Device('dev_1', 200, 1),
            Device('dev_2', 200, 2),
            Device('dev_3', 200, 3)
        ]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.kl_based,
            transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec,
            node_to_latency_mapping=node_to_latency_mapping)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        dag = ret.dag
        assert dag.nodes[0].size_bytes == 176
        assert dag.nodes[1].size_bytes == 112
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitioner.partitions, node_to_latency_mapping)
        cost = get_latency_of_partitioned_graph(
            partitioner.partitions, partition_to_latency_mapping,
            transfer_rate_bytes_per_sec)
        assert cost == 208.

    def test_aot_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.b = torch.rand(4)
                self.c = torch.rand(4)

            def forward(self, a):
                add_1 = a + self.b
                add_2 = self.c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        node_to_partition_id = {}
        partition_to_logical_devices = {}
        count = 0
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        for node in traced.graph.nodes:
            if node.op not in {'placeholder', 'get_attr', 'output'}:
                node_to_partition_id[node] = count
                partition_to_logical_devices[count] = [0]
                count += 1
        devices = [Device('dev_0', 200, 0)]
        partitioner_config = PartitionerConfig(
            devices=devices,
            mode=PartitionMode.aot_based,
            node_to_partition_mapping=node_to_partition_id,
            partition_to_logical_device_mapping=partition_to_logical_devices
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(module_with_submodules(a), traced(a))
        for node in dag.nodes:
            assert node.size_bytes == 48
            assert node.logical_device_ids == [0]

    def test_replace_target_nodes_with(self):
        class testModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = testModule()
        traced = symbolic_trace(m)
        input1 = torch.randn(1)
        input2 = torch.randn(1)
        assert (input1 + input2) == traced(input1, input2)
        graph_manipulation.replace_target_nodes_with(
            fx_module=traced,
            old_op="call_function",
            old_target=operator.add,
            new_op="call_function",
            new_target=operator.mul,
        )
        assert (input1 * input2) == traced(input1, input2)
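
The get_node_to_latency_mapping helper in Example 1 encodes a simple latency model: each op node's memory latency is its total size in bytes, and its compute latency is its output size in bytes (or 1 when total and output sizes coincide). Below is a minimal sketch of the values it constructs, assuming NodeLatency is the NamedTuple from torch.fx.experimental.partitioner_utils with fields mem_latency_sec and computer_latency_sec; the field names and import path are assumptions to verify against your PyTorch version.

# Assumed import location and field names; adjust for your PyTorch version.
from torch.fx.experimental.partitioner_utils import NodeLatency

# The two positional arguments used in the helper above map to these fields
# (the byte counts here are purely illustrative).
latency = NodeLatency(48, 16)
assert latency.mem_latency_sec == 48
assert latency.computer_latency_sec == 16
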
Example 2
    def test_sparse_nn_partition(self):
        class MyRecommendationModule(torch.nn.Module):
            def create_mlp(self, num_of_layers: int, input_size: int,
                           output_size: int):
                layers = torch.nn.ModuleList()
                for _ in range(num_of_layers):
                    ll = torch.nn.Linear(input_size, output_size)
                    layers.append(ll)
                    layers.append(torch.nn.ReLU())
                return layers

            def __init__(self):
                super(MyRecommendationModule, self).__init__()
                layers = self.create_mlp(4, 4, 4)
                self.bottom_layers = torch.nn.Sequential(*layers)
                layers = self.create_mlp(3, 24, 24)
                self.top_layers = torch.nn.Sequential(*layers)
                self.embedding_layers = torch.nn.ModuleList()
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)
                for i in range(3):
                    el = torch.nn.EmbeddingBag(1000000,
                                               4,
                                               mode="sum",
                                               sparse=True)
                    self.embedding_layers.append(el)
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)

            def forward(self, a, b, offset):
                x = self.bottom_layers(a)
                y = []
                c = []
                for i in range(len(self.embedding_layers)):
                    temp = torch.randint(10, (8, ))
                    c.append(temp + b)
                for i in range(len(self.embedding_layers)):
                    if i % 2 == 0:
                        y.append(self.embedding_layers[i](c[i], offset))
                    else:
                        y.append(self.embedding_layers[i](torch.randint(
                            10, (8, )), offset))
                z = torch.cat([x] + y, dim=1)
                p = self.top_layers(z)
                return p

        m = MyRecommendationModule()
        a = torch.rand(2, 4)
        b = torch.randint(10, (8, ))
        offset = torch.randint(1, (2, ))
        traced = symbolic_trace(m)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset])
        devices = [
            Device("dev_0", 33000000, 0),
            Device("dev_1", 33000000, 1),
            Device("dev_2", 33000000, 2)
        ]
        partitioner_config = PartitionerConfig(devices,
                                               PartitionMode.sparse_nn)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b, offset),
                         module_with_submodules(a, b, offset))
        assert len(module_with_submodules.graph.nodes) == 24

    def partition_graph(
            self, fx_module: GraphModule, torch_module: torch.nn.Module,
            partitioner_config: PartitionerConfig) -> PartitionResult:
        """Given the fx module, torch module and partitioner_config,
           find the partitions, do the partitions,
           and then return a DAG and a new fx module with submodule nodes (partitions)
        """
        self.graph_module = fx_module
        self.torch_module = torch_module
        self.devices = partitioner_config.devices
        if len(self.devices) == 0:
            raise RuntimeError('No devices')
        # Tag the size in bytes to all nodes in the graph_module.
        get_size_of_all_nodes(self.graph_module)
        # Check if there are op nodes in the fx module
        nodes = self.graph_module.graph.nodes
        if all(node.op in {'placeholder', 'get_attr', 'output'}
               for node in nodes):
            raise RuntimeError(
                'No Partition since no operations in the module')
        # Calculate total size of the fx module
        total_size_of_graph = 0
        for node in nodes:
            if node.op == 'output':
                break
            total_size_of_graph += node.size_bytes.total_size
        # Find the device with the max mem size
        device_with_max_mem = max(self.devices,
                                  key=lambda d: d.available_mem_bytes)
        # AOT based partition
        if partitioner_config.mode == PartitionMode.aot_based:
            self.aot_based_partition(
                partitioner_config.node_to_partition_mapping,
                partitioner_config.partition_to_logical_device_mapping)
        # Single partition if the whole module can be fit into one device
        elif total_size_of_graph <= device_with_max_mem.available_mem_bytes:
            self.find_single_partition(total_size_of_graph)
        elif total_size_of_graph > sum(
                [d.available_mem_bytes for d in self.devices]):
            raise RuntimeError('Devices have no enough memory for the module')
        else:
            # Sparse nn based partition
            if partitioner_config.mode == PartitionMode.sparse_nn:
                available_mem_bytes = self.devices[0].available_mem_bytes
                if not all(device.available_mem_bytes == available_mem_bytes
                           for device in self.devices):
                    raise RuntimeError(
                        'All devices must have same memory size!')
                # sparse_nn_partition only support same memory size
                # TODO: add different size support for sparse_nn_partition
                self.sparse_nn_partition(available_mem_bytes)
            # Cost aware partition
            elif partitioner_config.mode == PartitionMode.cost_aware:
                self.cost_aware_partition(
                    partitioner_config.transfer_rate_bytes_per_sec,
                    partitioner_config.node_to_latency_mapping)
            # KL based partition
            elif partitioner_config.mode == PartitionMode.kl_based:
                self.kl_based_partition(
                    partitioner_config.transfer_rate_bytes_per_sec,
                    partitioner_config.node_to_latency_mapping)
            else:
                self.size_based_partition()
        module_with_submodules = self.do_partition()
        # The DAG contains DAGNodes with info of each partition's input nodes, output nodes
        # and how partitions are connected.
        dag = self.dump_dag(module_with_submodules)
        ret = PartitionResult(dag, module_with_submodules)
        return ret

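For orientation, here is a minimal sketch of how the dispatch above is typically driven from caller code, in the style of the tests in these examples. The wrapper name partition_for_devices is hypothetical, and the import paths are assumptions based on the torch.fx experimental layout these tests use; they have moved between PyTorch releases, so adjust them to your version.

from torch.fx import symbolic_trace
# Assumed import locations; adjust to your PyTorch version.
from torch.fx.experimental import graph_manipulation
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.partitioner_utils import PartitionerConfig, PartitionMode


def partition_for_devices(module, sample_inputs, devices, mode):
    """Hypothetical convenience wrapper around Partitioner.partition_graph.

    devices is a list of partitioner_utils.Device entries, as in the tests above.
    """
    traced = symbolic_trace(module)
    # Tag per-node sizes from sample inputs, as the tests above do;
    # partition_graph also re-tags sizes internally.
    graph_manipulation.get_size_of_all_nodes(traced, sample_inputs)
    config = PartitionerConfig(devices, mode=mode)
    ret = Partitioner().partition_graph(traced, module, config)
    return ret.dag, ret.module_with_submodules

# Example call, reusing the names from test_sparse_nn_partition above:
# dag, submodules = partition_for_devices(m, [a, b, offset], devices,
#                                         PartitionMode.sparse_nn)

For cost_aware and kl_based modes the config additionally needs transfer_rate_bytes_per_sec and node_to_latency_mapping, and aot_based needs the node-to-partition and partition-to-device mappings, exactly as the tests above construct them.
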
    def test_serialize_graph(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.e = torch.rand(4)
                self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

            def forward(self, a, b, c):
                add_1 = a + b
                conv1 = self.conv(c)
                linear = self.linear(add_1 + conv1)
                add_2 = linear + self.e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        c = torch.rand(3, 3, 2, 2)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, c])

        partitioner = Partitioner()
        devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
        partitioner_config = PartitionerConfig(devices,
                                               PartitionMode.sparse_nn)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        # Fix for now to add type/shape to output
        for node in traced.graph.nodes:
            if node.op == "output":
                node.shape = a.shape
                node.dtype = a.dtype
        for mod in module_with_submodules.modules():
            if isinstance(mod, GraphModule):
                for node in mod.graph.nodes:
                    node.shape = a.shape
                    node.dtype = a.dtype
        for node in module_with_submodules.graph.nodes:
            node.shape = a.shape
            node.dtype = a.dtype

        weights1 = {}
        weights2 = {}
        serialized_graph1 = graph_manipulation.serialize_module(
            traced, weights1)
        serialized_graph2 = graph_manipulation.serialize_module(
            module_with_submodules, weights2)
        assert len(weights1) == 4
        assert len(weights2) == 4
        assert len(serialized_graph1["nodes"]) == 10
        assert len(serialized_graph1["weights"]) == 4
        assert len(serialized_graph1["modules"]) == 0
        assert len(serialized_graph2["nodes"]) == 6
        assert len(serialized_graph2["weights"]) == 4
        assert len(serialized_graph2["modules"]) == 1
        assert serialized_graph1["weights"]["linear.weight"][
            "shape"] == "[4, 4]"
        assert (serialized_graph1["weights"]["linear.weight"]["dtype"] ==
                "torch.float32")
        assert (serialized_graph1["weights"]["linear.weight"]["is_quantized"]
                is False)
        assert serialized_graph1["nodes"][0]["shape"] == "[4]"
        assert serialized_graph1["nodes"][0]["dtype"] == "torch.float32"
        assert serialized_graph1["nodes"][0]["target"] == "a"
        assert serialized_graph1["nodes"][0]["op_code"] == "placeholder"
        assert serialized_graph1["nodes"][0]["name"] == "a"
        assert serialized_graph1["nodes"][6]["args"][0]["name"] == "add_2"
        assert serialized_graph1["nodes"][6]["args"][0]["is_node"] is True

        # Test quantization info serialization.
        x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
        q_tensor_channel = torch.quantize_per_channel(
            x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0,
            torch.quint8)
        result = graph_manipulation.serialize_tensor_quantization(q_tensor)
        result2 = graph_manipulation.serialize_tensor_quantization(
            q_tensor_channel)
        assert result["q_scheme"] == "torch.per_tensor_affine"
        assert result["q_scale"] == 1.0
        assert result2["q_scheme"] == "torch.per_channel_affine"
        assert len(result2["q_per_channel_scales"]) == 2
Example 5
    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {'output', 'placeholder', 'get_attr'}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1)
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size,
                            node.size_bytes.output_size)
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2),
            Device('dev_3', 125, 3)
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            is_sparse_nn=False,
            is_cost_aware=True,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping)
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec)
        assert critical_path_latency_sec == 160.
