예제 #1
0
    def run(self, prune_num):
        """Remove the ``prune_num`` weakest residual blocks and save the compact model.

        A block's score is the mean absolute weight of its shortcut BatchNorm
        in ``self.model``; the lowest-scoring blocks are pruned. A config with
        the reduced per-stage block counts is written to
        ``self.compact_model_cfg``, a model is built from it via ``posenet``,
        initialised from ``self.model`` and its state dict is saved to
        ``self.compact_model_path``.

        Args:
            prune_num: number of shortcut blocks to remove.
        """
        all_bn_id, _, shortcut_idx, _ = obtain_prune_idx_layer(self.model)

        # Materialise the module list once instead of once per shortcut layer.
        modules = list(self.model.named_modules())
        bn_mean = torch.tensor([
            modules[idx][1].weight.detach().abs().mean().item()
            for idx in shortcut_idx
        ])
        _, sorted_index_thre = torch.sort(bn_mean)

        # Module indices of the shortcut BNs with the smallest mean |weight|.
        prune_shortcuts = [
            int(shortcut_idx[i]) for i in sorted_index_thre[:prune_num].tolist()
        ]
        print_mean(bn_mean, shortcut_idx, prune_shortcuts)

        # Each pruned block contributes its shortcut BN plus the two BNs listed
        # right before it in all_bn_id (presumably the other convs of the same
        # block -- TODO confirm for every supported backbone).
        prune_layers = []
        for prune_shortcut in prune_shortcuts:
            target_idx = all_bn_id.index(prune_shortcut)
            prune_layers.extend(all_bn_id[target_idx - i] for i in range(3))

        CBLidx2mask = obtain_layer_filters_mask(self.model, all_bn_id,
                                                prune_layers)

        pruned_locations = self.obtain_block_idx(shortcut_idx, prune_shortcuts)
        # Work on a copy: decrementing self.block_num in place would leave the
        # instance describing a block layout that self.model no longer matches.
        blocks = list(self.block_num)
        for pruned_location in pruned_locations:
            blocks[pruned_location] -= 1

        m_cfg = {
            'backbone': self.backbone,
            'kps': self.kps,
            'se_ratio': self.se_ratio,
            "first_conv": self.first_conv,
            'residual': self.residual,
            'channels': obtain_channel_with_block_num(blocks),
            "head_type": self.head_type,
            "head_channel": self.head_channel
        }
        write_cfg(m_cfg, self.compact_model_cfg)
        posenet.build(self.compact_model_cfg)
        compact_model = posenet.model

        compact_all_bn_idx = obtain_prune_idx_layer(compact_model)[0]
        init_weights_from_loose_model_layer(compact_model, self.model,
                                            CBLidx2mask, compact_all_bn_idx)
        torch.save(compact_model.state_dict(), self.compact_model_path)
