Ejemplo n.º 1
0
                                      shuffle=True,
                                      train=True,
                                      download=download)
    val_loader = get_dataloader_ddp(args.data_dir,
                                    batch_size,
                                    num_workers=args.num_workers,
                                    shuffle=False,
                                    train=False,
                                    download=download)

    model = MyNet()
    model = load_weights(model, args.pretrained)
    input_signature = torch.randn([1, 3, 32, 32], dtype=torch.float32)
    input_signature = input_signature.to(device)
    model = model.to(device)
    pruning_runner = get_pruning_runner(model, input_signature, 'iterative')

    model = pruning_runner.prune(removal_ratio=args.sparsity, mode='sparse')
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=True)
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epochs)
    best_acc1 = 0
    for epoch in range(args.epochs):
Ejemplo n.º 2
0
    if args.subset_len:
        data_loader = get_subnet_dataloader(args.data_dir,
                                            batch_size,
                                            args.subset_len,
                                            num_workers=args.num_workers,
                                            shuffle=False,
                                            train=False,
                                            download=download)
    else:
        data_loader = get_dataloader(args.data_dir,
                                     args.batch_size,
                                     num_workers=args.num_workers,
                                     shuffle=False,
                                     train=False,
                                     download=download)

    model = MyNet()
    model = load_weights(model, args.pretrained)
    input_signature = torch.randn([1, 3, 32, 32], dtype=torch.float32)
    input_signature = input_signature.to(device)
    model = model.to(device)
    pruning_runner = get_pruning_runner(model, input_signature, 'one_step')

    pruning_runner.search(gpus=gpus,
                          calibration_fn=calibration_fn,
                          calib_args=(data_loader, ),
                          num_subnet=args.num_subnet,
                          removal_ratio=args.sparsity,
                          eval_fn=eval_fn,
                          eval_args=(data_loader, ))