def train_once(mb, net, trainloader, testloader, writer, config, ckpt_path,
               learning_rate, weight_decay, num_epochs, iteration, logger):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9,
                          weight_decay=weight_decay)
    lr_schedule = {0: learning_rate,
                   int(num_epochs * 0.5): learning_rate * 0.1,
                   int(num_epochs * 0.75): learning_rate * 0.01}
    lr_scheduler = PresetLRScheduler(lr_schedule)
    best_acc = 0
    best_epoch = 0
    for epoch in range(num_epochs):
        train(net, trainloader, optimizer, criterion, lr_scheduler, epoch, writer,
              iteration=iteration)
        test_acc = test(net, testloader, criterion, epoch, writer, iteration)
        if test_acc > best_acc:
            print('Saving..')
            state = {
                'net': net,
                'acc': test_acc,
                'epoch': epoch,
                'args': config,
                # 'mask': mb.masks,
                # 'ratio': mb.get_ratio_at_each_layer()
            }
            path = os.path.join(ckpt_path,
                                'finetune_%s_%s%s_r%s_it%d_best.pth.tar' %
                                (config.dataset, config.network, config.depth,
                                 config.target_ratio, iteration))
            torch.save(state, path)
            best_acc = test_acc
            best_epoch = epoch
    logger.info('Iteration [%d], best acc: %.4f, epoch: %d' %
                (iteration, best_acc, best_epoch))
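
# NOTE: PresetLRScheduler is used throughout these snippets but its definition is
# not shown here. Judging from how it is called (constructed from an {epoch: lr}
# dict, invoked as lr_scheduler(optimizer, epoch), and queried via
# get_lr(optimizer)), a minimal sketch could look like the following; this is an
# assumption, not the repository's actual implementation.
class PresetLRScheduler:
    """Set the learning rate according to a preset {epoch: lr} schedule."""

    def __init__(self, lr_schedule):
        # Keep milestones sorted so the latest one not exceeding the epoch wins.
        self.lr_schedule = dict(sorted(lr_schedule.items()))

    def __call__(self, optimizer, epoch):
        # Pick the learning rate of the most recent milestone <= current epoch.
        lr = None
        for milestone, value in self.lr_schedule.items():
            if epoch >= milestone:
                lr = value
        if lr is not None:
            for group in optimizer.param_groups:
                group['lr'] = lr

    @staticmethod
    def get_lr(optimizer):
        # Report the learning rate of the first parameter group.
        return optimizer.param_groups[0]['lr']
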
def fine_tune_model(self, trainloader, testloader, criterion, optim, learning_rate,
                    weight_decay, nepochs=10, device='cuda'):
    self.model = self.model.train()
    optimizer = optim.SGD(self.model.parameters(), lr=learning_rate, momentum=0.9,
                          weight_decay=weight_decay)
    # optimizer = optim.Adam(self.model.parameters(), weight_decay=5e-4)
    lr_schedule = {0: learning_rate,
                   int(nepochs * 0.5): learning_rate * 0.1,
                   int(nepochs * 0.75): learning_rate * 0.01}
    lr_scheduler = PresetLRScheduler(lr_schedule)
    best_test_acc, best_test_loss = 0, 100
    iterations = 0
    for epoch in range(nepochs):
        self.model = self.model.train()
        correct = 0
        total = 0
        all_loss = 0
        lr_scheduler(optimizer, epoch)
        desc = ('[LR: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (lr_scheduler.get_lr(optimizer), 0, 0, correct, total))
        prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc,
                        leave=True)
        for batch_idx, (inputs, targets) in prog_bar:
            optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)
            self.writer.add_scalar('train_%d/loss' % self.iter, loss.item(), iterations)
            iterations += 1
            all_loss += loss.item()
            loss.backward()
            optimizer.step()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            desc = ('[%d][LR: %.5f, WD: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                    (epoch, lr_scheduler.get_lr(optimizer), weight_decay,
                     all_loss / (batch_idx + 1), 100. * correct / total, correct, total))
            prog_bar.set_description(desc, refresh=True)
        test_loss, test_acc = self.test_model(testloader, criterion, device)
        best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
        best_test_acc = max(test_acc, best_test_acc)
    print('** Finetuning finished. Stabilizing batch norm and test again!')
    stablize_bn(self.model, trainloader)
    test_loss, test_acc = self.test_model(testloader, criterion, device)
    best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
    best_test_acc = max(test_acc, best_test_acc)
    return best_test_loss, best_test_acc
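
# NOTE: both fine_tune_model variants (above and below) call self.test_model, which
# is not shown here. The variant above expects it to return (test_loss, test_acc);
# the variant below additionally expects a top-5 accuracy. A minimal sketch of the
# two-value version, assuming average loss over batches and top-1 accuracy in
# percent, could look like this:
def test_model(self, testloader, criterion, device='cuda'):
    self.model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = self.model(inputs)
            total_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return total_loss / len(testloader), 100. * correct / total
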
def fine_tune_model(self, trainloader, testloader, criterion, optim, learning_rate,
                    weight_decay, nepochs=10, device='cuda'):
    self.model = self.model.train()
    self.model = self.model.cpu()
    self.model = self.model.to(device)
    optimizer = optim.SGD(self.model.parameters(), lr=learning_rate, momentum=0.9,
                          weight_decay=weight_decay)
    # optimizer = optim.Adam(self.model.parameters(), weight_decay=5e-4)
    if self.config.dataset == "cifar10":
        lr_schedule = {
            0: learning_rate,
            int(nepochs * 0.5): learning_rate * 0.1,
            int(nepochs * 0.75): learning_rate * 0.01
        }
    elif self.config.dataset == "imagenet":
        lr_schedule = {
            0: learning_rate,
            30: learning_rate * 0.1,
            60: learning_rate * 0.01
        }
    lr_scheduler = PresetLRScheduler(lr_schedule)
    best_test_acc, best_test_loss = 0, 100
    iterations = 0
    for epoch in range(nepochs):
        self.model = self.model.train()
        correct = 0
        total = 0
        all_loss = 0
        lr_scheduler(optimizer, epoch)
        desc = ('[LR: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (lr_scheduler.get_lr(optimizer), 0, 0, correct, total))
        prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc,
                        leave=True)
        for batch_idx, (inputs, targets) in prog_bar:
            optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)
            self.writer.add_scalar('train_%d/loss' % self.iter, loss.item(), iterations)
            iterations += 1
            all_loss += loss.item()
            loss.backward()
            optimizer.step()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            desc = ('[%d][LR: %.5f, WD: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                    (epoch, lr_scheduler.get_lr(optimizer), weight_decay,
                     all_loss / (batch_idx + 1), 100. * correct / total, correct, total))
            prog_bar.set_description(desc, refresh=True)
        test_loss, test_acc, top5_acc = self.test_model(testloader, criterion, device)
        self.logger.info(
            f'{epoch} Test Loss: %.3f, Test Top1 %.2f%%(test), Test Top5 %.2f%%(test).' %
            (test_loss, test_acc, top5_acc))
        if test_acc > best_test_acc:
            best_test_loss = test_loss
            best_test_acc = test_acc
            network = self.config.network
            depth = self.config.depth
            dataset = self.config.dataset
            path = os.path.join(self.config.checkpoint,
                                '%s_%s%s.pth.tar' % (dataset, network, depth))
            save = {
                'args': self.config,
                'net': self.model,
                'acc': test_acc,
                'loss': test_loss,
                'epoch': epoch
            }
            torch.save(save, path)
    print('** Finetuning finished. Stabilizing batch norm and test again!')
    stablize_bn(self.model, trainloader)
    test_loss, test_acc, top5_acc = self.test_model(testloader, criterion, device)
    best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
    best_test_acc = max(test_acc, best_test_acc)
    return best_test_loss, best_test_acc
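
# NOTE: stablize_bn, called after fine-tuning in both variants, is not defined in
# these snippets. Based on the surrounding message ("Stabilizing batch norm and
# test again!"), it presumably re-estimates the BatchNorm running statistics by
# forwarding training batches without any gradient updates. A minimal sketch under
# that assumption:
def stablize_bn(net, trainloader, device='cuda'):
    """Run training data through the network in train mode to refresh BN stats."""
    net.train()
    with torch.no_grad():
        for inputs, _ in trainloader:
            net(inputs.to(device))
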
# init dataloader
trainloader, testloader = get_dataloader(dataset=args.dataset,
                                         train_batch_size=args.batch_size,
                                         test_batch_size=256)

# init optimizer and lr scheduler
optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9,
                      weight_decay=args.weight_decay)
lr_schedule = {
    0: args.learning_rate,
    int(args.epoch * 0.5): args.learning_rate * 0.1,
    int(args.epoch * 0.75): args.learning_rate * 0.01
}
lr_scheduler = PresetLRScheduler(lr_schedule)
# lr_scheduler = StairCaseLRScheduler(0, args.decay_every, args.decay_ratio)

# init criterion
criterion = nn.CrossEntropyLoss()

start_epoch = 0
best_acc = 0
if args.resume:
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint/pretrain'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('checkpoint/pretrain/%s_%s%s_bn_best.t7' %
                            (args.dataset, args.network, args.depth))
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
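
# NOTE: get_dataloader is not defined in these snippets. A minimal CIFAR-10 sketch
# built on standard torchvision datasets and transforms is given below; the
# normalization constants and augmentation choices are assumptions for
# illustration, not taken from the original repository.
import torch
import torchvision
import torchvision.transforms as transforms

def get_dataloader(dataset, train_batch_size, test_batch_size, root='./data'):
    assert dataset == 'cifar10', 'this sketch only covers cifar10'
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2470, 0.2435, 0.2616))
    train_tf = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.ToTensor(), normalize])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    trainset = torchvision.datasets.CIFAR10(root, train=True, download=True,
                                            transform=train_tf)
    testset = torchvision.datasets.CIFAR10(root, train=False, download=True,
                                           transform=test_tf)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
                                              shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader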