예제 #2
0
def pruning(weight,
            compact_model_path,
            compact_model_cfg="cfg.txt",
            thresh=80,
            device="cpu"):
    """Channel-prune a pose model by BN-weight threshold and save the result.

    The backbone and structure are read from the module-level ``opt``
    (``opt.backbone``, ``opt.struct``, ``opt.kps``, ``opt.se_ratio``).

    Args:
        weight: path to the state dict of the model to prune.
        compact_model_path: path the pruned model's state dict is saved to.
        compact_model_cfg: text file that receives the comma-separated kept
            filter counts; the compact model is then built from this file.
        thresh: a percentage; ``thresh / 100`` is passed to
            ``obtain_bn_threshold`` (presumably the fraction of channels to
            prune -- confirm against that helper).
        device: ``"cpu"`` keeps the loaded model on CPU; anything else moves
            it to CUDA.

    Raises:
        ValueError: if ``opt.backbone`` is unknown, or is not one of the
            SE-ResNet variants this function knows how to prune.
    """
    if opt.backbone == "mobilenet":
        from models.mobilenet.MobilePose import createModel
        from config.model_cfg import mobile_opt as model_ls
    elif opt.backbone == "seresnet101":
        from models.seresnet101.FastPose import createModel
        from config.model_cfg import seresnet_cfg as model_ls
    elif opt.backbone == "seresnet18":
        from models.seresnet18.FastPose import createModel
        from config.model_cfg import seresnet_cfg as model_ls
    elif opt.backbone == "efficientnet":
        from models.efficientnet.EfficientPose import createModel
        from config.model_cfg import efficientnet_cfg as model_ls
    elif opt.backbone == "shufflenet":
        from models.shufflenet.ShufflePose import createModel
        from config.model_cfg import shufflenet_cfg as model_ls
    elif opt.backbone == "seresnet50":
        from models.seresnet50.FastPose import createModel
        from config.model_cfg import seresnet50_cfg as model_ls
    else:
        raise ValueError("Your model name is wrong")

    # opt.struct is either a key of the backbone's cfg table or, when it is
    # not, handed to createModel unchanged (e.g. a cfg file path). Only the
    # lookup is guarded so genuine createModel errors are not masked.
    try:
        model_cfg = model_ls[opt.struct]
    except (KeyError, IndexError, TypeError):
        model_cfg = opt.struct
    model = createModel(cfg=model_cfg)

    # Load onto CPU first so GPU-saved checkpoints also work on CPU-only
    # hosts; the model is moved to the requested device right after.
    model.load_state_dict(torch.load(weight, map_location="cpu"))
    if device == "cpu":
        model.cpu()
    else:
        model.cuda()

    with open("./buffer/model.txt", 'w') as f:
        print(model, file=f)

    if opt.backbone == "seresnet18":
        all_bn_id, normal_idx, _, downsample_idx, head_idx = obtain_prune_idx2(
            model)
    elif opt.backbone in ("seresnet50", "seresnet101"):
        all_bn_id, normal_idx, _, downsample_idx, head_idx = obtain_prune_idx_50(
            model)
    else:
        raise ValueError("Not a correct name")
    prune_idx = normal_idx + head_idx
    prune_set = set(prune_idx)
    sorted_bn = sort_bn(model, prune_idx)

    threshold = obtain_bn_threshold(model, sorted_bn, thresh / 100)
    pruned_filters, pruned_maskers = obtain_filters_mask(
        model, prune_idx, threshold)
    # Both maps are keyed by (BN module index - 1), presumably the conv that
    # feeds each BN.
    CBLidx2mask = {
        idx - 1: mask.astype('float32')
        for idx, mask in zip(all_bn_id, pruned_maskers)
    }
    CBLidx2filter = {
        idx - 1: filter_num
        for idx, filter_num in zip(all_bn_id, pruned_filters)
    }

    for head in head_idx:
        adjust_mask(CBLidx2mask, CBLidx2filter, model, head)

    valid_filter = {
        k: v
        for k, v in CBLidx2filter.items() if k + 1 in prune_set
    }
    filters = list(valid_filter.values())
    # The file must be closed (and flushed) before createModel reads it below.
    with open(compact_model_cfg, "w") as f:
        print(",".join(map(str, filters)), file=f)

    m_cfg = {
        'backbone': opt.backbone,
        'keypoints': opt.kps,
        'se_ratio': opt.se_ratio,
        "first_conv": CBLidx2filter[all_bn_id[0] - 1],
        'residual': get_residual_channel(filters, opt.backbone),
        'channels': get_channel_dict(filters, opt.backbone),
        "head_type": "pixel_shuffle",
        "head_channel": [CBLidx2filter[i - 1] for i in head_idx]
    }
    write_cfg(m_cfg, "buffer/cfg_{}.json".format(opt.backbone))

    compact_model = createModel(cfg=compact_model_cfg).cpu()
    with open("buffer/pruned.txt", 'w') as f:
        print(compact_model, file=f)

    if opt.backbone == "seresnet18":
        init_weights_from_loose_model(compact_model, model, CBLidx2mask,
                                      valid_filter, downsample_idx, head_idx)
    elif opt.backbone in ("seresnet50", "seresnet101"):
        init_weights_from_loose_model50(compact_model, model, CBLidx2mask,
                                        valid_filter, downsample_idx, head_idx)
    torch.save(compact_model.state_dict(), compact_model_path)
    def run(self, threshold):
        """Channel-prune ``self.model`` with a global BN threshold and save it.

        BN weights are thresholded, masks of layers in the same stage are
        merged, final/head masks are adjusted, a compact config is written to
        ``self.compact_model_cfg`` and the kept-filter list to
        ``buffer/cfg_shortcut_<backbone>.txt``. A model is then built via
        ``posenet``, initialised from ``self.model`` and its state dict is
        saved to ``self.compact_model_path``.

        Args:
            threshold: a percentage; divided by 100 before being passed to
                ``obtain_bn_threshold``.
        """
        all_bn_id, _, shortcut_idx, downsample_idx, head_idx = self.obtain_prune_idx(
            self.model)
        prune_idx = all_bn_id
        prune_set = set(prune_idx)
        sorted_bn = sort_bn(self.model, prune_idx)

        bn_threshold = obtain_bn_threshold(self.model, sorted_bn,
                                           threshold / 100)
        pruned_filters, pruned_maskers = obtain_filters_mask(
            self.model, prune_idx, bn_threshold)
        # Both maps are keyed by (BN module index - 1), presumably the conv
        # that feeds each BN.
        CBLidx2mask = {
            idx - 1: mask.astype('float32')
            for idx, mask in zip(all_bn_id, pruned_maskers)
        }
        CBLidx2filter = {
            idx - 1: filter_num
            for idx, filter_num in zip(all_bn_id, pruned_filters)
        }

        final_layer_groups = [
            downsample_idx[-1] - 1, shortcut_idx[-1] - 1, shortcut_idx[-2] - 1
        ]
        # One group per stage (4 stages): every shortcut BN of the stage plus
        # that stage's downsample BN. Presumably merged because these outputs
        # are summed by residual connections and must keep equal widths.
        mask_groups = []
        for stage in range(4):
            start = sum(self.block_num[:stage])
            mask_groups.append(
                [shortcut_idx[start + i] for i in range(self.block_num[stage])]
                + [downsample_idx[stage]])
        if self.backbone in ("seresnet50", "seresnet101"):
            final_layer_groups.append(shortcut_idx[-3] - 1)

        merge_mask(CBLidx2mask, CBLidx2filter, mask_groups)
        adjust_final_mask(CBLidx2mask, CBLidx2filter, self.model,
                          final_layer_groups)
        for head in head_idx:
            adjust_mask(CBLidx2mask, CBLidx2filter, self.model, head)

        valid_filter = {
            k: v
            for k, v in CBLidx2filter.items() if k + 1 in prune_set
        }
        filters = list(valid_filter.values())
        with open("buffer/cfg_shortcut_{}.txt".format(self.backbone),
                  "w") as f:
            print(",".join(map(str, filters)), file=f)

        m_cfg = {
            'backbone': self.backbone,
            'kps': self.kps,
            'se_ratio': self.se_ratio,
            "first_conv": valid_filter[all_bn_id[0] - 1],
            'residual': get_residual_channel(filters, self.backbone),
            'channels': get_channel_dict(filters, self.backbone),
            "head_type": "pixel_shuffle",
            "head_channel": [CBLidx2filter[i - 1] for i in head_idx]
        }
        write_cfg(m_cfg, self.compact_model_cfg)
        posenet.build(self.compact_model_cfg)
        compact_model = posenet.model
        self.init_weight(compact_model, self.model, CBLidx2mask, valid_filter,
                         downsample_idx, head_idx)
        torch.save(compact_model.state_dict(), self.compact_model_path)
