Example #1
    def async_elastic_average(self, itr, cp, req):
        # Communicate only once every cp local iterations.
        step_flag = (itr != 0 and itr % cp == 0)
        if step_flag:
            beta = 1 / self.size - self.alpha - self.alpha**2 / (1 - self.alpha)
            gamma = self.alpha / (1 - self.alpha)
            # Finish the previous asynchronous all_reduce before reusing
            # the communication buffers.
            if req:
                req.wait()
                for f, t in zip(
                        unflatten_tensors(self.flat_tensor, self.comm_buf),
                        self.comm_buf):
                    t.set_(f)

            for group in self.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    buf = param_state['anchor_model']

                    # Elastic update: pull the local model toward the anchor
                    # and move the anchor toward the local model.
                    p.data.mul_(1 - self.alpha).add_(buf, alpha=self.alpha)
                    buf.mul_(beta).add_(p.data, alpha=gamma)

            # Launch a new non-blocking all_reduce on the flattened anchors.
            self.flat_tensor = flatten_tensors(self.comm_buf)
            req = dist.all_reduce(tensor=self.flat_tensor, async_op=True)

        return req
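
The example above relies on flatten_tensors and unflatten_tensors helpers that are not shown. A minimal sketch, assuming they are thin wrappers around PyTorch's torch._utils._flatten_dense_tensors and _unflatten_dense_tensors (the wrapper bodies below are an assumption, not the repository's actual implementation):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flatten_tensors(tensor_list):
    # Concatenate a list of tensors into one contiguous 1-D tensor.
    return _flatten_dense_tensors(tuple(t.data for t in tensor_list))

def unflatten_tensors(flat_tensor, tensor_list):
    # Split a flat tensor back into chunks shaped like the tensors in tensor_list.
    return _unflatten_dense_tensors(flat_tensor, tuple(t.data for t in tensor_list))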
Example #2
    def prepare_comm_buffer(self):
        # flatten tensors
        self.x = flatten_tensors(self.tensor_list).cpu()
        # if not initialized, then initialize x_hat and s
        if not self.initialized:
            self.x_hat = torch.zeros_like(self.x)
            self.s = torch.zeros_like(self.x)
            self.initialized = True

        tic = time.time()
        # get compressed message
        # here, we use top_k compressor on GPU
        # one can define more in compressors.py
        self.send_buffer = self.x - self.x_hat
        values, indices = get_top_k(self.send_buffer.cuda(), self.ratio)
        toc = time.time()

        values, indices = values.cpu(), indices.cpu()
        self.compressed = {"values": values, "indices": indices}

        return toc - tic
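
get_top_k comes from compressors.py, which is not shown here. A minimal sketch of a top-k compressor with the same interface, assuming it returns the largest-magnitude values and their flat indices (the real compressor's tie-breaking and return format may differ):

import torch

def get_top_k(tensor, ratio):
    # Keep roughly the largest-magnitude `ratio` fraction of entries.
    flat = tensor.view(-1)
    k = max(1, int(flat.numel() * ratio))
    _, indices = torch.topk(flat.abs(), k, sorted=False)
    # Return the signed values at the selected positions.
    return flat[indices], indices

Running the selection on GPU, as the example does, makes torch.topk the dominant cost, which is what the tic/toc timing around the call measures.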
Example #3
    def async_BMUF(self, itr, cp, req):
        # Communicate only once every cp local iterations (block boundary).
        step_flag = (itr != 0 and itr % cp == 0)
        if step_flag:
            # Finish the previous asynchronous all_reduce before reusing
            # the communication buffers.
            if req:
                req.wait()
                for f, t in zip(
                        unflatten_tensors(self.flat_tensor, self.comm_buf),
                        self.comm_buf):
                    t.set_(f)

            for group in self.param_groups:
                lr = group['lr']
                for p in group['params']:
                    param_state = self.state[p]
                    old_data = param_state['anchor_model']

                    # Pull the local model toward the anchor model.
                    p.data.mul_(1 - self.alpha).add_(old_data, alpha=self.alpha)

                    # Update the global (block-level) momentum buffer with
                    # the model change accumulated over the block.
                    if 'global_momentum_buffer' not in param_state:
                        buf = param_state[
                            'global_momentum_buffer'] = torch.clone(
                                p.data).detach()
                        buf.sub_(old_data)
                        buf.div_(-lr)
                    else:
                        buf = param_state['global_momentum_buffer']
                        buf.mul_(self.gmf).sub_(p.data, alpha=1 / lr).add_(
                            old_data, alpha=1 / lr)

                    # Apply the momentum step to the anchor and pre-divide by
                    # the world size for the upcoming all_reduce (sum).
                    old_data.add_(buf, alpha=-lr)
                    old_data.div_(self.size)

            # Launch a new non-blocking all_reduce on the flattened anchors.
            self.flat_tensor = flatten_tensors(self.comm_buf)
            req = dist.all_reduce(tensor=self.flat_tensor, async_op=True)

        return req
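
Both async_elastic_average and async_BMUF return the outstanding all_reduce handle so it can be passed back in on the next call. A hypothetical training loop showing how the handle is threaded through (model, loader, criterion, and cp=4 are placeholders, and torch.distributed is assumed to be initialized):

req = None
for itr, (inputs, targets) in enumerate(loader):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # Every cp steps the call waits on the previous all_reduce,
    # averages the models, and launches a new non-blocking one.
    req = optimizer.async_BMUF(itr, cp=4, req=req)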
Example #4
    def prepare_comm_buffer(self):
        # flatten tensors into a single send buffer on CPU and
        # allocate a matching receive buffer
        self.send_buffer = flatten_tensors(self.tensor_list).cpu()
        self.recv_buffer = torch.zeros_like(self.send_buffer)
Example #5
    def prepare_comm_buffer(self):
        # flatten tensors into a single send buffer on CPU
        self.send_buffer = flatten_tensors(self.tensor_list).cpu()