Example 1
    def __init__(
        self,
        aggregator,
        rank,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        consensus_stepsize,
        **kargs
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.rank = rank
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.consensus_stepsize = consensus_stepsize
        self.kargs = kargs
        self.compressor_fn = SparsificationCompressor()
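
All of the examples in this section lean on the same helpers from the surrounding codebase. A minimal preamble like the following is assumed throughout; only the names are taken from the snippets, and their import locations are not shown in the source.

from copy import deepcopy

import torch

# Referenced by name in the examples below; their modules belong to the
# surrounding codebase and are not part of these excerpts:
#   SparsificationCompressor, QuantizationCompressor, get_n_bits,
#   TensorBuffer, get_aggregators, and comm (provides comm.recover_device).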
Example 2
    def __init__(
        self,
        aggregator,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        **kargs,
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = SparsificationCompressor()

        # define gossip_stream: the current CUDA stream is used for both
        # cpu and gpu communication devices.
        self.gossip_stream = torch.cuda.current_stream()
Example 3
    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        conf=None,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(DGC, self).__init__(params, defaults)

        # store the whole training arguments.
        self.conf = conf
        self.n_nodes = conf.graph.n_nodes
        self.rank = conf.graph.rank

        # define the aggregator.
        self.param_names = list(
            enumerate([group["name"] for group in self.param_groups]))
        self.world_aggregator = get_aggregators(
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=dict(
                (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks),
            aggregator_type="centralized",
        )

        # related to sparsification/quantization.
        self.comm_op = conf.comm_op
        self.comm_device = conf.comm_device
        self.is_compress_op = "compress" in self.comm_op
        self.compress_ratio = conf.compress_ratio
        self.compress_warmup_values = conf.compress_warmup_values
        self.compress_warmup_epochs = conf.compress_warmup_epochs
        self.quantize_level = conf.quantize_level
        self.is_biased = conf.is_biased

        self.clip_grad = conf.clip_grad
        self.clip_grad_val = conf.clip_grad_val
        self.mask_momentum = conf.mask_momentum

        self.init_memory()
        self.init_compression()

        # define compressors.
        if self.is_compress_op:
            self.compressor_fn = SparsificationCompressor()
        else:
            self.compressor_fn = QuantizationCompressor()

        # define reducer.
        self.backend = conf.backend
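
A minimal construction sketch for the optimizer above, assuming conf carries every field the constructor reads (conf.graph, conf.comm_op, conf.compress_ratio, and so on). Named param groups are required because the constructor enumerates group["name"]; the hyper-parameter values here are illustrative only.

# Hypothetical call site: each param group carries a "name" entry,
# matching the group["name"] lookup in the constructor.
params = [
    {"params": [param], "name": name}
    for name, param in model.named_parameters()
]
optimizer = DGC(params, lr=0.1, momentum=0.9, nesterov=True, conf=conf)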
Example 4
class DCDSparsificationCompressor(object):
    def __init__(self, aggregator, comm_op, comm_device, compress_ratio,
                 quantize_level, is_biased, backend, use_ipc, **kargs):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = SparsificationCompressor()

    def compress(self, sync_buffer):
        # select the values/indices of the difference (to be transmitted).
        selected_values, selected_indices = [], []

        for half_param, hat_param in zip(sync_buffer["flatten_half_params"],
                                         sync_buffer["flatten_params"]):
            _selected_values, _selected_indices = self.compressor_fn.compress(
                half_param - hat_param,
                self.comm_op,
                self.compress_ratio,
                self.is_biased,
            )
            selected_values.append(_selected_values)
            selected_indices.append(_selected_indices)

        # get selected shapes.
        selected_shapes = [len(_value) for _value in selected_values]

        # flatten selected values/indices.
        flatten_selected_values = TensorBuffer(selected_values)
        flatten_selected_indices = TensorBuffer(selected_indices)

        # get n_bits to transmit.
        n_bits = get_n_bits(flatten_selected_values.buffer) + get_n_bits(
            flatten_selected_indices.buffer)

        # update shared dict.
        sync_buffer["selected_shapes"] = selected_shapes
        sync_buffer["flatten_selected_values"] = flatten_selected_values
        sync_buffer["flatten_selected_indices"] = flatten_selected_indices
        sync_buffer["n_bits"] = n_bits

    def sync(self, sync_buffer):
        # get the flatten values.
        message_to_send = torch.cat([
            sync_buffer["flatten_selected_values"].buffer,
            sync_buffer["flatten_selected_indices"].buffer,
        ])

        if self.comm_device == "cpu":
            message_to_send = message_to_send.cpu().pin_memory()

        # sync.
        synced_message = self.aggregator_fn._agg(message_to_send,
                                                 op="get_raw_sync_data",
                                                 force_wait=True)

        # update sync_buffer.
        sync_buffer["synced_message"] = synced_message
        sync_buffer["sycned_message_size"] = len(message_to_send)

    def uncompress(self, sync_buffer, neighbor_hat_params):
        sycned_message_size = int(sync_buffer["sycned_message_size"] / 2)

        # uncompress and update.
        for rank, hat_params in neighbor_hat_params.items():
            _message = comm.recover_device(sync_buffer["synced_message"][rank],
                                           device=hat_params.buffer.device)
            values = _message[:sycned_message_size]
            indices = _message[sycned_message_size:]

            # deal with unbalanced values/indices
            q_values, q_indices = self.compressor_fn.uncompress(
                values,
                indices,
                sync_buffer["selected_shapes"],
                sync_buffer["original_shapes"],
            )

            # update the flatten hat params.
            hat_params.buffer[q_indices] += q_values
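
A sketch of how the three methods above chain together within one communication round. It assumes sync_buffer has already been populated by the optimizer with "flatten_half_params", "flatten_params", and "original_shapes"; the constructor arguments are illustrative values, and "compress_top_k" is a hypothetical comm_op name.

compressor = DCDSparsificationCompressor(
    aggregator=aggregator,      # decentralized aggregator instance (assumed)
    comm_op="compress_top_k",   # hypothetical sparsification op name
    comm_device="gpu",
    compress_ratio=0.01,        # transmit 1% of the entries
    quantize_level=16,
    is_biased=True,
    backend="nccl",
    use_ipc=False,
)
compressor.compress(sync_buffer)  # fills flatten_selected_* and n_bits
compressor.sync(sync_buffer)      # blocking exchange of values/indices
compressor.uncompress(sync_buffer, neighbor_hat_params)  # apply the deltas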
Example 5
class DeepSqueezeSparsificationCompressor(object):
    def __init__(
        self,
        aggregator,
        rank,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        consensus_stepsize,
        **kargs
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.rank = rank
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.consensus_stepsize = consensus_stepsize
        self.kargs = kargs
        self.compressor_fn = SparsificationCompressor()

    def compress(self, sync_buffer):
        # select the values/indices of each parameter (to be transmitted).
        selected_values, selected_indices = [], []

        # compress and get compressed model.
        local_compressed_params_tb = deepcopy(sync_buffer["params_tb"])
        local_compressed_params_tb.buffer = torch.zeros_like(
            local_compressed_params_tb.buffer
        )
        for param, local_compressed_param in zip(
            sync_buffer["params_tb"], local_compressed_params_tb
        ):
            _selected_values, _selected_indices = self.compressor_fn.compress(
                param, self.comm_op, self.compress_ratio, self.is_biased
            )
            selected_values.append(_selected_values)
            selected_indices.append(_selected_indices)

            # update the local compressed params.
            local_compressed_param.data = local_compressed_param.data.view(-1)
            local_compressed_param.data[_selected_indices] = _selected_values
            local_compressed_param.data = local_compressed_param.data.view(
                *param.size()
            )

        # get selected shapes.
        selected_shapes = [len(_value) for _value in selected_values]

        # flatten selected values/indices.
        flatten_selected_values = TensorBuffer(selected_values)
        flatten_selected_indices = TensorBuffer(selected_indices)

        # get n_bits to transmit.
        n_bits = get_n_bits(flatten_selected_values.buffer) + get_n_bits(
            flatten_selected_indices.buffer
        )

        # update shared dict.
        sync_buffer["selected_shapes"] = selected_shapes
        sync_buffer["flatten_selected_values"] = flatten_selected_values
        sync_buffer["flatten_selected_indices"] = flatten_selected_indices
        sync_buffer["n_bits"] = n_bits
        return local_compressed_params_tb

    def sync(self, sync_buffer):
        # get the flatten values.
        message_to_send = torch.cat(
            [
                sync_buffer["flatten_selected_values"].buffer,
                sync_buffer["flatten_selected_indices"].buffer,
            ]
        )

        if self.comm_device == "cpu":
            message_to_send = message_to_send.cpu().pin_memory()

        # sync.
        synced_message = self.aggregator_fn._agg(
            message_to_send, op="get_raw_sync_data", force_wait=True
        )

        # update sync_buffer.
        sync_buffer["synced_message"] = synced_message
        sync_buffer["sycned_message_size"] = len(message_to_send)

    def uncompress(self, sync_buffer, neighbors_info):
        aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
        aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

        # uncompress and update.
        sycned_message_size = int(sync_buffer["sycned_message_size"] / 2)

        for rank in neighbors_info.keys():
            _message = comm.recover_device(
                sync_buffer["synced_message"][rank],
                device=sync_buffer["params_tb"].buffer.device,
            )
            values = _message[:sycned_message_size]
            indices = _message[sycned_message_size:]

            # deal with unbalanced values/indices
            q_values, q_indices = self.compressor_fn.uncompress(
                values,
                indices,
                sync_buffer["selected_shapes"],
                sync_buffer["original_shapes"],
            )

            # update the flatten hat params.
            aggregated_info_tb.buffer[q_indices] += (
                self.consensus_stepsize
                * (neighbors_info[rank] - (1 if rank == self.rank else 0))
                * q_values
            )
        return aggregated_info_tb
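
The scaling in uncompress implements a gossip consensus step: each recovered value is weighted by consensus_stepsize * (W[rank] - 1) for the local rank and by consensus_stepsize * W[rank] for the neighbors, so the returned buffer accumulates gamma * sum_j (W_ij - delta_ij) * x_hat_j. A sketch of the assumed round-level usage, with the correction applied in place to the flattened parameters:

# Hypothetical call site for one DeepSqueeze round.
local_compressed_params_tb = compressor.compress(sync_buffer)
compressor.sync(sync_buffer)
aggregated_info_tb = compressor.uncompress(sync_buffer, neighbors_info)
sync_buffer["params_tb"].buffer += aggregated_info_tb.buffer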
Example 6
class CHOCOSparsificationCompressor(object):
    def __init__(
        self,
        aggregator,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        **kargs,
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = SparsificationCompressor()

        # define gossip_stream: the current CUDA stream is used for both
        # cpu and gpu communication devices.
        self.gossip_stream = torch.cuda.current_stream()

    def pipeline(self, sync_buffer, neighbor_hat_params, neighbors_info):
        with torch.cuda.stream(self.gossip_stream):
            try:
                self.compress(sync_buffer)
                self.sync(sync_buffer)
                self.uncompress(sync_buffer, neighbor_hat_params,
                                neighbors_info)
            except RuntimeError as e:
                print("Error: {}".format(e))

    def compress(self, sync_buffer):
        selected_values, selected_indices = [], []

        for half_param, hat_param in zip(sync_buffer["flatten_params"],
                                         sync_buffer["flatten_hat_params"]):
            _selected_values, _selected_indices = self.compressor_fn.compress(
                half_param - hat_param,
                self.comm_op,
                self.compress_ratio,
                self.is_biased,
            )
            selected_values.append(_selected_values)
            selected_indices.append(_selected_indices)

        # get selected shapes.
        selected_shapes = [len(_value) for _value in selected_values]

        # flatten selected values/indices.
        flatten_selected_values = TensorBuffer(selected_values)
        flatten_selected_indices = TensorBuffer(selected_indices)

        # get n_bits to transmit.
        n_bits = get_n_bits(flatten_selected_values.buffer) + get_n_bits(
            flatten_selected_indices.buffer)

        # update shared dict.
        sync_buffer["selected_shapes"] = selected_shapes
        sync_buffer["flatten_selected_values"] = flatten_selected_values
        sync_buffer["flatten_selected_indices"] = flatten_selected_indices
        sync_buffer["n_bits"] = n_bits

    def sync(self, sync_buffer):
        # get the flatten values and prepare the sync.
        message_to_send = torch.cat([
            sync_buffer["flatten_selected_values"].buffer,
            sync_buffer["flatten_selected_indices"].buffer,
        ])

        if self.comm_device == "cpu":
            message_to_send = message_to_send.cpu().pin_memory()

        # sync.
        sync_message_reqs, synced_message = self.aggregator_fn._agg(
            message_to_send, op="get_raw_sync_data", force_wait=False)

        # update sync_buffer.
        sync_buffer["sync_reqs"] = sync_message_reqs
        sync_buffer["synced_message"] = synced_message
        sync_buffer["sycned_message_size"] = len(message_to_send)

    def uncompress(self, sync_buffer, neighbor_hat_params, neighbors_info):
        # wait the sync.
        self.aggregator_fn.complete_wait(sync_buffer["sync_reqs"])

        # uncompress and update.
        message_size = int(sync_buffer["sycned_message_size"] / 2)

        for rank, weight in neighbors_info.items():
            hat_params = neighbor_hat_params[
                rank if rank in neighbor_hat_params else "memory"
            ]
            hat_params_memory = neighbor_hat_params["memory"]

            # recover values/indices to the correct device.
            q_values, q_indices = self._uncompress_helper(
                hat_params,
                rank,
                sync_buffer["synced_message"],
                message_size,
                sync_buffer["selected_shapes"],
                sync_buffer["original_shapes"],
            )

            # update neighbor_hat_params
            if rank in neighbor_hat_params:
                hat_params.buffer[q_indices] += q_values
            hat_params_memory.buffer[q_indices] += weight * q_values

    def _uncompress_helper(
        self,
        _hat_params,
        _rank,
        synced_message,
        sycned_message_size,
        selected_shapes,
        original_shapes,
    ):
        # recover the message and the corresponding device.
        _message = comm.recover_device(synced_message[_rank],
                                       device=_hat_params.buffer.device)
        values = _message[:sycned_message_size]
        indices = _message[sycned_message_size:]

        # deal with unbalanced values/indices
        q_values, q_indices = self.compressor_fn.uncompress(
            values, indices, selected_shapes, original_shapes)
        return q_values, q_indices
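
Since sync issues a non-blocking _agg call (force_wait=False) and uncompress later completes the requests, pipeline can run the whole round on the gossip stream. A sketch of the assumed driver, which synchronizes before the updated hat parameters are read elsewhere:

# Hypothetical driver for one CHOCO gossip round.
compressor.pipeline(sync_buffer, neighbor_hat_params, neighbors_info)
torch.cuda.synchronize()  # ensure the gossip round finished before reuse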