Example #1
    def __init__(
        self,
        module: nn.Module,
        sharded_optimizer: Union[OSS, List[OSS]],
        process_group: Any = None,
        broadcast_buffers: bool = True,
        sync_models_at_startup: bool = True,
    ):
        super().__init__()

        self.module = module
        self.sharded_optimizers = [sharded_optimizer] if isinstance(
            sharded_optimizer, OSS) else sharded_optimizer
        self.enable_broadcast_buffers = broadcast_buffers

        # Handle a no_sync() context which prevents gradient synchronization;
        # gradients are accumulated in place
        self.should_accumulate_grads = False

        # Communication related attributes
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        self.world_size = dist.get_world_size(self.process_group)
        self.reference_global_rank = OSS.get_global_rank(
            self.process_group, 0)  # picking rank 0 as the reference
        self.rank = dist.get_rank(self.process_group)
        self.global_rank = OSS.get_global_rank(self.process_group, self.rank)

        # Expose some of the PyTorch DDP attributes; some frameworks rely on them.
        # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
        # device_id related logic is not present; it is not handled here
        devices = {p.device for p in self.module.parameters()}
        self.is_multi_device_module = len(devices) > 1
        self.device = list(devices)[0]

        distinct_device_types = {
            p.device.type
            for p in self.module.parameters()
        }
        assert len(distinct_device_types) == 1, (
            "ShardedDataParallel's input module must be on a single device type, "
            "but its parameters were found on the following device types: {}"
        ).format(distinct_device_types)
        self.device_type = list(distinct_device_types)[0]

        # Scaffolding to be able to reduce the grads during the backward pass.
        # Several optimizers can be present, each working on separate parameter sets;
        # we build an iterator which goes through all the parameters involved globally
        self._param_iterator = chain(*[
            optim.should_bucket_param.keys()
            for optim in self.sharded_optimizers
        ])
        self._grad_to_be_reduced = [True for _ in self._param_iterator]
        self._grad_accs: List[Callable] = []
        self._setup_backward_hooks()

        # Make sure that all ranks start with the same model
        if sync_models_at_startup:
            self._sync_params_and_buffers()
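
The startup synchronization above calls _sync_params_and_buffers(), which is not shown in these examples. Below is a minimal sketch of what such a helper typically does, assuming a plain torch.distributed broadcast from the reference rank (the library's actual implementation may differ):

from itertools import chain
from typing import Any

import torch
import torch.distributed as dist
import torch.nn as nn


def sync_params_and_buffers(module: nn.Module, process_group: Any, reference_global_rank: int) -> None:
    # Broadcast every parameter and buffer from the reference rank so that
    # all ranks start training from exactly the same model state.
    with torch.no_grad():
        for tensor in chain(module.parameters(), module.buffers()):
            dist.broadcast(tensor.data, src=reference_global_rank, group=process_group)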
Example #2
            def reduce(*_: Any) -> None:
                # Skip the reduction if grads are being accumulated (no_sync()) or if
                # this grad has already been reduced; status flags are left untouched
                if not self.should_accumulate_grads and self._grad_to_be_reduced[
                        index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    if not self._bucket_flush_callback_set:
                        Variable._execution_engine.queue_callback(
                            self._flush_buckets)
                        self._bucket_flush_callback_set = True

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    param.grad.mul_(self.world_size_scaling)

                    if self.reduce_fp16:
                        param.grad.data = param.grad.data.half()

                    # Future work includes clearing up the buffer if possible
                    def cleanup() -> None:
                        if dst_rank != self.global_rank:
                            param.grad = None
                        else:
                            assert param.grad is not None
                            param.grad.data = param.grad.data.to(
                                dtype=param.dtype)

                    # Async reduce for this grad; keep track of the work handle
                    dst_global_rank = OSS.get_global_rank(
                        self.process_group, dst_rank)

                    self._work_handles.append(
                        Workhandle(
                            handle=dist.reduce(tensor=param.grad.data,
                                               dst=dst_global_rank,
                                               group=self.process_group,
                                               async_op=True),
                            callback=cleanup,
                        ))
                    self._reduced_grads += 1

                    # Opportunistically try to empty the queue
                    self._try_consume_work_handle()

                    # If all the reduce operations have been called,
                    # make sure that all the asynchronous calls have concluded before moving on
                    # and execute the delayed actions (release gradients, unroll the buckets)
                    if self._reduced_grads == self._reduced_grads_max:
                        self._consume_work_handles()
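
The reduce closure above only runs once the corresponding gradient is ready, which requires it to be registered as an autograd hook. The _setup_backward_hooks() implementation is not shown here; the following is a minimal sketch of the usual DDP-style attachment pattern (a hypothetical helper, not the library's exact code). Keeping a reference to the autograd node is what a list such as self._grad_accs in the constructors above is for.

from typing import Callable, List

import torch


def attach_grad_ready_hook(param: torch.Tensor, hook: Callable, grad_accs: List) -> None:
    # The AccumulateGrad node of a parameter fires once its gradient has been
    # fully accumulated; it is reached through the grad_fn of a trivial view.
    assert param.requires_grad
    p_tmp = param.expand_as(param)
    grad_acc = p_tmp.grad_fn.next_functions[0][0]
    grad_acc.register_hook(hook)
    # Keep a reference so that the node (and thus the hook) is not garbage collected
    grad_accs.append(grad_acc)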
Example #3
    def __init__(
        self,
        module: nn.Module,
        sharded_optimizer: Union[OSS, List[OSS]],
        process_group: Any = None,
        broadcast_buffers: bool = True,
        sync_models_at_startup: bool = True,
    ):
        super().__init__()

        self.module = module
        self.sharded_optimizers = [sharded_optimizer] if isinstance(
            sharded_optimizer, OSS) else sharded_optimizer
        self.enable_broadcast_buffers = broadcast_buffers

        # Handle a no_sync() context which prevents gradient synchronization;
        # gradients are accumulated in place
        self.should_accumulate_grads = False

        # Communication related attributes
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        self.world_size_scaling = 1.0 / dist.get_world_size(
            self.process_group)  # > 0
        self.reference_global_rank = OSS.get_global_rank(
            self.process_group, 0)  # picking rank 0 as the reference
        self.rank = dist.get_rank(self.process_group)
        self.global_rank = OSS.get_global_rank(self.process_group, self.rank)

        # Expose some of the PyTorch DDP attributes; some frameworks rely on them.
        # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
        # device_id related logic is not present; it is not handled here
        devices = {p.device for p in self.module.parameters()}
        self.is_multi_device_module = len(devices) > 1
        self.device = list(devices)[0]

        distinct_device_types = {
            p.device.type
            for p in self.module.parameters()
        }
        assert len(distinct_device_types) == 1, (
            "ShardedDataParallel's input module must be on a single device type, "
            "but its parameters were found on the following device types: {}"
        ).format(distinct_device_types)
        self.device_type = list(distinct_device_types)[0]

        # Scaffolding to be able to reduce the grads during the backward pass.
        # Several optimizers can be present, each working on a separate parameter set which is spread across multiple ranks

        # - we build an iterator which goes through all the parameters involved globally
        all_param_iterator = chain(*[
            sum([sum(p, []) for p in optim.per_device_params.values()], [])
            for optim in self.sharded_optimizers
        ])
        self._grad_to_be_reduced = [
            True for _ in filter(lambda x: x.requires_grad, all_param_iterator)
        ]

        # - keep track of the grads which have already been reduced
        self._reduced_grads: Dict[OSS, int] = {}
        self._reduced_grads_max = {
            o: len(o.param_to_rank.values())
            for o in self.sharded_optimizers
        }
        self._clear_counters()

        # - setup backward hooks which will be called by Torch's autograd in due time
        self._grad_accs: List[Callable] = []
        self._setup_backward_hooks()

        # passing a handle to torch.nn.SyncBatchNorm layer
        self._passing_sync_batchnorm_handle(self.module)

        # Make sure that all ranks start with the same model
        if sync_models_at_startup:
            self._sync_params_and_buffers()
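
The expression sum([sum(p, []) for p in optim.per_device_params.values()], []) used above flattens a nested structure into a single parameter list. A toy illustration with hypothetical data, assuming per_device_params maps each device to one parameter list per rank:

# Hypothetical layout: device -> [params owned by rank 0, params owned by rank 1, ...]
per_device_params = {
    "cuda:0": [["w0", "b0"], ["w1"]],
    "cuda:1": [["w2"], []],
}

# The inner sum flattens the per-rank lists of one device,
# the outer sum then flattens across devices
flat = sum([sum(p, []) for p in per_device_params.values()], [])
print(flat)  # ['w0', 'b0', 'w1', 'w2']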
Example #4
    def __init__(
        self,
        module: nn.Module,
        sharded_optimizer: Union[OSS, List[OSS]],
        process_group: Any = None,
        broadcast_buffers: bool = True,
        sync_models_at_startup: bool = True,
        reduce_buffer_size: int = 2**23,
        auto_refresh_trainable: bool = True,
        reduce_fp16: bool = False,
    ):
        super().__init__()

        self.module = module
        self.sharded_optimizers = [
            sharded_optimizer
        ] if not isinstance(sharded_optimizer, list) else sharded_optimizer
        self.enable_broadcast_buffers = broadcast_buffers
        self.auto_refresh_trainable = auto_refresh_trainable
        self.reduce_fp16 = reduce_fp16
        if reduce_buffer_size > 0 and reduce_fp16:
            self.reduce_fp16 = False
            logging.warning(
                "fp16 gradient reduction is not compatible with reduction buffers, which were requested; fp16 gradient reduction is deactivated."
            )

        # Handle a no_sync() context which prevents gradient synchronization;
        # gradients are accumulated in place
        self.should_accumulate_grads = False
        self.accumulate_grads_flipped = False

        # Communication related attributes
        self.process_group = process_group if process_group is not None else dist.group.WORLD
        self.backend = dist.get_backend(self.process_group)
        self.world_size_scaling = 1.0 / dist.get_world_size(
            self.process_group)  # > 0
        self.reference_global_rank = OSS.get_global_rank(
            self.process_group, 0)  # picking rank 0 as the reference
        self.rank = dist.get_rank(self.process_group)
        self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
        self._local_to_global_rank = [
            OSS.get_global_rank(self.process_group, i)
            for i in range(dist.get_world_size(self.process_group))
        ]

        # Expose some of the PyTorch DDP attributes; some frameworks rely on them.
        # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
        # device_id related logic is not present; it is not handled here
        devices = {p.device for p in self.module.parameters()}
        self.is_multi_device_module = len(devices) > 1
        self.device = list(devices)[0]

        distinct_device_types = {
            p.device.type
            for p in self.module.parameters()
        }
        assert len(distinct_device_types) == 1, (
            "ShardedDataParallel's input module must be on a single device type, "
            "but its parameters were found on the following device types: {}"
        ).format(distinct_device_types)
        self.device_type = list(distinct_device_types)[0]

        # Scaffolding to be able to reduce the grads during the backward pass.
        # Several optimizers can be present, each working on a separate parameter set which is spread across multiple ranks

        # - we build an iterator which goes through all the parameters involved globally
        self._all_params = list(
            chain(*[
                sum([sum(p, []) for p in optim.per_device_params.values()], [])
                for optim in self.sharded_optimizers
            ]))
        self._trainable_params: List[torch.Tensor] = []
        self._grad_to_be_reduced: List[bool] = []
        self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
        self._reference_trainable_mask = list(map(_trainable,
                                                  self._all_params))

        # - setup buckets and tensor views
        model_size = sum([p.numel() for p in self.module.parameters()])
        self.buffer_max_size = min(reduce_buffer_size, model_size)

        if dist.get_world_size(self.process_group) == 1:
            self.buffer_max_size = 0
            logging.info(
                "Training is not really distributed, single rank. Deactivating buckets"
            )

        logging.info(
            "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters"
            .format(self.buffer_max_size / 2**20, model_size / 2**20))
        self.use_buckets = self.buffer_max_size > 0

        self.buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
        self._should_bucket_grad: List[bool] = []
        self._bucket_list: List[GradBucket] = []

        # - setup backward hooks which will be called by Torch's autograd in due time
        self._grad_accs: List[Callable] = []
        self._grad_hooks: List[Any] = []
        self._manual_reduce: List[Callable] = []

        # passing a handle to torch.nn.SyncBatchNorm layer
        self._passing_sync_batchnorm_handle(self.module)

        # Make sure that all ranks start with the same model
        if sync_models_at_startup:
            self._sync_params_and_buffers()

        self._work_handles: Deque[Workhandle] = deque()
        self._bucket_flush_callback_set = False
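
For context, a minimal usage sketch of the constructor above, assuming the fairscale import paths (fairscale.optim.oss.OSS and fairscale.nn.data_parallel.ShardedDataParallel) and an already-initialized torch.distributed process group:

import torch
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim.oss import OSS

# assumes dist.init_process_group(...) has already been called on every rank
model = torch.nn.Linear(1024, 1024).cuda()

# OSS shards the optimizer state across the ranks of the process group
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

# ShardedDDP reduces each gradient to the rank owning its optimizer shard
model = ShardedDDP(model, optimizer, reduce_buffer_size=2**23)

inputs = torch.randn(32, 1024, device="cuda")
optimizer.zero_grad()
loss = model(inputs).sum()
loss.backward()   # gradients are reduced asynchronously during the backward pass
optimizer.step()  # each rank updates its shard, then broadcasts the new parameters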
Example #5
    def _reduce_grads_task(buffers: List[torch.Tensor],
                           per_rank_params: List[List[Parameter]], group: Any,
                           self_rank: int, world_size: int) -> None:
        """Helper to reduce a list of params. The params are sorted by size, smallest first, which allows for
        an opportunistic bucketing.

        NOTE: Missing param gradients are created as zero tensors before the reduction"""

        buffer_size = buffers[0].numel()
        bucket_requests = []
        requests = []

        for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
            # All the params are sorted per rank and per increasing size
            if len(params) == 0:
                continue

            for p in params:
                if p.grad is None:
                    p.grad = torch.zeros_like(p)

            global_rank = OSS.get_global_rank(group, rank)

            # Copy small gradients into per-GPU buffers and then async reduce
            i_bucketed = 0  # the number of tensors packed in the buffer
            offset = 0

            # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
            while i_bucketed < len(params) and offset + params[
                    i_bucketed].numel() < buffer_size:
                end = offset + params[i_bucketed].numel()
                buffer[offset:end].copy_(
                    params[i_bucketed].grad.data.view(-1))  # type: ignore
                offset = end
                i_bucketed += 1

            if i_bucketed > 0:
                buffer.div_(world_size)  # type: ignore
                bucket_requests.append((
                    dist.reduce(tensor=buffer,
                                dst=global_rank,
                                group=group,
                                async_op=True),  # type: ignore
                    rank,
                ))

            # Directly reduce the other grads
            for p in params[i_bucketed:]:
                p.grad = cast(Tensor, p.grad)
                if p.grad.requires_grad:
                    raise RuntimeError(
                        "DistributedDataParallel only works with gradients that don't require grad"
                    )

                p.grad.div_(world_size)  # type: ignore
                requests.append(
                    dist.reduce(tensor=p.grad,
                                dst=global_rank,
                                group=group,
                                async_op=True))  # type: ignore

        # Unroll the initial packed small gradients, as soon as possible
        for future, rank in bucket_requests:
            future.wait()

            if rank == self_rank:
                i_bucketed = 0  # the number of tensors packed in the buffer
                offset = 0
                params = per_rank_params[rank]
                buffer = buffers[rank]

                while i_bucketed < len(params) and offset + params[
                        i_bucketed].numel() < buffer_size:
                    end = offset + params[i_bucketed].numel()
                    params[i_bucketed].grad.data.copy_(
                        buffer[offset:end].view_as(
                            params[i_bucketed]))  # type: ignore
                    offset = end
                    i_bucketed += 1

        # Make sure that we're done with this device before moving on and cleaning the unused params
        _ = list(map(lambda x: x.wait(), requests))
Example #6
    def _setup_bucket_strategy(self) -> None:
        """Devise a bucketing strategy on a per-rank ownership level.
        These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
        This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
        """

        # A priori, one reduce call per param
        self._reduced_grads_max = len(self._trainable_params)

        if not self.use_buckets:
            return

        # Devise the bucketing strategy. Parameters are already sorted, in that:
        # - these are only the trainable parameters, so they should produce grads
        # - they are sorted by increasing size
        self.buckets = {}

        for param in self._trainable_params:
            device = param.device
            dst_rank = self._trainable_param_to_rank[param]

            if param.device not in self.buckets.keys():
                self.buckets[param.device] = [
                    Bucket(buffer=torch.zeros(self.buffer_max_size,
                                              dtype=param.dtype,
                                              device=device))
                    for _ in range(dist.get_world_size(self.process_group))
                ]

            bucket = self.buckets[device][dst_rank]
            bucket.destination = OSS.get_global_rank(self.process_group,
                                                     dst_rank)

            # Criteria to decide whether this parameter is to be bucketed or not:
            # - enough room in the bucket
            if (bucket.fill + param.numel()) < self.buffer_max_size:
                self._should_bucket_grad.append(True)

                # This parameter's gradient becomes a view of the bucket
                fill_next = bucket.fill + param.numel()

                if param.grad is None:
                    # will be overwritten just below, see next line
                    param.grad = torch.zeros_like(param)

                param.grad.data = bucket.buffer[bucket.fill:fill_next].view_as(
                    param.data)
                bucket.fill = fill_next

                # Update the bucket
                self._reduced_grads_max -= 1  # one less reduce call per bucketed grad
                self.buckets[device][dst_rank].max_params_checked_in += 1

            else:
                self._should_bucket_grad.append(False)

        self._bucket_list = list(
            chain(*[self.buckets[device] for device in self.buckets.keys()]))

        # Resize the buckets to remove the unused space at the end
        for bucket in self._bucket_list:
            bucket.buffer.resize_(bucket.fill)
            bucket.sent = True
            if bucket.max_params_checked_in > 0:
                self._reduced_grads_max += 1  # one reduce call per bucket
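
The Bucket objects used above are not defined in these examples. Below is a hypothetical minimal container that is consistent with the fields referenced in _setup_bucket_strategy (buffer, fill, destination, sent, max_params_checked_in); the real class in the library may carry more state and logic:

from dataclasses import dataclass

import torch


@dataclass
class Bucket:
    # Flat buffer holding the gradients of all the parameters checked into this bucket
    buffer: torch.Tensor
    # Write offset into the buffer; also its final size after the trailing resize_()
    fill: int = 0
    # Global rank which owns (and receives) the gradients packed in this bucket
    destination: int = -1
    # Whether the bucket has already been sent during the current backward pass
    sent: bool = False
    # Number of parameter gradients expected to be checked into this bucket
    max_params_checked_in: int = 0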