Пример #1
0
def _measure_baseline(net):
    """Print and return ``(flops, params)`` measured on a clone of ``net``.

    The clone comes from ``clone_model``; only the clone's ``.module`` is given
    to ``analyse_model``, together with one random 1x3x32x32 tensor moved to
    CUDA. Both figures are printed in millions.
    """
    replica, _ = clone_model(net)
    sample = torch.randn(1, 3, 32, 32).cuda()
    flops, params = analyse_model(replica.module, sample)
    print('%.3f MFLOPS' % (flops / 1e6))
    print('%.3f M' % (params / 1e6))
    # Drop the clone before the caller moves on to pruning.
    del replica
    return flops, params


def run():
    """Prune the network returned by ``get_pack``, melt it, then fine-tune it.

    Steps, in order:

    1. Measure the baseline FLOPs / parameter count on a clone of
       ``pack.net`` (1x3x32x32 input) and print them.
    2. Call ``prune`` with the second value returned by ``get_pack`` and the
       baseline figures.
    3. Wrap the conv layers via ``Conv2dObserver.transform`` and the
       ``classifier`` head with ``FinalLinearObserver``, then run
       ``Meltable.observe`` (second argument 0.001) and ``Meltable.melt_all``.
    4. Install a new SGD optimizer on ``pack`` and run ``finetune`` with the
       ``cfg.gbn`` learning-rate bounds and epoch count.
    """
    pack, gbns = get_pack()

    base_flops, base_params = _measure_baseline(pack.net)
    prune(pack, gbns, base_flops, base_params)

    # Attach observers, then let Meltable observe and melt the pruned net.
    _ = Conv2dObserver.transform(pack.net.module)
    body = pack.net.module
    body.classifier = FinalLinearObserver(body.classifier)
    Meltable.observe(pack, 0.001)
    Meltable.melt_all(pack.net)

    # The optimizer is built from pack.net.parameters() *after* melting.
    # NOTE(review): lr=1 looks like a base value that finetune rescales
    # between lr_min and lr_max -- confirm in finetune.
    sgd_options = dict(
        lr=1,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov,
    )
    pack.optimizer = optim.SGD(pack.net.parameters(), **sgd_options)

    _ = finetune(
        pack,
        lr_min=cfg.gbn.lr_min,
        lr_max=cfg.gbn.lr_max,
        T=cfg.gbn.finetune_epoch,
    )
Пример #2
0
def eval_prune(pack):
    """Measure ``(flops, params)`` of ``pack.net`` as it would look after melting.

    ``pack.net`` is not transformed here. Instead a clone is built with
    ``clone_model``:

    - its conv layers go through ``Conv2dObserver.transform``;
    - its ``classifier`` head is wrapped in ``FinalLinearObserver``.

    A ``dotdict`` made from ``pack.copy()`` is then pointed at the clone and
    handed to ``Meltable.observe`` (second argument 0.001) and
    ``Meltable.melt_all``. The result is analysed with ``analyse_model`` on a
    single random 1x3x32x32 CUDA tensor.

    NOTE(review): if ``pack.copy()`` is shallow, fields other than ``net`` are
    shared with ``pack`` -- confirm ``Meltable.observe`` does not mutate them.

    Returns:
        ``(flops, params)`` unpacked from ``analyse_model``'s result.
    """
    probe, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(probe.module)
    probe.module.classifier = FinalLinearObserver(probe.module.classifier)

    probe_pack = dotdict(pack.copy())
    probe_pack.net = probe
    Meltable.observe(probe_pack, 0.001)
    Meltable.melt_all(probe_pack.net)

    dummy_input = torch.randn(1, 3, 32, 32).cuda()
    flops, params = analyse_model(probe_pack.net.module, dummy_input)

    del probe, probe_pack
    return flops, params
Пример #3
0
def eval_prune(pack):
    """Measure FLOPs and parameter counts of the melted network per resolution.

    ``pack.net`` is not transformed here. Instead a clone is built with
    ``clone_model``:

    - its conv layers go through ``Conv2dObserver.transform``;
    - its ``fc`` head is wrapped in ``FinalLinearObserver``.

    A ``dotdict`` made from ``pack.copy()`` is pointed at the clone and passed
    to ``Meltable.observe`` (second argument 0.001) and ``Meltable.melt_all``.
    The melted clone is then analysed once for every entry of
    ``cfg.model.resolution``, each time with a random square 1x3xSxS CUDA
    tensor.

    Returns:
        ``(flops, params)``: two lists whose i-th items are the values
        ``analyse_model`` reported for the i-th resolution.
    """
    probe, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(probe.module)
    probe.module.fc = FinalLinearObserver(probe.module.fc)

    probe_pack = dotdict(pack.copy())
    probe_pack.net = probe
    Meltable.observe(probe_pack, 0.001)
    Meltable.melt_all(probe_pack.net)

    measurements = [
        analyse_model(probe_pack.net.module,
                      torch.randn(1, 3, side, side).cuda())
        for side in cfg.model.resolution
    ]
    flops = [f for f, _ in measurements]
    params = [p for _, p in measurements]

    del probe, probe_pack
    return flops, params
Пример #4
0
#             group_id = uuid.uuid1()
#             if len(masks) > 1:
#                 for mk in masks:
#                     mk.set_groupid(group_id)
#             masks = []

