def calculate_mem_bytes_needed(p1, p2):
    """Return how many mem bytes are needed if ``p1`` and ``p2`` are combined.

    The nodes of both partitions are merged into one set. The result is the
    sum of ``get_extra_size_of(node, merged)`` over every node in that
    merged set.
    """
    merged = p1.nodes.union(p2.nodes)
    return sum(get_extra_size_of(node, merged) for node in merged)
def calculate_extra_mem_bytes_needed_for(partition: Partition, partitions: List[Partition]):
    """Sum the extra sizes of the nodes of ``partition`` that are new relative to ``partitions``.

    A node is skipped if it already belongs to one of ``partitions`` or if
    its op is ``'placeholder'`` or ``'get_attr'``. Every other node adds
    ``get_extra_size_of(node, existing_nodes)``, where ``existing_nodes`` is
    the union of the nodes of ``partitions``.

    NOTE(review): the module defines ``calculate_extra_mem_bytes_needed_for``
    a second time right after this one. Once the module has executed, the
    name is bound to that later definition, so this version is unreachable
    by name.
    """
    existing_nodes: Set[Node] = set()
    for other in partitions:
        existing_nodes.update(other.nodes)
    ignored_ops = {'placeholder', 'get_attr'}
    return sum(
        get_extra_size_of(node, existing_nodes)
        for node in partition.nodes
        if node not in existing_nodes and node.op not in ignored_ops
    )
def calculate_extra_mem_bytes_needed_for(partition: Partition, partitions: List[Partition]):
    """Return the extra mem bytes ``partition`` needs when combined with ``partitions``.

    If ``partitions`` contain no nodes at all, the partition's own
    ``used_mem_bytes`` is returned. Otherwise the result is the sum of
    ``get_extra_size_of(node, all_nodes)`` over the nodes of ``partition``.
    Here ``all_nodes`` is the union of the nodes of ``partitions`` and the
    nodes of ``partition`` itself.

    Args:
        partition: the partition being considered for combination.
        partitions: the partitions it would be combined with.
    """
    # Grow a single set in place. Rebinding to ``all_nodes.union(...)``
    # copied the accumulated set on every iteration.
    all_nodes: Set[Node] = set()
    for p in partitions:
        all_nodes.update(p.nodes)
    if not all_nodes:
        return partition.used_mem_bytes
    # The partition's own nodes are part of the reference set handed to
    # get_extra_size_of.
    all_nodes.update(partition.nodes)
    return sum(get_extra_size_of(node, all_nodes) for node in partition.nodes)
def find_device_based_on_size(node, devices=None, occupied_devices=None) -> Device:
    """Find an unoccupied logical device that can fit ``node``.

    The node needs ``get_extra_size_of(node, set())`` bytes. The first device
    in ``devices`` that is not in ``occupied_devices`` and has
    ``available_mem_bytes`` at least that large is chosen. It is appended to
    ``occupied_devices`` and returned.

    The previous body read ``self.devices`` and ``occupied_devices`` as free
    names, which do not exist at module scope, so every call raised
    NameError. Both are now explicit parameters.

    Args:
        node: the node to place.
        devices: candidate devices, checked in order. Required.
        occupied_devices: devices already taken. The chosen device is
            appended to it in place. A fresh empty list is used when omitted.

    Returns:
        The chosen device.

    Raises:
        ValueError: if ``devices`` is not supplied.
        RuntimeError: if no unoccupied device has enough memory.
    """
    if devices is None:
        raise ValueError('find_device_based_on_size() requires the candidate devices')
    if occupied_devices is None:
        occupied_devices = []
    mem_size_needed = get_extra_size_of(node, set())
    device = next(
        (
            d
            for d in devices
            if d not in occupied_devices and d.available_mem_bytes >= mem_size_needed
        ),
        None,
    )
    # A chosen device with negative available memory is rejected, matching
    # the old sentinel Device('', -1, -1) check.
    if device is None or device.available_mem_bytes < 0:
        raise RuntimeError(str(node) + ' is too large to fit any device')
    occupied_devices.append(device)
    return device
