Example #1
def _allgather_buffer(layer_id,
                      trainable_params,
                      group,
                      use_calc_stream,
                      task_flow,
                      sync_wait=False):
    for param in trainable_params[layer_id]:
        # The parameter is already fully materialized on this rank; just bump
        # its reference count and move on.
        if param.status == "all":
            param.use_count += 1
            continue
        # Gather this parameter's shards from every rank in the group.
        with paddle.amp.auto_cast(enable=False):
            full_param = _all_gather(param.fw_storage,
                                     group,
                                     use_calc_stream=use_calc_stream)
        if sync_wait:
            with paddle.amp.auto_cast(enable=False):
                dist.wait(tensor=full_param,
                          group=group,
                          use_calc_stream=use_calc_stream)
            # Point the parameter at the gathered buffer and restore its
            # original shape; the local shard is released below.
            core.VarBase(full_param._slice(
                0, param._numel()))._share_buffer_to(param)
            param.value().get_tensor()._set_dims(param.shape)
            param.fw_storage._clear()
            param.fw_storage = None
            param.status = "all"
            param.use_count += 1
        task_flow.full_param[param.name] = full_param
    return task_flow
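The helper above relies on Paddle internals (_all_gather, core.VarBase buffer sharing). As a minimal sketch of the same idea with only the public paddle.distributed API, assuming an initialized parallel environment and equal-sized slices per rank (local_slice and full_param are illustrative names):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# Each rank holds a flat slice of the parameter (4 elements per rank here).
local_slice = paddle.arange(4, dtype='float32') + 4 * dist.get_rank()

# Gather every rank's slice and concatenate them back into the full tensor,
# which is conceptually what _all_gather plus _share_buffer_to achieve above.
gathered = []
dist.all_gather(gathered, local_slice)
full_param = paddle.concat(gathered)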
Example #2
    def get_all_parameters(self):
        assert len(self._trainable_params.keys()) > 0
        current_layer_params = self._layer.parameters(include_sublayers=True)
        trainable_params = list(
            filter(lambda x: x.trainable, current_layer_params))
        for param in trainable_params:
            if param.use_count > 0:
                continue
            assert hasattr(param, "fw_storage"
                           ), "Find {} don't have fw_storage attribute".format(
                               param.name)

            full_param = _all_gather(param.fw_storage,
                                     self._group,
                                     use_calc_stream=True)
            dist.wait(tensor=full_param,
                      group=self._group,
                      use_calc_stream=True)
            core.VarBase(full_param._slice(
                0, param._numel()))._share_buffer_to(param)
            param.value().get_tensor()._set_dims(param.shape)
            param.fw_storage._clear()
            param.fw_storage = None
            param.status = "all"
            param.use_count += 1

        # Restore the optimizer's original (unsharded) parameter bookkeeping.
        self._optim._parameter_list = self._ori_parameter_list
        self._optim._param_groups = self._ori_param_groups
Example #3
            def reduce(*_):
                # Skip gradient reduction, do not change status information
                if self._grad_reduced[index]:
                    assert param.grad is not None, "Parameter gradient cannot be None"

                    # Change reduce information
                    self._grad_reduced[index] = False
                    grad_storage = self._grad_storages[param.dtype][dst_rank]
                    grad_storage.params_checked_in += 1

                    if grad_storage.all_checked_in:
                        assert grad_storage.buffer is not None

                        # Scale by 1/world_size so the reduced sum becomes an
                        # average across ranks
                        if not self._accumulate_grads:
                            grad_storage.buffer.scale_(
                                scale=self._world_size_scaling)

                        # Callback that cleans up the grad_storage buffer once
                        # the reduce has completed
                        def cleanup():
                            if dst_rank != self._rank:
                                for p in grad_storage._params:
                                    p.clear_gradient(False)
                                    p._gradient_set_empty(False)

                                grad_storage.buffer.value().get_tensor()._clear()
                            elif self._offload:
                                grad_storage.to(device=self._offload_device)
                                for p in grad_storage._params:
                                    self._sharding_optimizers[
                                        0]._offload_acc_grad(
                                            p.name,
                                            p.grad.cast(dtype=Type.fp32.value))
                                    p.clear_gradient(False)
                                    p._gradient_set_empty(False)
                                grad_storage._device = self._default_device
                                grad_storage.buffer.value().get_tensor()._clear()

                        # Reduce the bucket
                        grad_storage.sent = True
                        self._tasks_flow.append(
                            Taskflow(
                                task=dist.reduce(
                                    tensor=grad_storage.buffer,
                                    dst=grad_storage.destination,
                                    group=self._group,
                                    use_calc_stream=True),
                                callback=cleanup))

                        # Multi stream operation will be supported later
                        dist.wait(
                            tensor=grad_storage.buffer,
                            group=self._group,
                            use_calc_stream=True)

                    # Clear the task flow and trigger callback to clear the redundant gradient
                    self._clear_task_flow()
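The hook above reduces a whole GradStorage bucket with a single collective. A minimal sketch of that bucketing idea using only public APIs, with two illustrative stand-in gradients and rank 0 assumed as the destination (GradStorage, Taskflow and the offload path are not reproduced):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
world_size = dist.get_world_size()

grads = [paddle.randn([4, 4]), paddle.randn([8])]     # stand-in gradients
bucket = paddle.concat([g.flatten() for g in grads])  # one flat buffer
bucket.scale_(1.0 / world_size)                       # pre-scale so the reduced sum is an average
dist.reduce(bucket, dst=0)                            # one collective call per bucket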
Example #4
    def _sync_buffers(self):
        for buffer in self._layer.buffers(include_sublayers=True):
            dist.broadcast(buffer,
                           self._global_root_rank,
                           self._group,
                           use_calc_stream=True)
        # Multi stream operation will be supported later
        dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)
Example #5
    def _sync_params_and_buffers(self):
        """
        Sync all model states for all ranks
        """

        for p in self._layer.parameters():
            dist.broadcast(p,
                           src=self._global_root_rank,
                           group=self._group,
                           use_calc_stream=True)

        # Multi stream operation will be supported later
        dist.wait(tensor=p, group=self._group, use_calc_stream=True)
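A minimal, self-contained sketch of this broadcast pattern, assuming rank 0 as the global root and a plain Linear layer standing in for the wrapped model; since dist.broadcast is synchronous by default, the explicit dist.wait is omitted here:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
layer = paddle.nn.Linear(8, 8)  # stand-in for the sharded layer

# Broadcasting every parameter from rank 0 makes all ranks start from
# identical model states, mirroring _sync_params_and_buffers above.
for p in layer.parameters():
    dist.broadcast(p, src=0)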
Example #6
    def __sync_buffers(self):
        """
        Sync all the param buffers from all ranks (e.g. batch norm statistics).
        """

        for buffer in self._layer.buffers(include_sublayers=True):
            dist.broadcast(
                buffer,
                self._global_root_rank,
                self._group,
                use_calc_stream=True)
        # Multi stream operation will be supported later
        dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)
Example #7
    def _broadcast_params(self):
        """Broadcast the parameters of the current rank to each rank"""

        assert self._default_device == "gpu", "Only supported gpu"

        # Exchange all the shards with the other ranks
        for dtype_per_rank in self.param_storages.values():
            for dst_rank, internal_storage in dtype_per_rank.items():
                dist.broadcast(
                    tensor=internal_storage.buffer,
                    src=dst_rank,
                    group=self.group,
                    use_calc_stream=True)

            # Multi stream operation will be supported later
            dist.wait(
                tensor=internal_storage.buffer,
                group=self.group,
                use_calc_stream=True)
Example #8
        def reduce(*_):
            if param.name in self._task_flow.full_grad.keys():
                full_grad = self._task_flow.full_grad[param.name]
                with paddle.amp.auto_cast(enable=False):
                    if not self._accumulate_grads:
                        full_grad.scale_(scale=self._world_size_scaling)
                    # Only synchronous all-reduce of the current layer is
                    # supported for now
                    dist.all_reduce(tensor=full_grad,
                                    group=self._group,
                                    use_calc_stream=True)
                    dist.wait(tensor=full_grad,
                              group=self._group,
                              use_calc_stream=True)

                    # Keep only this rank's slice of the reduced gradient.
                    start, end = self._param2buffer[param.name][self._rank]
                    if not self._accumulate_grads or param.bw_storage is None:
                        param.bw_storage = core.VarBase(
                            full_grad._slice(start, end)).detach().clone()
                    else:
                        param.bw_storage.add_(
                            core.VarBase(full_grad._slice(
                                start, end)).detach().clone())
                param.clear_gradient(False)
                param._gradient_set_empty(False)
                tmp_var = self._task_flow.full_grad.pop(param.name)
                tmp_var._clear()

            if param.name in self._task_flow.full_param.keys():
                if param.status == "all":
                    param.use_count = 0
                    param._clear()
                    start, end = self._param2buffer[param.name][self._rank]
                    with paddle.amp.auto_cast(enable=False):
                        param.fw_storage = core.VarBase(
                            self._task_flow.full_param[param.name]._slice(
                                start, end),
                            param.name + "@slice").detach().clone()
                    param.status = "part"
                    tmp_var = self._task_flow.full_param.pop(param.name)
                    tmp_var._clear()
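After the all-reduce, each rank keeps only its own slice of the full gradient, as recorded in _param2buffer. A minimal sketch of that slicing step with purely illustrative sizes (the 16-element gradient and equal shards per rank are assumptions):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank, world_size = dist.get_rank(), dist.get_world_size()

total = 16                                  # illustrative parameter size
full_grad = paddle.randn([total])           # stands in for the all-reduced gradient
shard = total // world_size                 # equal shard per rank
start, end = rank * shard, (rank + 1) * shard
bw_storage = full_grad[start:end].detach().clone()  # this rank keeps only its slice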
Example #9
            def reduce(*_):
                # Skip gradient reduction, do not change status information
                if self._grad_reduced[index]:
                    assert param.grad is not None, "Parameter gradient cannot be None"

                    # Change reduce information
                    self._grad_reduced[index] = False
                    if not self._accumulate_grads:
                        param.grad.scale_(scale=self._world_size_scaling)
                        param._reset_grad_inplace_version(True)

                    # Clear the gradient that does not belong to the current rank through the callback function
                    def cleanup():
                        if dst_rank != self._rank:
                            param.clear_gradient(False)
                        elif self._offload:
                            self._sharding_optimizers[0]._offload_acc_grad(
                                param.name,
                                param.grad.cast(dtype=Type.fp32.value).cpu())
                            param.clear_gradient(False)

                    # Synchronize the reduce parameter gradient
                    self._tasks_flow.append(
                        Taskflow(
                            task=dist.reduce(
                                tensor=param.grad,
                                dst=dst_rank,
                                group=self._group,
                                use_calc_stream=True),
                            callback=cleanup))

                    # Multi stream operation will be supported later
                    dist.wait(
                        tensor=param.grad,
                        group=self._group,
                        use_calc_stream=True)

                    # Clear the task flow and trigger callback to clear the redundant gradient
                    self._clear_task_flow()
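Example #9 reduces each parameter's gradient individually from inside a backward hook. A minimal post-backward sketch of the same per-parameter pattern with public APIs, assuming rank 0 owns every gradient (hook registration, Taskflow bookkeeping and the offload path are omitted):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
world_size = dist.get_world_size()

linear = paddle.nn.Linear(4, 4)
loss = linear(paddle.randn([2, 4])).mean()
loss.backward()

# Average each gradient across ranks and send the result to its owner rank;
# only that rank's optimizer shard consumes it afterwards.
for param in linear.parameters():
    param.grad.scale_(1.0 / world_size)
    dist.reduce(param.grad, dst=0)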