def _report_flops(arch_name, num_classes, build_model, multi, **flops_kwargs):
    """Print baseline and width-scaled conv FLOPs for one architecture.

    Args:
        arch_name: Label used in the printed messages (e.g. "ResNet-56").
        num_classes: Number of output classes; passed to ``build_model`` and
            used in the "CIFAR-<n>" label.
        build_model: Model constructor, called as ``build_model(num_classes)``
            for the baseline and ``build_model(num_classes,
            width_multiplier=multi)`` for the scaled model.
        multi: Width multiplier for the scaled model.
        **flops_kwargs: Extra keyword arguments forwarded unchanged to
            ``compute_conv_flops``.
    """
    baseline = compute_conv_flops(build_model(num_classes), **flops_kwargs)
    print(
        f"Baseline FLOPs of CIFAR-{num_classes} {arch_name}: {baseline:,}, 50% FLOPs: {baseline / 2:,}"
    )
    scaled = compute_conv_flops(
        build_model(num_classes, width_multiplier=multi), **flops_kwargs
    )
    print(
        f"FLOPs of CIFAR-{num_classes} {arch_name} {multi}x: {scaled:,}, FLOPs ratio: {scaled / baseline}"
    )
    print()


def baseline_flops(num_classes, resnet_multi=0.7, vgg_multi=0.71):
    """Print conv-layer FLOPs for baseline and width-scaled CIFAR models.

    For each architecture whose multiplier is not None, print:
    - the FLOPs of the unscaled model together with half of that value;
    - the FLOPs of the model built with ``width_multiplier=<multi>``, together
      with its ratio to the baseline.

    Args:
        num_classes: Number of output classes (CIFAR-10 / CIFAR-100).
        resnet_multi: Width multiplier for ResNet-56, or None to skip it.
        vgg_multi: Width multiplier for VGG-16 (``vgg16_linear``), or None to
            skip it.

    Returns:
        None. All results are written to stdout.
    """
    if resnet_multi is not None:
        # NOTE(review): ResNet-56 is measured with cuda=True while VGG-16 uses
        # compute_conv_flops' default setting -- confirm the asymmetry is intended.
        _report_flops("ResNet-56", num_classes, resnet56, resnet_multi, cuda=True)
    if vgg_multi is not None:
        _report_flops("VGG-16", num_classes, vgg16_linear, vgg_multi)
def pruning_summary_resnet56(model, num_classes):
    """Summarize the per-layer BatchNorm widths of a pruned ResNet-56.

    The (possibly pruned) ``model`` is walked alongside a freshly built
    reference ``resnet56(num_classes)``. For every BatchNorm1d/BatchNorm2d
    layer, one line is produced:

        "<name>: original shape: <n_ref>, pruned shape: <n>"

    where the numbers are the lengths of the reference and pruned BN weight
    vectors.

    Args:
        model: The model to summarize. If it exposes a ``.module`` attribute
            (a parallel wrapper), the wrapped module is used instead.
        num_classes: Number of classes used to build the reference model.

    Returns:
        The summary lines joined with "\\n".

    Raises:
        ValueError: if a BatchNorm layer of ``model`` is paired with a
            differently named module of the reference model, or if either
            BN weight is not one-dimensional.
    """
    model_ref = resnet56(num_classes)
    if hasattr(model, "module"):
        # Remove the parallel wrapper so module names have no "module." prefix.
        model = model.module

    batchnorm_types = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)
    pruning_layers = []
    # zip() pairs modules purely by position; verify every reported pair is the
    # same named layer instead of silently printing a mismatched original width.
    for (name, m), (name_ref, m_ref) in zip(model.named_modules(), model_ref.named_modules()):
        if not isinstance(m, batchnorm_types):
            continue
        if name != name_ref:
            raise ValueError(
                "module mismatch between pruned and reference ResNet-56: "
                "'{}' vs '{}'".format(name, name_ref))
        if len(m_ref.weight.shape) != 1 or len(m.weight.shape) != 1:
            raise ValueError(
                "expected 1-D BatchNorm weights for layer '{}'".format(name))
        pruning_layers.append(
            "{}: original shape: {}, pruned shape: {}".format(
                name, m_ref.weight.shape[0], m.weight.shape[0]))
    return "\n".join(pruning_layers)
print(args) output_name = "pruned_{}".format( args.pruning_strategy ) if args.pruning_strategy != "percent" else "pruned_{}".format(args.percent) if not os.path.exists(args.save): os.makedirs(args.save) if str.lower(args.dataset) == "cifar100": num_classes = 100 elif str.lower(args.dataset) == "cifar10": num_classes = 10 else: raise NotImplementedError("do not support dataset {}".format(args.dataset)) model = resnet56(num_classes, aux_fc=False) if args.model: if os.path.isfile(args.model): print("=> loading checkpoint '{}'".format(args.model)) checkpoint = torch.load(args.model, map_location=lambda storage, loc: storage) args.start_epoch = checkpoint['epoch'] best_prec1 = checkpoint['best_prec1'] model.load_state_dict(checkpoint['state_dict']) print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format( args.model, checkpoint['epoch'], best_prec1)) else: raise ValueError("=> no checkpoint found at '{}'".format(args.resume)) if args.dataset == 'cifar10':