Code example #1
File: deep_squeeze.py  Project: VhalPurohit/CHOCOSGD
    def __init__(
        self,
        aggregator,
        rank,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        consensus_stepsize,
        **kargs
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.rank = rank
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.consensus_stepsize = consensus_stepsize
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()
Code example #2
File: parallel_choco.py  Project: stjordanis/ChocoSGD
    def __init__(
        self,
        aggregator,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        **kargs,
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

        # define gossip_stream
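        # (note: both branches below resolve to the same torch.cuda.current_stream())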
        if self.comm_device == "cpu":
            self.gossip_stream = torch.cuda.current_stream()
        else:
            self.gossip_stream = torch.cuda.current_stream()
Code example #3
class ECDQuantizationCompressor(object):
    def __init__(self, aggregator, comm_op, comm_device, compress_ratio,
                 quantize_level, is_biased, backend, use_ipc, **kargs):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

    def compress(self, sync_buffer):
        # get the sign/magnitude for the tensor (to be transmitted).
        quantized_values = []

        for flatten_updated_param in sync_buffer["flatten_updated_params"]:
            _quantized_values = self.compressor_fn.compress(
                flatten_updated_param, self.comm_op, self.quantize_level,
                self.is_biased)
            quantized_values.append(_quantized_values)

        # flatten selected values/indices.
        flatten_updates = TensorBuffer(quantized_values)

        # get n_bits to transmit.
        n_bits = get_n_bits(flatten_updates.buffer) * self.quantize_level / 32

        # update shared dict.
        sync_buffer["flatten_updates"] = flatten_updates
        sync_buffer["n_bits"] = n_bits

    def sync(self, sync_buffer):
        # prepare the sync.
        to_sync_message = sync_buffer["flatten_updates"].buffer

        if self.comm_device == "cpu":
            to_sync_message = to_sync_message.cpu().pin_memory()

        # sync.
        synced_message = self.aggregator_fn._agg(to_sync_message,
                                                 op="get_raw_sync_data",
                                                 force_wait=True)

        # update sync_buffer.
        sync_buffer["synced_message"] = synced_message

    def uncompress(self, sync_buffer, neighbor_hat_params, local_index):
        # uncompress and update.
        for rank, hat_params in neighbor_hat_params.items():
            # map the tensors to the correct location.
            _message = comm.recover_device(sync_buffer["synced_message"][rank],
                                           device=hat_params.buffer.device)

            # update the flatten hat params.
            hat_params.buffer.mul_(1 - 2 / local_index).add_(
                _message, alpha=2 / local_index)
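The n_bits accounting in compress only estimates the transmitted payload: the flattened buffer holds float32 values, so scaling its raw bit count by quantize_level / 32 approximates the size once each element is quantized down to quantize_level bits. A minimal standalone sketch of that arithmetic, assuming get_n_bits returns the raw bit count of the tensor (the helper below is a hypothetical stand-in, not the project's implementation):

import torch

def get_n_bits_sketch(tensor):
    # raw size of the tensor in bits (float32 -> 32 bits per element)
    return tensor.numel() * tensor.element_size() * 8

flat = torch.randn(1000)   # a flattened update
quantize_level = 8         # target bits per quantized element
n_bits = get_n_bits_sketch(flat) * quantize_level / 32
print(n_bits)              # 8000.0 -> roughly 8 bits per element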
Code example #4
    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        conf=None,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(DGC, self).__init__(params, defaults)

        # store the whole training arguments.
        self.conf = conf
        self.n_nodes = conf.graph.n_nodes
        self.rank = conf.graph.rank

        # define the aggregator.
        self.param_names = list(
            enumerate([group["name"] for group in self.param_groups]))
        self.world_aggregator = get_aggregators(
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=dict(
                (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks),
            aggregator_type="centralized",
        )

        # related to sparsification/quantization.
        self.comm_op = conf.comm_op
        self.comm_device = conf.comm_device
        self.is_compress_op = "compress" in self.comm_op
        self.compress_ratio = conf.compress_ratio
        self.compress_warmup_values = conf.compress_warmup_values
        self.compress_warmup_epochs = conf.compress_warmup_epochs
        self.quantize_level = conf.quantize_level
        self.is_biased = conf.is_biased

        self.clip_grad = conf.clip_grad
        self.clip_grad_val = conf.clip_grad_val
        self.mask_momentum = conf.mask_momentum

        self.init_memory()
        self.init_compression()

        # define compressors.
        if self.is_compress_op:
            self.compressor_fn = SparsificationCompressor()
        else:
            self.compressor_fn = QuantizationCompressor()

        # define reducer.
        self.backend = conf.backend
Code example #5
class DGC(Optimizer):
    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        conf=None,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(DGC, self).__init__(params, defaults)

        # store the whole training arguments.
        self.conf = conf
        self.n_nodes = conf.graph.n_nodes
        self.rank = conf.graph.rank

        # define the aggregator.
        self.param_names = list(
            enumerate([group["name"] for group in self.param_groups]))
        self.world_aggregator = get_aggregators(
            cur_rank=self.rank,
            world=conf.graph.ranks,
            neighbors_info=dict(
                (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks),
            aggregator_type="centralized",
        )

        # related to sparsification/quantization.
        self.comm_op = conf.comm_op
        self.comm_device = conf.comm_device
        self.is_compress_op = "compress" in self.comm_op
        self.compress_ratio = conf.compress_ratio
        self.compress_warmup_values = conf.compress_warmup_values
        self.compress_warmup_epochs = conf.compress_warmup_epochs
        self.quantize_level = conf.quantize_level
        self.is_biased = conf.is_biased

        self.clip_grad = conf.clip_grad
        self.clip_grad_val = conf.clip_grad_val
        self.mask_momentum = conf.mask_momentum

        self.init_memory()
        self.init_compression()

        # define compressors.
        if self.is_compress_op:
            self.compressor_fn = SparsificationCompressor()
        else:
            self.compressor_fn = QuantizationCompressor()

        # define reducer.
        self.backend = conf.backend

    def init_memory(self):
        self.memory_of_grads = dict()

        for group in self.param_groups:
            for p in group["params"]:
                self.memory_of_grads[group["name"]] = torch.zeros_like(
                    p.data).view(-1)

    def init_compression(self):
        # configure gradient warmup values
        if self.compress_ratio is not None:
            compress_warmup_values = [
                float(value)
                for value in self.compress_warmup_values.split(",")
            ]
            self.compress_warmup_values = [
                value for value in compress_warmup_values
                if value <= self.compress_ratio
            ]

            num_compress_warmup_values = len(self.compress_warmup_values)
            self.detailed_compress_warmup_epochs = [
                1.0 * ind / num_compress_warmup_values *
                self.compress_warmup_epochs
                for ind in range(1, num_compress_warmup_values + 1)
            ]

    def __setstate__(self, state):
        super(DGC, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    def step(self, closure=None, **kargs):
        # apply local gradient.
        with kargs["timer"]("grad.apply_grad", epoch=self.conf.epoch_):
            self._apply_gradient()

        # Unflatten the saved hat params.
        with kargs["timer"]("grad.recover_hat_params", epoch=self.conf.epoch_):
            params, _ = get_data(self.param_groups,
                                 self.param_names,
                                 is_get_grad=False)
            grads, shapes = get_data(self.param_groups,
                                     self.param_names,
                                     is_get_grad=True)

        # compress.
        with kargs["timer"]("grad.compress", epoch=self.conf.epoch_):
            selected_values, selected_indices, n_bits = self._compress(grads)

        # sync.
        with kargs["timer"]("grad.sync", epoch=self.conf.epoch_):
            synced_message, message_size = self._sync(selected_values,
                                                      selected_indices)

        # recover and update the neighbor hat params.
        with kargs["timer"]("grad.recover_info", epoch=self.conf.epoch_):
            updated_flatten_params = self._recover_info(
                flatten(params),
                synced_message,
                message_size,
                self.selected_shapes,
                shapes,
            )

        with kargs["timer"]("grad.update_model", epoch=self.conf.epoch_):
            # finally unflatten.
            unflatten(params, updated_flatten_params, shapes)
        return n_bits

    def _compress(self, grads):
        selected_values, selected_indices, n_bits = [], [], []
        for (idx, param_name), grad in zip(self.param_names, grads):
            # add memory back.
            _grad = grad.data.view(-1) + self.memory_of_grads[param_name]

            # get values and indices
            compress_ratio = self._get_compress_ratio()
            _selected_values, _selected_indices, _n_bits = compress_or_quantize(
                grad=_grad,
                comm_op=self.comm_op,
                compressor_fn=self.compressor_fn,
                compress_ratio=compress_ratio,
                quantize_level=self.quantize_level,
                is_biased=self.is_biased,
            )
            selected_values.append(_selected_values)
            selected_indices.append(_selected_indices)
            n_bits.append(_n_bits)

            # update the memory
            if self.is_compress_op:
                _, nmask = self.compressor_fn.get_mask(_grad,
                                                       _selected_indices)
                self.memory_of_grads[param_name] = _grad * nmask

                # apply momentum factor masking.
                if self.mask_momentum:
                    self.state[self.param_groups[idx]["params"]
                               [0]]["momentum_buffer"].mul_(
                                   nmask.view(grad.size()))
            else:
                # self.memory_of_grads[param_name] = _grad - _selected_values
                pass

        # get selected shapes.
        self.selected_shapes = [len(_value) for _value in selected_values]

        # flatten selected values/indices.
        flatten_selected_values = flatten(selected_values)
        flatten_selected_indices = (flatten(selected_indices) if
                                    selected_indices[0] is not None else None)
        return flatten_selected_values, flatten_selected_indices, sum(n_bits)

    def _sync(self, selected_values, selected_indices):
        if self.is_compress_op:
            # concate values and indices.
            message_to_send = torch.cat([selected_values, selected_indices])

            if self.comm_device == "cpu":
                message_to_send = message_to_send.cpu().pin_memory()

            synced_message = self.world_aggregator._agg(
                message_to_send, communication_scheme="all_gather")
        else:
            message_to_send = selected_values

            if self.comm_device == "cpu":
                message_to_send = message_to_send.cpu().pin_memory()

            synced_message = self.world_aggregator._agg(
                message_to_send, op="sum", communication_scheme="all_reduce")

        # get message size.
        message_size = len(message_to_send)
        return synced_message, message_size

    def _recover_info(self, flatten_params, synced_message, message_size,
                      selected_shapes, shapes):
        # use the pointers to recover the info and get synced grad.
        _message_size = int(message_size / 2)

        if self.is_compress_op:
            empty_grads = torch.zeros_like(flatten_params)

            for message in synced_message:
                q_values, q_indices = self.compressor_fn.uncompress(
                    message[:_message_size],
                    message[_message_size:],
                    selected_shapes,
                    shapes,
                )

                empty_grads[q_indices] += q_values

            # get update tensor.
            _update = empty_grads / self.n_nodes
        else:
            # get update tensor.
            _update = synced_message / self.n_nodes

        # update flatten_params (assume the lr used is the same across param groups)
        updated_flatten_params = flatten_params.add(
            recover_device(_update, device=flatten_params.device),
            alpha=-self.param_groups[0]["lr"],
        )
        return updated_flatten_params

    def _clip_gradient(self, grad, param_state, scale=True):
        # calculate the grad norm.
        grad_norm = grad.norm(p=2)

        threshold = self.clip_grad_val
        if threshold is None:
            return grad
        else:
            threshold *= 1.0 if not scale else 1.0 / math.sqrt(self.n_nodes)
            if grad_norm >= threshold:
                grad = threshold / grad_norm * grad
            return grad

    def _get_compress_ratio(self):
        # if we are under the phase of warmup, use different dgc ratio,
        # otherwise return the expected one.
        if self.is_compress_op:
            if self.conf.epoch_ < self.compress_warmup_epochs:
                for ind, val in enumerate(
                        self.detailed_compress_warmup_epochs):
                    if self.conf.epoch_ < val:
                        return self.compress_warmup_values[ind]
            return self.compress_ratio
        else:
            return None

    def _apply_gradient(self):
        """Performs a single optimization step.

        Avoid using momentum to accumulate the gradients from other workers.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        for group in self.param_groups:
            # retrieve para.
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                # get param_state
                param_state = self.state[p]

                # get the gradient
                if p.grad is None:
                    continue
                d_p = p.grad.data

                # add the weight decay.
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)

                # clip the gradient.
                if self.clip_grad:
                    d_p = self._clip_gradient(d_p, param_state)

                # apply the momentum.
                if momentum != 0:
                    if "momentum_buffer" not in param_state:
                        buf = param_state[
                            "momentum_buffer"] = torch.zeros_like(d_p)
                        buf.add_(d_p)
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                p.grad.data = d_p
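For context, init_compression and _get_compress_ratio above implement a gradual sparsity warmup: the comma-separated compress_warmup_values are spread evenly over compress_warmup_epochs, after which the final compress_ratio is used. The standalone sketch below replays that logic with illustrative numbers (the values follow the exponential warmup described in the Deep Gradient Compression paper and are not taken from this repository's configs):

compress_ratio = 0.999
compress_warmup_values = "0.75,0.9375,0.984375,0.996,0.999"
compress_warmup_epochs = 5

values = [float(v) for v in compress_warmup_values.split(",")]
values = [v for v in values if v <= compress_ratio]
boundaries = [
    1.0 * ind / len(values) * compress_warmup_epochs
    for ind in range(1, len(values) + 1)
]  # [1.0, 2.0, 3.0, 4.0, 5.0]

def get_ratio(epoch):
    # mirrors _get_compress_ratio: during warmup, return the value attached
    # to the first boundary the current epoch has not yet reached.
    if epoch < compress_warmup_epochs:
        for ind, boundary in enumerate(boundaries):
            if epoch < boundary:
                return values[ind]
    return compress_ratio

print([get_ratio(e) for e in (0.5, 1.5, 2.5, 4.5, 6.0)])
# [0.75, 0.9375, 0.984375, 0.999, 0.999]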
Code example #6
File: deep_squeeze.py  Project: VhalPurohit/CHOCOSGD
class DeepSqueezeQuantizationCompressor(object):
    def __init__(
        self,
        aggregator,
        rank,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        consensus_stepsize,
        **kargs
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.rank = rank
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.consensus_stepsize = consensus_stepsize
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

    def compress(self, sync_buffer):
        # get the sign/magnitude for the tensor (to be transmitted).
        quantized_values = []

        # compress and get compressed model.
        local_compressed_params_tb = deepcopy(sync_buffer["params_tb"])
        local_compressed_params_tb.buffer = torch.zeros_like(
            local_compressed_params_tb.buffer
        )
        for param, local_compressed_param in zip(
            sync_buffer["params_tb"], local_compressed_params_tb
        ):
            # quantize.
            _quantized_values = self.compressor_fn.compress(
                param, self.comm_op, self.quantize_level, self.is_biased
            )
            quantized_values.append(_quantized_values)

            # update the local compressed params.
            local_compressed_param.data.copy_(_quantized_values)

        # flatten selected values/indices.
        flatten_updates = TensorBuffer(quantized_values)

        # get n_bits to transmit.
        n_bits = get_n_bits(flatten_updates.buffer) * self.quantize_level / 32

        # update shared dict.
        sync_buffer["flatten_updates"] = flatten_updates
        sync_buffer["n_bits"] = n_bits
        return local_compressed_params_tb

    def sync(self, sync_buffer):
        # prepare the sync.
        to_sync_message = sync_buffer["flatten_updates"].buffer

        if self.comm_device == "cpu":
            to_sync_message = to_sync_message.cpu().pin_memory()

        # sync.
        synced_message = self.aggregator_fn._agg(
            to_sync_message, op="get_raw_sync_data", force_wait=True
        )

        # update sync_buffer.
        sync_buffer["synced_message"] = synced_message

    def uncompress(self, sync_buffer, neighbors_info):
        aggregated_info_tb = deepcopy(sync_buffer["params_tb"])
        aggregated_info_tb.buffer = torch.zeros_like(aggregated_info_tb.buffer)

        # uncompress and update.
        for rank in neighbors_info.keys():
            # map the tensors to the correct location.
            _message = comm.recover_device(
                sync_buffer["synced_message"][rank],
                device=sync_buffer["params_tb"].buffer.device,
            )

            # update the flatten hat params.
            aggregated_info_tb.buffer.add_(
                self.consensus_stepsize
                * (neighbors_info[rank] - (1 if rank == self.rank else 0))
                * _message
            )
        return aggregated_info_tb
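The uncompress step above applies the DeepSqueeze consensus correction: each received quantized parameter vector is scaled by consensus_stepsize times (neighbors_info[rank] minus 1 for the local rank, or just neighbors_info[rank] otherwise), i.e. gamma * (W - I) applied row-wise to the gathered messages. A tiny numeric sketch with made-up values (a 3-node topology with uniform mixing weights is assumed):

import torch

consensus_stepsize = 0.5
rank = 0                                         # local rank
neighbors_info = {0: 1 / 3, 1: 1 / 3, 2: 1 / 3}  # mixing weights W[rank, j]
synced_message = {                               # quantized params from each rank
    0: torch.tensor([1.0, 1.0]),
    1: torch.tensor([2.0, 0.0]),
    2: torch.tensor([0.0, 3.0]),
}

aggregated = torch.zeros(2)
for j, weight in neighbors_info.items():
    # the local rank contributes (W[rank, rank] - 1), neighbours contribute W[rank, j]
    aggregated += consensus_stepsize * (weight - (1 if j == rank else 0)) * synced_message[j]

print(aggregated)  # tensor([0.0000, 0.1667]): gamma * (W - I) row applied to the messages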
Code example #7
File: parallel_choco.py  Project: stjordanis/ChocoSGD
class CHOCOQuantizationCompressor(object):
    def __init__(
        self,
        aggregator,
        comm_op,
        comm_device,
        compress_ratio,
        quantize_level,
        is_biased,
        backend,
        use_ipc,
        **kargs,
    ):
        # assign the common hyper-parameters
        self.aggregator_fn = aggregator
        self.comm_op = comm_op
        self.comm_device = comm_device
        self.compress_ratio = compress_ratio
        self.quantize_level = quantize_level
        self.is_biased = is_biased
        self.backend = backend
        self.use_ipc = use_ipc
        self.kargs = kargs
        self.compressor_fn = QuantizationCompressor()

        # define gossip_stream
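        # (note: both branches below resolve to the same torch.cuda.current_stream())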
        if self.comm_device == "cpu":
            self.gossip_stream = torch.cuda.current_stream()
        else:
            self.gossip_stream = torch.cuda.current_stream()

    def pipeline(self, sync_buffer, neighbor_hat_params, neighbors_info):
        with torch.cuda.stream(self.gossip_stream):
            try:
                self.compress(sync_buffer)
                self.sync(sync_buffer)
                self.uncompress(sync_buffer, neighbor_hat_params,
                                neighbors_info)
            except RuntimeError as e:
                print("Error: {}".format(e))

    def compress(self, sync_buffer):
        quantized_values = []

        for half_param, hat_param in zip(sync_buffer["flatten_params"],
                                         sync_buffer["flatten_hat_params"]):
            _quantized_values = self.compressor_fn.compress(
                half_param - hat_param,
                self.comm_op,
                self.quantize_level,
                self.is_biased,
            )
            quantized_values.append(_quantized_values)

        # flatten selected values/indices.
        flatten_updates = TensorBuffer(quantized_values)

        # get n_bits to transmit.
        n_bits = get_n_bits(flatten_updates.buffer) * self.quantize_level / 32

        # update shared dict.
        sync_buffer["flatten_updates"] = flatten_updates
        sync_buffer["n_bits"] = n_bits

    def sync(self, sync_buffer):
        # prepare the sync.
        to_sync_message = sync_buffer["flatten_updates"].buffer

        if self.comm_device == "cpu":
            to_sync_message = to_sync_message.cpu().pin_memory()

        # sync.
        sync_message_reqs, synced_message = self.aggregator_fn._agg(
            to_sync_message, op="get_raw_sync_data", force_wait=False)

        # update sync_buffer.
        sync_buffer["sync_reqs"] = sync_message_reqs
        sync_buffer["synced_message"] = synced_message

    def uncompress(self, sync_buffer, neighbor_hat_params, neighbors_info):
        # wait the sync.
        self.aggregator_fn.complete_wait(sync_buffer["sync_reqs"])

        for rank, weight in neighbors_info.items():
            hat_params = neighbor_hat_params[rank if rank in
                                             neighbor_hat_params else "memory"]
            hat_params_memory = neighbor_hat_params["memory"]

            # recover correct values/indices.
            q_values = comm.recover_device(sync_buffer["synced_message"][rank],
                                           device=hat_params.buffer.device)

            # update neighbor_hat_params
            if rank in neighbor_hat_params:
                hat_params.buffer += q_values
            hat_params_memory.buffer += weight * q_values
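For orientation, the compress/uncompress pair above implements the CHOCO-Gossip bookkeeping: each worker transmits the quantized difference q(x - x_hat) between its parameters and the stale copy its neighbours hold, every receiver advances its copy x_hat_j by the received difference, and the shared "memory" buffer accumulates the mixing-weighted sum. A minimal numeric sketch with made-up tensors (the rounding-based quantizer below is only a stand-in for compressor_fn.compress):

import torch

x = torch.tensor([0.8, -1.2])      # current params of neighbour j
x_hat = torch.tensor([0.5, -1.0])  # our stale copy of neighbour j's params
weight = 0.25                      # mixing weight for neighbour j
memory = torch.zeros(2)

q_values = torch.round((x - x_hat) * 4) / 4  # stand-in quantizer with 0.25 resolution
x_hat += q_values                  # hat_params.buffer += q_values
memory += weight * q_values        # hat_params_memory.buffer += weight * q_values

print(x_hat)   # tensor([ 0.7500, -1.2500]) -- drifts towards the true x
print(memory)  # tensor([ 0.0625, -0.0625])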