예제 #1
0
        def __init__(self, args):
            """Build the Hinge wrapper, record baseline complexity, then modify the network.

            Args:
                args: indexable pair; ``args[0]`` is stored as ``self.args``
                    (the option object) and ``args[1]`` as ``self.ckp``.
            """
            self.args = args[0]
            self.ckp = args[1]
            super(Hinge, self).__init__(self.args)

            opts = self.args

            # Training or loading phase of network compression.
            if 'hinge' in opts.model.lower() and not opts.test_only:
                self.load(strict=True)

            # Choose the input shape from the training-set name; anything that
            # is neither CIFAR nor Tiny_ImageNet falls back to 224x224.
            dataset = opts.data_train
            shape = (3, 224, 224)
            for tag, candidate in (('CIFAR', (3, 32, 32)),
                                   ('Tiny_ImageNet', (3, 64, 64))):
                if tag in dataset:
                    shape = candidate
                    break
            self.input_dim = shape

            # Baseline figures are taken before modify_network() runs.
            complexity = get_model_complexity_info(
                self, self.input_dim, print_per_layer_stat=False)
            self.flops, self.params = complexity

            modify_network(self)

            if ('hinge' in opts.model.lower() and not opts.test_only
                    and opts.layer_balancing):
                # The layer-balancing regularizer is needed during both the
                # training and the loading phase.
                self.layer_reg = self.calc_regularization()
                line = 'DenseBlock {:<2}: {:<2.4f}'
                for block_no, reg in enumerate(self.layer_reg, start=1):
                    print(line.format(block_no, reg))
예제 #2
0
def calc_model_complexity(model):
    """Re-measure the wrapped model and print its FLOPs / parameter ratios.

    The object returned by ``model.get_model()`` gets ``flops_compress`` and
    ``params_compress`` set from ``get_model_complexity_info``; these are then
    reported against its existing ``flops`` and ``params`` attributes.
    """
    net = model.get_model()
    flops_now, params_now = get_model_complexity_info(
        net, net.input_dim, print_per_layer_stat=False)
    net.flops_compress = flops_now
    net.params_compress = params_now

    template = ('FLOPs ratio {:.2f} = {:.4f} [G] / {:.4f} [G]; '
                'Parameter ratio {:.2f} = {:.2f} [k] / {:.2f} [k].\n\n')
    giga = 10. ** 9
    kilo = 10. ** 3
    report = template.format(
        flops_now / net.flops * 100, flops_now / giga, net.flops / giga,
        params_now / net.params * 100, params_now / kilo, net.params / kilo)
    print(report)
예제 #3
0
        def __init__(self, args):
            """Build the Hinge wrapper, record baseline complexity, then modify the network.

            Args:
                args: indexable pair; ``args[0]`` is stored as ``self.args``
                    (the option object) and ``args[1]`` as ``self.ckp``.
            """
            self.args = args[0]
            self.ckp = args[1]
            super(Hinge, self).__init__(self.args)

            opts = self.args

            # Training phase of network compression: weights are loaded
            # non-strictly unless testing only or resuming via ``load``.
            is_hinge = 'hinge' in opts.model.lower()
            if is_hinge and not opts.test_only and not opts.load:
                self.load(opts, strict=False)

            # Choose the input shape from the training-set name.
            dataset = opts.data_train
            if 'CIFAR' in dataset:
                shape = (3, 32, 32)
            elif 'Tiny_ImageNet' in dataset:
                shape = (3, 64, 64)
            else:
                shape = (3, 224, 224)
            self.input_dim = shape

            # Baseline figures are taken before modify_network() runs.
            complexity = get_model_complexity_info(
                self, self.input_dim, print_per_layer_stat=False)
            self.flops, self.params = complexity

            # Registered with no initial value.
            self.register_buffer('running_grad_ratio', None)

            modify_network(self)
예제 #4
0
    def __init__(self, args, ckp, converging):
        """Build the Prune wrapper and record its baseline complexity.

        Args:
            args: option object, stored as ``self.args``.
            ckp: stored as ``self.ckp``.
            converging: when truthy, the weight-loading step is skipped.
        """
        self.args = args
        self.ckp = ckp
        super(Prune, self).__init__(self.args)

        opts = self.args

        # Training or loading for searching.
        if not (opts.test_only or converging):
            self.load(opts, strict=True)

        # Choose the input shape from the training-set name.
        dataset = opts.data_train
        self.input_dim = (
            (3, 32, 32) if 'CIFAR' in dataset
            else (3, 64, 64) if 'Tiny_ImageNet' in dataset
            else (3, 224, 224)
        )

        complexity = get_model_complexity_info(
            self, self.input_dim, print_per_layer_stat=False)
        self.flops, self.params = complexity
