# --- Chunk interior: DDP training setup with iterative pruning ---
# NOTE(review): this chunk starts and ends mid-definition. The first line is
# the tail of a dataloader call that begins above this view, and the final
# `for epoch ...:` loop's body continues below it. Names such as `args`,
# `device`, `download` and `batch_size` come from the enclosing (unseen) scope.
                                  shuffle=True, train=True, download=download)  # tail of a train-loader call begun above
val_loader = get_dataloader_ddp(args.data_dir, batch_size,
                                num_workers=args.num_workers,
                                shuffle=False, train=False, download=download)

# Build the model and load pretrained weights before pruning.
model = MyNet()
model = load_weights(model, args.pretrained)

# Dummy input used by the pruning runner to trace the graph.
# Shape [1, 3, 32, 32] — presumably CIFAR-sized RGB input; confirm against MyNet.
input_signature = torch.randn([1, 3, 32, 32], dtype=torch.float32)
input_signature = input_signature.to(device)
model = model.to(device)

# Iterative pruning: sparsify in place to `args.sparsity` before DDP wrapping.
pruning_runner = get_pruning_runner(model, input_signature, 'iterative')
model = pruning_runner.prune(removal_ratio=args.sparsity, mode='sparse')

# Wrap in DistributedDataParallel; find_unused_parameters=True because pruned
# subgraphs may leave some parameters without gradients.
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[args.local_rank], output_device=args.local_rank,
    find_unused_parameters=True)

criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(model.parameters(), args.lr,
                             weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, args.epochs)

best_acc1 = 0
for epoch in range(args.epochs):
    # loop body (train/validate/checkpoint) continues beyond this chunk
# --- Chunk interior: one-step pruning subnet search ---
# Relies on `args`, `device`, `gpus`, `download`, `batch_size`,
# `calibration_fn` and `eval_fn` from the enclosing (unseen) scope.

# Pick the evaluation/calibration loader: a fixed-size subset when
# `args.subset_len` is set, otherwise the full (non-shuffled) test split.
if args.subset_len:
    data_loader = get_subnet_dataloader(
        args.data_dir, batch_size, args.subset_len,
        num_workers=args.num_workers, shuffle=False, train=False,
        download=download)
else:
    data_loader = get_dataloader(
        args.data_dir, args.batch_size, num_workers=args.num_workers,
        shuffle=False, train=False, download=download)
# NOTE(review): the subset branch passes `batch_size` while the full-set
# branch passes `args.batch_size` — confirm both refer to the same value.

# Build the pretrained model and the tracing input, both on `device`.
model = load_weights(MyNet(), args.pretrained).to(device)
input_signature = torch.randn([1, 3, 32, 32], dtype=torch.float32).to(device)

# One-step pruning: search `args.num_subnet` candidate subnets at the target
# sparsity, calibrating and evaluating each with the loader chosen above.
pruning_runner = get_pruning_runner(model, input_signature, 'one_step')
pruning_runner.search(
    gpus=gpus,
    calibration_fn=calibration_fn,
    calib_args=(data_loader,),
    num_subnet=args.num_subnet,
    removal_ratio=args.sparsity,
    eval_fn=eval_fn,
    eval_args=(data_loader,))