Code example #1
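Note: all three snippets are excerpts; they assume the CHOCO-SGD project helpers QuantizationCompressor, TensorBuffer, get_n_bits, and comm are importable, along with import torch and from copy import deepcopy.
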
class ECDQuantizationCompressor(object):
    def __init__(self, aggregator, comm_op, comm_device, compress_ratio,
                 quantize_level, is_biased, backend, use_ipc, **kargs):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

    def compress(self, sync_buffer):
        # quantize each flattened update tensor (to be transmitted).
        quantized_values = []

        for flatten_updated_param in sync_buffer["flatten_updated_params"]:
            _quantized_values = self.compressor_fn.compress(
                flatten_updated_param, self.comm_op, self.quantize_level,
                self.is_biased)
            quantized_values.append(_quantized_values)

        # flatten the quantized values into a single buffer.
        flatten_updates = TensorBuffer(quantized_values)

        # n_bits to transmit: quantize_level bits per entry rather than 32.
        n_bits = get_n_bits(flatten_updates.buffer) * self.quantize_level / 32

        # update shared dict.
        sync_buffer["flatten_updates"] = flatten_updates
        sync_buffer["n_bits"] = n_bits

    def sync(self, sync_buffer):
        # prepare the sync.
        to_sync_message = sync_buffer["flatten_updates"].buffer

        if self.comm_device == "cpu":
            to_sync_message = to_sync_message.cpu().pin_memory()

        # sync.
        synced_message = self.aggregator_fn._agg(to_sync_message,
                                                 op="get_raw_sync_data",
                                                 force_wait=True)

        # update sync_buffer.
        sync_buffer["synced_message"] = synced_message

    def uncompress(self, sync_buffer, neighbor_hat_params, local_index):
        # uncompress and update.
        for rank, hat_params in neighbor_hat_params.items():
            # map the tensors to the correct location.
            _message = comm.recover_device(sync_buffer["synced_message"][rank],
                                           device=hat_params.buffer.device)

            # update the flatten hat params with a moving average toward
            # the received message (weight 2 / local_index).
            hat_params.buffer.mul_(1 - 2 / local_index).add_(
                _message, alpha=2 / local_index)
Code example #2
File: deep_squeeze.py  Project: VhalPurohit/CHOCOSGD
class DeepSqueezeQuantizationCompressor(object):
    def __init__(
        self,
        aggregator,
        rank,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        consensus_stepsize,
        **kargs
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.rank = rank
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.consensus_stepsize = consensus_stepsize
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

    def compress(self, sync_buffer):
        # quantize each parameter tensor (to be transmitted).
        quantized_values = []

        # compress and get compressed model.
        local_compressed_params_tb = deepcopy(sync_buffer["params_tb"])
        local_compressed_params_tb.buffer = torch.zeros_like(
            local_compressed_params_tb.buffer
        )
        for param, local_compressed_param in zip(
            sync_buffer["params_tb"], local_compressed_params_tb
        ):
            # quantize.
            _quantized_values = self.compressor_fn.compress(
                param, self.comm_op, self.quantize_level, self.is_biased
            )
            quantized_values.append(_quantized_values)

            # update the local compressed params.
            local_compressed_param.data.copy_(_quantized_values)

        # flatten the quantized values into a single buffer.
        flatten_updates = TensorBuffer(quantized_values)

        # n_bits to transmit: quantize_level bits per entry rather than 32.
        n_bits = get_n_bits(flatten_updates.buffer) * self.quantize_level / 32

        # update shared dict.
        sync_buffer["flatten_updates"] = flatten_updates
        sync_buffer["n_bits"] = n_bits
        return local_compressed_params_tb

    def sync(self, sync_buffer):
        # prepare the sync.
        to_sync_message = sync_buffer["flatten_updates"].buffer

        if self.comm_device == "cpu":
            to_sync_message = to_sync_message.cpu().pin_memory()

        # sync.
        synced_message = self.aggregator_fn._agg(
            to_sync_message, op="get_raw_sync_data", force_wait=True
        )

        # update sync_buffer.
        sync_buffer["synced_message"] = synced_message

    def uncompress(self, sync_buffer, neighbors_info):
        aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
        aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

        # uncompress and update.
        for rank in neighbors_info.keys():
            # map the tensors to the correct location.
            _message = comm.recover_device(
                sync_buffer["synced_message"][rank],
                device=sync_buffer["params_tb"].buffer.device,
            )

            # accumulate the consensus update:
            # consensus_stepsize * (w_ij - delta_ij) * message_j.
            aggregated_info_tb.buffer.add_(
                self.consensus_stepsize
                * (neighbors_info[rank] - (1 if rank == self.rank else 0))
                * _message
            )
        return aggregated_info_tb
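
To make uncompress() above concrete, here is a runnable toy of the same update with made-up weights and messages (gamma, my_rank, and the tensors are all illustrative): each received message is scaled by consensus_stepsize times (w_ij - delta_ij), where delta_ij subtracts the node's own contribution.

import torch

gamma = 0.5                                  # consensus_stepsize
my_rank = 0
neighbors_info = {0: 0.5, 1: 0.25, 2: 0.25}  # mixing weights w_ij (row sums to 1)
messages = {r: torch.randn(4) for r in neighbors_info}  # synced compressed params

aggregated = torch.zeros(4)
for rank, weight in neighbors_info.items():
    aggregated += gamma * (weight - (1 if rank == my_rank else 0)) * messages[rank]
# the caller then adds `aggregated` to the local model parameters.
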
Code example #3
File: parallel_choco.py  Project: stjordanis/ChocoSGD
class CHOCOQuantizationCompressor(object):
    def __init__(
        self,
        aggregator,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        **kargs,
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

        # define gossip_stream (the current CUDA stream, regardless of
        # comm_device).
        self.gossip_stream = torch.cuda.current_stream()

    def pipeline(self, sync_buffer, neighbor_hat_params, neighbors_info):
        with torch.cuda.stream(self.gossip_stream):
            try:
                self.compress(sync_buffer)
                self.sync(sync_buffer)
                self.uncompress(sync_buffer, neighbor_hat_params,
                                neighbors_info)
            except RuntimeError as e:
                print("Error: {}".format(e))

    def compress(self, sync_buffer):
        # quantize the difference between each param and its hat copy.
        quantized_values = []

        for half_param, hat_param in zip(sync_buffer["flatten_params"],
                                         sync_buffer["flatten_hat_params"]):
            _quantized_values = self.compressor_fn.compress(
                half_param - hat_param,
                self.comm_op,
                self.quantize_level,
                self.is_biased,
            )
            quantized_values.append(_quantized_values)

        # flatten the quantized values into a single buffer.
        flatten_updates = TensorBuffer(quantized_values)

        # n_bits to transmit: quantize_level bits per entry rather than 32.
        n_bits = get_n_bits(flatten_updates.buffer) * self.quantize_level / 32

        # update shared dict.
        sync_buffer["flatten_updates"] = flatten_updates
        sync_buffer["n_bits"] = n_bits

    def sync(self, sync_buffer):
        # prepare the sync.
        to_sync_message = sync_buffer["flatten_updates"].buffer

        if self.comm_device == "cpu":
            to_sync_message = to_sync_message.cpu().pin_memory()

        # sync.
        sync_message_reqs, synced_message = self.aggregator_fn._agg(
            to_sync_message, op="get_raw_sync_data", force_wait=False)

        # update sync_buffer.
        sync_buffer["sync_reqs"] = sync_message_reqs
        sync_buffer["synced_message"] = synced_message

    def uncompress(self, sync_buffer, neighbor_hat_params, neighbors_info):
        # wait for the sync to complete.
        self.aggregator_fn.complete_wait(sync_buffer["sync_reqs"])

        for rank, weight in neighbors_info.items():
            # use the neighbor's dedicated hat copy if one exists,
            # otherwise fall back to the shared "memory" copy.
            hat_params = neighbor_hat_params.get(
                rank, neighbor_hat_params["memory"])
            hat_params_memory = neighbor_hat_params["memory"]

            # recover correct values/indices.
            q_values = comm.recover_device(sync_buffer["synced_message"][rank],
                                           device=hat_params.buffer.device)

            # update neighbor_hat_params
            if rank in neighbor_hat_params:
                hat_params.buffer += q_values
            hat_params_memory.buffer += weight * q_values
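
The same style of toy for the CHOCO uncompress() above (all values illustrative): each neighbor's dedicated hat copy absorbs its own quantized difference q_j, while the shared "memory" copy accumulates the mixing-weighted sum that, per the CHOCO-Gossip scheme, the later consensus step consumes.

import torch

neighbors_info = {1: 0.5, 2: 0.5}            # mixing weights for two neighbors
neighbor_hat_params = {
    1: torch.zeros(4),                       # dedicated hat copy for rank 1
    "memory": torch.zeros(4),                # shared weighted accumulator
}
q_values = {r: torch.randn(4) for r in neighbors_info}

for rank, weight in neighbors_info.items():
    if rank in neighbor_hat_params:
        neighbor_hat_params[rank] += q_values[rank]           # x_hat_j += q_j
    neighbor_hat_params["memory"] += weight * q_values[rank]  # weighted sum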