def test_pruning(self):
    """Run an end-to-end LPOT pruning pass over ``self.model`` on dummy data.

    A ``Pruning`` object is configured from ``'fake.yaml'`` (presumably
    written by the test fixture -- confirm). It is given a user training
    function that drives the pruning callbacks and is then invoked. The
    test passes if ``prune()`` completes without raising. Its return
    value is discarded.
    """
    from lpot.experimental import Pruning, common

    prune = Pruning('fake.yaml')

    # One labelled synthetic loader serves both the training closure and
    # evaluation. It is bound before the closure is defined, so the closure
    # does not depend on a later reassignment of the same name (closures
    # resolve free variables at call time, not at definition time).
    # NOTE(review): label=True is assumed to make the dataset yield
    # (image, target) pairs, as the unpacking below requires -- confirm.
    dummy_dataset = PyTorchDummyDataset((100, 3, 256, 256), label=True)
    dummy_dataloader = PyTorchDataLoader(dummy_dataset)

    def training_func_for_lpot(model):
        """Train ``model`` briefly while firing LPOT's pruning hooks.

        Runs ``epochs`` epochs of at most ``iters`` batches each, using SGD
        (lr=1e-4) and cross-entropy loss. The epoch/batch begin/end
        callbacks on ``prune`` are called around every step.
        """
        epochs = 16
        iters = 30
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
        for nepoch in range(epochs):
            model.train()
            cnt = 0
            prune.on_epoch_begin(nepoch)
            for image, target in dummy_dataloader:
                # cnt is the 0-based index of the batch about to be run.
                prune.on_batch_begin(cnt)
                print('.', end='')
                cnt += 1
                output = model(image)
                loss = criterion(output, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                prune.on_batch_end()
                if cnt >= iters:
                    break
            prune.on_epoch_end()

    prune.model = common.Model(self.model)
    prune.q_func = training_func_for_lpot
    prune.eval_dataloader = dummy_dataloader
    _ = prune()
def main_worker(gpu, args):
    """Build the model and image-folder loaders, then evaluate or prune.

    Args:
        gpu: Value echoed in the start-up message. The model itself is
            never moved off the CPU in this function.
        args: Parsed command-line namespace. Fields read here: ``arch``,
            ``pretrained``, ``lr``, ``momentum``, ``weight_decay``,
            ``resume``, ``gpu`` (only when resuming), ``data``,
            ``batch_size``, ``workers``, ``evaluate``, ``prune`` and
            ``config`` (only when pruning). ``args.start_epoch`` is written
            when a checkpoint is loaded.

    Side effects:
        Rebinds the module-level ``best_acc1`` when resuming from a
        checkpoint. Returns ``None`` on every path.
    """
    global best_acc1
    print("Use CPU: {} for training".format(gpu))

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True, quantize=False)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # The model lives on the CPU, so map every stored tensor there.
            # Without map_location, a checkpoint saved from a GPU fails to
            # load on a CPU-only host.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    # Single-process training: no custom sampler, so just shuffle.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    if args.prune:
        from lpot.experimental import Pruning, common
        prune = Pruning(args.config)

        def training_func_for_lpot(model):
            """Fine-tune ``model`` briefly while firing LPOT's pruning hooks.

            Runs ``epochs`` epochs of at most ``iters`` training batches with
            SGD (lr=1e-4) and validates after every epoch.
            """
            epochs = 16
            iters = 30
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
            for nepoch in range(epochs):
                model.train()
                cnt = 0
                prune.on_epoch_begin(nepoch)
                for image, target in train_loader:
                    # cnt is the 0-based index of the batch about to be run.
                    prune.on_batch_begin(cnt)
                    print('.', end='')
                    cnt += 1
                    output = model(image)
                    loss = criterion(output, target)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    prune.on_batch_end()
                    if cnt >= iters:
                        break
                prune.on_epoch_end()
                # NOTE(review): per the torch docs, these two helpers only act
                # on FakeQuantize / fused ConvBn QAT modules. They are
                # presumably no-ops unless the model was QAT-prepared -- confirm.
                if nepoch > 3:
                    # Freeze quantizer parameters
                    model.apply(torch.quantization.disable_observer)
                if nepoch > 2:
                    # Freeze batch norm mean and variance estimates
                    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                validate(val_loader, model, criterion, args)

        prune.model = common.Model(model)
        prune.q_func = training_func_for_lpot
        prune()