Ejemplo n.º 1
0
def _pruned_channel_count(prob, chn_input_full, chn_output_full, dcfg,
                          expand_rate, verbose=False):
    """Return the pruned output width of a single gated convolution.

    A temporary ``DNAS.Conv2d`` is built only to read its ``out_plane_list``
    (presumably the candidate output widths of the gate -- confirm against
    ``DNAS.Conv2d``). The expected width is the dot product of ``prob`` with
    that list, capped at ``chn_output_full`` and rounded to an integer.
    ``expand_rate`` then adds back a fraction (rounded up) of the channels
    that were removed.

    Args:
        prob: probability tensor for this layer, same length as the
            layer's ``out_plane_list``.
        chn_input_full: unpruned number of input channels.
        chn_output_full: unpruned number of output channels.
        dcfg: DNAS configuration forwarded to ``DNAS.Conv2d``.
        expand_rate: fraction of the removed channels to restore.
        verbose: when true, print the probabilities, the candidate list and
            the expected width for this layer.

    Returns:
        The pruned channel count as an ``int``.
    """
    import numpy as np

    conv = DNAS.Conv2d(chn_input_full, chn_output_full, 1, 1, 1, False, dcfg=dcfg)
    # Evaluate the expectation once and reuse it for logging and rounding.
    expected = torch.dot(prob, conv.out_plane_list).item()
    if verbose:
        print(prob, conv.out_plane_list, expected)
    pruned = int(np.round(min(expected, chn_output_full)))
    pruned += int(np.ceil(expand_rate * (chn_output_full - pruned)))
    return pruned


def get_prune_list(resnet_channel_list, prob_list, dcfg, expand_rate=0.0001):
    """Translate per-layer gate probabilities into a pruned channel layout.

    ``resnet_channel_list`` has the form
    ``[stem_width, [[w, w, ...], ...], [[w, w, ...], ...], ...]``: the first
    entry is the width of the stem convolution (fed by 3 input channels) and
    every following entry is a group of blocks, each block being a list of
    layer widths. ``prob_list`` is consumed in the same traversal order, one
    entry per layer.

    Args:
        resnet_channel_list: unpruned channel layout as described above.
        prob_list: sequence of probability tensors, one per layer.
        dcfg: DNAS configuration forwarded to ``DNAS.Conv2d``.
        expand_rate: fraction of the pruned-away channels that is added back
            to every layer (rounded up).

    Returns:
        A list mirroring ``resnet_channel_list`` with every width replaced by
        its pruned integer width.

    Raises:
        IndexError: if ``prob_list`` has fewer entries than there are layers.
    """
    chn_input_full, chn_output_full = 3, resnet_channel_list[0]
    # The stem layer is not logged; only the block layers are (as before).
    prune_list = [
        _pruned_channel_count(prob_list[0], chn_input_full, chn_output_full,
                              dcfg, expand_rate)
    ]
    chn_input_full = chn_output_full
    idx = 1

    for blocks in resnet_channel_list[1:]:
        blocks_list = []
        for block in blocks:
            block_prune_list = []
            for chn_output_full in block:
                block_prune_list.append(
                    _pruned_channel_count(prob_list[idx], chn_input_full,
                                          chn_output_full, dcfg, expand_rate,
                                          verbose=True))
                # The next layer consumes the *unpruned* width of this one.
                chn_input_full = chn_output_full
                idx += 1
            blocks_list.append(block_prune_list)
        prune_list.append(blocks_list)
    return prune_list
