Code Example #1
def baseline_flops(num_classes, resnet_multi=0.7, vgg_multi=0.71):
    """Print baseline FLOPs for CIFAR ResNet-56 / VGG-16 and for width-scaled variants.

    Pass None for resnet_multi or vgg_multi to skip that architecture.
    """
    if resnet_multi is not None:
        model = resnet56(num_classes)
        res_baseline_flops = compute_conv_flops(model, cuda=True)
        print(
            f"Baseline FLOPs of CIFAR-{num_classes} ResNet-56: {res_baseline_flops:,}, 50% FLOPs: {res_baseline_flops / 2:,}"
        )

        multi = resnet_multi
        model = resnet56(num_classes, width_multiplier=multi)
        flops = compute_conv_flops(model, cuda=True)
        print(
            f"FLOPs of CIFAR-{num_classes} ResNet-56 {multi}x: {flops:,}, FLOPs ratio: {flops / res_baseline_flops}"
        )
        print()

    # from compute_flops import count_model_param_flops
    # flops_original_imple = count_model_param_flops(model, multiply_adds=False)
    # print(flops_original_imple)

    if vgg_multi is not None:
        model = vgg16_linear(num_classes)
        vgg_baseline_flops = compute_conv_flops(model)
        print(
            f"Baseline FLOPs of CIFAR-{num_classes} VGG-16: {vgg_baseline_flops:,}, 50% FLOPs: {vgg_baseline_flops / 2:,}"
        )

        multi = vgg_multi
        model = vgg16_linear(num_classes, width_multiplier=multi)
        flops = compute_conv_flops(model)
        print(
            f"FLOPs of CIFAR-{num_classes} VGG-16 {multi}x: {flops:,}, FLOPs ratio: {flops / vgg_baseline_flops}"
        )
        print()
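For context, a minimal call sequence might look like this (hypothetical; it assumes the defining module already imports resnet56, vgg16_linear, and compute_conv_flops, e.g. from models.resnet_expand, models, and common as in the other examples on this page):

for n in (10, 100):                    # CIFAR-10 and CIFAR-100
    baseline_flops(num_classes=n)      # report both architectures
baseline_flops(10, vgg_multi=None)     # ResNet-56 only
baseline_flops(10, resnet_multi=None)  # VGG-16 only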
Code Example #2
def main():
    """Load a sparse VGG-16 checkpoint, prune it, report FLOPs, and save the result."""
    # reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    parser = _get_parser()
    args = parser.parse_args()

    if args.dataset == "cifar10":
        num_classes = 10
    elif args.dataset == "cifar100":
        num_classes = 100
    else:
        raise ValueError(f"Unrecognized dataset {args.dataset}")

    os.makedirs(args.save, exist_ok=True)

    print(args)
    print(f"Current git hash: {common.get_git_id()}")

    if not os.path.isfile(args.model):
        raise ValueError("=> no checkpoint found at '{}'".format(args.model))

    checkpoint: Dict[str, Any] = torch.load(args.model)
    print(
        f"=> Loading the model...\n=> Epoch: {checkpoint['epoch']}, Acc.: {checkpoint['best_prec1']}"
    )

    # build the sparse model
    sparse_model: VGG = vgg16_linear(num_classes=num_classes, gate=args.gate)
    sparse_model.load_state_dict(checkpoint['state_dict'])

    saved_model = prune_vgg(num_classes=num_classes,
                            sparse_model=sparse_model,
                            pruning_strategy=args.pruning_strategy,
                            sanity_check=True,
                            prune_mode=args.prune_mode)

    # compute FLOPs
    baseline_flops = common.compute_conv_flops(
        vgg16_linear(num_classes=num_classes, gate=False))
    saved_flops = common.compute_conv_flops(saved_model)

    print(f"Unpruned FLOPs: {baseline_flops:,}")
    print(f"Saved FLOPs: {saved_flops:,}")
    print(f"FLOPs ratio: {saved_flops / baseline_flops:,}")

    # save state_dict
    torch.save(
        {
            'state_dict': saved_model.state_dict(),
            'cfg': saved_model.config()
        }, os.path.join(args.save, f'pruned_{args.pruning_strategy}.pth.tar'))
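_get_parser is not shown on this page. A plausible reconstruction, inferred only from the attributes main() actually reads (hypothetical; the defaults and help strings are guesses):

import argparse

def _get_parser():
    # Hypothetical sketch: only the option names are implied by main().
    parser = argparse.ArgumentParser(description="Prune a sparse VGG-16 model")
    parser.add_argument("--dataset", choices=["cifar10", "cifar100"], default="cifar10")
    parser.add_argument("--model", required=True, help="path to the sparse checkpoint")
    parser.add_argument("--save", required=True, help="output directory for the pruned model")
    parser.add_argument("--gate", action="store_true", help="build the model with channel gates")
    parser.add_argument("--pruning-strategy", default="grad")   # accessed as args.pruning_strategy
    parser.add_argument("--prune-mode", default="default")      # accessed as args.prune_mode
    return parser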
Code Example #3
File: main_finetune.py  Project: xue1234730/Prune
def calculate_flops(current_model):
    # Note: reads args.arch and num_classes from the enclosing module scope.
    if args.arch == "resnet56":
        model_ref = models.resnet_expand.resnet56(num_classes=num_classes)
    elif args.arch == "vgg16_linear":
        model_ref = vgg16_linear(num_classes=num_classes, gate=False)
    else:
        raise NotImplementedError(f"Unsupported arch: {args.arch}")

    # .cpu() moves the models off the GPU as a side effect.
    current_flops = compute_conv_flops(current_model.cpu())
    ref_flops = compute_conv_flops(model_ref.cpu())
    flops_ratio = current_flops / ref_flops

    print("FLOPs remains {}".format(flops_ratio))
Code Example #4
def prune_while_training(model: nn.Module, arch: str, prune_mode: str,
                         num_classes: int):
    """Prune the sparse model with the 'grad' and 'fixed' strategies and
    return (grad FLOPs, fixed FLOPs, baseline FLOPs)."""
    if arch == "resnet56":
        from resprune_gate import prune_resnet
        from models.resnet_expand import resnet56 as resnet56_expand
        saved_model_grad = prune_resnet(sparse_model=model,
                                        pruning_strategy='grad',
                                        sanity_check=False,
                                        prune_mode=prune_mode,
                                        num_classes=num_classes)
        saved_model_fixed = prune_resnet(sparse_model=model,
                                         pruning_strategy='fixed',
                                         sanity_check=False,
                                         prune_mode=prune_mode,
                                         num_classes=num_classes)
        baseline_model = resnet56_expand(num_classes=num_classes,
                                         gate=False,
                                         aux_fc=False)
    elif arch == 'vgg16_linear':
        from vggprune_gate import prune_vgg
        from models import vgg16_linear

        saved_model_grad = prune_vgg(num_classes=num_classes,
                                     sparse_model=model,
                                     prune_mode=prune_mode,
                                     sanity_check=False,
                                     pruning_strategy='grad')
        saved_model_fixed = prune_vgg(num_classes=num_classes,
                                      sparse_model=model,
                                      prune_mode=prune_mode,
                                      sanity_check=False,
                                      pruning_strategy='fixed')
        baseline_model = vgg16_linear(num_classes=num_classes, gate=False)
    else:
        raise NotImplementedError(f"Unsupported arch: {arch}")

    saved_flops_grad = compute_conv_flops(saved_model_grad, cuda=True)
    saved_flops_fixed = compute_conv_flops(saved_model_fixed, cuda=True)
    baseline_flops = compute_conv_flops(baseline_model, cuda=True)

    return saved_flops_grad, saved_flops_fixed, baseline_flops
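A typical call site would turn the returned counts into ratios (hypothetical usage):

flops_grad, flops_fixed, flops_base = prune_while_training(
    model, arch="resnet56", prune_mode=args.prune_mode, num_classes=10)
print(f"FLOPs ratio (grad):  {flops_grad / flops_base:.4f}")
print(f"FLOPs ratio (fixed): {flops_fixed / flops_base:.4f}")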
Code Example #5
                    # Excerpt: restore the expand-layer channel indices from the
                    # bn3 masks saved in the checkpoint (resnet56 retrain path).
                    if isinstance(m.expand_layer, Identity):
                        continue
                    mask = bn3_masks[i]
                    assert mask[1].shape[0] == m.expand_layer.idx.shape[0]
                    m.expand_layer.idx = np.argwhere(
                        mask[1].clone().cpu().numpy()).squeeze().reshape(-1)
        else:
            raise NotImplementedError("Key bn3_masks expected in checkpoint.")

    elif args.arch == "vgg16_linear":
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           cfg=checkpoint['cfg'])
    else:
        raise NotImplementedError(f"Do not support {args.arch} for retrain.")

training_flops = compute_conv_flops(model, cuda=True)
print(f"Training model. FLOPs: {training_flops:,}")


def compute_flops_weight(cuda=False):
    # Compute the FLOPs weight for each layer in advance.
    # Closes over `model` from the enclosing scope.
    print("Computing the FLOPs weight...")
    flops_weight = model.compute_flops_weight(cuda=cuda)
    flops_weight_string_builder: typing.List[str] = []
    for fw in flops_weight:
        flops_weight_string_builder.append(",".join(str(w) for w in fw))
    flops_weight_string = "\n".join(flops_weight_string_builder)
    print("FLOPs weight:")
    print(flops_weight_string)
    print()
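Since compute_flops_weight closes over model, it is presumably invoked once before the training loop starts (hypothetical usage):

compute_flops_weight(cuda=True)  # cache and log the per-layer FLOPs weights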