# Module-level imports assumed by these test snippets (torch is standard;
# DATASETS ships in lpot.data, while PyTorchDataLoader / PyTorchDummyDataset
# are helpers from the LPOT test suite -- verify paths against your version):
import torch
import torch.nn as nn
from lpot.data import DATASETS

def test_pruning_external(self):
    from lpot.experimental import common
    from lpot import Pruning
    prune = Pruning('fake.yaml')
    datasets = DATASETS('pytorch')
    dummy_dataset = datasets['dummy'](shape=(100, 3, 224, 224),
                                      low=0., high=1., label=True)
    dummy_dataloader = PyTorchDataLoader(dummy_dataset)

    def training_func_for_lpot(model):
        epochs = 16
        iters = 30
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
        for nepoch in range(epochs):
            model.train()
            cnt = 0
            prune.on_epoch_begin(nepoch)
            for image, target in dummy_dataloader:
                prune.on_batch_begin(cnt)
                print('.', end='')
                cnt += 1
                output = model(image)
                loss = criterion(output, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                prune.on_batch_end()
                if cnt >= iters:
                    break
            prune.on_epoch_end()

    # The attribute assignments below duplicate the keyword arguments passed
    # to the older lpot.Pruning __call__ entry point; the call arguments are
    # what that entry point actually consumes.
    prune.model = common.Model(self.model)
    prune.pruning_func = training_func_for_lpot
    prune.eval_dataloader = dummy_dataloader
    prune.train_dataloader = dummy_dataloader
    _ = prune(common.Model(self.model),
              train_dataloader=dummy_dataloader,
              pruning_func=training_func_for_lpot,
              eval_dataloader=dummy_dataloader)
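The test above references self.model without showing where it comes from. A minimal sketch of a setUp that would satisfy it, assuming torchvision is available; the class name and model choice are illustrative, not from the original test file:

import torchvision
import unittest

class TestPruning(unittest.TestCase):
    def setUp(self):
        # Any torch module works as a pruning target; resnet18 matches the
        # (N, 3, 224, 224) dummy inputs used in test_pruning_external.
        self.model = torchvision.models.resnet18()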
# Standard PyTorch dataloader imports assumed; collate_fn is defined
# elsewhere in the originating script:
# from torch.utils.data import DataLoader, RandomSampler
# from torch.utils.data.distributed import DistributedSampler

def train(args, train_dataset, model, tokenizer):
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 \
        else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    # These closures look up `prune` at call time, so defining them before
    # `prune` is assigned below is safe: LPOT only invokes them once prune()
    # is running.
    def train_func(model):
        return take_train_steps(args, model, tokenizer, train_dataloader, prune)

    def eval_func(model):
        return take_eval_steps(args, model, tokenizer, prune)

    if args.prune:
        from lpot.experimental import Pruning, common
        prune = Pruning(args.config)
        prune.model = common.Model(model)
        prune.train_dataloader = train_dataloader
        prune.pruning_func = train_func
        prune.eval_dataloader = train_dataloader
        prune.eval_func = eval_func
        model = prune()    # returns the pruned model wrapper
        torch.save(model, args.output_model)
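take_train_steps and take_eval_steps are referenced above but not defined in this section. A minimal sketch of what take_train_steps could look like, mirroring the callback pattern of the test functions here; the body, and the args.learning_rate / args.num_train_epochs fields, are assumptions rather than LPOT API:

def take_train_steps(args, model, tokenizer, train_dataloader, prune):
    # Hypothetical pruning-aware training loop driven by the LPOT callbacks.
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    model.train()
    for epoch in range(int(args.num_train_epochs)):
        prune.on_epoch_begin(epoch)
        for step, batch in enumerate(train_dataloader):
            prune.on_batch_begin(step)
            outputs = model(**batch)   # transformers models return the loss
            loss = outputs[0]          # first when labels are supplied
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            prune.on_batch_end()
        prune.on_epoch_end()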
def test_pruning(self):
    from lpot.experimental import Pruning, common
    prune = Pruning('fake.yaml')
    dummy_dataset = PyTorchDummyDataset([(100, 3, 256, 256)])
    dummy_dataloader = PyTorchDataLoader(dummy_dataset)

    def training_func_for_lpot(model):
        epochs = 16
        iters = 30
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
        for nepoch in range(epochs):
            model.train()
            cnt = 0
            prune.on_epoch_begin(nepoch)
            for image, target in dummy_dataloader:
                prune.on_batch_begin(cnt)
                print('.', end='')
                cnt += 1
                output = model(image)
                loss = criterion(output, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                prune.on_batch_end()
                if cnt >= iters:
                    break
            prune.on_epoch_end()

    # Rebuild the dataset with labels so the (image, target) unpacking in the
    # training loop works; the closure above resolves `dummy_dataloader` at
    # call time, so it picks up this labeled loader.
    dummy_dataset = PyTorchDummyDataset((100, 3, 256, 256), label=True)
    dummy_dataloader = PyTorchDataLoader(dummy_dataset)
    prune.model = common.Model(self.model)
    prune.pruning_func = training_func_for_lpot
    prune.eval_dataloader = dummy_dataloader
    prune.train_dataloader = dummy_dataloader
    _ = prune()
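After prune() returns, a quick way to confirm that pruning actually zeroed weights is to measure overall sparsity. A small helper sketch; note that prune() returns an LPOT model wrapper, so unwrapping to the underlying torch module (the .model attribute in the usage comment) is an assumption about that wrapper:

def weight_sparsity(torch_model):
    # Fraction of exactly-zero parameters across the whole model.
    zeros = sum(int((p == 0).sum()) for p in torch_model.parameters())
    total = sum(p.numel() for p in torch_model.parameters())
    return zeros / total

# e.g. pruned = prune(); print(f"sparsity: {weight_sparsity(pruned.model):.2%}")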