Ejemplo n.º 2
0
    def forward(self, x, tau=1, noise=False, reuse_prob=None, rmask=None):
        """Run the gated bottleneck cell.

        Data flow: ``bn0`` + mask + ReLU, an optional gated shortcut applied
        to the activated input, then ``conv1`` -> ``conv2`` -> ``conv3`` (the
        first two followed by batch norm, mask and ReLU), a residual add and
        a final masking with ``rmask``.

        ``tau`` and ``noise`` are forwarded unchanged to every gated
        convolution. ``reuse_prob`` is passed as ``p_in`` to the shortcut and
        ``conv1`` and as ``reuse_prob`` to ``conv3``; when a shortcut exists
        its own probability and mask replace ``reuse_prob`` / ``rmask``.

        Returns:
            ``(x, rmask_3, prob, prob_list, flops_list)`` where ``prob`` is
            the probability returned by ``conv3`` and both lists are ordered
            ``[conv1, conv2, conv3]`` followed by the shortcut entry when a
            shortcut is present.
        """
        identity = x
        out = F.relu(DNAS.weighted_feature(self.bn0(x), rmask))

        prob = reuse_prob
        shortcut_probs, shortcut_flops = [], []
        if self.shortcut is not None:
            identity, rmask, prob, sc_flops = self.shortcut(out, tau, noise,
                                                            p_in=prob)
            shortcut_probs.append(prob)
            shortcut_flops.append(sc_flops)

        out, mask_1, p1, flops_1 = self.conv1(out, tau, noise, p_in=prob)
        out = F.relu(DNAS.weighted_feature(self.bn1(out), mask_1))

        out, mask_2, p2, flops_2 = self.conv2(out, tau, noise, p_in=p1)
        out = F.relu(DNAS.weighted_feature(self.bn2(out), mask_2))

        out, mask_3, prob, flops_3 = self.conv3(out, tau, noise,
                                                reuse_prob=prob, p_in=p2)

        # In-place residual add, then mask the sum with the (possibly
        # shortcut-provided) reuse mask.
        out += identity
        out = DNAS.weighted_feature(out, rmask)

        # Main-path entries first, shortcut entry (if any) last.
        prob_list = [p1, p2, prob] + shortcut_probs
        flops_list = [flops_1, flops_2, flops_3] + shortcut_flops
        return out, mask_3, prob, prob_list, flops_list
Ejemplo n.º 3
0
    def __init__(self,
                 in_planes,
                 out_planes_list,
                 stride=2,
                 project=False,
                 dcfg=None):
        """Build the gated bottleneck cell.

        Args:
            in_planes: number of input channels.
            out_planes_list: output widths of ``conv1``, ``conv2`` and
                ``conv3`` (indices 0, 1, 2); the projection shortcut uses the
                last entry.
            stride: stride of the 3x3 ``conv2`` and of the shortcut.
            project: when true, add a gated 1x1 projection shortcut and make
                its gate the ``reuse_gate`` of ``dcfg``.
            dcfg: DNAS configuration; must not be ``None``. ``conv3`` uses it
                directly, every other gated convolution uses a copy taken
                before ``reuse_gate`` is modified here.
        """
        super(BottleneckGated, self).__init__()
        width_1 = out_planes_list[0]
        width_2 = out_planes_list[1]
        width_3 = out_planes_list[2]
        assert dcfg is not None
        self.dcfg = dcfg
        self.dcfg_nonreuse = dcfg.copy()

        def batch_norm(channels):
            # All batch norms of the cell share momentum and epsilon.
            return nn.BatchNorm2d(channels,
                                  momentum=_BATCH_NORM_DECAY,
                                  eps=_EPSILON)

        def gated_conv(c_in, c_out, kernel, conv_stride, pad, cfg):
            # Bias-free gated convolution.
            return DNAS.Conv2d(c_in,
                               c_out,
                               kernel_size=kernel,
                               stride=conv_stride,
                               padding=pad,
                               bias=False,
                               dcfg=cfg)

        # Sub-modules are created in the same order as before so that the
        # registration order (and hence parameter order) is unchanged.
        self.bn0 = batch_norm(in_planes)
        self.conv1 = gated_conv(in_planes, width_1, 1, 1, 0,
                                self.dcfg_nonreuse)
        self.bn1 = batch_norm(width_1)
        self.conv2 = gated_conv(width_1, width_2, 3, stride, 1,
                                self.dcfg_nonreuse)
        self.bn2 = batch_norm(width_2)

        self.shortcut = None
        if project:
            self.shortcut = gated_conv(in_planes, out_planes_list[-1], 1,
                                       stride, 0, self.dcfg_nonreuse)
            # conv3 (built below with self.dcfg) sees the shortcut's gate.
            self.dcfg.reuse_gate = self.shortcut.gate

        self.conv3 = gated_conv(width_2, width_3, 1, 1, 0, self.dcfg)
