Example #1
    def uncompress(self, sync_buffer, neighbor_hat_params, neighbors_info):
        # wait for the sync to complete.
        self.aggregator_fn.complete_wait(sync_buffer["sync_reqs_1"])
        self.aggregator_fn.complete_wait(sync_buffer["sync_reqs_2"])

        # uncompress and update.
        for rank, weight in neighbors_info.items():
            # get hat_params of the current rank.
            hat_params = neighbor_hat_params[
                rank if rank in neighbor_hat_params else "memory"
            ]

            # recover the message and the corresponding device.
            sync_buffer["flatten_norms"].buffer = comm.recover_device(
                sync_buffer["synced_flatten_norms"][rank],
                device=hat_params.buffer.device)
            sync_buffer[
                "flatten_directions"].buffer = self.compressor_fn.uncompress(
                    comm.recover_device(sync_buffer["synced_signs"][rank],
                                        device=hat_params.buffer.device),
                    sync_buffer["sign_size"],
                )

            # update neighbor_hat_params
            for hat_param, hat_param_memory, norm, sign in zip(
                    hat_params,
                    neighbor_hat_params["memory"],
                    sync_buffer["flatten_norms"],
                    sync_buffer["flatten_directions"],
            ):
                _update = norm / sign.nelement() * sign
                if rank in neighbor_hat_params:
                    hat_param.add_(_update)
                hat_param_memory.add_(_update, alpha=weight)
Example #2
    def decompress(self, sync_buffer):
        # decompress and update.
        for rank in range(self.world_size):
            if rank == self.rank:
                continue

            # get grad_norm and build its tensorbuffer.
            _grad_norms = comm.recover_device(
                sync_buffer["synced_grad_norms"][rank],
                device=sync_buffer["synced_grads_tb"].buffer.device,
            )
            grad_norms_tb = TensorBuffer(_grad_norms)

            # get signs and build its tensorbuffer.
            signs = comm.recover_device(
                sync_buffer["synced_signs"][rank],
                device=sync_buffer["synced_grads_tb"].buffer.device,
            )
            _signs = self.compressor_fn.uncompress(signs,
                                                   sync_buffer["sign_size"])
            signs_tb = copy.deepcopy(sync_buffer["synced_grads_tb"])
            signs_tb.buffer = _signs

            # update grads.
            for grad_norm, sign, synced_grad in zip(
                    grad_norms_tb, signs_tb, sync_buffer["synced_grads_tb"]):
                _update = grad_norm * sign / synced_grad.nelement()
                synced_grad.add_(_update)

        # average grad.
        sync_buffer["synced_grads_tb"].buffer /= self.world_size * 1.0
        return sync_buffer["synced_grads_tb"]
Example #3
    def uncompress(self, sync_buffer, neighbors_info):
        aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
        aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

        # uncompress and update.
        for rank in neighbors_info.keys():
            param_norms = sync_buffer["synced_param_norms"][rank]
            signs = sync_buffer["synced_signs"][rank]

            # recover the message and the corresponding device.
            param_norms = comm.recover_device(
                param_norms, device=sync_buffer["params_tb"].buffer.device
            )
            signs = self.compressor_fn.uncompress(
                comm.recover_device(
                    signs, device=sync_buffer["params_tb"].buffer.device
                ),
                sync_buffer["sign_size"],
            )

            # build the corresponding tensorbuffer.
            param_norms_tb = TensorBuffer(param_norms)
            signs_tb = deepcopy(sync_buffer["params_tb"])
            signs_tb.buffer = signs

            # accumulate information for the neighborhood.
            for _info, _param_norm, _sign in zip(
                aggregated_info_tb, param_norms_tb, signs_tb
            ):
                _info.add_(
                    self.consensus_stepsize
                    * (neighbors_info[rank] - (1 if rank == self.rank else 0))
                    * (_param_norm / _sign.nelement() * _sign)
                )
        return aggregated_info_tb
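
TensorBuffer is assumed to flatten a list of tensors into one contiguous buffer while still allowing iteration over per-tensor views, which is why the loops above can zip several buffers together. A minimal sketch of that assumed behavior (not the library's actual class):

import torch

class MinimalTensorBuffer:
    # Hypothetical sketch of the TensorBuffer behavior assumed above.

    def __init__(self, tensors):
        self._shapes = [t.shape for t in tensors]
        self._numels = [t.numel() for t in tensors]
        self.buffer = torch.cat([t.reshape(-1) for t in tensors])

    def __iter__(self):
        # yield per-tensor views into the flat buffer.
        offset = 0
        for shape, numel in zip(self._shapes, self._numels):
            yield self.buffer[offset:offset + numel].view(shape)
            offset += numel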
Example #4
    def decompress(self, sync_buffer):
        # wait for the sync to complete.
        self.aggregator_fn.complete_wait(sync_buffer["sync_req"])

        # init placeholder.
        synced_updates_tb = deepcopy(sync_buffer["grads_tb"])
        synced_updates_tb.buffer = torch.zeros_like(synced_updates_tb.buffer)

        # decompress and update.
        for rank in range(self.world_size):
            # decompress the signs from this rank and accumulate them.
            synced_updates_tb.buffer += self.compressor_fn.uncompress(
                comm.recover_device(
                    sync_buffer["synced_signs"][rank],
                    device=sync_buffer["grads_tb"].buffer.device,
                ),
                sync_buffer["sign_size"],
            )

        # average the updates or take the majority vote.
        if self.majority_vote:
            synced_updates_tb.buffer = torch.sign(synced_updates_tb.buffer)
        else:
            synced_updates_tb.buffer /= self.world_size * 1.0
        return synced_updates_tb
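
A toy illustration (not from the source) of the two aggregation modes in Example #4, using three workers' sign updates on a 3-element buffer:

import torch

updates = torch.tensor([[1., -1., 1.],
                        [1., 1., -1.],
                        [-1., 1., 1.]])
summed = updates.sum(dim=0)                 # tensor([1., 1., 1.])
majority_vote = torch.sign(summed)          # tensor([1., 1., 1.])
averaged = summed / updates.shape[0]        # tensor([0.3333, 0.3333, 0.3333])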
Example #5
    def _recover_info(self, flatten_params, synced_message, message_size,
                      selected_shapes, shapes):
        # use the pointers to recover the info and get synced grad.
        _message_size = int(message_size / 2)

        if self.is_compress_op:
            empty_grads = torch.zeros_like(flatten_params)

            for message in synced_message:
                q_values, q_indices = self.compressor_fn.uncompress(
                    message[:_message_size],
                    message[_message_size:],
                    selected_shapes,
                    shapes,
                )

                empty_grads[q_indices] += q_values

            # get update tensor.
            _update = empty_grads / self.n_nodes
        else:
            # get update tensor.
            _update = synced_message / self.n_nodes

        # update flatten_params (assumes the lr is the same across param groups).
        updated_flatten_params = flatten_params.add(
            recover_device(_update, device=flatten_params.device),
            alpha=-self.param_groups[0]["lr"],
        )
        return updated_flatten_params
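
Example #5 (and Examples #6, #7, #12) assumes the sender packs the selected values into the first half of the flat message and the indices into the second half, so that the receiver can split it at message_size / 2. A hypothetical sketch of that sender-side packing:

import torch

def pack_values_and_indices(values, indices):
    # Hypothetical sender-side helper (an assumption, not shown in the source):
    # values occupy the first half of the message and float-cast indices the
    # second half, so both halves have equal length.
    assert values.nelement() == indices.nelement()
    return torch.cat([values, indices.float()])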
Example #6
    def uncompress(self, sync_buffer, neighbors_info):
        aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
        aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

        # uncompress and update.
        sycned_message_size = int(sync_buffer["sycned_message_size"] / 2)

        for rank in neighbors_info.keys():
            _message = comm.recover_device(
                sync_buffer["synced_message"][rank],
                device=sync_buffer["params_tb"].buffer.device,
            )
            values = _message[:sycned_message_size]
            indices = _message[sycned_message_size:]

            # deal with unbalanced values/indices.
            q_values, q_indices = self.compressor_fn.uncompress(
                values,
                indices,
                sync_buffer["selected_shapes"],
                sync_buffer["original_shapes"],
            )

            # accumulate the update into the aggregated info buffer.
            aggregated_info_tb.buffer[q_indices] += (
                self.consensus_stepsize
                * (neighbors_info[rank] - (1 if rank == self.rank else 0))
                * q_values
            )
        return aggregated_info_tb
Example #7
    def uncompress(self, sync_buffer, neighbor_hat_params, local_index):
        sycned_message_size = int(sync_buffer["sycned_message_size"] / 2)

        # uncompress and update.
        for rank, hat_params in neighbor_hat_params.items():
            _message = comm.recover_device(
                sync_buffer["synced_message"][rank], device=hat_params.buffer.device
            )
            values = _message[:sycned_message_size]
            indices = _message[sycned_message_size:]

            # deal with unbalanced values/indices.
            q_values, q_indices = self.compressor_fn.uncompress(
                values,
                indices,
                sync_buffer["selected_shapes"],
                sync_buffer["original_shapes"],
            )

            # update the flatten hat params.
            hat_params.buffer[q_indices] = (
                hat_params.buffer[q_indices]
                .mul(1 - 2 / local_index)
                .add(q_values, alpha=2 / local_index)
            )
Example #8
    def uncompress(self, sync_buffer, neighbor_hat_params):
        # uncompress and update.
        for rank, hat_params in neighbor_hat_params.items():
            # map the tensors to the correct location.
            _message = comm.recover_device(sync_buffer["synced_message"][rank],
                                           device=hat_params.buffer.device)

            # update the flatten hat params.
            hat_params.buffer.add_(_message)
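
comm.recover_device is used throughout to place received tensors on the local compute device. A minimal sketch of what it is assumed to do:

def recover_device(tensor, device=None):
    # Assumed behavior of comm.recover_device: move a received (typically CPU)
    # tensor back to the caller's compute device.
    return tensor.to(device) if device is not None else tensor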
Example #9
    def uncompress(self, sync_buffer, neighbor_hat_params):
        # uncompress and update.
        for rank, hat_params in neighbor_hat_params.items():
            # recover the message and the corresponding device.
            sync_buffer["flatten_norms"].buffer = comm.recover_device(
                sync_buffer["synced_flatten_norms"][rank],
                device=hat_params.buffer.device,
            )
            sync_buffer["flatten_updates"].buffer = self.compressor_fn.uncompress(
                comm.recover_device(
                    sync_buffer["synced_signs"][rank],
                    device=hat_params.buffer.device,
                ),
                sync_buffer["sign_size"],
            )

            # update hat_params.
            for hat_param, norm, sign in zip(hat_params,
                                             sync_buffer["flatten_norms"],
                                             sync_buffer["flatten_updates"]):
                # update the flatten hat params.
                hat_param.add_(sign, alpha=norm / sign.nelement())
Example #10
    def uncompress(self, sync_buffer, neighbor_hat_params, neighbors_info):
        # wait for the sync to complete.
        self.aggregator_fn.complete_wait(sync_buffer["sync_reqs"])

        for rank, weight in neighbors_info.items():
            hat_params = neighbor_hat_params[
                rank if rank in neighbor_hat_params else "memory"
            ]
            hat_params_memory = neighbor_hat_params["memory"]

            # recover correct values/indices.
            q_values = comm.recover_device(sync_buffer["synced_message"][rank],
                                           device=hat_params.buffer.device)

            # update neighbor_hat_params
            if rank in neighbor_hat_params:
                hat_params.buffer += q_values
            hat_params_memory.buffer += weight * q_values
Example #11
    def uncompress(self, sync_buffer, neighbors_info):
        aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
        aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

        # uncompress and update.
        for rank in neighbors_info.keys():
            # map the tensors to the correct location.
            _message = comm.recover_device(
                sync_buffer["synced_message"][rank],
                device=sync_buffer["params_tb"].buffer.device,
            )

            # accumulate the update into the aggregated info buffer.
            aggregated_info_tb.buffer.add_(
                self.consensus_stepsize
                * (neighbors_info[rank] - (1 if rank == self.rank else 0))
                * _message
            )
        return aggregated_info_tb
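
The weighting term in Examples #3, #6 and #11 implements a gossip consensus step: each neighbor contributes with its mixing weight, while the local rank contributes with its weight minus one. A quick sanity check with hypothetical weights:

# hypothetical mixing weights (one row of a doubly stochastic matrix).
neighbors_info = {0: 0.5, 1: 0.25, 2: 0.25}
my_rank = 0

terms = {
    rank: weight - (1 if rank == my_rank else 0)
    for rank, weight in neighbors_info.items()
}
# terms == {0: -0.5, 1: 0.25, 2: 0.25}; they sum to 0, so the consensus
# update vanishes once all workers hold identical parameters.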
Example #12
    def _uncompress_helper(
        self,
        _hat_params,
        _rank,
        synced_message,
        sycned_message_size,
        selected_shapes,
        original_shapes,
    ):
        # recover the message and the corresponding device.
        _message = comm.recover_device(synced_message[_rank],
                                       device=_hat_params.buffer.device)
        values = _message[:sycned_message_size]
        indices = _message[sycned_message_size:]

        # deal with unbalanced values/indices.
        q_values, q_indices = self.compressor_fn.uncompress(
            values, indices, selected_shapes, original_shapes)
        return q_values, q_indices
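
For the value/index compressor used in Examples #5, #6, #7 and #12, compressor_fn.uncompress is assumed to turn per-tensor local indices into indices over the flattened buffer, even when the number of selected entries differs per tensor ("unbalanced values/indices"). A hypothetical sketch under that assumption:

import torch

def uncompress_topk(values, indices, selected_shapes, original_shapes):
    # Hypothetical sketch (an assumption, not the source implementation):
    # selected_shapes[i] is the number of entries kept for tensor i and
    # original_shapes[i] its full shape.
    q_values, q_indices = [], []
    value_offset, index_offset = 0, 0
    for n_selected, shape in zip(selected_shapes, original_shapes):
        n_selected = int(n_selected)
        chunk = slice(value_offset, value_offset + n_selected)
        q_values.append(values[chunk])
        q_indices.append(indices[chunk].long() + index_offset)
        value_offset += n_selected
        index_offset += int(torch.Size(shape).numel())
    return torch.cat(q_values), torch.cat(q_indices)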