コード例 #1
0
 def calculate_mem_bytes_needed(p1, p2):
     """Return the memory needed if partitions ``p1`` and ``p2`` were merged.

     The node sets of both partitions are unioned, and the extra size that
     ``get_extra_size_of`` reports for each node (measured against that
     union) is totalled.
     """
     merged_nodes = p1.nodes.union(p2.nodes)
     return sum(get_extra_size_of(n, merged_nodes) for n in merged_nodes)
コード例 #2
0
 def calculate_extra_mem_bytes_needed_for(partition: Partition, partitions: List[Partition]):
     all_nodes: Set[Node] = set()
     for p in partitions:
         all_nodes = all_nodes.union(p.nodes)
     extra_size_needed = 0
     for node in partition.nodes:
         if node in all_nodes or node.op in {'placeholder', 'get_attr'}:
             continue
         else:
             extra_size_needed += get_extra_size_of(node, all_nodes)
     return extra_size_needed
コード例 #3
0
 def calculate_extra_mem_bytes_needed_for(partition: Partition, partitions: List[Partition]):
     """Return the extra memory ``partition`` needs alongside ``partitions``.

     When ``partitions`` hold no nodes at all, the partition's own
     ``used_mem_bytes`` is returned. Otherwise each node of ``partition``
     contributes the size ``get_extra_size_of`` reports for it against the
     union of all nodes (those of ``partitions`` plus those of ``partition``).
     """
     existing_nodes: Set[Node] = set().union(*(p.nodes for p in partitions))
     if not existing_nodes:
         # Nothing placed yet: the partition costs exactly what it already uses.
         return partition.used_mem_bytes
     combined_nodes = existing_nodes.union(partition.nodes)
     return sum(get_extra_size_of(n, combined_nodes) for n in partition.nodes)
コード例 #4
0
 def find_device_based_on_size(node) -> Device:
     """Find the first unoccupied logical device that can fit ``node``.

     The size required is what ``get_extra_size_of`` reports for the node
     against an empty node set. Devices in ``self.devices`` are scanned in
     order; the first one that is not in ``occupied_devices`` and has at
     least that much ``available_mem_bytes`` is chosen, recorded in
     ``occupied_devices`` and returned.

     Raises:
         RuntimeError: if no unoccupied device is large enough.
     """
     mem_size_needed = get_extra_size_of(node, set())
     # Placeholder device: the check after the loop relies on its
     # available_mem_bytes being negative to detect "nothing found".
     device = Device('', -1, -1)
     for d in self.devices:
         if d not in occupied_devices and d.available_mem_bytes >= mem_size_needed:
             device = d
             break
     if device.available_mem_bytes < 0:
         # The separating space was missing, which glued the node name to the text.
         raise RuntimeError(str(node) + ' is too large to fit any device')
     occupied_devices.append(device)
     return device
