Example 1
        def hook(*ignore):
            assert p not in self._handles
            assert not p.grad.requires_grad
            name = self._parameter_names.get(p)
            p_size = np.prod(p.size())
            torch.cuda.synchronize()
            begin_time = time.time()

            if self._use_allgather and p_size > self._plan1:
                torch.cuda.synchronize()
                begin_mom_time = time.time()

                weight_decay = self._weight_decay  # group['weight_decay']
                momentum = self._momentum  # group['momentum']
                dampening = 0.0  # group['dampening']
                param_state = self.state[p]  # hoisted: 'residue_buffer' below is needed even when momentum == 0
                d_p = p.grad.data
                d_p.div_(hvd.size())
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                if 'residue_buffer' not in param_state:
                    rsd = param_state['residue_buffer'] = torch.zeros_like(p.data)
                else:
                    rsd = param_state['residue_buffer']
                rsd.add_(param_state['momentum_buffer'])
                if self._use_nesterov:
                    # in-place: the original `rsd = rsd.add(...)` silently dropped the update
                    rsd.add_(d_p, alpha=momentum)

                torch.cuda.synchronize()
                self.mom_time += time.time() - begin_mom_time

                compressed_val = []
                compressed_idx = []

                torch.cuda.synchronize()
                begin_select_time = time.time()

                if 'flag' not in param_state:
                    param_state['flag'] = 0
                if 'interval' not in param_state:
                    param_state['interval'] = 10
                it = 0
                sparsity = 0.0

                if p_size > self._plan3:
                    if param_state['flag'] == 1:
                        compressed_val, compressed_idx, it, _, sparsity = \
                            select_bs_top(param_state['residue_buffer'], 0.001)
                        param_state['flag'] = 0
                    else:
                        compressed_val, compressed_idx, it, _, sparsity = \
                            select_bs_bottom(param_state['residue_buffer'], 0.001)
                        param_state['flag'] = 1
                elif p_size > self._plan2:
                    if param_state['flag'] == 1:
                        compressed_val, compressed_idx = \
                            select_trim_topk_mean(param_state['residue_buffer'], 0.001)
                        param_state['flag'] = 0
                    else:
                        compressed_val, compressed_idx = \
                            select_trim_lowk_mean(param_state['residue_buffer'], 0.001)
                        param_state['flag'] = 1
                else:
                    if param_state['flag'] == 1:
                        compressed_val, compressed_idx = \
                            select_topk_mean(param_state['residue_buffer'], 0.001)
                        param_state['flag'] = 0
                    else:
                        compressed_val, compressed_idx = \
                            select_lowk_mean(param_state['residue_buffer'], 0.001)
                        param_state['flag'] = 1

                assert len(compressed_idx) > 0
                torch.cuda.synchronize()
                end_select_time = time.time()
                self.select_time += end_select_time - begin_select_time
                #if param_state['interval'] == 10:
                #    compressed_val, compressed_idx, it, param_state['mid_store'], sparsity = \
                #            select_top_k_thdv3(param_state['residue_buffer'], 0.001)
                #    param_state['interval'] = 0
                #else:
                #    compressed_val, compressed_idx, sparsity = \
                #            select_top_k_fixthd(param_state['residue_buffer'], param_state['mid_store'])
                #    param_state['interval'] += 1
                #if hvd.rank() == 0:
                #    print(name, p.size())
                #if hvd.rank() == 0 and name == "features.27.weight":
                #if name == "features.27.weight":
                #    torch.save(compressed_val, 'compressed_val' + str(local_rank()))
                #    torch.save(compressed_idx, 'compressed_idx' + str(local_rank()))
                #if hvd.rank() == 0 and name == "features.27.weight":
                #    self._it = it
                #    self._mid = param_state['mid_store']
                #    self._sparsity = sparsity
                #tmp_t = torch.tensor([local_len], dtype=torch.long)
#                tmp_t = torch.tensor([local_len])
                # print("len list, ", global_len_list)
                #local_len = torch.min(global_len_list)
                ##print("local_len, ", local_len)
                #compressed_val = compressed_val[0:local_len]
                #compressed_idx = compressed_idx[0:local_len]

                torch.cuda.synchronize()
                begin_mask_time = time.time()

                masks_size = self._masks[name].size()
                self._masks[name].zero_()
                self._masks[name] = self._masks[name].view(-1)
                self._masks[name][compressed_idx] = 1.0

                self._masks[name] = 1.0 - self._masks[name]
                self._masks[name] = self._masks[name].view(masks_size)

                if self._debug:
                    self._v_ref[name] = param_state['residue_buffer'] * (1.0 - self._masks[name])
                    allreduce_(self._v_ref[name], average=False)


                if hvd.size() == 1:
                    p.grad.data = param_state['residue_buffer'] * (1.0 - self._masks[name])

                param_state['residue_buffer'].mul_(self._masks[name])
                param_state['momentum_buffer'].mul_(self._masks[name])

                end_mask_time = time.time()
                self.mask_time += end_mask_time - begin_mask_time

                torch.cuda.synchronize()
                begin_pack_time = time.time()

                if hvd.size() > 1:
                    if self._use_gpu:
                        if p_size > self._plan3:
                            compressed_msg = torch.cat((
                                torch.tensor([len(compressed_idx)]).type(torch.cuda.LongTensor),
                                compressed_idx))
                            handle = _allgather_async(compressed_msg, self._compressed_idx[name], name=name + "idx")
                            self._handles[p] = handle

                            handle = _allgather_async(torch.mean(compressed_val), self._compressed_val[name], name=name + "val")
                            self._handles_val[p] = handle
                        else:
                            self._compressed_msg_size[name] = len(compressed_idx)
                            handle = _allgather_async(compressed_idx, self._compressed_idx[name],
                                                      name=name + "idx")
                            self._handles[p] = handle
                            handle = _allgather_async(torch.mean(compressed_val),
                                                      self._compressed_val[name], name=name + "val")
                            self._handles_val[p] = handle
                torch.cuda.synchronize()
                self.pack_time += time.time() - begin_pack_time
            else:
                torch.cuda.synchronize()
                begin_allreduce_time = time.time()
                p.grad.data.div_(hvd.size())
                p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    param_state['momentum_buffer'] = torch.zeros_like(p.data)
                if self._use_nesterov:
                    param_state['momentum_buffer'] = torch.mul(torch.add(param_state['momentum_buffer'], p.grad.data), self._momentum)
                    p.grad.data = param_state['momentum_buffer'] + p.grad.data
                else:
                    param_state['momentum_buffer'] = self._momentum * param_state['momentum_buffer'] + p.grad.data
                    p.grad.data = param_state['momentum_buffer']
                if hvd.size() > 1:
                    handle = allreduce_async_(p.grad.data, average=False, name=name)
                    self._handles[p] = handle
                torch.cuda.synchronize()
                self.allreduce_time += time.time() - begin_allreduce_time

            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time
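
The selector helpers called above (select_topk_mean, select_lowk_mean, select_bs_top, select_bs_bottom, and friends) are not defined in this snippet. A minimal stand-in for the assumed contract, where 0.001 is read as the fraction of entries to keep and the return is (values, flat indices), might look like this (topk_by_ratio is a hypothetical name):

import torch

def topk_by_ratio(tensor, ratio):
    # keep the largest `ratio` fraction of entries (by signed value) and
    # return their values together with indices into the flattened tensor
    flat = tensor.view(-1)
    k = max(1, int(flat.numel() * ratio))
    values, indices = torch.topk(flat, k)
    return values, indices

# e.g. vals, idx = topk_by_ratio(torch.randn(10000), 0.001)  -> 10 entries
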
Example 2
    def step(self, closure=None):
        # local clipping
        # DGC
        for group in self.param_groups:
            for p in group['params']:
                p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                p.grad.data.div_(hvd.size())

            torch.nn.utils.clip_grad_norm_(group['params'],
                                           0.25 * hvd.size()**-0.5)
            #torch.nn.utils.clip_grad_norm(group['params'], 0.25)
            #weight_decay = group['weight_decay']
            #momentum = group['momentum']
            torch.cuda.synchronize()
            begin_time = time.time()

            dampening = 0.0  # group['dampening']
            for p in group['params']:
                assert p not in self._handles
                assert not p.grad.requires_grad
                name = self._parameter_names.get(p)
                p_size = np.prod(p.size())
                if self._use_allgather and p_size > 1024:
                    param_state = self.state[p]
                    self._V[name].add_(p.grad.data)
                    # fjr compress grad
                    compressed_val = []
                    compressed_idx = []

                    torch.cuda.synchronize()
                    begin_select_time = time.time()
                    if 'interval' not in param_state:
                        param_state['interval'] = 1
                    if param_state['interval'] == 0:
                        compressed_val, compressed_idx, _, _, _ = \
                            select_bs_top(self._V[name], 0.001)
                        param_state['interval'] = 1
                    else:
                        compressed_val, compressed_idx, _, _, _ = \
                            select_bs_bottom(self._V[name], 0.001)
                        param_state['interval'] = 0

                    #masks_size = self._masks[name].size()
                    #self._masks[name].zero_()
                    #self._masks[name] = self._masks[name].view(-1)
                    #self._masks[name][compressed_idx] = 1.0
                    #self._masks[name] = 1.0 - self._masks[name]
                    #self._masks[name] = self._masks[name].view(masks_size)
                    torch.cuda.synchronize()
                    end_select_time = time.time()
                    self.select_time += end_select_time - begin_select_time

                    if self._debug:
                        self._v_ref[name] = self._V[name] * (1.0 -
                                                             self._masks[name])
                        allreduce_(self._v_ref[name], average=False)

                    torch.cuda.synchronize()
                    begin_mask_time = time.time()
                    V_size = self._masks[name].size()
                    self._V[name] = self._V[name].view(-1)
                    self._V[name][compressed_idx] = 0.0
                    self._V[name] = self._V[name].view(V_size)

                    torch.cuda.synchronize()
                    self.mask_time += time.time() - begin_mask_time
                    begin_pack_time = time.time()

                    self._compressed_msg_size[name] = len(compressed_idx)
                    if self._use_gpu:
                        compressed_msg = torch.cat([
                            torch.tensor([len(compressed_idx)]).type('torch.cuda.LongTensor'),
                            compressed_idx])
                    else:
                        # CPU fallback so compressed_msg is defined before the allgather below
                        compressed_msg = torch.cat([
                            torch.tensor([len(compressed_idx)], dtype=torch.long),
                            compressed_idx])

                    if p_size == 1500 * 10000:
                        compressed_msg = torch.cat([
                            torch.tensor([len(compressed_idx)]).type('torch.cuda.FloatTensor'),
                            compressed_idx.type('torch.cuda.FloatTensor'),
                            compressed_val])
                        handle = _allgather_async(compressed_msg,
                                                  self._compressed_msg[name],
                                                  name=name)
                        self._handles[p] = handle
                    else:
                        handle = _allgather_async(compressed_msg,
                                                  self._compressed_idx[name],
                                                  name=name + "idx")
                        self._handles[p] = handle
                        handle = _allgather_async(torch.mean(compressed_val),
                                self._compressed_val[name], name=name + "val")
                        self._handles_val[p] = handle

                    torch.cuda.synchronize()
                    self.pack_time += time.time() - begin_pack_time

                else:
                    handle = allreduce_async_(p.grad.data,
                                              average=True,
                                              name=name)
                    self._handles[p] = handle

            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time

        self.synchronize()
        return super(self.__class__, self).step(closure)
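
The matching decompression in synchronize() is not shown. Assuming each worker contributes a LongTensor message [count, idx_0, ..., idx_{count-1}] plus one scalar mean (the torch.mean(compressed_val) sent above), a plausible sketch of rebuilding the dense update from the allgathered buffers is:

import torch

def decompress_allgathered(msg, mean_vals, numel, world_size):
    # msg: concatenation of world_size per-worker [count, idx...] messages
    # mean_vals: one mean value per worker
    dense = torch.zeros(numel)
    offset = 0
    for rank in range(world_size):
        count = int(msg[offset])
        idx = msg[offset + 1:offset + 1 + count]
        dense[idx] += mean_vals[rank]  # every selected coordinate gets that worker's mean
        offset += 1 + count
    return dense
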
Example 3
    def step(self, closure=None):
        # local clipping
        # DGC
        for group in self.param_groups:
            for p in group['params']:
                p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                p.grad.data.div_(hvd.size())

            torch.nn.utils.clip_grad_norm_(group['params'],
                                           0.25 * hvd.size()**-0.5)
            #torch.nn.utils.clip_grad_norm(group['params'], 0.25)
            #weight_decay = group['weight_decay']
            #momentum = group['momentum']
            torch.cuda.synchronize()
            begin_time = time.time()

            dampening = 0.0  # group['dampening']
            for p in group['params']:
                assert p not in self._handles
                assert not p.grad.requires_grad
                name = self._parameter_names.get(p)
                p_size = np.prod(p.size())
                if self._use_allgather and p_size > 1024:
                    if self._momentum != 0:
                        if self._use_nesterov:
                            self._U[name] = torch.mul(
                                torch.add(self._U[name], p.grad.data),
                                self._momentum)
                            self._V[name] = self._V[name] + self._U[name] + p.grad.data
                        else:
                            self._U[name] = self._momentum * self._U[name] + p.grad.data
                            self._V[name] = self._V[name] + self._U[name]
                    else:
                        self._V[name].add_(p.grad.data)
                    compressed_val = []
                    compressed_idx = []

                    torch.cuda.synchronize()
                    begin_select_time = time.time()

                    #compressed_val, compressed_idx = select_top_k_thd(self._V[name], 0.001)
                    if self._flag[name] == 0:
                        compressed_val, compressed_idx = \
                                select_lowk_truncated_mean(self._V[name], 0.001)
                        self._flag[name] = 1
                    else:
                        compressed_val, compressed_idx = \
                                select_topk_truncated_mean(self._V[name], 0.001)
                        self._flag[name] = 0
                    #compressed_val_low, compressed_idx_low = \
                    #        select_lowk_truncated_mean(self._V[name], 0.001)
                    #compressed_val_top, compressed_idx_top = \
                    #        select_topk_truncated_mean(self._V[name], 0.001)
                    #compressed_mean = 0
                    #if(-torch.mean(compressed_val_low) > torch.mean(compressed_val_top)):
                    #    compressed_val = compressed_val_low
                    #    compressed_idx = compressed_idx_low
                    #    compressed_mean = torch.mean(compressed_val_low)
                    #else:
                    #    compressed_val = compressed_val_top
                    #    compressed_idx = compressed_idx_top
                    #    compressed_mean = torch.mean(compressed_val_top)

                    torch.cuda.synchronize()
                    end_select_time = time.time()
                    self.select_time += end_select_time - begin_select_time
                    if self._debug:
                        masks_size = self._masks[name].size()
                        self._masks[name].zero_()
                        self._masks[name] = self._masks[name].view(-1)
                        self._masks[name][compressed_idx] = 1.0
                        self._masks[name] = 1.0 - self._masks[name]
                        self._masks[name] = self._masks[name].view(masks_size)
                        self._v_ref[name] = self._V[name] * (1.0 -
                                                             self._masks[name])
                        allreduce_(self._v_ref[name], average=False)

                    #self._V[name].mul_(self._masks[name])

                    V_size = self._V[name].size()
                    self._V[name] = self._V[name].view(-1)
                    self._V[name][compressed_idx] = 0.0
                    self._V[name] = self._V[name].view(V_size)

                    if self._momentum != 0.0:
                        # zero the transmitted coordinates of U directly; the
                        # mask tensor is only rebuilt in the debug branch above
                        U_size = self._U[name].size()
                        self._U[name] = self._U[name].view(-1)
                        self._U[name][compressed_idx] = 0.0
                        self._U[name] = self._U[name].view(U_size)

                    torch.cuda.synchronize()
                    begin_comm_time = time.time()
                    self._compressed_msg_size[name] = len(compressed_idx)

                    handle = _allgather_async(compressed_idx,
                                              self._compressed_idx[name],
                                              name=name + "idx")
                    self._handles[p] = handle
                    handle = _allgather_async(torch.mean(compressed_val),
                            self._compressed_val[name], name=name + "val")
                    self._handles_val[p] = handle

                    torch.cuda.synchronize()
                    self.comm_time += time.time() - begin_comm_time  # timer was started above but never accumulated
                else:
                    # weight decay is already added to every grad at the top of
                    # step(), so it is not applied a second time here
                    if self._momentum != 0:
                        if self._use_nesterov:
                            self._U[name] = torch.mul(
                                torch.add(self._U[name], p.grad.data),
                                self._momentum)
                            self._V[name] = self._V[name] + self._U[name] + p.grad.data
                        else:
                            self._U[name] = self._momentum * self._U[name] + p.grad.data
                            self._V[name] = self._V[name] + self._U[name]
                        p.grad.data = self._V[name]
                    #compressed_msg = torch.randn(100).cuda()
                    #handle = _allgather_async(compressed_msg, self._compressed_msg[name], name=name)
                    torch.cuda.synchronize()
                    begin_comm_time = time.time()

                    handle = allreduce_async_(p.grad.data,
                                              average=True,
                                              name=name)
                    self._handles[p] = handle

                    torch.cuda.synchronize()
                    self.comm_time += time.time() - begin_comm_time

            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time

        self.synchronize()
        return super(self.__class__, self).step(closure)
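
select_topk_truncated_mean and select_lowk_truncated_mean are assumed helpers. A hypothetical sketch consistent with how they are used above (top fraction by value, with the most extreme entries trimmed before the caller averages) is:

import torch

def select_topk_truncated_mean(tensor, ratio, trim=0.1):
    # hypothetical semantics: take the top `ratio` fraction by value, then drop
    # the most extreme `trim` fraction of those so the mean sent over the wire
    # is robust to outliers
    flat = tensor.view(-1)
    k = max(1, int(flat.numel() * ratio))
    values, indices = torch.topk(flat, k)   # sorted descending
    keep = max(1, int(k * (1.0 - trim)))
    return values[-keep:], indices[-keep:]  # trim the largest outliers
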
Example 4
        def hook(*ignore):
            assert p not in self._handles
            assert not p.grad.requires_grad
            name = self._parameter_names.get(p)
            p_size = np.prod(p.size())
            torch.cuda.synchronize()
            begin_time = time.time()

            if self._use_allgather and p_size > 1024:
                weight_decay = self._weight_decay  # group['weight_decay']
                momentum = self._momentum  # group['momentum']
                dampening = 0.0  # group['dampening']
                nesterov = False  # group['nesterov']
                param_state = self.state[p]  # hoisted: 'residue_buffer' below is needed even when momentum == 0
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    # TODO
                    # if nesterov:
                    #     d_p = d_p.add(buf, alpha=momentum)
                    # else:
                    #     d_p = buf
                if 'residue_buffer' not in param_state:
                    rsd = param_state['residue_buffer'] = torch.zeros_like(p.data)
                else:
                    rsd = param_state['residue_buffer']
                rsd.add_(param_state['momentum_buffer'])

                compressed_val = []
                compressed_idx = []

                torch.cuda.synchronize()
                begin_select_time = time.time()
                if 'mid_store' not in param_state:
                    param_state['mid_store'] = 0.0
                if 'interval' not in param_state:
                    param_state['interval'] = 10
                it = 0
                sparsity = 0.0
                if param_state['interval'] == 10:
                    compressed_val, compressed_idx, it, param_state['mid_store'], sparsity = \
                            select_top_k_thdv3(param_state['residue_buffer'], 0.001)
                    param_state['interval'] = 0
                else:
                    compressed_val, compressed_idx, sparsity = \
                            select_top_k_fixthd(param_state['residue_buffer'], param_state['mid_store'])
                    param_state['interval'] += 1
                assert len(compressed_idx) > 0
                #if hvd.rank() == 0:
                #    print(name, p.size())
                #if hvd.rank() == 0 and name == "features.27.weight":
                #if name == "features.27.weight":
                #    torch.save(compressed_val, 'compressed_val' + str(local_rank()))
                #    torch.save(compressed_idx, 'compressed_idx' + str(local_rank()))
                #if hvd.rank() == 0 and name == "features.27.weight":
                #    self._it = it
                #    self._mid = param_state['mid_store']
                #    self._sparsity = sparsity
                torch.cuda.synchronize()
                end_select_time = time.time()
                self.select_time += end_select_time - begin_select_time
                #tmp_t = torch.tensor([local_len], dtype=torch.long)
                #                tmp_t = torch.tensor([local_len])
                # print("len list, ", global_len_list)
                #local_len = torch.min(global_len_list)
                ##print("local_len, ", local_len)
                #compressed_val = compressed_val[0:local_len]
                #compressed_idx = compressed_idx[0:local_len]

                masks_size = self._masks[name].size()
                self._masks[name].zero_()
                self._masks[name] = self._masks[name].view(-1)
                self._masks[name][compressed_idx] = 1.0

                self._masks[name] = 1.0 - self._masks[name]
                self._masks[name] = self._masks[name].view(masks_size)

                if self._debug:
                    self._v_ref[name] = param_state['residue_buffer'] * (
                        1.0 - self._masks[name])
                    allreduce_(self._v_ref[name], average=False)

                #self._V[name] = self._V[name] * (1 - self._masks[name])
                #self._U[name] = self._U[name] * (1 - self._masks[name])
                param_state['residue_buffer'].mul_(self._masks[name])
                param_state['momentum_buffer'].mul_(self._masks[name])
                #self._compressed_msg_size[name] = len(compressed_idx)

                torch.cuda.synchronize()
                begin_pack_time = time.time()

                #if self._use_gpu:
                #    #compressed_msg = torch.cat([\
                #    #        torch.tensor([len(compressed_idx)]).type('torch.cuda.FloatTensor'),\
                #    #        compressed_idx.type('torch.cuda.FloatTensor'), \
                #    #        compressed_val])
                #    compressed_msg = torch.cat([\
                #            torch.tensor([len(compressed_idx)]).type('torch.cuda.LongTensor'), \
                #            compressed_idx])

                handle = _allgather_async(compressed_idx,
                                          self._compressed_idx[name],
                                          name=name + 'idx')
                #compressed_msg = torch.randn(100).cuda()
                self._handles[p] = handle

                handle = _allgather_async(compressed_val,
                                          self._compressed_val[name],
                                          name=name + 'val')
                #compressed_msg = torch.randn(100).cuda()
                self._handles_val[p] = handle

                handle = _allgather_async(torch.tensor([len(compressed_idx)]),
                                          self._compressed_len[name],
                                          name=name + 'len')
                #handle = _allgather_async(len(compressed_idx), self._compressed_len[name], name=name + 'len')
                #compressed_msg = torch.randn(100).cuda()
                self._handles_len[p] = handle

                torch.cuda.synchronize()
                self.pack_time += time.time() - begin_pack_time

            else:
                p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                if self._use_nesterov:
                    self._U[name] = torch.mul(
                        torch.add(self._U[name], p.grad.data), self._momentum)
                    self._V[name] = self._V[name] + self._U[name] + p.grad.data
                else:
                    self._U[name] = self._momentum * self._U[name] + p.grad.data
                    self._V[name] = self._V[name] + self._U[name]
                p.grad.data = self._V[name]
                #compressed_msg = torch.randn(100).cuda()
                #handle = _allgather_async(compressed_msg, self._compressed_msg[name], name=name)
                handle = allreduce_async_(p.grad.data, average=True, name=name)
                self._handles[p] = handle
            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time
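
select_top_k_thdv3 and select_top_k_fixthd are not shown; the interval logic above suggests a full threshold search every 10 steps and a cached threshold in between. A sketch of the fixed-threshold variant, assuming it returns (values, indices, sparsity):

import torch

def select_top_k_fixthd(tensor, threshold):
    # reuse a previously computed magnitude threshold instead of re-searching
    flat = tensor.view(-1)
    indices = (flat.abs() > threshold).nonzero(as_tuple=False).view(-1)
    values = flat[indices]
    sparsity = 1.0 - indices.numel() / flat.numel()
    return values, indices, sparsity
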
Example 5
        def hook(*ignore):
            assert p not in self._handles
            assert not p.grad.requires_grad
            name = self._parameter_names.get(p)
            p_size = np.prod(p.size())
            torch.cuda.synchronize()
            begin_time = time.time()
            if self._use_allgather and p_size > 1024:
                # fjr compress grad
                p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                p.grad.data.div_(hvd.size())
                if self._use_nesterov:
                    self._U[name] = torch.mul(
                        torch.add(self._U[name], p.grad.data), self._momentum)
                    self._V[name] = self._V[name] + self._U[name] + p.grad.data
                else:
                    self._U[name] = self._momentum * self._U[name] + p.grad.data
                    self._V[name] = self._V[name] + self._U[name]
                compressed_val = []
                compressed_idx = []
                #if p_size < 1000:
                #    self._masks[name], compressed_val, compressed_idx = select_top_k_thd(self._V[name], 0.001, self._masks[name])
                #else:
                #self._masks[name], compressed_val, compressed_idx = select_top_k_thd(self._V[name], 0.001, self._masks[name])
                #self._masks[name], compressed_val, compressed_idx = select_top_k_thd(self._V[name], 0.001, self._masks[name])

                torch.cuda.synchronize()
                begin_select_time = time.time()
                local_mean, compressed_idx = select_top_k_thd_mean(
                    self._V[name], 0.001)
                torch.cuda.synchronize()
                end_select_time = time.time()
                self.select_time += end_select_time - begin_select_time

                #tmp_t = torch.tensor([local_len], dtype=torch.long)
                #                tmp_t = torch.tensor([local_len])
                # print("len list, ", global_len_list)
                #local_len = torch.min(global_len_list)
                ##print("local_len, ", local_len)
                #compressed_val = compressed_val[0:local_len]
                #compressed_idx = compressed_idx[0:local_len]
                masks_size = self._masks[name].size()
                self._masks[name].zero_()
                self._masks[name] = self._masks[name].view(-1)
                self._masks[name][compressed_idx] = 1.0

                self._masks[name] = 1.0 - self._masks[name]
                self._masks[name] = self._masks[name].view(masks_size)

                if self._debug:
                    self._v_ref[name] = self._V[name] * (1.0 -
                                                         self._masks[name])
                    allreduce_(self._v_ref[name], average=False)

                #self._V[name] = self._V[name] * (1 - self._masks[name])
                #self._U[name] = self._U[name] * (1 - self._masks[name])
                self._V[name].mul_(self._masks[name])
                self._U[name].mul_(self._masks[name])
                #self._compressed_msg_size[name] = len(compressed_idx)
                if self._use_gpu:
                    compressed_msg = torch.cat(
                            [torch.tensor([len(compressed_idx)]).type('torch.cuda.FloatTensor'),
                             torch.tensor([local_mean]).type('torch.cuda.FloatTensor'),
                             compressed_idx.type('torch.cuda.FloatTensor')])
                else:
                    # CPU fallback; the original `pass` left compressed_msg undefined
                    compressed_msg = torch.cat(
                            [torch.tensor([float(len(compressed_idx))]),
                             torch.tensor([float(local_mean)]),
                             compressed_idx.float()])

                handle = _allgather_async(compressed_msg,
                                          self._compressed_msg[name],
                                          name=name)
                #compressed_msg = torch.randn(100).cuda()
                self._handles[p] = handle

            else:
                p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                if self._use_nesterov:
                    self._U[name] = torch.mul(
                        torch.add(self._U[name], p.grad.data), self._momentum)
                    self._V[name] = self._V[name] + self._U[name] + p.grad.data
                else:
                    self._U[name] = self._momentum * self._U[name] + p.grad.data
                    self._V[name] = self._V[name] + self._U[name]
                p.grad.data = self._V[name]
                #compressed_msg = torch.randn(100).cuda()
                #handle = _allgather_async(compressed_msg, self._compressed_msg[name], name=name)
                handle = allreduce_async_(p.grad.data, average=True, name=name)
                self._handles[p] = handle
            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time
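
select_top_k_thd_mean returns a scalar plus indices in the call above; a hypothetical implementation consistent with that shape is:

import torch

def select_top_k_thd_mean(tensor, ratio):
    # pick the top `ratio` fraction by magnitude; send only the mean of the
    # selected values and their indices (one float instead of k floats)
    flat = tensor.view(-1)
    k = max(1, int(flat.numel() * ratio))
    _, indices = torch.topk(flat.abs(), k)
    local_mean = flat[indices].mean().item()
    return local_mean, indices
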
Example 6
    def step(self, closure=None):
        # local clipping
        # DGC
        for group in self.param_groups:
            for p in group['params']:
                p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                p.grad.data.div_(hvd.size())

            torch.nn.utils.clip_grad_norm_(group['params'], 0.25 * hvd.size() ** -0.5)
            #torch.nn.utils.clip_grad_norm(group['params'], 0.25)
            #weight_decay = group['weight_decay']
            #momentum = group['momentum']
            torch.cuda.synchronize()
            begin_time = time.time()

            dampening = 0.0  # group['dampening']
            for p in group['params']:
                assert p not in self._handles
                assert not p.grad.requires_grad
                name = self._parameter_names.get(p)
                p_size = np.prod(p.size())
                if self._use_allgather and p_size > 1024:
                    param_state = self.state[p]
                    # fjr compress grad
                    if self._use_nesterov:
                        self._U[name] = torch.mul(torch.add(self._U[name], p.grad.data), self._momentum)
                        self._V[name] = self._V[name] + self._U[name] + p.grad.data
                    else:
                        self._U[name] = self._momentum * self._U[name] + p.grad.data
                        self._V[name] = self._V[name] + self._U[name]
                    compressed_val = []
                    compressed_idx = []

                    torch.cuda.synchronize()
                    begin_select_time = time.time()
                    #if 'interval' not in param_state:
                    #    param_state['interval'] = 1
                    #if param_state['interval'] == 0:
                    #    compressed_val, compressed_idx, _, _, _ = \
                    #        select_bs_top(self._V[name], 0.001)
                    #    param_state['interval'] = 1
                    #else:
                    #    compressed_val, compressed_idx, _, _, _ = \
                    #        select_bs_bottom(self._V[name], 0.001)
                    #    param_state['interval'] = 0

                    compressed_val_top, compressed_idx_top, _, _, _ = \
                        select_bs_top(self._V[name], 0.001)
                    compressed_val_low, compressed_idx_low, _, _, _ = \
                        select_bs_bottom(self._V[name], 0.001)
                    # keep whichever tail has the larger mean magnitude
                    if torch.mean(compressed_val_top) > -torch.mean(compressed_val_low):
                        compressed_val = compressed_val_top
                        compressed_idx = compressed_idx_top
                    else:
                        compressed_val = compressed_val_low
                        compressed_idx = compressed_idx_low

                    masks_size = self._masks[name].size()
                    self._masks[name].zero_()
                    self._masks[name] = self._masks[name].view(-1)
                    self._masks[name][compressed_idx] = 1.0
                    self._masks[name] = 1.0 - self._masks[name]
                    self._masks[name] = self._masks[name].view(masks_size)
                    torch.cuda.synchronize()
                    end_select_time = time.time()
                    self.select_time += end_select_time - begin_select_time

                    if self._debug:
                        self._v_ref[name] = self._V[name] * (1.0 - self._masks[name])
                        allreduce_(self._v_ref[name], average=False)

                    #self._V[name] = self._V[name] * (1 - self._masks[name])
                    #self._U[name] = self._U[name] * (1 - self._masks[name])
                    self._V[name].mul_(self._masks[name])
                    self._U[name].mul_(self._masks[name])

                    torch.cuda.synchronize()
                    begin_comm_time = time.time()

                    self._compressed_msg_size[name] = len(compressed_idx)
                    if self._use_gpu:
                        compressed_msg = torch.cat([
                            torch.tensor([len(compressed_idx)]).type('torch.cuda.LongTensor'),
                            compressed_idx])
                    else:
                        # CPU fallback so compressed_msg is defined before the allgather below
                        compressed_msg = torch.cat([
                            torch.tensor([len(compressed_idx)], dtype=torch.long),
                            compressed_idx])

                    handle = _allgather_async(compressed_msg, self._compressed_idx[name], name=name + "idx")
                    self._handles[p] = handle
                    handle = _allgather_async(torch.mean(compressed_val),
                            self._compressed_val[name], name=name + "val")
                    self._handles_val[p] = handle

                    torch.cuda.synchronize()
                    self.comm_time += time.time() - begin_comm_time

                else:
                    #p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                    if self._use_nesterov:
                        self._U[name] = torch.mul(torch.add(self._U[name], p.grad.data), self._momentum)
                        self._V[name] = self._V[name] + self._U[name] + p.grad.data
                    else:
                        self._U[name] = self._momentum * self._U[name] + p.grad.data
                        self._V[name] = self._V[name] + self._U[name]
                    p.grad.data = self._V[name]
                    #compressed_msg = torch.randn(100).cuda()
                    handle = allreduce_async_(p.grad.data, average=True, name=name)
                    self._handles[p] = handle

            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time

        self.synchronize()
        return super(self.__class__, self).step(closure)
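
A toy check of the tail-selection rule used above (keep whichever tail, largest positives or largest-magnitude negatives, has the bigger mean magnitude):

import torch

v = torch.tensor([3.0, -5.0, 0.1, -4.0, 2.0])
top_vals, top_idx = torch.topk(v, 2)    # positive tail: values [3., 2.]
low_vals, low_idx = torch.topk(-v, 2)   # negative tail magnitudes: [5., 4.]
if top_vals.mean() > low_vals.mean():   # compare mean magnitudes of the two tails
    vals, idx = top_vals, top_idx
else:
    vals, idx = -low_vals, low_idx      # restore the original signs
print(vals, idx)  # tensor([-5., -4.]) tensor([1, 3]): the negative tail wins
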
Example 7
        def hook(*ignore):
            assert p not in self._handles
            assert not p.grad.requires_grad
            name = self._parameter_names.get(p)
            p_size = np.prod(p.size())
            torch.cuda.synchronize()
            begin_time = time.time()

            if self._use_allgather and p_size > 1024:
                weight_decay = self._weight_decay  # group['weight_decay']
                momentum = self._momentum  # group['momentum']
                dampening = 0.0  # group['dampening']
                nesterov = False  # group['nesterov']
                param_state = self.state[p]  # hoisted: 'residue_buffer' below is needed even when momentum == 0
                d_p = p.grad.data
                d_p.div_(hvd.size())
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p)
                        # buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    # TODO
                    # if nesterov:
                    #     d_p = d_p.add(buf, alpha=momentum)
                    # else:
                    #     d_p = buf
                if 'residue_buffer' not in param_state:
                    rsd = param_state['residue_buffer'] = torch.zeros_like(p.data)
                else:
                    rsd = param_state['residue_buffer']
                rsd.add_(param_state['momentum_buffer'])
                if self._use_nesterov:
                    # in-place: the original `rsd = rsd.add(...)` silently dropped the update
                    rsd.add_(d_p, alpha=momentum)

                compressed_val = []
                compressed_idx = []

                torch.cuda.synchronize()
                begin_select_time = time.time()
                if 'interval' not in param_state:
                    param_state['interval'] = 1
                it = 0
                sparsity = 0.0
                if param_state['interval'] == 1:
                    compressed_val, compressed_idx, it, _, sparsity = \
                            select_bs_top(param_state['residue_buffer'], 0.001)
                    param_state['interval'] = 0
                else:
                    compressed_val, compressed_idx, it, _, sparsity = \
                            select_bs_bottom(param_state['residue_buffer'], 0.001)
                    param_state['interval'] = 1
                assert len(compressed_idx) > 0
                torch.cuda.synchronize()
                end_select_time = time.time()
                self.select_time += end_select_time - begin_select_time

                masks_size = self._masks[name].size()
                self._masks[name].zero_()
                self._masks[name] = self._masks[name].view(-1)
                self._masks[name][compressed_idx] = 1.0

                self._masks[name] = 1.0 - self._masks[name]
                self._masks[name] = self._masks[name].view(masks_size)

                if self._debug:
                    self._v_ref[name] = torch.mean(compressed_val) \
                            * (1.0 - self._masks[name])
                    allreduce_(self._v_ref[name], average=False)

                if hvd.size() == 1:
                    p.grad.data = torch.mean(compressed_val) \
                            * (1.0 - self._masks[name])

                param_state['residue_buffer'].mul_(self._masks[name])
                param_state['momentum_buffer'].mul_(self._masks[name])

                torch.cuda.synchronize()
                begin_pack_time = time.time()
                compressed_msg = []

                if hvd.size() > 1:
                    if self._use_gpu:
                        compressed_msg = torch.cat((
                                torch.tensor([len(compressed_idx)]).type(torch.cuda.LongTensor),
                                compressed_idx))

                    handle = _allgather_async(compressed_msg, self._compressed_idx[name], name=name + "idx")
                    self._handles[p] = handle

                    handle = _allgather_async(torch.mean(compressed_val), self._compressed_val[name], name=name + "val")
                    self._handles_val[p] = handle

                torch.cuda.synchronize()
                self.pack_time += time.time() - begin_pack_time

            else:
                weight_decay = self._weight_decay  # group['weight_decay']
                momentum = self._momentum  # group['momentum']
                dampening = 0.0  # group['dampening']
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if self._use_nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                #compressed_msg = torch.randn(100).cuda()
                #handle = _allgather_async(compressed_msg, self._compressed_msg[name], name=name)
                if hvd.size() > 1:
                    handle = allreduce_async_(p.grad.data, average=True, name=name)
                    self._handles[p] = handle
            torch.cuda.synchronize()
            end_time = time.time()
            self.pruning_time += end_time - begin_time
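
The mask multiply at the end of the compression path implements error feedback: coordinates that were just sent are cleared from both buffers and everything else is carried over. Written directly with index assignment (equivalent to the mask trick, assuming contiguous buffers; the function name is illustrative):

import torch

def clear_sent_coordinates(residue, momentum_buf, sent_idx):
    # zero exactly the transmitted entries; the untransmitted residual
    # survives and is added to the next step's gradient
    residue.view(-1)[sent_idx] = 0.0
    momentum_buf.view(-1)[sent_idx] = 0.0

# usage sketch:
# clear_sent_coordinates(param_state['residue_buffer'],
#                        param_state['momentum_buffer'], compressed_idx)
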
Example 8
    def step(self, closure=None):
        # local clipping
        # DGC
        for group in self.param_groups:
            for p in group['params']:
                p.grad.data.add_(torch.mul(p.data, self._weight_decay))
                p.grad.data.div_(hvd.size())

            torch.nn.utils.clip_grad_norm_(group['params'], 0.25 * hvd.size() ** -0.5)
            #torch.nn.utils.clip_grad_norm(group['params'], 0.25)
            #weight_decay = group['weight_decay']
            #momentum = group['momentum']
            torch.cuda.synchronize()
            begin_time = time.time()

            dampening = 0.0  # group['dampening']
            for p in group['params']:
                assert p not in self._handles
                assert not p.grad.requires_grad
                name = self._parameter_names.get(p)
                p_size = np.prod(p.size())
                if self._use_allgather and p_size > 1024:
                    param_state = self.state[p]
                    self._V[name].add_(p.grad.data)
                    compressed_val = []
                    compressed_idx = []
                    #if p_size < 1000:
                    #self._masks[name], compressed_val, compressed_idx = select_top_k_appr(self._V[name], 0.001, self._masks[name])

                    torch.cuda.synchronize()
                    begin_select_time = time.time()
                    if 'mid_store' not in param_state:
                        param_state['mid_store'] = 0.0
                    if 'interval' not in param_state:
                        param_state['interval'] = self._interval
                    compressed_val = []
                    compressed_idx = []
                    if param_state['interval'] == self._interval:
                        compressed_val, compressed_idx, it, param_state['mid_store'], sparsity = \
                            select_top_k_thdv3(self._V[name], 0.001)
                        param_state['interval'] = 0
                    else:
                        compressed_val, compressed_idx, sparsity = \
                            select_top_k_fixthd(self._V[name], param_state['mid_store'])
                        param_state['interval'] += 1
                    #masks_size = self._masks[name].size()
                    #self._masks[name].zero_()
                    #self._masks[name] = self._masks[name].view(-1)
                    #self._masks[name][compressed_idx] = 1.0
                    #self._masks[name] = 1.0 - self._masks[name]
                    #self._masks[name] = self._masks[name].view(masks_size)
                    torch.cuda.synchronize()
                    self.select_time += time.time() - begin_select_time

                    if self._debug:
                        self._v_ref[name] = self._V[name] * self._masks[name]
                        allreduce_(self._v_ref[name], average=False)

                    #self._V[name] = self._V[name] * (1 - self._masks[name])
                    #self._U[name] = self._U[name] * (1 - self._masks[name])
                    torch.cuda.synchronize()
                    begin_mask_time = time.time()
                    V_size = self._masks[name].size()
                    self._V[name] = self._V[name].view(-1)
                    self._V[name][compressed_idx] = 0.0
                    self._V[name] = self._V[name].view(V_size)

                    torch.cuda.synchronize()
                    self.mask_time += time.time() - begin_mask_time
                    begin_pack_time = time.time()

                    self._compressed_msg_size[name] = len(compressed_idx)
                    if self._use_gpu:
                        compressed_msg = torch.cat([
                            torch.tensor([len(compressed_idx)]).type('torch.cuda.FloatTensor'),
                            compressed_idx.type('torch.cuda.FloatTensor'),
                            compressed_val])
                    else:
                        # CPU fallback so compressed_msg is defined before the allgather below
                        compressed_msg = torch.cat([
                            torch.tensor([float(len(compressed_idx))]),
                            compressed_idx.float(),
                            compressed_val])

                    handle = _allgather_async(compressed_msg, self._compressed_msg[name], name=name)
                    self._handles[p] = handle

                    torch.cuda.synchronize()
                    self.pack_time += time.time() - begin_pack_time

                else:
                    handle = allreduce_async_(p.grad.data, average=True, name=name)
                    self._handles[p] = handle

            torch.cuda.synchronize()
            self.pruning_time += time.time() - begin_time

        self.synchronize()
        return super(self.__class__, self).step(closure)
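
Here indices and values are packed into a single float tensor [count, idx..., val...] before the allgather. A sketch of unpacking one worker's slice on the receiving side (the function name is illustrative):

import torch

def unpack_msg(msg):
    # msg layout produced above: [count, idx_0..idx_{n-1}, val_0..val_{n-1}]
    count = int(msg[0].item())
    idx = msg[1:1 + count].long()        # indices were cast to float for packing
    val = msg[1 + count:1 + 2 * count]
    return idx, val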