Ejemplo n.º 4
0
 def __init__(self,
              num_blocks,
              num_planes,
              num_colors=3,
              scale=1,
              res_scale=0.1):
     """Build the gated EDSR network.

     Args:
         num_blocks: number of ``EDSRBlockGated`` cells.
         num_planes: channel width of the head convolution and the cells.
         num_colors: channel count of the input and output images.
         scale: factor handed to ``Upsampler``.
         res_scale: value handed to every ``EDSRBlockGated``.
     """
     super(EDSRGated, self).__init__()
     self.num_blocks = num_blocks
     self.act = nn.ReLU(inplace=True)
     self.sub_mean = MeanShift(1)
     self.add_mean = MeanShift(1, sign=1)

     # NOTE(review): self.dcfg is read here but never assigned in this
     # method; presumably it is provided elsewhere (e.g. as a class
     # attribute) -- confirm.
     self.conv0 = dnas.Conv2d(num_colors, num_planes, 3,
                              stride=1, padding=1, bias=False,
                              dcfg=self.dcfg)
     # The head convolution's gate becomes the reuse gate for the cells.
     self.dcfg.reuse_gate = self.conv0.gate

     # NOTE(review): a plain Python list is used (not nn.ModuleList), so
     # the cells are not registered as sub-modules by this assignment --
     # confirm that this is intended.
     self.blocks = [
         EDSRBlockGated(num_planes, res_scale, self.dcfg)
         for _ in range(num_blocks)
     ]

     self.tail = nn.Sequential(
         Upsampler(scale, num_planes),
         nn.Conv2d(in_channels=num_planes,
                   out_channels=num_colors,
                   kernel_size=3,
                   stride=1,
                   padding=1,
                   bias=False),
     )
Ejemplo n.º 5
0
    def __init__(self, n_layer, n_class, channel_lists, dcfg):
        """Build the gated ResNet.

        Args:
            n_layer: network depth; must be a key of the module-level
                ``cfg`` mapping, otherwise an error is printed and the
                process exits with status 1.
            n_class: number of outputs of the final linear layer.
            channel_lists: channel layout; entry 0 is the stem width and
                ``channel_lists[-1][-1][-1]`` is the width fed to the final
                batch norm and linear layer.
            dcfg: DNAS configuration shared by the stem, the blocks and the
                linear layer; its ``reuse_gate`` is set to the stem's gate.
        """
        super(ResNet, self).__init__()

        if n_layer not in cfg:
            print('Numer of layers Error: ', n_layer)
            exit(1)

        self.n_class = n_class
        self.channel_lists = channel_lists
        self.block_n_cell = cfg[n_layer]
        self.dcfg = dcfg
        self.base_n_channel = channel_lists[0]
        # More than three block groups selects the 7x7-stem + max-pool layout.
        self.imagenet = len(self.block_n_cell) > 3

        if self.imagenet:
            self.cell_fn = BottleneckGated if n_layer >= 50 else ResidualBlockGated
            stem_kernel, stem_stride, stem_padding = 7, 2, 3
        else:
            self.cell_fn = ResidualBlockGated
            stem_kernel, stem_stride, stem_padding = 3, 1, 1

        self.conv0 = DNAS.Conv2d(3,
                                 self.base_n_channel,
                                 stem_kernel,
                                 stride=stem_stride,
                                 padding=stem_padding,
                                 bias=False,
                                 dcfg=self.dcfg)
        if self.imagenet:
            # Registered right after conv0, as before.
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.dcfg.reuse_gate = self.conv0.gate

        self.block_list = self._block_layers()

        final_width = channel_lists[-1][-1][-1]
        self.bn = nn.BatchNorm2d(final_width,
                                 momentum=_BATCH_NORM_DECAY,
                                 eps=_EPSILON)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = DNAS.Linear(final_width, self.n_class, dcfg=self.dcfg)
        self.apply(_weights_init)