コード例 #5
0
    def sparse_nn_partition(self, available_mem_bytes: int) -> None:
        """Partition a sparse nn module based on memory size.

        It is a size based partition but, unlike size_based_partition, it
        only works when all the devices have the same memory size
        (available_mem_bytes). In the future, devices with different memory
        sizes will be supported like size_based_partition.

        The method first traverses all the nodes and builds partitions
        against that single memory size. If the current partition does not
        have enough memory left for a new op node (call_module, call_method,
        call_function), a new partition is created. When crossing the
        boundary between non-embedding nodes and embedding nodes, a new
        partition is created regardless of the memory left. For example, if
        the current node is a non-embedding node but the next node is an
        embedding node, a new partition is created for the next node.

        After that, partitions are combined as much as possible. A
        non-embedding partition is only combined with another non-embedding
        one, and likewise for embedding partitions.

        Finally the i-th embedding partition is mapped to the i-th device of
        self.devices, every non-embedding partition is mapped to all of the
        devices used by embedding partitions, and self.node_to_partition is
        refreshed.

        Args:
            available_mem_bytes: memory budget in bytes shared by every device.

        Raises:
            RuntimeError: if one op node alone exceeds available_mem_bytes,
                if there are more embedding partitions than devices, or if
                an embedding partition together with all non-embedding
                partitions exceeds available_mem_bytes.
        """
        def combine_partitions_based_on_size(partitions: List[Partition],
                                             available_mem_bytes: int) -> None:
            """Combine small partitions to keep as few partitions as possible.

            Here is an example of the algorithm:
            Assume some partitions, first sorted by used memory size.
            [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)]
            The available memory is 10.
            step 1: find_partition_to_combine_based_on_size()
            First, mark the bfs level of each partition.
            Second, look at the smallest partition, partition_4: 10 - 1 = 9.
            Any partition using 9 or less could be combined with it.
            Going from the largest one, partition_0 is selected.
            If the bfs levels of the two partitions differ by less than 2,
            they can be combined.
            step 2: repeat step 1 until no partitions can be combined.
            """
            find_combination = True
            while find_combination:
                # Sort partitions based on memory size
                sorted_partitions = sorted(partitions,
                                           key=lambda p: p.used_mem_bytes)
                # Mark bfs level
                get_bfs_level_partition(self.partitions)
                find_combination, partitions = \
                    find_partition_to_combine_based_on_size(
                        sorted_partitions,
                        available_mem_bytes,
                        partitions
                    )
            return

        def calculate_mem_bytes_needed(p1, p2):
            """Given two partitions, calculate how many mem bytes
               are needed if the two partitions are combined
            """
            nodes = p1.nodes.union(p2.nodes)
            return sum(get_extra_size_of(node, nodes) for node in nodes)

        def find_partition_to_combine_based_on_size(
                sorted_partitions: List[Partition], available_mem_bytes: int,
                partitions: List[Partition]) -> Tuple[bool, List[Partition]]:
            """step 1 in combine_partitions_based_on_size()

            Tries to merge the smallest partition with the largest partition
            that is at most one bfs level away and still fits the budget.
            Returns whether a merge happened, plus the updated list.
            """
            find_combination = False
            # NOTE(review): pop(0) raises IndexError when the list is empty,
            # e.g. when the graph produced no partitions of this kind.
            smallest_partition = sorted_partitions.pop(0)
            for p in sorted_partitions[::-1]:
                if abs(smallest_partition.bfs_level - p.bfs_level) <= 1:
                    # Calculate how many bytes needed if combined
                    mem_bytes_needed = calculate_mem_bytes_needed(
                        p, smallest_partition)
                    if mem_bytes_needed <= available_mem_bytes:
                        combine_two_partitions(p, smallest_partition,
                                               self.partitions)
                        partitions.remove(smallest_partition)
                        partitions.remove(p)
                        # The merged partition is taken from the end of self.partitions
                        partitions.append(self.partitions[-1])
                        find_combination = True
                        break
            return find_combination, partitions

        def reset_partition_in_sparse_nn(partition, new_partition=True):
            """File the finished partition under the current region and,
               unless new_partition is False, start a fresh partition.

               The region is read from in_embedding_region at call time, so
               callers must invoke this before flipping that flag.
            """
            if in_embedding_region:
                embedding_partitions.append(partition)
            else:
                non_embedding_partitions.append(partition)
            if new_partition:
                partition = self.create_partition()
                partition.left_mem_bytes = available_mem_bytes
                return partition
            return None

        def is_embedding_node(node: Node) -> bool:
            """Check if a node is an embedding node.

            Only call_module nodes qualify. The dotted target is resolved
            attribute by attribute from self.graph_module, and the node
            counts as embedding as soon as the string form of any module on
            that path contains 'Embedding'.
            """
            if node.op == 'call_module':
                submodule = self.graph_module
                for atom in str(node.target).split('.'):
                    if not hasattr(submodule, atom):
                        raise RuntimeError(
                            f'Module {submodule} has no attribute {atom}')
                    submodule = getattr(submodule, atom)
                    if 'Embedding' in str(submodule):
                        return True
            return False

        # Track embedding partitions and non-embedding partitions separately
        embedding_partitions: List[Partition] = []
        non_embedding_partitions: List[Partition] = []
        # A flag to check the boundary
        in_embedding_region: bool = False
        # NOTE(review): unlike partitions made by reset_partition_in_sparse_nn,
        # this first partition does not get left_mem_bytes set here - confirm intended.
        partition = self.create_partition()
        for node in self.graph_module.graph.nodes:
            if node.op in {'call_module', 'call_method', 'call_function'}:
                # Check if crossing the boundary between embedding nodes and non embedding nodes
                if is_embedding_node(node) != in_embedding_region:
                    # Crossing the boundary
                    # Check if the current partition is an empty partition
                    if partition.used_mem_bytes != 0:
                        # The current partition isn't an empty partition. Create a new one.
                        partition = reset_partition_in_sparse_nn(partition)
                    in_embedding_region = not in_embedding_region
                total_size_of_input_nodes = get_extra_size_of(
                    node, partition.nodes)
                if total_size_of_input_nodes + partition.used_mem_bytes > available_mem_bytes:
                    partition = reset_partition_in_sparse_nn(partition)
                    total_size_of_input_nodes = get_extra_size_of(
                        node, partition.nodes)
                    if total_size_of_input_nodes > available_mem_bytes:
                        # node.target is a callable for call_function nodes, so it
                        # must be converted before being joined to the message.
                        raise RuntimeError(
                            str(node.target) +
                            ' is too large to fit into a device')
                partition.add_node(node)
        # File the last partition without opening a new one
        reset_partition_in_sparse_nn(partition, new_partition=False)
        # Set parents and children for partitions
        set_parents_and_children(self.partitions)
        # Combining non-embedding partitions
        combine_partitions_based_on_size(non_embedding_partitions,
                                         available_mem_bytes)
        # Combining embedding partitions
        combine_partitions_based_on_size(embedding_partitions,
                                         available_mem_bytes)
        total_size_of_non_embedding_partitions = sum(
            p.used_mem_bytes for p in non_embedding_partitions)
        # Check if devices are enough for all partitions
        if len(embedding_partitions) > len(self.devices):
            msg = f'Need {len(embedding_partitions)} devices, but only ' \
                f'{len(self.devices)} provided'
            raise RuntimeError(msg)
        occupied_devices = []
        for i, partition in enumerate(embedding_partitions):
            # Check if all non-embedding partitions can fit into embedding partition devices
            if total_size_of_non_embedding_partitions + partition.used_mem_bytes > available_mem_bytes:
                raise RuntimeError(
                    'partition_' + str(partition.partition_id) +
                    '(embedding partition) and non embedding partitions can not fit into one device'
                )
            # Add logical device to the partition
            partition.logical_device_ids = [self.devices[i].logical_id]
            occupied_devices.append(self.devices[i].logical_id)
        # Add logical devices to the non_embedding_partitions
        for partition in non_embedding_partitions:
            # Give every partition its own list so that mutating one
            # partition's device ids cannot leak into the others.
            partition.logical_device_ids = list(occupied_devices)
        # Get the node to partition mapping
        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
        return
