示例#1
0
 def __getitem__(self, index):
     """Return the preprocessed and the raw image stored at ``index``.

     The image path is taken from ``self.img_list[index]`` and loaded with
     ``get_img``. The loaded image is resized by ``resize_image`` using the
     configured algorithm, ``self.test_size`` and the configured stride,
     converted to an RGB PIL image, and finally normalized with
     ``self.TSM.normalize_img``.

     Returns:
         A ``(normalized_img, ori_img)`` pair, where ``ori_img`` is exactly
         what ``get_img`` returned.
     """
     original = get_img(self.img_list[index])

     algorithm = self.config['base']['algorithm']
     test_size = self.test_size
     stride = self.config['testload']['stride']
     resized = resize_image(original, algorithm, test_size, stride)

     rgb = Image.fromarray(resized).convert('RGB')
     return self.TSM.normalize_img(rgb), original
示例#2
0
    def infer_img(self, ori_img):
        """Run the detection model on one image and post-process its output.

        Args:
            ori_img: image array; it is indexed as ``ori_img.shape[0]`` and
                ``ori_img.shape[1]`` and passed to ``Image.fromarray`` after
                resizing (presumably an H x W x C array as returned by
                ``cv2.imread`` -- confirm against the caller).

        Returns:
            ``(bbox_batch, score_batch)`` exactly as returned by
            ``self.img_process``.
        """
        # NOTE: ``congig`` (sic) is the attribute name used by the enclosing
        # class, so it is kept as-is here.
        algorithm = self.congig['base']['algorithm']
        testload_cfg = self.congig['testload']

        img = resize_image(ori_img,
                           algorithm,
                           testload_cfg['test_size'],
                           stride=testload_cfg['stride'])
        img = Image.fromarray(img).convert('RGB')
        img = transforms.ToTensor()(img)
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])(img).unsqueeze(0)
        if torch.cuda.is_available():
            img = img.cuda()

        with torch.no_grad():
            out = self.model(img)

        # Use the instance config here, like every other lookup in this
        # method, instead of depending on a module-level ``config`` name.
        if algorithm == 'SAST':
            # For SAST the model output is a dict; the scale tuple carries
            # the (4x score-map size) / (original size) ratios followed by
            # the original height and width.
            score_shape = out['f_score'].shape
            scale = ((score_shape[2] * 4) / ori_img.shape[0],
                     (score_shape[3] * 4) / ori_img.shape[1],
                     ori_img.shape[0],
                     ori_img.shape[1])
        else:
            # Otherwise: (original width / output width,
            #             original height / output height).
            scale = (ori_img.shape[1] * 1.0 / out.shape[3],
                     ori_img.shape[0] * 1.0 / out.shape[2])

        out = create_process_obj(algorithm, out)
        bbox_batch, score_batch = self.img_process(out, [scale])
        return bbox_batch, score_batch
示例#3
0
    def infer_img(self, ori_img):
        """Run the recognition model on one image and decode its prediction.

        The image is resized with ``resize_image`` (third argument 32),
        converted to RGB and, when ``is_gray`` is set in the config, further
        to single-channel ``'L'`` mode. Pixel values are then mapped with
        ``(x - 0.5) / 0.5`` and a batch dimension is prepended.

        Returns:
            Whatever ``self.converter.decode(..., raw=False)`` yields for the
            per-step argmax of the model output.
        """
        base_cfg = self.congig['base']  # attribute name (sic) set by the class

        resized = resize_image(ori_img, base_cfg['algorithm'], 32)
        pil_img = Image.fromarray(resized).convert('RGB')
        if base_cfg['is_gray']:
            pil_img = pil_img.convert('L')

        batch = transforms.ToTensor()(pil_img)
        batch.sub_(0.5).div_(0.5)
        batch = batch.unsqueeze(0)
        if torch.cuda.is_available():
            batch = batch.cuda()

        with torch.no_grad():
            logits = self.model(batch)

        # The decoder is given the size of dim 0 of the model output as the
        # sequence length, and the argmax over dim 2 as the predictions.
        seq_len = torch.IntTensor([logits.size(0)])
        _, best = logits.max(2)
        flat = best.squeeze(1).contiguous().view(-1)

        return self.converter.decode(flat.data, seq_len.data, raw=False)
