def baseline_flops(num_classes, resnet_multi=0.7, vgg_multi=0.71):
    if resnet_multi is not None:
        model = resnet56(num_classes)
        res_baseline_flops = compute_conv_flops(model, cuda=True)
        print(
            f"Baseline FLOPs of CIFAR-{num_classes} ResNet-56: {res_baseline_flops:,}, "
            f"50% FLOPs: {res_baseline_flops / 2:,}"
        )

        multi = resnet_multi
        model = resnet56(num_classes, width_multiplier=multi)
        flops = compute_conv_flops(model, cuda=True)
        print(
            f"FLOPs of CIFAR-{num_classes} ResNet-56 {multi}x: {flops:,}, "
            f"FLOPs ratio: {flops / res_baseline_flops}"
        )
        print()

        # from compute_flops import count_model_param_flops
        # flops_original_imple = count_model_param_flops(model, multiply_adds=False)
        # print(flops_original_imple)

    if vgg_multi is not None:
        model = vgg16_linear(num_classes)
        vgg_baseline_flops = compute_conv_flops(model)
        print(
            f"Baseline FLOPs of CIFAR-{num_classes} VGG-16: {vgg_baseline_flops:,}, "
            f"50% FLOPs: {vgg_baseline_flops / 2:,}"
        )

        multi = vgg_multi
        model = vgg16_linear(num_classes, width_multiplier=multi)
        flops = compute_conv_flops(model)
        print(
            f"FLOPs of CIFAR-{num_classes} VGG-16 {multi}x: {flops:,}, "
            f"FLOPs ratio: {flops / vgg_baseline_flops}"
        )
        print()
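
# Hedged usage sketch (not part of the original script): print the FLOPs tables
# for both CIFAR datasets with the default width multipliers. The helper name
# below is hypothetical; it only assumes baseline_flops as defined above.
def _print_all_baseline_flops():
    for n in (10, 100):
        baseline_flops(num_classes=n)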
def main():
    # reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    parser = _get_parser()
    args = parser.parse_args()

    if args.dataset == "cifar10":
        num_classes = 10
    elif args.dataset == 'cifar100':
        num_classes = 100
    else:
        raise ValueError(f"Unrecognized dataset {args.dataset}")

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    print(args)
    print(f"Current git hash: {common.get_git_id()}")

    if not os.path.isfile(args.model):
        raise ValueError("=> no checkpoint found at '{}'".format(args.model))
    checkpoint: Dict[str, Any] = torch.load(args.model)
    print(
        f"=> Loading the model...\n"
        f"=> Epoch: {checkpoint['epoch']}, Acc.: {checkpoint['best_prec1']}"
    )

    # build the sparse model
    sparse_model: VGG = vgg16_linear(num_classes=num_classes, gate=args.gate)
    sparse_model.load_state_dict(checkpoint['state_dict'])

    saved_model = prune_vgg(num_classes=num_classes,
                            sparse_model=sparse_model,
                            pruning_strategy=args.pruning_strategy,
                            sanity_check=True,
                            prune_mode=args.prune_mode)

    # compute FLOPs
    baseline_flops = common.compute_conv_flops(
        vgg16_linear(num_classes=num_classes, gate=False))
    saved_flops = common.compute_conv_flops(saved_model)

    print(f"Unpruned FLOPs: {baseline_flops:,}")
    print(f"Saved FLOPs: {saved_flops:,}")
    print(f"FLOPs ratio: {saved_flops / baseline_flops:,}")

    # save state_dict
    torch.save(
        {
            'state_dict': saved_model.state_dict(),
            'cfg': saved_model.config()
        },
        os.path.join(args.save, f'pruned_{args.pruning_strategy}.pth.tar'))
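
# Hedged sketch (not part of the original): reloading the checkpoint that main()
# saves above. The helper name and `path` argument are hypothetical; it assumes
# vgg16_linear accepts the pruned `cfg`, as the retrain script does below.
def _load_pruned_vgg(path: str, num_classes: int):
    ckpt = torch.load(path)
    model = vgg16_linear(num_classes=num_classes, cfg=ckpt['cfg'])
    model.load_state_dict(ckpt['state_dict'])
    return model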
def calculate_flops(current_model):
    if args.arch == "resnet56":
        model_ref = models.resnet_expand.resnet56(num_classes=num_classes)
    elif args.arch == 'vgg16_linear':
        model_ref = vgg16_linear(num_classes=num_classes, gate=False)
    else:
        raise NotImplementedError()

    current_flops = compute_conv_flops(current_model.cpu())
    ref_flops = compute_conv_flops(model_ref.cpu())
    flops_ratio = current_flops / ref_flops
    print("FLOPs remaining: {}".format(flops_ratio))
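
# Hedged usage note (not part of the original): calculate_flops moves
# current_model to the CPU via .cpu(), so a caller training on GPU would
# typically restore the device afterwards, e.g.:
#   calculate_flops(model)
#   model.cuda()  # move back after the FLOPs measurement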
def prune_while_training(model: nn.Module, arch: str, prune_mode: str, num_classes: int):
    if arch == "resnet56":
        from resprune_gate import prune_resnet
        from models.resnet_expand import resnet56 as resnet50_expand

        saved_model_grad = prune_resnet(sparse_model=model,
                                        pruning_strategy='grad',
                                        sanity_check=False,
                                        prune_mode=prune_mode,
                                        num_classes=num_classes)
        saved_model_fixed = prune_resnet(sparse_model=model,
                                         pruning_strategy='fixed',
                                         sanity_check=False,
                                         prune_mode=prune_mode,
                                         num_classes=num_classes)
        baseline_model = resnet50_expand(num_classes=num_classes, gate=False, aux_fc=False)
    elif arch == 'vgg16_linear':
        from vggprune_gate import prune_vgg
        from models import vgg16_linear

        saved_model_grad = prune_vgg(num_classes=num_classes,
                                     sparse_model=model,
                                     prune_mode=prune_mode,
                                     sanity_check=False,
                                     pruning_strategy='grad')
        saved_model_fixed = prune_vgg(num_classes=num_classes,
                                      sparse_model=model,
                                      prune_mode=prune_mode,
                                      sanity_check=False,
                                      pruning_strategy='fixed')
        baseline_model = vgg16_linear(num_classes=num_classes, gate=False)
    else:
        # not available
        raise NotImplementedError(f"do not support arch {arch}")

    saved_flops_grad = compute_conv_flops(saved_model_grad, cuda=True)
    saved_flops_fixed = compute_conv_flops(saved_model_fixed, cuda=True)
    baseline_flops = compute_conv_flops(baseline_model, cuda=True)

    return saved_flops_grad, saved_flops_fixed, baseline_flops
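
# Hedged example (not in the original): turning the three FLOPs numbers returned
# by prune_while_training into ratios for logging. The helper name is
# hypothetical; it only uses the function defined above.
def _log_flops_ratios(model, arch, prune_mode, num_classes):
    flops_grad, flops_fixed, flops_base = prune_while_training(
        model, arch=arch, prune_mode=prune_mode, num_classes=num_classes)
    print(f"FLOPs ratio (grad pruning): {flops_grad / flops_base:.4f}")
    print(f"FLOPs ratio (fixed pruning): {flops_fixed / flops_base:.4f}")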
                if isinstance(m.expand_layer, Identity):
                    continue
                mask = bn3_masks[i]
                assert mask[1].shape[0] == m.expand_layer.idx.shape[0]
                m.expand_layer.idx = np.argwhere(
                    mask[1].clone().cpu().numpy()).squeeze().reshape(-1)
        else:
            raise NotImplementedError("Key bn3_masks expected in checkpoint.")
    elif args.arch == "vgg16_linear":
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           cfg=checkpoint['cfg'])
    else:
        raise NotImplementedError(f"Do not support {args.arch} for retrain.")

    training_flops = compute_conv_flops(model, cuda=True)
    print(f"Training model. FLOPs: {training_flops:,}")

    def compute_flops_weight(cuda=False):
        # compute the flops weight for each layer in advance
        print("Computing the FLOPs weight...")
        flops_weight = model.compute_flops_weight(cuda=cuda)

        flops_weight_string_builder: typing.List[str] = []
        for fw in flops_weight:
            flops_weight_string_builder.append(",".join(str(w) for w in fw))
        flops_weight_string = "\n".join(flops_weight_string_builder)

        print("FLOPs weight:")
        print(flops_weight_string)
        print()