def run():
    """Prune the network, melt observers into concrete layers, then fine-tune.

    Side effects: mutates ``pack`` in place (replaces its classifier and
    optimizer) and prints the baseline FLOPs / parameter counts of the
    unpruned model.
    """
    pack, gbns = get_pack()

    # Measure the unpruned baseline on a throwaway clone so the analysis
    # pass never touches the live network.
    snapshot, _ = clone_model(pack.net)
    base_flops, base_params = analyse_model(
        snapshot.module, torch.randn(1, 3, 32, 32).cuda())
    print('%.3f MFLOPS' % (base_flops / 1e6))
    print('%.3f M' % (base_params / 1e6))
    del snapshot

    prune(pack, gbns, base_flops, base_params)

    # Wrap conv layers and the final linear layer with observers, calibrate
    # on a small fraction (0.001) of data, then melt everything back into
    # plain modules.
    _ = Conv2dObserver.transform(pack.net.module)
    pack.net.module.classifier = FinalLinearObserver(pack.net.module.classifier)
    Meltable.observe(pack, 0.001)
    Meltable.melt_all(pack.net)

    # Fresh SGD optimizer for fine-tuning. NOTE(review): lr=1 is presumably
    # a placeholder overridden by the lr_min/lr_max schedule inside
    # finetune() — confirm against finetune's implementation.
    pack.optimizer = optim.SGD(
        pack.net.parameters(),
        lr=1,
        momentum=cfg.train.momentum,
        weight_decay=cfg.train.weight_decay,
        nesterov=cfg.train.nesterov,
    )
    _ = finetune(
        pack,
        lr_min=cfg.gbn.lr_min,
        lr_max=cfg.gbn.lr_max,
        T=cfg.gbn.finetune_epoch,
    )
def eval_prune(pack):
    """Estimate FLOPs and parameter count of the pruned model.

    Works on a clone so the live ``pack.net`` is never modified; the clone
    and its wrapper pack are released before returning.

    NOTE(review): a second ``eval_prune`` defined later in this file
    shadows this one at import time — only one variant is ever callable.
    """
    trial_net, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(trial_net.module)
    trial_net.module.classifier = FinalLinearObserver(trial_net.module.classifier)

    trial_pack = dotdict(pack.copy())
    trial_pack.net = trial_net

    # Calibrate observers on a small data fraction, then replace them with
    # concrete (melted) modules before analysis.
    Meltable.observe(trial_pack, 0.001)
    Meltable.melt_all(trial_pack.net)

    flops, params = analyse_model(
        trial_pack.net.module, torch.randn(1, 3, 32, 32).cuda())

    del trial_net
    del trial_pack
    return flops, params
def eval_prune(pack):
    """Estimate pruned-model cost at every configured input resolution.

    Returns two parallel lists ``(flops, params)``, one entry per value in
    ``cfg.model.resolution``. This variant patches ``.fc`` as the final
    linear layer (unlike the earlier ``eval_prune``, which it shadows and
    which patches ``.classifier``). Operates on a clone, so the live
    ``pack.net`` is untouched.
    """
    probe, _ = clone_model(pack.net)
    _ = Conv2dObserver.transform(probe.module)
    probe.module.fc = FinalLinearObserver(probe.module.fc)

    probe_pack = dotdict(pack.copy())
    probe_pack.net = probe

    # Calibrate observers on a small data fraction, then melt them back
    # into plain modules before measuring.
    Meltable.observe(probe_pack, 0.001)
    Meltable.melt_all(probe_pack.net)

    flops, params = [], []
    for res in cfg.model.resolution:
        f, p = analyse_model(
            probe_pack.net.module, torch.randn(1, 3, res, res).cuda())
        flops.append(f)
        params.append(p)

    del probe
    del probe_pack
    return flops, params