コード例 #1
0
def run():
    """Measure, prune, melt and fine-tune the network supplied by ``get_pack``.

    Steps, in order:

    1. Analyse a throw-away clone of ``pack.net`` with a random
       ``1 x 3 x 32 x 32`` CUDA input and print the reported FLOPs and
       parameter count, each divided by 1e6.
    2. Call ``prune`` with the pack, the GBNs and the two baseline figures.
    3. Run ``Conv2dObserver.transform`` on the network module, wrap its
       ``classifier`` in ``FinalLinearObserver``, then call
       ``Meltable.observe`` (with 0.001) and ``Meltable.melt_all``.
    4. Replace ``pack.optimizer`` with a fresh SGD optimizer and call
       ``finetune`` with the ``cfg.gbn`` learning-rate bounds and epoch count.

    A CUDA device is required because the probe input is created with
    ``.cuda()``. Nothing is returned.
    """
    pack, gbns = get_pack()

    # Baseline cost is measured on a clone, not on pack.net itself.
    baseline_net, _ = clone_model(pack.net)
    probe = torch.randn(1, 3, 32, 32).cuda()
    base_flops, base_params = analyse_model(baseline_net.module, probe)
    print('%.3f MFLOPS' % (base_flops / 1e6))
    print('%.3f M' % (base_params / 1e6))
    del baseline_net

    prune(pack, gbns, base_flops, base_params)

    # Attach the observers, then observe and melt the live network.
    Conv2dObserver.transform(pack.net.module)
    module = pack.net.module
    module.classifier = FinalLinearObserver(module.classifier)
    Meltable.observe(pack, 0.001)
    Meltable.melt_all(pack.net)

    # The optimizer is built only after melting, from pack.net.parameters().
    # NOTE(review): lr=1 is presumably a placeholder that finetune overrides
    # using lr_min / lr_max - confirm against finetune's implementation.
    trainable = pack.net.parameters()
    sgd_options = {
        'lr': 1,
        'momentum': cfg.train.momentum,
        'weight_decay': cfg.train.weight_decay,
        'nesterov': cfg.train.nesterov,
    }
    pack.optimizer = optim.SGD(trainable, **sgd_options)

    finetune(pack,
             lr_min=cfg.gbn.lr_min,
             lr_max=cfg.gbn.lr_max,
             T=cfg.gbn.finetune_epoch)
コード例 #2
0
def eval_prune(pack):
    """Report the cost of ``pack.net`` as measured on a melted clone.

    The work is done on a clone from ``clone_model`` and on a copy of
    ``pack`` (``pack.copy()`` wrapped in ``dotdict``) whose ``net`` entry is
    that clone. The clone gets ``Conv2dObserver.transform`` applied and its
    ``classifier`` wrapped in ``FinalLinearObserver``; it is then passed
    through ``Meltable.observe`` (with 0.001) and ``Meltable.melt_all``.

    Args:
        pack: object exposing ``.net`` and ``.copy()``.

    Returns:
        The ``(flops, params)`` pair produced by ``analyse_model`` for a
        random ``1 x 3 x 32 x 32`` CUDA input.
    """
    replica, _ = clone_model(pack.net)
    Conv2dObserver.transform(replica.module)
    replica.module.classifier = FinalLinearObserver(replica.module.classifier)

    # Every entry except the network is taken over from the caller's pack.
    scratch_pack = dotdict(pack.copy())
    scratch_pack.net = replica

    Meltable.observe(scratch_pack, 0.001)
    Meltable.melt_all(scratch_pack.net)

    melted = scratch_pack.net.module
    probe = torch.randn(1, 3, 32, 32).cuda()
    n_flops, n_params = analyse_model(melted, probe)

    # Drop the local references to the clone before returning.
    del melted
    del replica, scratch_pack

    return n_flops, n_params
コード例 #3
0
def eval_prune(pack):
    """Report the per-resolution cost of ``pack.net`` on a melted clone.

    The work is done on a clone from ``clone_model`` and on a copy of
    ``pack`` (``pack.copy()`` wrapped in ``dotdict``) whose ``net`` entry is
    that clone. The clone gets ``Conv2dObserver.transform`` applied and its
    ``fc`` layer wrapped in ``FinalLinearObserver``; it is then passed
    through ``Meltable.observe`` (with 0.001) and ``Meltable.melt_all``.

    Args:
        pack: object exposing ``.net`` and ``.copy()``.

    Returns:
        Two lists ``(flops, params)`` with one entry per value in
        ``cfg.model.resolution``, in that order. Each entry comes from
        ``analyse_model`` run on a random ``1 x 3 x res x res`` CUDA input.
    """
    replica, _ = clone_model(pack.net)
    Conv2dObserver.transform(replica.module)
    # NOTE(review): an earlier revision wrapped `classifier[0]` instead of
    # `fc`; presumably that targeted a model with a different head - confirm
    # before reusing this function with another architecture.
    replica.module.fc = FinalLinearObserver(replica.module.fc)

    # Every entry except the network is taken over from the caller's pack.
    scratch_pack = dotdict(pack.copy())
    scratch_pack.net = replica

    Meltable.observe(scratch_pack, 0.001)
    Meltable.melt_all(scratch_pack.net)

    # One (flops, params) measurement per configured input resolution.
    measurements = [
        analyse_model(scratch_pack.net.module,
                      torch.randn(1, 3, side, side).cuda())
        for side in cfg.model.resolution
    ]
    flops = [n_flops for n_flops, _ in measurements]
    params = [n_params for _, n_params in measurements]

    # Drop the local references to the clone before returning.
    del replica, scratch_pack

    return flops, params