# Module-level call applied to pack.net before the baseline cost is measured
# on a clone of it.
# NOTE(review): presumably assigns shared group ids to the gates inside each
# bottleneck block so they are pruned together -- confirm in
# bottleneck_set_group.
bottleneck_set_group(pack.net)

def clone_model(net):
    """Return a freshly built model carrying ``net``'s weights.

    A new model is created with ``get_model()``. ``GatedBatchNorm2d.transform``
    is applied to its ``.module`` *before* ``net``'s state dict is loaded, so
    ``net`` presumably has already been transformed the same way (its keys
    must match the transformed structure -- confirm).

    Returns:
        ``(model, gbns)`` where ``gbns`` is whatever
        ``GatedBatchNorm2d.transform`` returned for the new model.
    """
    replica = get_model()
    gated = GatedBatchNorm2d.transform(replica.module)
    weights = net.state_dict()
    replica.load_state_dict(weights)
    return replica, gated

# Baseline FLOPs / parameter count, measured on a throw-away clone of
# pack.net with one random square CUDA input whose side is
# cfg.model.resolution (a single int here). Both are printed in millions.
cloned, _ = clone_model(pack.net)
BASE_FLOPS, BASE_PARAM = analyse_model(
    cloned.module,
    torch.randn(1, 3, cfg.model.resolution, cfg.model.resolution).cuda(),
)
print('%.3f MFLOPS' % (BASE_FLOPS / 1e6))
print('%.3f M' % (BASE_PARAM / 1e6))
# Only the figures are kept; release the clone.
del cloned

def eval_prune(pack):
    """Measure ``(flops, params)`` of ``pack.net`` as it would look after melting.

    ``pack.net`` is not transformed here. Instead a clone is built with
    ``clone_model``:

    - its conv layers go through ``Conv2dObserver.transform``;
    - its ``fc`` head is wrapped in ``FinalLinearObserver``.

    A ``dotdict`` made from ``pack.copy()`` is pointed at the clone and passed
    to ``Meltable.observe`` (second argument 0.001) and ``Meltable.melt_all``.
    The melted clone is analysed with ``analyse_model`` on one random square
    CUDA input whose side is ``cfg.model.resolution`` (a single int in this
    configuration).

    Returns:
        ``(flops, params)`` unpacked from ``analyse_model``'s result.
    """
    cloned, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(cloned.module)
    cloned.module.fc = FinalLinearObserver(cloned.module.fc)
    cloned_pack = dotdict(pack.copy())
    cloned_pack.net = cloned
    Meltable.observe(cloned_pack, 0.001)
    Meltable.melt_all(cloned_pack.net)

    side = cfg.model.resolution
    flops, params = analyse_model(cloned_pack.net.module,
                                  torch.randn(1, 3, side, side).cuda())

    # Release both references to the melted clone before handing back the
    # measurements.
    del cloned
    del cloned_pack
    # The measurements are the whole point of this function; without this
    # return they were computed and silently discarded (None was returned).
    return flops, params
Пример #5
0
#             masks = []

# Module-level call applied to pack.net before the per-resolution baseline is
# measured on a clone of it.
# NOTE(review): presumably assigns shared group ids to the gates inside each
# bottleneck block so they are pruned together -- confirm in
# bottleneck_set_group.
bottleneck_set_group(pack.net)


def clone_model(net):
    """Build a new model via ``get_model()`` and copy ``net``'s weights into it.

    ``GatedBatchNorm2d.transform`` runs on the new model's ``.module`` before
    the weights are loaded. ``net``'s state dict presumably already reflects
    that same transformation -- confirm, since the keys have to line up.

    Returns:
        ``(model, gbns)``: the populated model and the value returned by
        ``GatedBatchNorm2d.transform``.
    """
    replica = get_model()
    gated = GatedBatchNorm2d.transform(replica.module)
    weights = net.state_dict()
    replica.load_state_dict(weights)
    return replica, gated


# Baseline FLOPs / parameter counts, measured on a throw-away clone of pack.net
# once for every input side listed in cfg.model.resolution.
cloned, _ = clone_model(pack.net)
# BASE_FLOPS[i] and BASE_PARAM[i] belong to the i-th entry of
# cfg.model.resolution.
BASE_FLOPS, BASE_PARAM = [], []
for res in cfg.model.resolution:
    # One random square 1x3xres x res input on CUDA per resolution.
    f, p = analyse_model(cloned.module, torch.randn(1, 3, res, res).cuda())
    BASE_FLOPS.append(f)
    BASE_PARAM.append(p)
    # Report in millions.
    print('%.3f MFLOPS' % (f / 1e6))
    print('%.3f M' % (p / 1e6))
# Release the clone. Note that res, f and p remain bound at module level
# after the loop (when cfg.model.resolution is non-empty).
del cloned


def eval_prune(pack):
    cloned, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(cloned.module)
    # cloned.module.classifier[0] = FinalLinearObserver(cloned.module.classifier[0])
    cloned.module.fc = FinalLinearObserver(cloned.module.fc)
    cloned_pack = dotdict(pack.copy())
    cloned_pack.net = cloned
    Meltable.observe(cloned_pack, 0.001)