def try_combining_partitions(p0_index, p1_index, partitions) -> float:
     """Given two partitions and a list of partitions, combine these two partitions
        and see what is the cost of the modified partition list
     """
     p0 = partitions[p0_index]
     p1 = partitions[p1_index]
     """If two partitions' bfs level are less than 2 or two partitions are connected to each other,
        then they can be combined
     """
     if (abs(p0.bfs_level - p1.bfs_level) <= 1 or
             p0 in p1.parents or p0 in p1.children):
         combine_two_partitions(p0, p1, partitions)
         # Check if a circular dependency exists after combining
         if check_dependency(partitions[-1]):
             return float('inf')
         # Check if the modified partition list can be mapped to devices after combination
         reset_partition_device(partitions)
         found_device = get_device_to_partitions_mapping(
             partitions, self.devices)
         if not found_device:
             return float('inf')
         # Calculate the new cost
         partition_to_latency_mapping = get_partition_to_latency_mapping(
             partitions, node_to_latency_mapping)
         cost = get_latency_of_partitioned_graph(
             partitions, partition_to_latency_mapping,
             transfer_rate_bytes_per_sec)
         return cost
     # If the two partitions cannot be combined, the cost is inf
     return float('inf')
def try_swap_nodes(n0, n1, p0, p1, node_to_latency_mapping,
                   transfer_rate_per_sec):
     cost = float('inf')
     swap_nodes(n0, n1, p0, p1)
     # Reorganize partitions after swapping
     reorganize_partitions(self.partitions)
     # Check if there is a circular dependency after swapping
     if (not check_dependency(p0)) and (not check_dependency(p1)):
         reset_partition_device(self.partitions)
         partition_to_latency_mapping = get_partition_to_latency_mapping(
             self.partitions, node_to_latency_mapping)
         # Check if all partitions can be mapped to logical devices after swapping
         found_device = get_device_to_partitions_mapping(
             self.partitions, self.devices)
         if not found_device:
             cost = float('inf')
         else:
             cost = get_latency_of_partitioned_graph(
                 self.partitions, partition_to_latency_mapping,
                 transfer_rate_bytes_per_sec)
     # Swap back and reset all partitions back to original
     swap_nodes(n1, n0, p0, p1)
     reorganize_partitions(self.partitions)
     reset_partition_device(self.partitions)
     get_device_to_partitions_mapping(self.partitions, self.devices)
     return cost
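The fragment above relies on a swap_nodes helper that is not shown here; it is defined later inside kl_based_partition (see the full method further below). For readability, a minimal sketch of that helper is reproduced here, where either node may be None to express a one-way move:

def swap_nodes(n0, n1, p0, p1):
    # Either n0 or n1 may be None; in that case the non-None node is simply
    # moved to the other partition instead of being swapped.
    if n0 is not None:
        p0.remove_node(n0)
        p1.add_node(n0)
    if n1 is not None:
        p0.add_node(n1)
        p1.remove_node(n1)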
Example #3
    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {'output', 'placeholder', 'get_attr'}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1)
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size,
                            node.size_bytes.output_size)
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2),
            Device('dev_3', 125, 3)
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            is_sparse_nn=False,
            is_cost_aware=True,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping)
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec)
        assert critical_path_latency_sec == 160.
Example #4
    def test_partition_latency(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Given a fx module, generate node latency for each node
            based on the size of each node
            """
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        for p in partition_to_latency_mapping:
            if p.partition_id == 0:
                assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
            else:
                assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
        transfer_rate_bytes_per_sec = 2
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
        )
        assert critical_path_latency_sec == 208.0
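A hedged reading of the asserted triples, assuming PartitionLatency is the (mem_latency_sec, computation_latency_sec, overall_latency_sec) named tuple from torch.fx.experimental.partitioner_utils; the field names are an assumption and are not checked by the test above:

        # Hedged sketch, not part of the original test: unpack the asserted
        # latency triples by field name (assumed PartitionLatency fields).
        for p, latency in partition_to_latency_mapping.items():
            print(p.partition_id,
                  latency.mem_latency_sec,
                  latency.computation_latency_sec,
                  latency.overall_latency_sec)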
Example #5
        def test_kl_based_partition(self):
            class TestModule(torch.nn.Module):
                def __init__(self):
                    super(TestModule, self).__init__()
                    self.linear = torch.nn.Linear(4, 4)
                    self.b = torch.rand(4)
                    self.c = torch.rand(4)
                    self.d = torch.rand(4)

                def forward(self, a):
                    add_1 = a + self.b
                    add_2 = add_1 + self.c
                    linear_1 = self.linear(add_1)
                    add_3 = add_2 + linear_1
                    add_4 = add_2 + self.d
                    add_5 = add_3 + add_4
                    return add_4
            m = TestModule()
            traced = symbolic_trace(m)
            a = torch.rand(4)
            graph_manipulation.get_size_of_all_nodes(traced, [a])
            node_to_latency_mapping = get_node_to_latency_mapping(traced)
            transfer_rate_bytes_per_sec = 2
            devices = [
                Device('dev_0', 200, 0),
                Device('dev_1', 200, 1),
                Device('dev_2', 200, 2),
                Device('dev_3', 200, 3)
            ]
            partitioner = Partitioner()
            partitioner_config = PartitionerConfig(
                devices,
                mode=PartitionMode.kl_based,
                transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec,
                node_to_latency_mapping=node_to_latency_mapping
            )
            ret = partitioner.partition_graph(traced, m, partitioner_config)
            module_with_submodules = ret.module_with_submodules
            self.assertEqual(traced(a), module_with_submodules(a))
            dag = ret.dag
            assert dag.nodes[0].size_bytes == 176
            assert dag.nodes[1].size_bytes == 112
            partition_to_latency_mapping = get_partition_to_latency_mapping(
                partitioner.partitions,
                node_to_latency_mapping
            )
            cost = get_latency_of_partitioned_graph(
                partitioner.partitions,
                partition_to_latency_mapping,
                transfer_rate_bytes_per_sec
            )
            assert cost == 208.
def search_combination(
    transfer_rate_bytes_per_sec,
    node_to_latency_mapping
) -> bool:
     """Given transfer rate between partitions and each node's latency,
        find two partitions to combine so the cost of the partitions can
        be reduced.
        The algorithm is :
        1. Go through all the partition pairs and see
        if any pair of partitions can be combined.
        2. Calculate the cost after the combination.
        3. Select the minimum cost and combine its cooresponding partition pair.
     """
     partition_to_latency_mapping = get_partition_to_latency_mapping(self.partitions, node_to_latency_mapping)
     cost = get_latency_of_partitioned_graph(self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec)
     if len(self.partitions) == 1:
         return False
     partition_pair: List[int] = []
     for i in range(len(self.partitions) - 1):
         for j in range(i + 1, len(self.partitions)):
             # Try to combine the partition pair
             # and see the new cost after combination
             new_cost = try_combining_partitions(
                 i,
                 j,
                 self.partitions[:]
             )
             if new_cost <= cost:
                 partition_pair = [i, j]
                 cost = new_cost
             reorganize_partitions(self.partitions)
     # If a partition pair is found, combine them
     if len(partition_pair) != 0:
         p0 = self.partitions[partition_pair[0]]
         p1 = self.partitions[partition_pair[1]]
         combine_two_partitions(p0, p1, self.partitions)
     get_bfs_level_partition(self.partitions)
     reset_partition_device(self.partitions)
     get_device_to_partitions_mapping(self.partitions, self.devices)
     return len(partition_pair) != 0
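For context, a hedged sketch of the loop that typically drives search_combination, inferred from its boolean return value rather than copied from the surrounding cost_aware_partition method: keep searching and combining until no pair lowers the cost.

# Hedged driver sketch (an assumption, not a verbatim copy of cost_aware_partition):
# repeatedly search for a cost-reducing partition pair and stop once none is found.
find_combination = True
while find_combination:
    find_combination = search_combination(
        transfer_rate_bytes_per_sec,
        node_to_latency_mapping)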
    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {'output', 'placeholder', 'get_attr'}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1)
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size,
                            node.size_bytes.output_size)
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2),
            Device('dev_3', 125, 3)
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.cost_aware,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping)
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec)
        assert critical_path_latency_sec == 160.
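As a hedged aside (not part of the original test), the dag captured above can be inspected; size_bytes and logical_device_ids are the DAG node attributes exercised by test_aot_based_partition further below:

        # Hedged sketch, not part of the original test: inspect the DAG nodes
        # produced by partition_graph.
        for dag_node in dag.nodes:
            print(dag_node.size_bytes, dag_node.logical_device_ids)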

        def test_kl_based_partition(self):
            class TestModule(torch.nn.Module):
                def __init__(self):
                    super(TestModule, self).__init__()
                    self.linear = torch.nn.Linear(4, 4)
                    self.b = torch.rand(4)
                    self.c = torch.rand(4)
                    self.d = torch.rand(4)

                def forward(self, a):
                    add_1 = a + self.b
                    add_2 = add_1 + self.c
                    linear_1 = self.linear(add_1)
                    add_3 = add_2 + linear_1
                    add_4 = add_2 + self.d
                    add_5 = add_3 + add_4
                    return add_4

            m = TestModule()
            traced = symbolic_trace(m)
            a = torch.rand(4)
            graph_manipulation.get_size_of_all_nodes(traced, [a])
            node_to_latency_mapping = get_node_to_latency_mapping(traced)
            transfer_rate_bytes_per_sec = 2
            devices = [
                Device('dev_0', 200, 0),
                Device('dev_1', 200, 1),
                Device('dev_2', 200, 2),
                Device('dev_3', 200, 3)
            ]
            partitioner = Partitioner()
            partitioner_config = PartitionerConfig(
                devices,
                mode=PartitionMode.kl_based,
                transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec,
                node_to_latency_mapping=node_to_latency_mapping)
            ret = partitioner.partition_graph(traced, m, partitioner_config)
            module_with_submodules = ret.module_with_submodules
            self.assertEqual(traced(a), module_with_submodules(a))
            dag = ret.dag
            assert dag.nodes[0].size_bytes == 176
            assert dag.nodes[1].size_bytes == 112
            partition_to_latency_mapping = get_partition_to_latency_mapping(
                partitioner.partitions, node_to_latency_mapping)
            cost = get_latency_of_partitioned_graph(
                partitioner.partitions, partition_to_latency_mapping,
                transfer_rate_bytes_per_sec)
            assert cost == 208.

        def test_aot_based_partition(self):
            class TestModule(torch.nn.Module):
                def __init__(self):
                    super(TestModule, self).__init__()
                    self.b = torch.rand(4)
                    self.c = torch.rand(4)

                def forward(self, a):
                    add_1 = a + self.b
                    add_2 = self.c + add_1
                    return add_2

            m = TestModule()
            traced = symbolic_trace(m)
            a = torch.rand(4)
            node_to_partition_id = {}
            partition_to_logical_devices = {}
            count = 0
            GraphManipulation.get_size_of_all_nodes(traced, [a])
            for node in traced.graph.nodes:
                if node.op not in {'placeholder', 'get_attr', 'output'}:
                    node_to_partition_id[node] = count
                    partition_to_logical_devices[count] = [0]
                    count += 1
            devices = [Device('dev_0', 200, 0)]
            partitioner_config = PartitionerConfig(
                devices=devices,
                mode=PartitionMode.aot_based,
                node_to_partition_mapping=node_to_partition_id,
                partition_to_logical_device_mapping=partition_to_logical_devices
            )
            partitioner = Partitioner()
            ret = partitioner.partition_graph(traced, m, partitioner_config)
            module_with_submodules = ret.module_with_submodules
            dag = ret.dag
            self.assertEqual(module_with_submodules(a), traced(a))
            for node in dag.nodes:
                assert node.size_bytes == 48
                assert node.logical_device_ids == [0]

        def test_replace_target_nodes_with(self):
            class testModule(torch.nn.Module):
                def forward(self, a, b):
                    return a + b

            m = testModule()
            traced = symbolic_trace(m)
            input1 = torch.randn(1)
            input2 = torch.randn(1)
            assert (input1 + input2) == traced(input1, input2)
            graph_manipulation.replace_target_nodes_with(
                fx_module=traced,
                old_op="call_function",
                old_target=operator.add,
                new_op="call_function",
                new_target=operator.mul,
            )
            assert (input1 * input2) == traced(input1, input2)
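As a small hedged follow-up (not part of the original test), the rewritten graph can be dumped with the standard torch.fx Graph.print_tabular helper to confirm that the call_function target is now operator.mul:

            # Hedged sketch, not part of the original test: dump the rewritten
            # graph; the add node's target should now be operator.mul.
            traced.graph.print_tabular()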
    def kl_based_partition(
            self, transfer_rate_bytes_per_sec: float,
            node_to_latency_mapping: Dict[Node, NodeLatency]) -> None:
        """This function is a cost aware partition based
           on Kernighan-Lin algorithm.
           First, the graph is partitioned using size_based_partition.
           Then, each node is swapped with any other node in a different
           partition, and at the same time, the cost is estimated after
           the swapping.
           For example, we have nodes n0, n1, n2, n3 and n4.
           Using size_based_partition, n0 and n1 are in Partition p0.
           n2, n3 and n4 in Partition p1. The current cost is esimated.
           We first tried using n0 to swap with n2 from the other partiton.
           Then we see that swapping n0 and n2 shows a lower cost
           than the current cost and it is the minimum among other pairs like
           (n0, None)(This means moving n0 to Partition without swapping other nodes),
           (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost
           as the current cost.
           Then We repeat this process for all the other nodes until all swapping pairs
           are tried.
        """
        def swap_nodes(n0, n1, p0, p1):
            # Either n0 or n1 could be None
            # That means we simply move the node
            # to another partition
            if n0 is not None:
                p0.remove_node(n0)
                p1.add_node(n0)
            if n1 is not None:
                p0.add_node(n1)
                p1.remove_node(n1)

        def try_swap_nodes(n0, n1, p0, p1, node_to_latency_mapping,
                           transfer_rate_per_sec):
            cost = float('inf')
            swap_nodes(n0, n1, p0, p1)
            # Reorganize partitions after swapping
            reorganize_partitions(self.partitions)
            # Check if there is a circular dependency after swapping
            if (not check_dependency(p0)) and (not check_dependency(p1)):
                reset_partition_device(self.partitions)
                partition_to_latency_mapping = get_partition_to_latency_mapping(
                    self.partitions, node_to_latency_mapping)
                # Check if all partitions can be mapped to logical devices after swapping
                found_device = get_device_to_partitions_mapping(
                    self.partitions, self.devices)
                if not found_device:
                    cost = float('inf')
                else:
                    cost = get_latency_of_partitioned_graph(
                        self.partitions, partition_to_latency_mapping,
                        transfer_rate_bytes_per_sec)
            # Swap back and reset all partitions back to original
            swap_nodes(n1, n0, p0, p1)
            reorganize_partitions(self.partitions)
            reset_partition_device(self.partitions)
            get_device_to_partitions_mapping(self.partitions, self.devices)
            return cost

        def swap_node_to_partition(node, p0, p1, node_to_latency_mapping,
                                   transfer_rate_per_sec):
            """This function helps to swap one node from partition p0
               with all the nodes in another partition p1
            """
            p1_nodes = list(p1.nodes) + [None]
            min_cost = float('inf')
            node_pair: List[Node] = []
            for n1 in p1_nodes:
                # Skip n1 if it is not an op node
                if n1 is not None and n1.op in {'placeholder', 'get_attr'}:
                    continue
                # Try swapping node in p0 with n1 in p1
                cost = try_swap_nodes(node, n1, p0, p1,
                                      node_to_latency_mapping,
                                      transfer_rate_per_sec)
                if cost < min_cost:
                    node_pair = [node, n1]
                    min_cost = cost
            return min_cost, node_pair

        # First use size_based_partition to get an initial partitioning
        self.size_based_partition()
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            self.partitions, node_to_latency_mapping)
        # Calculate the cost of the partitions
        cost = get_latency_of_partitioned_graph(self.partitions,
                                                partition_to_latency_mapping,
                                                transfer_rate_bytes_per_sec)
        # Keep track of the node pair that gives the best cost
        node_pair: List[Node] = []
        # Keep track of the partition pair that the node pair belongs to
        partition_pair: List[Partition] = []
        # Collect all the op nodes from the graph
        op_nodes = []
        for n in self.graph_module.graph.nodes:
            if n.op not in {'placeholder', 'get_attr', 'output'}:
                op_nodes.append(n)
        for node in op_nodes:
            # Find which partition the current node belongs to
            p0_index = self.node_to_partition[node]
            p0 = self.partitions[p0_index]
            # Go through all the other partitions to swap
            # with other nodes from those partitions
            for p1_index, _ in enumerate(self.partitions):
                if p0_index != p1_index:
                    p1 = self.partitions[p1_index]
                    new_cost, new_node_pair = swap_node_to_partition(
                        node, p0, p1, node_to_latency_mapping,
                        transfer_rate_bytes_per_sec)
                    # Update the cost
                    # Track the swapped node pair and their partitions
                    if new_cost < cost:
                        cost = new_cost
                        node_pair = new_node_pair
                        partition_pair = [p0, p1]
            # Do the swapping after trying all the nodes from a partition
            if len(node_pair) != 0:
                swap_nodes(node_pair[0], node_pair[1], partition_pair[0],
                           partition_pair[1])
                reorganize_partitions(self.partitions)
                get_device_to_partitions_mapping(self.partitions, self.devices)
        reorganize_partitions(self.partitions)
        # Mapping the device to the partition
        get_device_to_partitions_mapping(self.partitions, self.devices)
        return
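Finally, a minimal usage sketch mirroring test_kl_based_partition above: kl_based_partition is normally selected through PartitionerConfig rather than called directly. The names traced, m, devices and node_to_latency_mapping are assumed to be prepared exactly as in that test.

# Minimal usage sketch, mirroring test_kl_based_partition above; `traced`, `m`,
# `devices` and `node_to_latency_mapping` are assumed to be set up as in that test.
partitioner = Partitioner()
partitioner_config = PartitionerConfig(
    devices,
    mode=PartitionMode.kl_based,
    transfer_rate_bytes_per_sec=2,
    node_to_latency_mapping=node_to_latency_mapping)
ret = partitioner.partition_graph(traced, m, partitioner_config)
module_with_submodules = ret.module_with_submodules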