def sparse_nn_partition(self, available_mem_bytes: int) -> None:
    """Partition a sparse nn module.

    This is a size-based partition. Unlike size_based_partition, it only works
    when all devices have the same memory size (``available_mem_bytes``).
    Devices with different memory sizes may be supported in the future, as
    size_based_partition does.

    Algorithm:
      1. Traverse all nodes in graph order. If the current partition does
         not have enough memory left for a new op node (call_module,
         call_method, call_function), a new partition is created.
      2. When crossing the boundary between non-embedding and embedding
         nodes, a new partition is created regardless of remaining memory.
         An empty current partition is reused instead. For example, if the
         current node is a non-embedding node and the next one is an
         embedding node, the next node starts a new partition.
      3. The partitions are then combined as much as possible. A
         non-embedding partition is only combined with another non-embedding
         partition, and likewise for embedding partitions.
      4. Each embedding partition is assigned one logical device. Every
         non-embedding partition is assigned all of those devices, so each
         device must fit one embedding partition plus all non-embedding
         partitions.

    Args:
        available_mem_bytes: memory size of every device.

    Raises:
        RuntimeError: if a node cannot fit into a device on its own, if there
            are more embedding partitions than devices, or if an embedding
            partition together with all non-embedding partitions exceeds
            ``available_mem_bytes``.
    """

    def combine_partitions_based_on_size(
        partitions: List[Partition], available_mem_bytes: int
    ) -> None:
        """Combine small partitions to keep the number of partitions low.

        ``partitions`` is mutated in place. Example: partitions sorted by
        used memory size are
        [(partition_4, 1), (partition_3, 1), (partition_2, 2),
         (partition_1, 7), (partition_0, 9)]
        and the available memory is 10.

        Step 1 (find_partition_to_combine_based_on_size):
          First, mark the bfs level of each partition.
          Second, look at the smallest partition, partition_4: 10 - 1 = 9,
          so a partition using roughly 9 bytes or less could be combined
          with it. The exact check uses calculate_mem_bytes_needed on the
          merged nodes. Going from the largest, partition_0 is tried first.
          If the bfs levels of the two partitions differ by at most 1, they
          can be combined.
        Step 2: repeat step 1 until no partitions can be combined.
        """
        find_combination = True
        while find_combination:
            # Sort partitions based on memory size
            sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes)
            # Mark bfs level
            get_bfs_level_partition(self.partitions)
            find_combination, partitions = find_partition_to_combine_based_on_size(
                sorted_partitions, available_mem_bytes, partitions
            )
        return

    def calculate_mem_bytes_needed(p1, p2):
        """Given two partitions, calculate how many mem bytes are needed
        if the two partitions are combined.
        """
        nodes = p1.nodes.union(p2.nodes)
        return sum(get_extra_size_of(node, nodes) for node in nodes)

    def find_partition_to_combine_based_on_size(
        sorted_partitions: List[Partition],
        available_mem_bytes: int,
        partitions: List[Partition],
    ) -> Tuple[bool, List[Partition]]:
        """Step 1 of combine_partitions_based_on_size().

        Tries to combine the smallest partition with another one, largest
        first. On success, ``partitions`` is updated in place: both inputs
        are removed and the combined partition (``self.partitions[-1]``) is
        appended.

        Returns:
            (whether a combination happened, ``partitions``)
        """
        find_combination = False
        if not sorted_partitions:
            # No partitions in this group (e.g. a model made only of embedding
            # nodes has no non-embedding partitions). pop(0) would raise
            # IndexError here.
            return find_combination, partitions
        smallest_partition = sorted_partitions.pop(0)
        for p in sorted_partitions[::-1]:
            if abs(smallest_partition.bfs_level - p.bfs_level) <= 1:
                # Calculate how many bytes are needed if combined
                mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition)
                if mem_bytes_needed <= available_mem_bytes:
                    combine_two_partitions(p, smallest_partition, self.partitions)
                    partitions.remove(smallest_partition)
                    partitions.remove(p)
                    partitions.append(self.partitions[-1])
                    find_combination = True
                    break
        return find_combination, partitions

    def reset_partition_in_sparse_nn(partition, new_partition=True):
        """Record ``partition`` under the current region and optionally start a new one.

        The partition is appended to embedding_partitions or
        non_embedding_partitions, depending on ``in_embedding_region`` at call
        time. If ``new_partition`` is true, a fresh partition with
        ``left_mem_bytes = available_mem_bytes`` is created and returned.
        Otherwise None is returned.
        """
        if in_embedding_region:
            embedding_partitions.append(partition)
        else:
            non_embedding_partitions.append(partition)
        if new_partition:
            partition = self.create_partition()
            partition.left_mem_bytes = available_mem_bytes
            return partition
        return None

    def is_embedding_node(node: Node) -> bool:
        """Check if a node is an embedding node.

        Only call_module nodes qualify. The submodule is resolved by walking
        the dotted target, and counts as an embedding if its string form
        contains 'Embedding'.
        """
        if node.op == 'call_module':
            submodule = self.graph_module
            for atom in str(node.target).split('.'):
                if not hasattr(submodule, atom):
                    raise RuntimeError(
                        f'Module {submodule} has no attribute {atom}'
                    )
                submodule = getattr(submodule, atom)
                if 'Embedding' in str(submodule):
                    return True
        return False

    # Track embedding partitions and non-embedding partitions separately
    embedding_partitions: List[Partition] = []
    non_embedding_partitions: List[Partition] = []
    # A flag to check the boundary. reset_partition_in_sparse_nn reads it
    # through the closure.
    in_embedding_region: bool = False
    partition = self.create_partition()
    for node in self.graph_module.graph.nodes:
        if node.op in {'call_module', 'call_method', 'call_function'}:
            # Check if crossing the boundary between embedding and non-embedding nodes
            if is_embedding_node(node) != in_embedding_region:
                # Crossing the boundary. Only start a new partition if the
                # current one is not empty; an empty one is reused.
                if partition.used_mem_bytes != 0:
                    partition = reset_partition_in_sparse_nn(partition)
                in_embedding_region = not in_embedding_region
            total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
            if total_size_of_input_nodes + partition.used_mem_bytes > available_mem_bytes:
                partition = reset_partition_in_sparse_nn(partition)
                total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
                if total_size_of_input_nodes > available_mem_bytes:
                    # node.target is a callable for call_function nodes, so
                    # format it instead of concatenating with a str.
                    raise RuntimeError(
                        f'{node.target} is too large to fit into a device'
                    )
            partition.add_node(node)
    reset_partition_in_sparse_nn(partition, new_partition=False)
    # Set parents and children for partitions
    set_parents_and_children(self.partitions)
    # Combining non-embedding partitions
    combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes)
    # Combining embedding partitions
    combine_partitions_based_on_size(embedding_partitions, available_mem_bytes)
    total_size_of_non_embedding_partitions = sum(
        p.used_mem_bytes for p in non_embedding_partitions
    )
    # Check if devices are enough for all partitions
    if len(embedding_partitions) > len(self.devices):
        msg = (
            f'Need {len(embedding_partitions)} devices, but only '
            f'{len(self.devices)} provided'
        )
        raise RuntimeError(msg)
    occupied_devices = []
    for i, emb_partition in enumerate(embedding_partitions):
        # Check if all non-embedding partitions can fit into this embedding partition's device
        if total_size_of_non_embedding_partitions + emb_partition.used_mem_bytes > available_mem_bytes:
            raise RuntimeError(
                'partition_' + str(emb_partition.partition_id)
                + '(embedding partition) and non embedding partitions can not fit into one device'
            )
        # Add logical device to the partition
        emb_partition.logical_device_ids = [self.devices[i].logical_id]
        occupied_devices.append(self.devices[i].logical_id)
    # Add logical devices to the non-embedding partitions. Each one gets its
    # own copy so that mutating one partition's ids cannot affect the others.
    for non_emb_partition in non_embedding_partitions:
        non_emb_partition.logical_device_ids = list(occupied_devices)
    # Get the node to partition mapping
    self.node_to_partition = get_node_to_partition_mapping(self.partitions)
    return
