示例#1
0
    def step(self, closure=None, **kargs):
        """Perform one optimisation step with compressed gradient synchronisation.

        Phases, each wrapped in the ``timer`` context manager from ``kargs``:
        apply the local gradient, fetch parameters and gradients, compress the
        gradients, synchronise the compressed data via ``self._sync``, recover
        the updated flat parameters via ``self._recover_info``, and unflatten
        them back into the parameter tensors.

        Args:
            closure: Unused by this implementation.
            **kargs: Must contain ``"timer"``, a callable invoked as
                ``timer(name, epoch=...)`` that returns a context manager.

        Returns:
            The ``n_bits`` value produced by ``self._compress``.
        """
        timer = kargs["timer"]

        def _phase(label):
            # The epoch is read at the start of every phase, not once up front.
            return timer(label, epoch=self.conf.epoch_)

        with _phase("grad.apply_grad"):
            self._apply_gradient()

        # Fetch the current parameter tensors and their gradients (with shapes).
        with _phase("grad.recover_hat_params"):
            params, _ = get_data(
                self.param_groups, self.param_names, is_get_grad=False
            )
            grads, shapes = get_data(
                self.param_groups, self.param_names, is_get_grad=True
            )

        with _phase("grad.compress"):
            values, indices, total_bits = self._compress(grads)

        with _phase("grad.sync"):
            message, message_size = self._sync(values, indices)

        # Combine the local flat params with the synced message.
        with _phase("grad.recover_info"):
            new_flat_params = self._recover_info(
                flatten(params),
                message,
                message_size,
                self.selected_shapes,
                shapes,
            )

        # Write the flat result back into the individual parameter tensors.
        with _phase("grad.update_model"):
            unflatten(params, new_flat_params, shapes)

        return total_bits
示例#2
0
    def _compress(self, grads):
        """Compress every gradient and return the flattened selections.

        For each ``((idx, param_name), grad)`` pair, the stored residual in
        ``self.memory_of_grads[param_name]`` is added to the flattened
        gradient. The sum is then passed through ``compress_or_quantize``.

        Side effects:
            * ``self.selected_shapes`` is set to the length of each
              parameter's selected values.
            * When ``self.is_compress_op`` is true,
              ``self.memory_of_grads[param_name]`` is overwritten with the
              masked gradient. If ``self.mask_momentum`` is also set, the
              matching momentum buffer is multiplied in place by the same mask.

        Args:
            grads: Gradient tensors, aligned with ``self.param_names``.

        Returns:
            A tuple ``(flat_values, flat_indices, total_bits)``:
            ``flat_values`` and ``flat_indices`` are ``flatten`` of the
            per-parameter lists, and ``total_bits`` is their summed bit count.
            ``flat_indices`` is ``None`` when the first parameter's indices
            are ``None``.
        """
        values_per_param = []
        indices_per_param = []
        bits_per_param = []

        for (group_idx, param_name), grad in zip(self.param_names, grads):
            # Add the stored residual memory back onto the flattened gradient.
            corrected = grad.data.view(-1) + self.memory_of_grads[param_name]

            ratio = self._get_compress_ratio()
            values, indices, bits = compress_or_quantize(
                grad=corrected,
                comm_op=self.comm_op,
                compressor_fn=self.compressor_fn,
                compress_ratio=ratio,
                quantize_level=self.quantize_level,
                is_biased=self.is_biased,
            )
            values_per_param.append(values)
            indices_per_param.append(indices)
            bits_per_param.append(bits)

            # NOTE(review): when is_compress_op is false, the residual memory
            # is left unchanged. An error-feedback update
            # (memory = grad - selected values) used to sit here but was
            # commented out. Confirm this is intended.
            if not self.is_compress_op:
                continue

            # Keep only the masked part of the gradient as the new residual.
            # Presumably nmask marks the entries that were NOT selected;
            # confirm against compressor_fn.get_mask.
            _, nmask = self.compressor_fn.get_mask(corrected, indices)
            self.memory_of_grads[param_name] = corrected * nmask

            # Optionally apply the same mask to the momentum buffer.
            if self.mask_momentum:
                owner = self.param_groups[group_idx]["params"][0]
                self.state[owner]["momentum_buffer"].mul_(
                    nmask.view(grad.size())
                )

        self.selected_shapes = [len(chunk) for chunk in values_per_param]

        flat_values = flatten(values_per_param)
        if indices_per_param[0] is None:
            flat_indices = None
        else:
            flat_indices = flatten(indices_per_param)
        return flat_values, flat_indices, sum(bits_per_param)
示例#3
0
    def __init__(self, tensors, use_cuda=True):
        """Pack ``tensors`` into one flat buffer and record each tensor's span.

        Tensor ``i`` occupies the half-open range
        ``[self._start_idx[i], self._end_idx[i])`` of the flat layout. The
        range boundaries are the running totals of ``tensor.nelement()``.

        Args:
            tensors: Iterable of tensors, each exposing ``nelement()``. The
                same object is stored in ``self._tensors`` and passed to
                ``flatten``.
            use_cuda: Forwarded unchanged to ``flatten``.
        """
        starts, ends = [], []
        offset = 0
        for tensor in tensors:
            starts.append(offset)
            offset += tensor.nelement()
            ends.append(offset)

        self._start_idx = starts
        self._end_idx = ends
        self._tensors = tensors

        # According to the original author, `flatten` copies the data into a new buffer.
        self.buffer = flatten(tensors, use_cuda=use_cuda)