def _prune_kept_indices(mask):
    """Return the positions of the non-zero entries of ``mask`` as a 1-D array."""
    kept = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy())))
    # np.squeeze collapses a single hit into a 0-d array; restore a 1-D shape
    # so that indexing a weight with it keeps the channel dimension.
    if kept.size == 1:
        kept = np.resize(kept, (1, ))
    return kept


def _prune_share_mask(prued_mask, bn_index, module_indices, label=None):
    """OR the masks of the given BN modules and store the union on each of them.

    ``module_indices`` are positions in ``model.modules()``; ``bn_index`` maps
    them to positions in ``prued_mask``. When ``label`` is given, each visited
    index is printed with it. Returns the merged mask.
    """
    merged = None
    for idx in module_indices:
        if label is not None:
            print(label, idx)
        mask = prued_mask[bn_index.index(idx)]
        merged = mask if merged is None else merged | mask
    for idx in module_indices:
        prued_mask[bn_index.index(idx)] = merged
    return merged


def prune(args):
    """Prune channels of a detection model using its BatchNorm scale factors.

    The function loads the config and checkpoint, derives one keep-mask per
    BatchNorm2d layer from a global threshold on ``|gamma|``, forces groups of
    layers to share a mask, builds a fresh model whose Conv2d/BatchNorm2d
    layers are narrowed accordingly, copies the surviving weights over, runs
    one forward pass on a test image and saves the result.

    Args:
        args: object with the attributes ``config`` (YAML path), ``img_file``
            (test image path), ``checkpoint`` (weights passed to
            ``load_model``) and ``save_prune_model_path`` (output directory).

    Side effects:
        Prints both models and several statistics, writes ``re1.jpg`` to the
        current directory, and saves ``pruned_dict.dict`` (masks and BN module
        indices) and ``pruned_dict.pth.tar`` (pruned state dict) into
        ``args.save_prune_model_path``.

    Note:
        ``.cuda()`` is called unconditionally, so a CUDA device is required.
        The hard-coded module indices and offsets below are tied to one
        specific architecture (blocks of type
        ``ptocr.model.backbone.det_mobilev3.Block``).
    """
    # NOTE(review): FullLoader is acceptable for trusted local configs; use
    # yaml.safe_load if the config can come from an untrusted source.
    with open(args.config, 'r', encoding='utf-8') as stream:
        config = yaml.load(stream, Loader=yaml.FullLoader)

    img = cv2.imread(args.img_file)
    img = resize_image(img,
                       config['base']['algorithm'],
                       config['testload']['test_size'],
                       stride=config['testload']['stride'])
    img = Image.fromarray(img)
    img = img.convert('RGB')
    img = transforms.ToTensor()(img)
    img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])(img)
    img = img.cuda().unsqueeze(0)

    model = create_module(
        config['architectures']['model_function'])(config).cuda()
    model = load_model(model, args.checkpoint)

    model.eval()
    print(model)

    cut_percent = 0.5  # fraction of all BN channels below the threshold
    base_num = 4  # kept channel counts above this are rounded to a multiple
    # Modules whose position in ``modules()`` exceeds this index are not
    # pruned (presumably everything after the backbone -- confirm against
    # the printed model).
    last_pruned_module = 187

    bn_weights = torch.cat([
        m.weight.data.abs().clone() for m in model.modules()
        if isinstance(m, nn.BatchNorm2d)
    ], 0)

    sort_result, _ = torch.sort(bn_weights)

    thresh_index = int(cut_percent * bn_weights.shape[0])
    if thresh_index == bn_weights.shape[0]:
        thresh_index = bn_weights.shape[0] - 1

    # ---- one keep-mask per BatchNorm2d among the prunable modules ----
    prued = 0
    prued_mask = []
    bn_index = []  # module index of each BN, parallel to prued_mask
    conv_index = []
    remain_channel_nums = []
    for k, m in enumerate(model.modules()):
        if k > last_pruned_module:
            break
        if isinstance(m, nn.BatchNorm2d):
            bn_magnitude = m.weight.data.abs().clone()
            mask = bn_magnitude.gt(sort_result[thresh_index])
            remain_channel = int(mask.sum())

            if remain_channel == 0:
                # Never remove a layer completely: keep its strongest channel
                # (ranked by magnitude, like the threshold itself).
                remain_channel = 1
                mask[int(torch.argmax(bn_magnitude))] = 1

            if remain_channel % base_num != 0 and remain_channel > base_num:
                # Round to the nearest multiple of base_num (ties round up),
                # capped at the layer width.
                upper = -(-remain_channel // base_num) * base_num
                lower = upper - base_num
                if remain_channel - lower < upper - remain_channel:
                    remain_channel = lower
                else:
                    remain_channel = upper
                remain_channel = min(remain_channel, bn_magnitude.size()[0])
                # Re-select the top ``remain_channel`` channels. The ranking
                # must use magnitudes, because the comparison below does;
                # ranking signed weights would let negative gammas keep more
                # channels than intended.
                sorted_magnitude, _ = torch.sort(bn_magnitude)
                mask = bn_magnitude.ge(sorted_magnitude[-remain_channel])

            remain_channel_nums.append(int(mask.sum()))
            prued_mask.append(mask)
            bn_index.append(k)
            prued += mask.shape[0] - int(mask.sum())
        elif isinstance(m, nn.Conv2d):
            conv_index.append(k)

    print('remain_channel_nums', remain_channel_nums)
    print('total_prune_ratio:', float(prued) / bn_weights.shape[0])
    print('bn_index', bn_index)
    print('conv_index', conv_index)

    new_model = create_module(
        config['architectures']['model_function'])(config).cuda()

    # keys[i] = module index of the i-th Block in the new model.
    keys = {}
    for k, m in enumerate(new_model.modules()):
        if isinstance(m, ptocr.model.backbone.det_mobilev3.Block):
            keys[len(keys)] = k
    print(keys)

    # ---- force BN layers that must agree on width to share one mask ----
    # The offsets are relative to a Block's module index and select BN layers
    # inside or around it (presumably the layers joined by residual
    # additions -- confirm against det_mobilev3.Block).
    pair_offsets = np.array([7, 16])

    #### step 1
    msk = _prune_share_mask(prued_mask, bn_index,
                            np.array([-3, 7, 16]) + keys[0], 'step1')
    print(msk.sum())

    #### step 2
    msk = _prune_share_mask(prued_mask, bn_index,
                            7 + np.array([keys[1], keys[2]]), 'step2')
    print(msk.sum())
    msk_2 = msk.clone()

    #### step 3
    step3_indices = [
        idx for block_idx in (keys[3], keys[4], keys[5])
        for idx in block_idx + pair_offsets
    ]
    msk = _prune_share_mask(prued_mask, bn_index, step3_indices, 'step3')
    print(msk.sum())
    msk_3 = msk.clone()

    #### step 4
    step4_indices = ((keys[6] + pair_offsets).tolist() +
                     (keys[7] + pair_offsets).tolist())
    msk = _prune_share_mask(prued_mask, bn_index, step4_indices, 'step4')
    print(msk.sum())
    msk_4 = msk.clone()

    #### step 5
    step5_indices = [
        idx for block_idx in (keys[8], keys[9], keys[10])
        for idx in block_idx + pair_offsets
    ]
    msk = _prune_share_mask(prued_mask, bn_index, step5_indices)
    print(msk.sum())
    msk_5 = msk.clone()

    # Inside every block, the BNs at offsets +2 and +5 share a mask, and the
    # conv at offset +4 is recorded as a grouped conv: its ``groups`` is set
    # to its kept channel count and its weight is sliced on the output axis
    # only.
    group_index = []
    # [conv module index, BN module index]: that conv takes its input mask
    # from that BN instead of from the BN immediately preceding it.
    spl_index = []
    for i in range(11):
        block_idx = keys[i]
        _prune_share_mask(prued_mask, bn_index, np.array([2, 5]) + block_idx)
        if i == 6:
            spl_index.extend([block_idx + 9, block_idx - 2])
        group_index.append(block_idx + 4)

    # ---- narrow the layer metadata of the new model ----
    # The i-th conv is paired with the i-th BN mask for its output channels
    # and, by default, with the (i-1)-th mask for its input channels. The
    # very first conv keeps its 3 input channels.
    count_conv = 0
    count_bn = 0
    conv_in_mask = [torch.ones(3)]
    conv_out_mask = []
    bn_mask = []
    for k, m in enumerate(new_model.modules()):
        if k > last_pruned_module:
            break
        if isinstance(m, nn.Conv2d):
            if k in group_index:
                m.groups = int(prued_mask[bn_index.index(k + 1)].sum())
            m.out_channels = int(prued_mask[count_conv].sum())
            conv_out_mask.append(prued_mask[count_conv])
            if count_conv > 0:
                if k == spl_index[0]:
                    in_mask = prued_mask[bn_index.index(spl_index[1])]
                else:
                    in_mask = prued_mask[count_conv - 1]
                m.in_channels = int(in_mask.sum())
                conv_in_mask.append(in_mask)
            count_conv += 1
        elif isinstance(m, nn.BatchNorm2d):
            # Plain int, consistent with the conv attributes above.
            m.num_features = int(prued_mask[count_bn].sum())
            bn_mask.append(prued_mask[count_bn])
            count_bn += 1

    # ---- copy the surviving weights from the trained model ----
    bn_i = 0
    conv_i = 0
    # Convs after the pruned part whose inputs come from pruned feature maps:
    # only their input channels are sliced, all 96 outputs are kept.
    scale = [188, 192, 196, 200]
    scale_mask = [msk_5, msk_4, msk_3, msk_2]
    module_pairs = zip(model.modules(), new_model.modules())
    for model_i, (m0, m1) in enumerate(module_pairs):
        if model_i > last_pruned_module:
            if isinstance(m0, nn.Conv2d):
                if model_i in scale:
                    in_mask = scale_mask[scale.index(model_i)]
                    m1.in_channels = int(in_mask.sum())
                    idx0 = _prune_kept_indices(in_mask)
                    idx1 = _prune_kept_indices(torch.ones(96))
                    w = m0.weight.data[:, idx0, :, :].clone()
                    m1.weight.data = w[idx1, :, :, :].clone()
                    if m1.bias is not None:
                        m1.bias.data = m0.bias.data[idx1].clone()
                else:
                    m1.weight.data = m0.weight.data.clone()
                    if m1.bias is not None:
                        m1.bias.data = m0.bias.data.clone()
            elif isinstance(m0, nn.BatchNorm2d):
                m1.weight.data = m0.weight.data.clone()
                if m1.bias is not None:
                    m1.bias.data = m0.bias.data.clone()
                m1.running_mean = m0.running_mean.clone()
                m1.running_var = m0.running_var.clone()
        else:
            if isinstance(m0, nn.BatchNorm2d):
                idx1 = _prune_kept_indices(bn_mask[bn_i])
                m1.weight.data = m0.weight.data[idx1].clone()
                if m1.bias is not None:
                    m1.bias.data = m0.bias.data[idx1].clone()
                m1.running_mean = m0.running_mean[idx1].clone()
                m1.running_var = m0.running_var[idx1].clone()
                bn_i += 1
            elif isinstance(m0, nn.Conv2d):
                idx0 = _prune_kept_indices(conv_in_mask[conv_i])
                idx1 = _prune_kept_indices(conv_out_mask[conv_i])
                if model_i in group_index:
                    # Grouped conv: slice the output axis only.
                    m1.weight.data = m0.weight.data[idx1, :, :, :].clone()
                    if m1.bias is not None:
                        # The bias has one entry per output channel, so it
                        # is sliced with the same indices as the weight.
                        m1.bias.data = m0.bias.data[idx1].clone()
                else:
                    w = m0.weight.data[:, idx0, :, :].clone()
                    m1.weight.data = w[idx1, :, :, :].clone()
                    if m1.bias is not None:
                        m1.bias.data = m0.bias.data[idx1].clone()
                conv_i += 1

    # ---- smoke-test the pruned model and persist it ----
    print(new_model)
    new_model.eval()
    with torch.no_grad():
        out = new_model(img)
    print(out.shape)
    cv2.imwrite('re1.jpg', out[0, 0].cpu().numpy() * 255)

    save_obj = {'prued_mask': prued_mask, 'bn_index': bn_index}
    torch.save(save_obj,
               os.path.join(args.save_prune_model_path, 'pruned_dict.dict'))
    torch.save(new_model.state_dict(),
               os.path.join(args.save_prune_model_path, 'pruned_dict.pth.tar'))