def size_based_partition(self) -> None:
    """Partition the fx module based on memory size.

    Uses a greedy approach, so the result may not be the best. The basic idea:
    Step 1: Find a device with enough memory to fit the current node and
        create an empty partition sized to that device. Keep adding the
        following nodes to the partition until it is full.
    Step 2: Repeat Step 1 until no device is left.
    Step 3: If some nodes are left, create a single-node partition for each
        of them, then try to map all partitions onto logical devices with
        enough memory left.

    Raises:
        RuntimeError: if a node does not fit into any unoccupied device, or
            if no valid partition-to-logical-device mapping is found.
    """

    def find_device_based_on_size(node) -> Device:
        """Return the first unoccupied device in ``self.devices`` with
        ``available_mem_bytes >= get_extra_size_of(node, set())``, and
        record it in ``occupied_devices``.
        """
        mem_size_needed = get_extra_size_of(node, set())
        device = next(
            (
                d
                for d in self.devices
                if d not in occupied_devices and d.available_mem_bytes >= mem_size_needed
            ),
            None,
        )
        # A chosen device with negative available memory is rejected, matching
        # the old sentinel Device('', -1, -1) check.
        if device is None or device.available_mem_bytes < 0:
            raise RuntimeError(str(node) + ' is too large to fit any device')
        # This helper is the only place that marks devices occupied. The
        # caller used to append the same device a second time.
        occupied_devices.append(device)
        return device

    # Track each partition and its remaining mem size. A partition is in
    # this dict once it has been bound to a device.
    partition_to_left_mem_bytes: Dict[Partition, int] = {}
    # Track all the devices that have been used
    occupied_devices: List[Device] = []
    partition = self.create_partition()
    for node in self.graph_module.graph.nodes:
        if node.op not in {'call_module', 'call_method', 'call_function'}:
            continue
        # Once there are more partitions than devices, every remaining op
        # node gets its own single-node partition.
        if len(self.partitions) > len(self.devices):
            self.create_single_node_partition(node)
            continue
        total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
        if partition not in partition_to_left_mem_bytes:
            # The very first partition has no device yet. Find one that fits
            # this node. Dict membership, unlike ``used_mem_bytes == 0``,
            # cannot bind a second device after a zero-size node.
            device = find_device_based_on_size(node)
            partition_to_left_mem_bytes[partition] = device.available_mem_bytes
            partition.logical_device_ids.append(device.logical_id)
        elif partition_to_left_mem_bytes[partition] < total_size_of_input_nodes:
            # The node does not fit into the current partition.
            if len(self.partitions) == len(self.devices):
                # No device is left: the node gets the first single-node partition.
                self.create_single_node_partition(node)
                continue
            # Some devices are still left. Create a new partition on a
            # device with enough memory for the current node.
            device = find_device_based_on_size(node)
            partition = self.create_partition()
            total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
            partition_to_left_mem_bytes[partition] = device.available_mem_bytes
            partition.logical_device_ids.append(device.logical_id)
        partition.add_node(node)
        partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
    reorganize_partitions(self.partitions)
    # Get the node to partition mapping
    self.node_to_partition = get_node_to_partition_mapping(self.partitions)
    # Map all partitions onto devices
    found_partition_to_device_mapping = get_device_to_partitions_mapping(
        self.partitions, self.devices
    )
    if not found_partition_to_device_mapping:
        raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping")
    return