예제 #4
0
    def run(self, threshold, layer_num):
        """Prune channels, then whole residual blocks, and save the compact model.

        Stage 1 (channels): BN weights of ``self.model`` are thresholded,
        masks are merged per stage, a config is written to
        ``self.compact_model_cfg`` (and the kept-filter list to
        ``buffer/cfg_all_<backbone>.txt``), and a channel-pruned model is
        built via ``posenet`` and initialised from ``self.model``.

        Stage 2 (layers): the ``layer_num`` shortcut BNs of ``self.model`` with
        the smallest mean |weight| are selected, their blocks are dropped from
        the channel-pruned model, a second config overwrites
        ``self.compact_model_cfg`` and the resulting model's state dict is saved
        to ``self.compact_model_path``.

        Args:
            threshold: channel-pruning percentage; divided by 100 before being
                passed to ``obtain_bn_threshold``.
            layer_num: number of shortcut blocks to remove.
        """
        print("------------------------- Prune channels first -------------------------------")
        all_bn_id, _, shortcut_idx, downsample_idx, head_idx = self.obtain_prune_idx(self.model)
        prune_idx = all_bn_id
        prune_set = set(prune_idx)
        sorted_bn = sort_bn(self.model, prune_idx)

        bn_threshold = obtain_bn_threshold(self.model, sorted_bn, threshold / 100)
        pruned_filters, pruned_maskers = obtain_filters_mask(self.model, prune_idx, bn_threshold)
        # Both maps are keyed by (BN module index - 1), presumably the conv that feeds each BN.
        CBLidx2mask_channel = {idx - 1: mask.astype('float32') for idx, mask in zip(all_bn_id, pruned_maskers)}
        CBLidx2filter = {idx - 1: filter_num for idx, filter_num in zip(all_bn_id, pruned_filters)}

        final_layer_groups = [downsample_idx[-1] - 1, shortcut_idx[-1] - 1, shortcut_idx[-2] - 1]
        # One group per stage (4 stages): every shortcut BN of the stage plus its downsample BN.
        mask_groups = []
        for stage in range(4):
            start = sum(self.block_num[:stage])
            mask_groups.append(
                [shortcut_idx[start + i] for i in range(self.block_num[stage])] + [downsample_idx[stage]])
        if self.backbone in ("seresnet50", "seresnet101"):
            final_layer_groups.append(shortcut_idx[-3] - 1)

        merge_mask(CBLidx2mask_channel, CBLidx2filter, mask_groups)
        adjust_final_mask(CBLidx2mask_channel, CBLidx2filter, self.model, final_layer_groups)
        for head in head_idx:
            adjust_mask(CBLidx2mask_channel, CBLidx2filter, self.model, head)

        valid_filter = {k: v for k, v in CBLidx2filter.items() if k + 1 in prune_set}
        filters = list(valid_filter.values())
        with open("buffer/cfg_all_{}.txt".format(self.backbone), "w") as f:
            print(",".join(map(str, filters)), file=f)
        m_channel_cfg = {
            'backbone': self.backbone,
            'kps': self.kps,
            'se_ratio': self.se_ratio,
            "first_conv": valid_filter[all_bn_id[0] - 1],
            'residual': get_residual_channel(filters, self.backbone),
            'channels': get_channel_dict(filters, self.backbone),
            "head_type": "pixel_shuffle",
            "head_channel": [CBLidx2filter[i - 1] for i in head_idx]
        }
        write_cfg(m_channel_cfg, self.compact_model_cfg)
        posenet.build(self.compact_model_cfg)
        compact_channel_model = posenet.model
        self.init_weight_channel(compact_channel_model, self.model, CBLidx2mask_channel, valid_filter, downsample_idx,
                                 head_idx)

        print("\n------------------------- Prune layers -------------------------------")
        layer_bn_id, _, layer_shortcut_idx, _ = obtain_prune_idx_layer(self.model)

        # Materialise the module list once instead of once per shortcut layer.
        modules = list(self.model.named_modules())
        bn_mean = torch.tensor([
            modules[idx][1].weight.detach().abs().mean().item()
            for idx in layer_shortcut_idx
        ])
        _, sorted_index_thre = torch.sort(bn_mean)
        pruned_order = sorted_index_thre[:layer_num].tolist()

        # Module indices of the shortcut BNs with the smallest mean |weight|.
        prune_shortcuts = [int(layer_shortcut_idx[i]) for i in pruned_order]
        print_mean(bn_mean, layer_shortcut_idx, prune_shortcuts)

        # Each pruned block contributes its shortcut BN plus the two BNs listed right
        # before it in layer_bn_id (presumably the block's other convs -- TODO confirm).
        prune_layers = []
        for prune_shortcut in prune_shortcuts:
            target_idx = layer_bn_id.index(prune_shortcut)
            prune_layers.extend(layer_bn_id[target_idx - i] for i in range(3))

        CBLidx2mask_layer = obtain_layer_filters_mask(compact_channel_model, layer_bn_id, prune_layers)

        # self.block_num is deliberately left untouched: it must keep describing
        # self.model, which still contains every block. The layer config below
        # derives its channels from m_channel_cfg instead.
        m_layer_cfg = {
            'backbone': self.backbone,
            # Same key as every other cfg handed to posenet.build in this file.
            'kps': self.kps,
            'se_ratio': self.se_ratio,
            "first_conv": m_channel_cfg["first_conv"],
            'residual': m_channel_cfg["residual"],
            'channels': obtain_all_prune_channels(pruned_order, m_channel_cfg["channels"]),
            # NOTE(review): the channel model above was built with "pixel_shuffle";
            # weights are copied from it, so confirm self.head_type matches.
            "head_type": self.head_type,
            "head_channel": m_channel_cfg["head_channel"],
        }
        write_cfg(m_layer_cfg, self.compact_model_cfg)
        posenet.build(self.compact_model_cfg)
        compact_layer_model = posenet.model
        compact_all_bn_idx = obtain_prune_idx_layer(compact_layer_model)[0]
        init_weights_from_loose_model_layer(compact_layer_model, compact_channel_model, CBLidx2mask_layer,
                                            compact_all_bn_idx)

        torch.save(compact_layer_model.state_dict(), self.compact_model_path)