Example #1
        def reduce(*_: Any) -> None:
            # When accumulating gradients, skip the reduction and leave the status flags untouched
            if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                # Make sure that this is not fired twice
                self._grad_to_be_reduced[index] = False

                if not self.use_buckets or not self._should_bucket_grad[index]:
                    param.grad.mul_(self.world_size_scaling)

                    # Future work includes clearing up the buffer if possible
                    def cleanup() -> None:
                        if dst_rank != self.global_rank:
                            param.grad = None

                    # Async reduce for this buffer, log the future
                    optimizer.work_handles.append(
                        Workhandle(
                            handle=dist.reduce(tensor=param.grad.data,
                                               dst=dst_rank,
                                               group=self.process_group,
                                               async_op=True),
                            callback=cleanup,
                        ))
                    self._reduced_grads[optimizer] += 1
                else:
                    bucket = self.buckets[optimizer][param.device][dst_rank]
                    bucket.params_checked_in += 1

                    if bucket.full():
                        # Normalize the bucket in one go
                        bucket.buffer.mul_(self.world_size_scaling)

                        # Reduce the bucket
                        bucket.sent = True
                        optimizer.work_handles.append(
                            Workhandle(
                                handle=dist.reduce(
                                    tensor=bucket.buffer,
                                    dst=dst_rank,
                                    group=self.process_group,
                                    async_op=True,
                                ),
                                callback=None,
                            ))
                        self._reduced_grads[optimizer] += 1

                # Opportunistically try to empty the queue
                optimizer._try_consume_work_handle()

                # If all the reduce operations have been called,
                # make sure that all the asynchronous calls have concluded before moving on
                # and execute the delayed actions (release gradients, unroll the buckets)
                if self._reduced_grads[optimizer] == self._reduced_grads_max[optimizer]:
                    optimizer._consume_work_handles()
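The hook above leans on bookkeeping that lives outside the excerpt: Workhandle objects wrap the async request returned by dist.reduce(..., async_op=True), and the optimizer drains them either opportunistically or in a blocking fashion. Below is a minimal, hypothetical sketch of that machinery; the class bodies are assumptions for illustration, not the actual fairscale implementation.

from dataclasses import dataclass
from typing import Any, Callable, List, Optional


@dataclass
class Workhandle:
    # Pairs an async c10d request (the return value of dist.reduce(..., async_op=True))
    # with an optional callback to run once the request has completed.
    handle: Any
    callback: Optional[Callable[[], None]] = None


class WorkQueueSketch:
    def __init__(self) -> None:
        self.work_handles: List[Workhandle] = []

    def _try_consume_work_handle(self) -> None:
        # Non-blocking: pop finished requests from the front of the queue and
        # fire their callbacks (e.g. releasing gradients owned by another rank).
        while self.work_handles and self.work_handles[0].handle.is_completed():
            work = self.work_handles.pop(0)
            if work.callback is not None:
                work.callback()

    def _consume_work_handles(self) -> None:
        # Blocking: wait for every outstanding request, run the callbacks,
        # then clear the queue so the optimizer step can proceed.
        for work in self.work_handles:
            work.handle.wait()
            if work.callback is not None:
                work.callback()
        self.work_handles.clear()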
Example #2
            def reduce(*_: Any) -> None:
                # When accumulating gradients, skip the reduction and leave the status flags untouched
                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    if not self._bucket_flush_callback_set:
                        Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                        self._bucket_flush_callback_set = True

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    bucket = self.buckets[param.device][dst_rank]
                    bucket.params_checked_in += 1

                    if bucket.full():
                        # Normalize the bucket in one go
                        bucket.buffer.mul_(self.world_size_scaling)

                        # Reduce the bucket
                        bucket.sent = True
                        self._work_handles.append(
                            Workhandle(
                                handle=dist.reduce(
                                    tensor=bucket.buffer,
                                    dst=bucket.destination,
                                    group=self.process_group,
                                    async_op=True,
                                ),
                                callback=None,
                            )
                        )

                    # Opportunistically try to empty the queue
                    self._try_consume_work_handle()
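Variable._execution_engine.queue_callback is an internal PyTorch autograd API that schedules a function to run once, after the current backward pass has finished; the _bucket_flush_callback_set flag ensures it is queued only once per backward. The snippet below is a self-contained illustration of that timing, where the hook and flush function are made-up stand-ins for the example's _flush_reduce_calls.

import torch
from torch.autograd import Variable

_flush_callback_set = False


def _flush() -> None:
    # Stand-in for _flush_reduce_calls: in ShardedDDP this is where buckets
    # that never filled up would be reduced.
    global _flush_callback_set
    _flush_callback_set = False  # re-arm for the next backward pass
    print("end of backward: flushing partially filled buckets")


def _grad_hook(grad: torch.Tensor) -> torch.Tensor:
    global _flush_callback_set
    if not _flush_callback_set:
        Variable._execution_engine.queue_callback(_flush)
        _flush_callback_set = True
    return grad


x = torch.ones(4, requires_grad=True)
x.register_hook(_grad_hook)
x.sum().backward()  # the flush message prints after all gradients are computed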
Example #3
        def reduce_direct(*_: Any) -> None:
            # When accumulating gradients, skip the reduction and leave the status flags untouched
            if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                # Make sure that this is not fired twice
                self._grad_to_be_reduced[index] = False
                param.grad /= self.world_size

                # Future work includes clearing up the buffer if possible
                def cleanup() -> None:
                    if dst_rank != self.global_rank:
                        param.grad = None

                # Async reduce for this buffer, log the future
                optimizer.work_handles.append(
                    Workhandle(
                        handle=dist.reduce(tensor=param.grad.data,
                                           dst=dst_rank,
                                           group=self.process_group,
                                           async_op=True),
                        callback=cleanup,
                    ))

                # If all the reduce operations have been called, add the gatekeeper
                if len(optimizer.work_handles) == optimizer._max_work_handles:
                    gatekeeper()
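Two remarks on this variant. Dividing param.grad by world_size before the sum-reduce makes the reduced result an average across ranks. And gatekeeper() is defined outside the excerpt; judging from how it is invoked once the last reduce has been posted, it plausibly waits on every queued Workhandle and runs the cleanup callbacks before the optimizer step. The helper below is a hypothetical sketch of that behaviour, reusing the Workhandle sketch from above.

from typing import Callable, List


def make_gatekeeper(work_handles: List["Workhandle"]) -> Callable[[], None]:
    # Hypothetical: once every reduce for this optimizer has been launched,
    # block on all async requests and fire the cleanup callbacks (which drop
    # gradients whose shard is owned by another rank).
    def gatekeeper() -> None:
        for work in work_handles:
            work.handle.wait()
            if work.callback is not None:
                work.callback()
        work_handles.clear()

    return gatekeeper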
Example #4
        def reduce_bucket(*_: Any) -> None:
            # When accumulating gradients, skip the reduction and leave the status flags untouched
            if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                # Make sure that this is not fired twice
                self._grad_to_be_reduced[index] = False

                # Copy to the flat buffer, update the buffer state
                bucket = optimizer.buckets[param.device][dst_rank]

                assert bucket.append(
                    param, use_gradient=True
                ), "Bucket overflow: max %s - current %s - adding %s" % (
                    bucket.max_size,
                    bucket.current_offset,
                    param.grad.numel(),
                )

                if bucket.full():

                    def unwrap() -> None:
                        for flat in bucket.params:
                            if dst_rank != self.global_rank:
                                # this rank is not the owner, release the grad
                                flat.param.grad = None
                            else:
                                # this rank is the owner, unroll the results
                                assert flat.param.grad is not None

                                flat.param.grad.data.copy_(
                                    bucket.buffer[flat.start:flat.stop].view_as(flat.param.data),
                                    non_blocking=True)

                        bucket.reset()

                    bucket.buffer /= self.world_size

                    optimizer.work_handles.append(
                        Workhandle(
                            handle=dist.reduce(
                                tensor=bucket.buffer,
                                dst=dst_rank,
                                group=self.process_group,
                                async_op=True,
                            ),
                            callback=unwrap,
                        ))

                    # If all the reduce operations have been called, add the gatekeeper
                    if len(optimizer.work_handles) == optimizer._max_work_handles:
                        gatekeeper()
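The bucketed variant copies each gradient into a per-device, per-destination flat buffer so that one dist.reduce covers many small tensors, then unrolls the result back into param.grad on the owning rank. A minimal, hypothetical bucket exposing the append()/full()/buffer/reset() surface used above could look like this; the real fairscale bucket also tracks params_checked_in and a destination rank.

import torch
from typing import List, NamedTuple


class FlatParam(NamedTuple):
    param: torch.nn.Parameter
    start: int
    stop: int


class GradBucketSketch:
    # Hypothetical flat gradient bucket: gradients are copied back-to-back into
    # a single 1-D buffer so one collective call covers many small parameters.
    def __init__(self, max_size: int, device: torch.device) -> None:
        self.max_size = max_size
        self.current_offset = 0
        self.buffer = torch.zeros(max_size, device=device)
        self.params: List[FlatParam] = []

    def append(self, param: torch.nn.Parameter, use_gradient: bool = True) -> bool:
        tensor = param.grad if use_gradient else param.data
        numel = tensor.numel()
        if self.current_offset + numel > self.max_size:
            return False  # the caller asserts on overflow
        start, stop = self.current_offset, self.current_offset + numel
        self.buffer[start:stop].copy_(tensor.reshape(-1))
        self.params.append(FlatParam(param, start, stop))
        self.current_offset = stop
        return True

    def full(self) -> bool:
        # Simplification: the real bucket knows how many params it expects and
        # compares against params_checked_in rather than the raw offset.
        return self.current_offset == self.max_size

    def reset(self) -> None:
        self.current_offset = 0
        self.params.clear()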
Example #5
            def reduce(*_: Any) -> None:
                # When accumulating gradients, skip the reduction and leave the status flags untouched
                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    if not self._bucket_flush_callback_set:
                        Variable._execution_engine.queue_callback(self._flush_buckets)
                        self._bucket_flush_callback_set = True

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    param.grad.mul_(self.world_size_scaling)

                    if self.reduce_fp16:
                        param.grad.data = param.grad.data.half()

                    # Future work includes clearing up the buffer if possible
                    def cleanup() -> None:
                        if dst_rank != self.global_rank:
                            param.grad = None
                        else:
                            assert param.grad is not None
                            param.grad.data = param.grad.data.to(dtype=param.dtype)

                    # Async reduce for this buffer, log the future
                    dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)

                    self._work_handles.append(
                        Workhandle(
                            handle=dist.reduce(tensor=param.grad.data,
                                               dst=dst_global_rank,
                                               group=self.process_group,
                                               async_op=True),
                            callback=cleanup,
                        ))
                    self._reduced_grads += 1

                    # Opportunistically try to empty the queue
                    self._try_consume_work_handle()

                    # If all the reduce operations have been called,
                    # make sure that all the asynchronous calls have concluded before moving on
                    # and execute the delayed actions (release gradients, unroll the buckets)
                    if self._reduced_grads == self._reduced_grads_max:
                        self._consume_work_handles()
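Two details differ here from the earlier examples. First, reduce_fp16 casts the gradient to float16 before the collective to halve the communication volume, and the completion callback casts it back to the parameter dtype on the owning rank (non-owners simply drop it). Second, dist.reduce expects a global rank for dst, so the shard owner's rank within the process group is translated via OSS.get_global_rank. Below is a standalone sketch of just the dtype round trip, with hypothetical names and no communication involved.

import torch
from typing import Optional


def fp16_reduce_roundtrip(grad: torch.Tensor, param_dtype: torch.dtype, is_owner: bool) -> Optional[torch.Tensor]:
    # Hypothetical, communication-free version of the dtype handling above.
    sent = grad.half()                 # what would be handed to dist.reduce
    if not is_owner:
        return None                    # non-owners release the gradient (cleanup sets param.grad = None)
    return sent.to(dtype=param_dtype)  # owners cast the reduced gradient back


g = torch.randn(8, dtype=torch.float32)
print(fp16_reduce_roundtrip(g, torch.float32, is_owner=True).dtype)  # torch.float32
print(fp16_reduce_roundtrip(g, torch.float32, is_owner=False))       # None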
Example #6
    def _flush_reduce_calls(self) -> None:
        if self._bucket_list is not None:
            for bucket in self._bucket_list:
                if not bucket.sent:
                    # Normalize the bucket in one go
                    bucket.buffer.mul_(self.world_size_scaling)

                    # Reduce the bucket
                    self._work_handles.append(
                        Workhandle(
                            handle=dist.reduce(
                                tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
                            ),
                            callback=None,
                        )
                    )
                    bucket.sent = True

        self._consume_work_handles()
Example #7
            def reduce(*_: Any) -> None:
                # When accumulating gradients, skip the reduction and leave the status flags untouched
                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    if not self._bucket_flush_callback_set:
                        Variable._execution_engine.queue_callback(self._flush_reduce_calls)
                        self._bucket_flush_callback_set = True

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    param.grad.mul_(self.world_size_scaling)

                    if self.reduce_fp16:
                        param.grad.data = param.grad.data.half()

                    # Future work includes clearing up the buffer if possible
                    def cleanup() -> None:
                        if dst_rank != self.global_rank:
                            param.grad = None
                        else:
                            assert param.grad is not None
                            param.grad.data = param.grad.data.to(dtype=param.dtype)

                    # Async reduce for this buffer, log the future
                    self._work_handles.append(
                        Workhandle(
                            handle=dist.reduce(
                                tensor=param.grad.data,
                                dst=self._local_to_global_rank[dst_rank],
                                group=self.process_group,
                                async_op=True,
                            ),
                            callback=cleanup,
                        ))

                    # Opportunistically try to empty the queue, free memory
                    self._try_consume_work_handle()
Example #8
        def reduce_bucket(*_: Any) -> None:
            # When accumulating gradients, skip the reduction and leave the status flags untouched
            if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                # Make sure that this is not fired twice
                self._grad_to_be_reduced[index] = False

                # Copy to the flat buffer, update the buffer state
                bucket = optimizer.buckets[param.device][dst_rank]

                assert bucket.append(
                    param, use_gradient=True
                ), "Bucket overflow: max %s - current %s - adding %s" % (
                    bucket.max_size,
                    bucket.current_offset,
                    param.grad.numel(),
                )

                if bucket.full():
                    bucket.buffer /= self.world_size

                    optimizer.work_handles.append(
                        Workhandle(
                            handle=dist.reduce(
                                tensor=bucket.buffer,
                                dst=dst_rank,
                                group=self.process_group,
                                async_op=True,
                            ),
                            callback=bucket.unroll,
                        ))

                    # If all the reduce operations have been called, add the gatekeeper
                    if len(optimizer.work_handles) == optimizer._max_work_handles:
                        gatekeeper()
Example #9
            def reduce(*_: Any) -> None:
                # When accumulating gradients, skip the reduction and leave the status flags untouched
                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                    assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                    # Make sure that this is not fired twice
                    self._grad_to_be_reduced[index] = False
                    bucket = self.buckets[param.device][dst_rank]
                    bucket.params_checked_in += 1

                    if bucket.full():
                        # Normalize the bucket in one go
                        bucket.buffer.mul_(self.world_size_scaling)

                        # Reduce the bucket
                        bucket.sent = True
                        self._work_handles.append(
                            Workhandle(
                                handle=dist.reduce(
                                    tensor=bucket.buffer,
                                    dst=bucket.destination,
                                    group=self.process_group,
                                    async_op=True,
                                ),
                                callback=None,
                            ))
                        self._reduced_grads += 1

                    # Opportunistically try to empty the queue
                    self._try_consume_work_handle()

                    # If all the reduce operations have been called,
                    # make sure that all the asynchronous calls have concluded before moving on
                    # and execute the delayed actions (release gradients, unroll the buckets)
                    if self._reduced_grads == self._reduced_grads_max:
                        self._consume_work_handles()
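None of the excerpts show how these reduce closures get called. In DDP-style implementations they are typically registered on each parameter's AccumulateGrad autograd node, which fires after the gradient has been fully written to param.grad; that also explains the *_ signature, since node hooks receive grad_inputs/grad_outputs that the closure ignores. The sketch below wires up such hooks with made-up names, assuming (but not confirming) that the excerpted code is attached the same way.

import torch
from typing import Any

model = torch.nn.Linear(4, 2)
_grad_accs = []  # keep references so the AccumulateGrad nodes are not collected

for index, param in enumerate(model.parameters()):
    if not param.requires_grad:
        continue

    # expand_as(param) yields a view whose grad_fn leads back to param's
    # AccumulateGrad node; hooking that node fires once param.grad is ready.
    grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

    def make_hook(index: int, param: torch.nn.Parameter) -> Any:
        def reduce(*_: Any) -> None:
            # In ShardedDDP this is where the scaling, bucketing and
            # dist.reduce calls from the examples above would run.
            print(f"gradient ready for param {index}: {tuple(param.shape)}")

        return reduce

    grad_acc.register_hook(make_hook(index, param))
    _grad_accs.append(grad_acc)

model(torch.randn(3, 4)).sum().backward()  # prints one line per parameter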