예제 #5
0
        def __init__(self, args):
            """Build the Hinge wrapper, record baseline complexity, then modify the network.

            Args:
                args: indexable pair; ``args[0]`` is stored as ``self.args``
                    (the option object) and ``args[1]`` as ``self.ckp``.
            """
            self.args = args[0]
            self.ckp = args[1]
            super(Hinge, self).__init__(self.args)

            opts = self.args

            # Training phase of network compression: strict weight loading
            # unless testing only or resuming via ``load``.
            is_hinge = 'hinge' in opts.model.lower()
            if is_hinge and not opts.test_only and not opts.load:
                self.load(opts, strict=True)

            # Choose the input shape from the training-set name.
            dataset = opts.data_train
            shape = (3, 224, 224)
            for tag, candidate in (('CIFAR', (3, 32, 32)),
                                   ('Tiny_ImageNet', (3, 64, 64))):
                if tag in dataset:
                    shape = candidate
                    break
            self.input_dim = shape

            # Baseline figures are taken before modify_network() runs.
            complexity = get_model_complexity_info(
                self, self.input_dim, print_per_layer_stat=False)
            self.flops, self.params = complexity

            modify_network(self)
    def __init__(self, args, ckp, converging):
        """Build the Hinge wrapper, record baseline complexity, then modify the network.

        Args:
            args: option object, stored as ``self.args``.
            ckp: stored as ``self.ckp``.
            converging: when truthy, the weight-loading step is skipped.
        """
        self.args = args
        self.ckp = ckp
        super(Hinge, self).__init__(self.args)

        opts = self.args

        # Training or loading for searching (non-strict load).
        if not (opts.test_only or converging):
            self.load(opts, strict=False)

        # Choose the input shape from the training-set name.
        dataset = opts.data_train
        if 'CIFAR' in dataset:
            shape = (3, 32, 32)
        elif 'Tiny_ImageNet' in dataset:
            shape = (3, 64, 64)
        else:
            shape = (3, 224, 224)
        self.input_dim = shape

        # Baseline figures are taken before modify_network() runs.
        complexity = get_model_complexity_info(
            self, self.input_dim, print_per_layer_stat=False)
        self.flops, self.params = complexity

        # Registered with no initial value.
        self.register_buffer('running_grad_ratio', None)

        modify_network(self)
    def __init__(self, args, ckp, converging):
        """Build the Hinge wrapper, optionally load weights, and record baseline complexity.

        Args:
            args: option object, stored as ``self.args``.
            ckp: stored as ``self.ckp``.
            converging: when truthy, weight loading and the layer-balancing
                computation are both skipped.
        """
        self.args = args
        self.ckp = ckp
        super(Hinge, self).__init__(self.args)

        opts = self.args

        # Training or loading for searching.
        if not (opts.test_only or converging):
            self.load(opts.pretrain, strict=True)
            if opts.layer_balancing:
                # The layer-balancing regularizer is needed during both the
                # training and the loading phase.
                self.layer_reg = self.calc_regularization()
                line = 'Block {:<2}: regularization {:<2.4f}'
                for block_no, reg in enumerate(self.layer_reg, start=1):
                    print(line.format(block_no, reg))

        # Choose the input shape from the training-set name.
        dataset = opts.data_train
        shape = (3, 224, 224)
        for tag, candidate in (('CIFAR', (3, 32, 32)),
                               ('Tiny_ImageNet', (3, 64, 64))):
            if tag in dataset:
                shape = candidate
                break
        self.input_dim = shape

        complexity = get_model_complexity_info(
            self, self.input_dim, print_per_layer_stat=False)
        self.flops, self.params = complexity
