Exemplo n.º 1
0
def baseline_flops(num_classes, resnet_multi=0.7, vgg_multi=0.71):
    """Print FLOPs of full-width and width-scaled CIFAR ResNet-56 / VGG-16.

    For each architecture whose multiplier is not ``None``, the full-width
    model is built and measured with ``compute_conv_flops``; the same is then
    done for the model built with ``width_multiplier`` set to the multiplier.
    The two counts, half of the full-width count, and the scaled/full-width
    ratio are printed. Nothing is returned.

    Args:
        num_classes: Forwarded to the model constructors and shown in the
            printed ``CIFAR-<num_classes>`` label.
        resnet_multi: Width multiplier for ResNet-56; ``None`` skips the
            ResNet-56 report entirely.
        vgg_multi: Width multiplier for VGG-16; ``None`` skips the VGG-16
            report entirely.
    """
    if resnet_multi is not None:
        # ResNet-56 measurements pass cuda=True to compute_conv_flops.
        res_baseline_flops = compute_conv_flops(resnet56(num_classes),
                                                cuda=True)
        baseline_report = f"Baseline FLOPs of CIFAR-{num_classes} ResNet-56: {res_baseline_flops:,}, 50% FLOPs: {res_baseline_flops / 2:,}"
        print(baseline_report)

        multi = resnet_multi
        scaled_resnet = resnet56(num_classes, width_multiplier=multi)
        flops = compute_conv_flops(scaled_resnet, cuda=True)
        scaled_report = f"FLOPs of CIFAR-{num_classes} ResNet-56 {multi}x: {flops:,}, FLOPs ratio: {flops / res_baseline_flops}"
        print(scaled_report)
        print()

    if vgg_multi is not None:
        # VGG-16 measurements rely on compute_conv_flops' default for `cuda`.
        vgg_baseline_flops = compute_conv_flops(vgg16_linear(num_classes))
        baseline_report = f"Baseline FLOPs of CIFAR-{num_classes} VGG-16: {vgg_baseline_flops:,}, 50% FLOPs: {vgg_baseline_flops / 2:,}"
        print(baseline_report)

        multi = vgg_multi
        scaled_vgg = vgg16_linear(num_classes, width_multiplier=multi)
        flops = compute_conv_flops(scaled_vgg)
        scaled_report = f"FLOPs of CIFAR-{num_classes} VGG-16 {multi}x: {flops:,}, FLOPs ratio: {flops / vgg_baseline_flops}"
        print(scaled_report)
        print()
Exemplo n.º 2
0
def pruning_summary_resnet56(model, num_classes):
    """Summarise BatchNorm weight sizes of ``model`` against a fresh ResNet-56.

    A reference network is built with ``resnet56(num_classes)`` and its
    modules are walked in lockstep with those of ``model``. For every module
    of ``model`` that is a ``BatchNorm2d`` or ``BatchNorm1d``, one line is
    produced with the module name, the reference weight length and the
    weight length found in ``model``.

    Args:
        model: Network to inspect. If it has a ``module`` attribute (e.g. a
            parallel wrapper), that inner module is inspected instead.
        num_classes: Forwarded to ``resnet56`` when building the reference.

    Returns:
        The per-layer lines joined with newlines (empty string if no
        BatchNorm module is found).
    """
    model_ref = resnet56(num_classes)
    # Strip a parallel wrapper, if any, so module names line up.
    model = getattr(model, "module", model)

    norm_types = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)
    line_template = "{}: original shape: {}, pruned shape: {}"

    summary_lines = []
    # Modules are paired purely by iteration order.
    # NOTE(review): assumes both networks enumerate modules identically - confirm.
    paired_modules = zip(model.named_modules(), model_ref.named_modules())
    for (layer_name, layer), (_, ref_layer) in paired_modules:
        if not isinstance(layer, norm_types):
            continue
        # BatchNorm weights are expected to be 1-D vectors on both sides.
        assert len(ref_layer.weight.shape) == 1
        assert len(layer.weight.shape) == 1
        summary_lines.append(
            line_template.format(layer_name, ref_layer.weight.shape[0],
                                 layer.weight.shape[0]))
    return "\n".join(summary_lines)
Exemplo n.º 3
0
print(args)

# Output name: suffixed with the percent value for the "percent" strategy,
# otherwise with the strategy name itself.
if args.pruning_strategy == "percent":
    output_name = "pruned_{}".format(args.percent)
else:
    output_name = "pruned_{}".format(args.pruning_strategy)

# Make sure the save directory exists before anything is written to it.
if not os.path.exists(args.save):
    os.makedirs(args.save)

# The dataset name is matched case-insensitively; lowercase it once.
dataset_name = str.lower(args.dataset)
if dataset_name == "cifar100":
    num_classes = 100
elif dataset_name == "cifar10":
    num_classes = 10
else:
    raise NotImplementedError("do not support dataset {}".format(args.dataset))
model = resnet56(num_classes, aux_fc=False)

# Optionally restore weights and bookkeeping from a checkpoint file.
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        # map_location hands each storage back unchanged instead of moving it
        # to the device it was saved from.
        checkpoint = torch.load(args.model,
                                map_location=lambda storage, loc: storage)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(
            args.model, checkpoint['epoch'], best_prec1))
    else:
        # Report the path that was actually checked (args.model); the
        # previous message formatted args.resume instead.
        raise ValueError("=> no checkpoint found at '{}'".format(args.model))

if args.dataset == 'cifar10':