Ejemplo n.º 6
0
 def train_prune(self, tau, n_epoch=250, load_path=None, save_path=None):
     """Load a searched model and run one forward pass to read its gates.

     The original signature placed the required ``load_path`` and
     ``save_path`` after the defaulted ``n_epoch``, which is a syntax
     error. They now default to ``None`` (keeping the positional order) and
     are validated explicitly, so they remain effectively required.

     Args:
         tau: value forwarded to ``self.forward`` as ``tau``.
         n_epoch: number of epochs; not used by the code visible here.
         load_path: path handed to ``self.load_model``. Required.
         save_path: output path; not used by the code visible here.
             Required.

     Raises:
         TypeError: if ``load_path`` or ``save_path`` is not supplied.
     """
     if load_path is None or save_path is None:
         raise TypeError(
             "train_prune() requires both 'load_path' and 'save_path'")
     self.load_model(load_path)
     # NOTE(review): both the `DNAS` and `dnas` aliases are used below;
     # confirm that the file imports both names.
     dcfg = DNAS.DcpConfig(n_param=8,
                           split_type=dnas.TYPE_A,
                           reuse_gate=None)
     # Only the first training batch is used.
     data = next(iter(self.train_loader))
     in_data, im_gt = [x.to(self.device) for x in data]
     self.net.eval()
     outputs, prob_list, flops, flops_list = self.forward(in_data,
                                                          tau=tau,
                                                          noise=False)
Ejemplo n.º 7
0
 def __init__(self, num_planes, res_scale=1, dcfg=None):
     """Build a gated EDSR cell made of two 3x3 gated convolutions.

     Args:
         num_planes: input and output channel count of both convolutions.
         res_scale: stored on the instance as ``self.res_scale``.
         dcfg: DNAS configuration; must not be ``None``. ``conv2`` uses it
             directly, ``conv1`` uses a copy of it.
     """
     super(EDSRBlockGated, self).__init__()
     assert dcfg is not None
     self.dcfg = dcfg
     self.dcfg_nonreuse = dcfg.copy()

     # Both convolutions share every argument except the configuration.
     conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)

     self.conv1 = dnas.Conv2d(num_planes, num_planes,
                              dcfg=self.dcfg_nonreuse, **conv_kwargs)
     self.act = nn.ReLU(inplace=True)
     self.conv2 = dnas.Conv2d(num_planes, num_planes,
                              dcfg=self.dcfg, **conv_kwargs)
     self.res_scale = res_scale
Ejemplo n.º 8
0
    def forward(self, x, tau=1, noise=False, reuse_prob=None, rmask=None):
        """Run the gated two-convolution residual cell.

        Data flow: ``bn0`` + mask + ReLU, an optional gated shortcut applied
        to the activated input, ``conv1`` followed by ``bn1`` + mask + ReLU,
        ``conv2``, a residual add and a final masking with ``rmask``.

        ``tau`` and ``noise`` are forwarded unchanged to every gated
        convolution. ``reuse_prob`` is passed as ``p_in`` to the shortcut and
        ``conv1`` and as ``reuse_prob`` to ``conv2``; when a shortcut exists
        its own probability and mask replace ``reuse_prob`` / ``rmask``.

        Returns:
            ``(x, rmask_2, prob, prob_list, flops_list)`` where ``prob`` is
            the probability returned by ``conv2`` and both lists are ordered
            ``[conv1, conv2]`` followed by the shortcut entry when a shortcut
            is present.
        """
        identity = x
        out = F.relu(DNAS.weighted_feature(self.bn0(x), rmask))

        prob = reuse_prob
        shortcut_probs, shortcut_flops = [], []
        if self.shortcut is not None:
            identity, rmask, prob, sc_flops = self.shortcut(out, tau, noise,
                                                            p_in=prob)
            shortcut_probs.append(prob)
            shortcut_flops.append(sc_flops)

        out, mask_1, p1, flops_1 = self.conv1(out, tau, noise, p_in=prob)
        out = F.relu(DNAS.weighted_feature(self.bn1(out), mask_1))

        out, mask_2, prob, flops_2 = self.conv2(out, tau, noise,
                                                reuse_prob=prob, p_in=p1)

        # In-place residual add, then mask the sum with the (possibly
        # shortcut-provided) reuse mask.
        out += identity
        out = DNAS.weighted_feature(out, rmask)

        # Main-path entries first, shortcut entry (if any) last, so the order
        # of probabilities and FLOPs follows the order of the channel list.
        prob_list = [p1, prob] + shortcut_probs
        flops_list = [flops_1, flops_2] + shortcut_flops
        return out, mask_2, prob, prob_list, flops_list