コード例 #6
0
    def size_based_partition(self) -> None:
        """Partition the fx module based on memory size.

        It uses a greedy approach, so the result may not be the best.
        The basic idea is:
        Step 1:
        Find a device which has enough memory to fit the current node and
        create an empty partition with the size of that device.
        Then keep adding the following nodes into the partition until the
        partition is full.
        Step 2:
        Repeat Step 1 until no device is left.
        Step 3:
        If some nodes are left, create a partition for each left node
        (single node partition), and then try to map those partitions onto
        logical devices with enough memory left.

        On success self.node_to_partition holds the node to partition mapping.

        Raises:
            RuntimeError: if a node does not fit any unoccupied device, or
                if no valid partition to logical device mapping is found.
        """
        def find_device_based_on_size(node) -> Device:
            """Given a node, find an unoccupied logical device that could
               fit the node, mark it as occupied and return it.
            """
            mem_size_needed = get_extra_size_of(node, set())
            # Placeholder device: the check after the loop relies on its
            # available_mem_bytes being negative to detect "nothing found".
            device = Device('', -1, -1)
            for d in self.devices:
                if d not in occupied_devices and d.available_mem_bytes >= mem_size_needed:
                    device = d
                    break
            if device.available_mem_bytes < 0:
                # The separating space was missing, which glued the node name to the text.
                raise RuntimeError(
                    str(node) + ' is too large to fit any device')
            occupied_devices.append(device)
            return device

        # Track partition and its left mem size
        partition_to_left_mem_bytes: Dict[Partition, int] = {}
        # Track all the devices that have been used
        occupied_devices: List[Device] = []
        partition = self.create_partition()
        for node in self.graph_module.graph.nodes:
            if node.op not in {'call_module', 'call_method', 'call_function'}:
                continue
            # Create single node partitions if no device is left
            if len(self.partitions) > len(self.devices):
                self.create_single_node_partition(node)
                continue
            total_size_of_input_nodes = get_extra_size_of(
                node, partition.nodes)
            # Check if the current partition is still empty (the very first node)
            if partition.used_mem_bytes == 0:
                # Find a device to fit the first node; the helper already
                # records it in occupied_devices.
                device = find_device_based_on_size(node)
                # Update partition and its left mem size
                partition_to_left_mem_bytes[
                    partition] = device.available_mem_bytes
                # Bind the partition to the chosen device
                partition.logical_device_ids.append(device.logical_id)
            elif partition_to_left_mem_bytes[
                    partition] < total_size_of_input_nodes:
                # The current node can not fit into the current partition
                if len(self.partitions) == len(self.devices):
                    # No device is left: create the first single node
                    # partition for the current node
                    self.create_single_node_partition(node)
                    continue
                # Some devices are still left
                # Create a new partition with a mem size that is enough for the current node
                device = find_device_based_on_size(node)
                partition = self.create_partition()
                total_size_of_input_nodes = get_extra_size_of(
                    node, partition.nodes)
                partition_to_left_mem_bytes[
                    partition] = device.available_mem_bytes
                partition.logical_device_ids.append(
                    device.logical_id)
            partition.add_node(node)
            partition_to_left_mem_bytes[
                partition] -= total_size_of_input_nodes
        reorganize_partitions(self.partitions)
        # Get the node to partition mapping
        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
        # Mapping all partitions into device
        found_partition_to_device_mapping = get_device_to_partitions_mapping(
            self.partitions, self.devices)
        if not found_partition_to_device_mapping:
            raise RuntimeError(
                "Cannot Get a Valid Partition to Logical Device Mapping")
        return