예제 #8
0
    def algorithm(self, t, P, ratio):
        """Prune groups of ``P`` filters one at a time until the FLOPs ratio reaches ``ratio``.

        Input: a pretrained network with filter set F. Output: the pruned
        network with filter set F'. Filters of element 0 of each candidate
        block are ranked by L2 norm; every group of ``P`` consecutive filter
        indices accumulates ``norm * rank / P`` as its importance. The loop
        then repeatedly removes the group chosen by ``findMinGroup`` and
        re-measures the model.

        Args:
            t: object whose ``test()`` is called before every pruning step and
                whose ``t.model.get_model()`` holds the result lists that are
                appended to and finally passed to ``save``.
            P: number of filters per group.
            ratio: target for ``flops_compress / flops``; the loop runs while
                the current ratio is greater than this value.
        """
        G_MWG = {}  # {key: "layer#group#P", value: importance}

        # Score every filter group. Only the first 11 BasicBlock modules are
        # considered (original author's note: the last layer is not pruned).
        modules = [m for m in self.modules()
                   if isinstance(m, BasicBlock)][0:11]

        for layer, module_cur in enumerate(modules):
            conv11 = module_cur[0]
            ws1 = conv11.weight.shape
            # Reshape to (input * k * k, output): one column per filter.
            projection1 = conv11.weight.data.view(ws1[0], -1).t()
            Fl = torch.norm(projection1, p=2, dim=0)  # shape: (output,)
            # Filter indices in ascending L2-norm order, e.g.
            # tensor([0, 2, 3, 1]); a smaller norm is pruned more readily.
            FR = Fl.sort()[1]
            filter_num = FR.shape[0]
            cluster_num = math.ceil(filter_num / P)
            factors = np.zeros(cluster_num)
            # ``rank`` is the filter's position in the ascending ordering;
            # integer floor division gives the group index exactly.
            for rank, filter_index in enumerate(FR.tolist()):
                group = filter_index // P
                factors[group] = factors[group] + (Fl[filter_index] * rank) / P
            for cluster in range(cluster_num):
                key = str(layer) + '#' + str(cluster) + '#' + str(P)
                G_MWG[key] = factors[cluster]

        current_ratio = 1.0
        self.flops_compress = self.flops

        while current_ratio > ratio:
            if not G_MWG:
                # Every group has been consumed and the target is still not
                # met; stop here so the results gathered so far are saved.
                print("no filter groups left to prune; stopping at ratio "
                      "{:.4f}".format(current_ratio))
                break

            t.test()

            (layer_num, group_num, key) = findMinGroup(G_MWG)
            del G_MWG[key]  # each group is selected at most once
            print("prun layer :%d, group : %d" % (layer_num, group_num))
            cur_layer = modules[layer_num]
            conv11 = cur_layer[0]
            ws1 = conv11.weight.shape
            weight1 = conv11.weight.data.view(ws1[0], -1).t()
            # Keep-mask over the filters: zero the selected group, then keep
            # the indices of the remaining ones.
            # NOTE(review): group_num was computed from the filter layout
            # before any pruning; if compress_one_layer shrinks this weight,
            # later slices for the same layer may not line up -- confirm.
            pindex1 = torch.ones(weight1.shape[1]).to(weight1.device)
            pindex1[group_num * P:group_num * P + P] = 0
            pindex1 = torch.nonzero(pindex1).squeeze(dim=1)
            self.compress_one_layer(layer_num, -1, pindex1=pindex1)

            self.flops_compress, self.params_compress = get_model_complexity_info(
                self, self.input_dim, print_per_layer_stat=False)
            print(
                'FLOPs ratio {:.2f} = {:.4f} [G] / {:.4f} [G]; Parameter ratio {:.2f} = {:.2f} [k] / {:.2f} [k].'
                .format(self.flops_compress / self.flops * 100,
                        self.flops_compress / 10.**9, self.flops / 10.**9,
                        self.params_compress / self.params * 100,
                        self.params_compress / 10.**3, self.params / 10.**3))

            current_ratio = self.flops_compress / self.flops
            t.model.get_model().current_ratio_list.append(
                "{:.4f}".format(current_ratio))
            print("current_ratio_list : ")
            print(t.model.get_model().current_ratio_list)

            t.model.get_model().parameter_ratio_list.append("{:.4f}".format(
                self.params_compress / self.params))
            print("parameter ratio list : ")
            print(t.model.get_model().parameter_ratio_list)

        save(t.model.get_model().current_ratio_list,
             t.model.get_model().timer_test_list,
             t.model.get_model().sum_list,
             t.model.get_model().top1_err_list,
             t.model.get_model().parameter_ratio_list)