Ejemplo n.º 9
0
    def train_prune(self, tau, n_epoch=250,
                    load_path='./models/search/model.pth',
                    save_path='./models/prune/model.pth'):
        """Derive a slim network from a searched model and train it.

        Steps:
          0. load the searched model and read the gate probabilities from a
             forward pass over the first training batch;
          1. convert the probabilities into a pruned channel list and build
             a ``ResNetL`` from it;
          2. train that network with a ``FullLearner``.

        Args:
            tau: value forwarded to ``self.forward`` as ``tau``.
            n_epoch: number of epochs handed to ``FullLearner.train``.
            load_path: checkpoint path handed to ``self.load_model``.
            save_path: path handed to ``FullLearner.train``.
        """
        self.load_model(load_path)
        dcfg = DNAS.DcpConfig(n_param=8, split_type=DNAS.TYPE_A, reuse_gate=None)
        channel_list = ResNetChannelList(self.args.teacher_net_index)

        self.net.eval()
        batch = next(iter(self.train_loader))
        inputs = batch[0].to(self.device)
        labels = batch[1].to(self.device)
        _, prob_list, _, _ = self.forward(inputs, tau=tau, noise=False)

        # Dump the probabilities, one layer per line, two decimals each.
        print('=================')
        print(tau)
        for layer_prob in prob_list:
            print(''.join('{0:.2f} '.format(value)
                          for value in layer_prob.tolist()))
        print('------------------')

        # NOTE(review): this relies on self.forward exposing
        # named_parameters(), i.e. being a module-like wrapper rather than a
        # plain method -- confirm.
        for name, param in self.forward.named_parameters():
            if 'conv0.conv.weight' in name:
                print(param[:, 0, 0, 0])
                print(param[:, 1, 2, 1])

        channel_list_prune = get_prune_list(channel_list, prob_list, dcfg=dcfg)
        print(channel_list_prune)

        slim_net = ResNetL(self.args.net_index, self.dataset.n_class,
                           channel_list_prune)
        full_learner = FullLearner(self.dataset,
                                   slim_net,
                                   device=self.device,
                                   args=self.args,
                                   teacher=self.teacher)
        print('FLOPs:', full_learner.cnt_flops())
        full_learner.train(n_epoch=n_epoch, save_path=save_path)
Ejemplo n.º 10
0
 def forward(self, x, tau=1, noise=False):
     """Run the gated network and collect per-layer statistics.

     ``tau`` and ``noise`` are forwarded unchanged to the stem and to every
     cell.

     Returns:
         ``(logits, prob_list, total_flops, flops_list)`` where
         ``prob_list`` holds the stem probability followed by every cell's
         probabilities, ``flops_list`` holds the stem, cell and linear-layer
         FLOPs entries, and ``total_flops`` is the sum of the stacked
         ``flops_list``.
     """
     x, rmask, prob, stem_flops = self.conv0(x, tau, noise)
     prob_list = [prob]
     flops_list = [stem_flops]

     if self.imagenet:
         x = self.maxpool(x)

     for stage in self.block_list:
         for cell in stage:
             x, rmask, prob, cell_probs, cell_flops = cell(
                 x, tau, noise, reuse_prob=prob, rmask=rmask)
             prob_list.extend(cell_probs)
             flops_list.extend(cell_flops)
             # The probability carried to the next cell is the LAST entry
             # of the cell's list, not the value the cell returned.
             prob = cell_probs[-1]

     x = F.relu(DNAS.weighted_feature(self.bn(x), rmask))
     x = torch.flatten(self.avgpool(x), 1)
     logits, fc_flops = self.fc(x, p_in=prob)
     flops_list.append(fc_flops)

     total_flops = torch.sum(torch.stack(flops_list))
     return logits, prob_list, total_flops, flops_list
Ejemplo n.º 11
0
def ResNetGated(n_layer, n_class):
    """Build a gated ``ResNet`` of depth ``n_layer`` with ``n_class`` outputs.

    The network uses a fresh ``DNAS.DcpConfig`` (``n_param=8``,
    ``split_type=DNAS.TYPE_A``, no reuse gate) and the channel layout
    returned by ``ResNetChannelList(n_layer)``.
    """
    gate_cfg = DNAS.DcpConfig(n_param=8, split_type=DNAS.TYPE_A, reuse_gate=None)
    return ResNet(n_layer, n_class, ResNetChannelList(n_layer), gate_cfg)