def try_combining_partitions(p0_index, p1_index, partitions) -> float:
    """Given two partition indices and a list of partitions, combine
    the two partitions and return the cost of the modified partition list.
    """
    p0 = partitions[p0_index]
    p1 = partitions[p1_index]
    # Two partitions can be combined if their BFS levels differ by at most 1
    # or if they are directly connected to each other
    if (abs(p0.bfs_level - p1.bfs_level) <= 1) or (p0 in p1.parents) or (p0 in p1.children):
        combine_two_partitions(p0, p1, partitions)
        # Check if a circular dependency exists after combining
        if check_dependency(partitions[-1]):
            return float('inf')
        # Check if the modified partition list can be mapped to devices after combination
        reset_partition_device(partitions)
        found_device = get_device_to_partitions_mapping(partitions, self.devices)
        if not found_device:
            return float('inf')
        # Calculate the new cost
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping)
        cost = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec)
        return cost
    # If the two partitions cannot be combined, the cost is inf
    return float('inf')
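# The combinability condition above reads more clearly as a standalone predicate.
# This is only a restatement for illustration, not a helper defined in the source;
# it uses the same Partition attributes (bfs_level, parents, children).
def can_combine(p0, p1) -> bool:
    # Partitions on adjacent BFS levels, or partitions directly connected
    # as parent/child, are candidates for combination.
    adjacent_bfs_levels = abs(p0.bfs_level - p1.bfs_level) <= 1
    directly_connected = (p0 in p1.parents) or (p0 in p1.children)
    return adjacent_bfs_levels or directly_connected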
def test_cost_aware_partition(self):
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_3 = add_2 + torch.rand(4)
            add_4 = add_2 + linear_1
            add_5 = add_3 + add_4
            return add_5

    def get_node_to_latency_mapping(fx_module: GraphModule):
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op not in {'output', 'placeholder', 'get_attr'}:
                if node.size_bytes.total_size == node.size_bytes.output_size:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, 1)
                else:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, node.size_bytes.output_size)
        return node_to_latency_mapping

    m = MyModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a])
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2),
        Device('dev_3', 125, 3)
    ]
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    partitioner_config = PartitionerConfig(
        devices,
        is_sparse_nn=False,
        is_cost_aware=True,
        transfer_rate_bytes_per_sec=2,
        node_to_latency_mapping=node_to_latency_mapping)
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitions, node_to_latency_mapping)
    critical_path_latency_sec = get_latency_of_partitioned_graph(
        partitions, partition_to_latency_mapping,
        partitioner_config.transfer_rate_bytes_per_sec)
    assert critical_path_latency_sec == 160.0
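# For reference: the positional arguments to Device('dev_0', 125, 0) above read as
# (name, available memory in bytes, logical device id). A minimal sketch of the
# record, with field names assumed from torch.fx.experimental.partitioner_utils:
from typing import NamedTuple

class Device(NamedTuple):
    name: str                 # e.g. 'dev_0'
    available_mem_bytes: int  # e.g. 125
    logical_id: int           # e.g. 0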
def test_partition_latency(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_3 = add_2 + linear_1
            add_4 = add_2 + add_3
            return add_4

    def get_node_to_latency_mapping(fx_module: GraphModule):
        """Given an fx module, generate node latency for each node
        based on the size of each node
        """
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op not in {"output", "placeholder", "get_attr"}:
                if node.size_bytes.total_size == node.size_bytes.output_size:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                    )
                else:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, node.size_bytes.output_size
                    )
        return node_to_latency_mapping

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
    partitioner = Partitioner()
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitions, node_to_latency_mapping
    )
    for p in partition_to_latency_mapping:
        if p.partition_id == 0:
            assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
        else:
            assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
    transfer_rate_bytes_per_sec = 2
    critical_path_latency_sec = get_latency_of_partitioned_graph(
        partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
    )
    assert critical_path_latency_sec == 208.0
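# For reference: the three-element tuples asserted above follow the PartitionLatency
# layout, so (128.0, 80.0, 160.0) reads as (memory latency, computation latency,
# overall latency) in seconds. A sketch of that record, with field names assumed
# from torch.fx.experimental.partitioner_utils:
from typing import NamedTuple

class PartitionLatency(NamedTuple):
    mem_latency_sec: float          # time to move the partition's bytes
    computation_latency_sec: float  # time to execute the partition's ops
    overall_latency_sec: float      # max of the two, per partition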
def test_kl_based_partition(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.b = torch.rand(4)
            self.c = torch.rand(4)
            self.d = torch.rand(4)

        def forward(self, a):
            add_1 = a + self.b
            add_2 = add_1 + self.c
            linear_1 = self.linear(add_1)
            add_3 = add_2 + linear_1
            add_4 = add_2 + self.d
            add_5 = add_3 + add_4
            return add_4

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    transfer_rate_bytes_per_sec = 2
    devices = [
        Device('dev_0', 200, 0),
        Device('dev_1', 200, 1),
        Device('dev_2', 200, 2),
        Device('dev_3', 200, 3)
    ]
    partitioner = Partitioner()
    partitioner_config = PartitionerConfig(
        devices,
        mode=PartitionMode.kl_based,
        transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec,
        node_to_latency_mapping=node_to_latency_mapping
    )
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a), module_with_submodules(a))
    dag = ret.dag
    assert dag.nodes[0].size_bytes == 176
    assert dag.nodes[1].size_bytes == 112
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitioner.partitions, node_to_latency_mapping
    )
    cost = get_latency_of_partitioned_graph(
        partitioner.partitions,
        partition_to_latency_mapping,
        transfer_rate_bytes_per_sec
    )
    assert cost == 208.0
def search_combination(
    transfer_rate_bytes_per_sec,
    node_to_latency_mapping
) -> bool:
    """Given the transfer rate between partitions and each node's latency,
    find two partitions to combine so the cost of the partitions can
    be reduced.
    The algorithm is:
    1. Go through all the partition pairs and see
       if any pair of partitions can be combined.
    2. Calculate the cost after the combination.
    3. Select the minimum cost and combine its corresponding partition pair.
    """
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        self.partitions, node_to_latency_mapping)
    cost = get_latency_of_partitioned_graph(
        self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec)
    if len(self.partitions) == 1:
        return False
    partition_pair: List[int] = []
    for i in range(len(self.partitions) - 1):
        for j in range(i + 1, len(self.partitions)):
            # Try to combine the partition pair
            # and see the new cost after combination
            new_cost = try_combining_partitions(i, j, self.partitions[:])
            if new_cost <= cost:
                partition_pair = [i, j]
                cost = new_cost
            reorganize_partitions(self.partitions)
    # If a partition pair is found, combine them
    if len(partition_pair) != 0:
        p0 = self.partitions[partition_pair[0]]
        p1 = self.partitions[partition_pair[1]]
        combine_two_partitions(p0, p1, self.partitions)
    get_bfs_level_partition(self.partitions)
    reset_partition_device(self.partitions)
    get_device_to_partitions_mapping(self.partitions, self.devices)
    return len(partition_pair) != 0
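# For context, search_combination is typically driven in a loop that keeps merging
# until no pair of partitions improves the cost. A minimal sketch of such a driver,
# assuming the partitioner starts from one partition per op node;
# create_single_node_partition is an assumed helper, not defined in this excerpt.
def cost_aware_partition(self, transfer_rate_bytes_per_sec, node_to_latency_mapping):
    # Start from one partition per op node (assumed helper).
    for node in self.graph_module.graph.nodes:
        if node.op not in {'placeholder', 'get_attr', 'output'}:
            self.create_single_node_partition(node)
    # Greedily merge partition pairs while search_combination keeps finding
    # a combination that lowers the overall latency.
    find_combination = True
    while find_combination:
        find_combination = search_combination(
            transfer_rate_bytes_per_sec, node_to_latency_mapping)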
def test_cost_aware_partition(self):
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_3 = add_2 + torch.rand(4)
            add_4 = add_2 + linear_1
            add_5 = add_3 + add_4
            return add_5

    def get_node_to_latency_mapping(fx_module: GraphModule):
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op not in {'output', 'placeholder', 'get_attr'}:
                if node.size_bytes.total_size == node.size_bytes.output_size:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, 1)
                else:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, node.size_bytes.output_size)
        return node_to_latency_mapping

    m = MyModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2),
        Device('dev_3', 125, 3)
    ]
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    partitioner_config = PartitionerConfig(
        devices,
        mode=PartitionMode.cost_aware,
        transfer_rate_bytes_per_sec=2,
        node_to_latency_mapping=node_to_latency_mapping)
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitions, node_to_latency_mapping)
    critical_path_latency_sec = get_latency_of_partitioned_graph(
        partitions, partition_to_latency_mapping,
        partitioner_config.transfer_rate_bytes_per_sec)
    assert critical_path_latency_sec == 160.0
def test_aot_based_partition(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.b = torch.rand(4)
            self.c = torch.rand(4)

        def forward(self, a):
            add_1 = a + self.b
            add_2 = self.c + add_1
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    node_to_partition_id = {}
    partition_to_logical_devices = {}
    count = 0
    GraphManipulation.get_size_of_all_nodes(traced, [a])
    for node in traced.graph.nodes:
        if node.op not in {'placeholder', 'get_attr', 'output'}:
            node_to_partition_id[node] = count
            partition_to_logical_devices[count] = [0]
            count += 1
    devices = [Device('dev_0', 200, 0)]
    partitioner_config = PartitionerConfig(
        devices=devices,
        mode=PartitionMode.aot_based,
        node_to_partition_mapping=node_to_partition_id,
        partition_to_logical_device_mapping=partition_to_logical_devices
    )
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(module_with_submodules(a), traced(a))
    for node in dag.nodes:
        assert node.size_bytes == 48
        assert node.logical_device_ids == [0]

def test_replace_target_nodes_with(self):
    class testModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = testModule()
    traced = symbolic_trace(m)
    input1 = torch.randn(1)
    input2 = torch.randn(1)
    assert (input1 + input2) == traced(input1, input2)
    graph_manipulation.replace_target_nodes_with(
        fx_module=traced,
        old_op="call_function",
        old_target=operator.add,
        new_op="call_function",
        new_target=operator.mul,
    )
    assert (input1 * input2) == traced(input1, input2)
def kl_based_partition(
    self,
    transfer_rate_bytes_per_sec: float,
    node_to_latency_mapping: Dict[Node, NodeLatency]
) -> None:
    """This function is a cost-aware partition based on the Kernighan-Lin
    algorithm. First, the graph is partitioned using size_based_partition.
    Then, each node is swapped with any other node in a different partition,
    and the cost is estimated after each swap.
    For example, we have nodes n0, n1, n2, n3 and n4. After
    size_based_partition, n0 and n1 are in partition p0; n2, n3 and n4 are
    in partition p1. The current cost is estimated. We first try swapping n0
    with n2 from the other partition. If swapping n0 and n2 shows a lower
    cost than the current cost, and that cost is the minimum among the other
    pairs such as (n0, None) (which means moving n0 to partition p1 without
    swapping any node back), (n0, n3) and (n0, n4), we swap n0 and n2 and set
    the new cost as the current cost. We then repeat this process for all the
    other nodes until all swapping pairs have been tried.
    """
    def swap_nodes(n0, n1, p0, p1):
        # Either n0 or n1 could be None,
        # which means we simply move the other node
        # to another partition instead of exchanging two nodes
        if n0 is not None:
            p0.remove_node(n0)
            p1.add_node(n0)
        if n1 is not None:
            p0.add_node(n1)
            p1.remove_node(n1)

    def try_swap_nodes(n0, n1, p0, p1, node_to_latency_mapping,
                       transfer_rate_bytes_per_sec):
        cost = float('inf')
        swap_nodes(n0, n1, p0, p1)
        # Reorganize partitions after swapping
        reorganize_partitions(self.partitions)
        # Check if there is a circular dependency after swapping
        if (not check_dependency(p0)) and (not check_dependency(p1)):
            reset_partition_device(self.partitions)
            partition_to_latency_mapping = get_partition_to_latency_mapping(
                self.partitions, node_to_latency_mapping)
            # Check if all partitions can be mapped to logical devices after swapping
            found_device = get_device_to_partitions_mapping(
                self.partitions, self.devices)
            if not found_device:
                cost = float('inf')
            else:
                cost = get_latency_of_partitioned_graph(
                    self.partitions,
                    partition_to_latency_mapping,
                    transfer_rate_bytes_per_sec)
        # Swap back and reset all partitions to their original state
        swap_nodes(n1, n0, p0, p1)
        reorganize_partitions(self.partitions)
        reset_partition_device(self.partitions)
        get_device_to_partitions_mapping(self.partitions, self.devices)
        return cost

    def swap_node_to_partition(node, p0, p1, node_to_latency_mapping,
                               transfer_rate_bytes_per_sec):
        """This function helps to swap one node from partition p0
        with all the nodes in another partition p1
        """
        p1_nodes = list(p1.nodes) + [None]
        min_cost = float('inf')
        node_pair: List[Node] = []
        for n1 in p1_nodes:
            # Ignore the node if it is not an op node
            if n1 is not None and n1.op in {'placeholder', 'get_attr'}:
                continue
            # Try swapping node in p0 with n1 in p1
            cost = try_swap_nodes(node, n1, p0, p1, node_to_latency_mapping,
                                  transfer_rate_bytes_per_sec)
            if cost < min_cost:
                node_pair = [node, n1]
                min_cost = cost
        return min_cost, node_pair

    # First use size_based_partition
    self.size_based_partition()
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        self.partitions, node_to_latency_mapping)
    # Calculate the cost of the partitions
    cost = get_latency_of_partitioned_graph(
        self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec)
    # Keep track of the node pair that gives the best cost
    node_pair: List[Node] = []
    # Keep track of the partition pair of that node pair
    partition_pair: List[Partition] = []
    # Collect all the op nodes from the graph
    op_nodes = []
    for n in self.graph_module.graph.nodes:
        if n.op not in {'placeholder', 'get_attr', 'output'}:
            op_nodes.append(n)
    for node in op_nodes:
        # Find which partition the current node belongs to
        p0_index = self.node_to_partition[node]
        p0 = self.partitions[p0_index]
        # Go through all the other partitions to swap
        # with other nodes from those partitions
        for p1_index, _ in enumerate(self.partitions):
            if p0_index != p1_index:
                p1 = self.partitions[p1_index]
                new_cost, new_node_pair = swap_node_to_partition(
                    node, p0, p1, node_to_latency_mapping,
                    transfer_rate_bytes_per_sec)
                # Update the cost and
                # track the swapped node pair and their partitions
                if new_cost < cost:
                    cost = new_cost
                    node_pair = new_node_pair
                    partition_pair = [p0, p1]
        # Do the swapping after trying all the nodes from a partition
        if len(node_pair) != 0:
            swap_nodes(node_pair[0], node_pair[1],
                       partition_pair[0], partition_pair[1])
            reorganize_partitions(self.partitions)
            get_device_to_partitions_mapping(self.partitions, self.devices)
    reorganize_partitions(self.partitions)
    # Map devices to partitions
    get_device_to_partitions_mapping(self.partitions, self.devices)
    return
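# The (node, None) convention in swap_nodes above is easy to miss: passing None
# for one side turns the swap into a plain move. A self-contained toy
# illustration; ToyPartition is a stand-in for the real Partition class.
class ToyPartition:
    def __init__(self, nodes):
        self.nodes = set(nodes)

    def add_node(self, node):
        self.nodes.add(node)

    def remove_node(self, node):
        self.nodes.remove(node)

def swap_nodes(n0, n1, p0, p1):
    # Either n0 or n1 may be None, in which case the other node
    # is simply moved across partitions instead of exchanged.
    if n0 is not None:
        p0.remove_node(n0)
        p1.add_node(n0)
    if n1 is not None:
        p0.add_node(n1)
        p1.remove_node(n1)

p0, p1 = ToyPartition({'n0', 'n1'}), ToyPartition({'n2', 'n3', 'n4'})
swap_nodes('n0', None, p0, p1)   # pure move: n0 leaves p0 and joins p1
assert p0.nodes == {'n1'}
assert p1.nodes == {'n0', 'n2', 'n3', 'n4'}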