Ejemplo n.º 1
0
 def fine_tune_model(self, trainloader, testloader, criterion, optim, learning_rate, weight_decay, nepochs=10,
                     device='cuda'):
     """Fine-tune ``self.model`` with SGD and a preset step learning-rate schedule.

     The learning rate starts at ``learning_rate``. It becomes
     ``learning_rate * 0.1`` at 50% of ``nepochs`` and ``learning_rate * 0.01``
     at 75%. After every epoch the model is evaluated with ``self.test_model``.
     After the last epoch, batch-norm statistics are re-estimated with
     ``stablize_bn`` and the model is evaluated once more.

     Args:
         trainloader: iterable of ``(inputs, targets)`` batches used for
             training; must support ``len()`` (used for the progress bar).
         testloader: loader forwarded to ``self.test_model``.
         criterion: loss function called as ``criterion(outputs, targets)``.
         optim: module exposing ``SGD`` (e.g. ``torch.optim``).
         learning_rate: base learning rate.
         weight_decay: SGD weight decay.
         nepochs: number of training epochs.
         device: device that every batch is moved to.

     Returns:
         ``(best_test_loss, best_test_acc)``: the best test accuracy observed
         (including the final post-stabilization evaluation) and the test loss
         recorded alongside it.
     """
     self.model = self.model.train()
     optimizer = optim.SGD(self.model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
     # Step schedule: base LR, then x0.1 at 50% and x0.01 at 75% of the run.
     # For very short runs int(nepochs * fraction) can truncate to 0; such a
     # milestone would overwrite the epoch-0 entry and the base learning rate
     # would never be applied, so only milestones after epoch 0 are added.
     lr_schedule = {0: learning_rate}
     for fraction, factor in ((0.5, 0.1), (0.75, 0.01)):
         milestone = int(nepochs * fraction)
         if milestone > 0:
             lr_schedule[milestone] = learning_rate * factor
     lr_scheduler = PresetLRScheduler(lr_schedule)
     best_test_acc, best_test_loss = 0, 100
     iterations = 0  # global step passed to self.writer.add_scalar
     for epoch in range(nepochs):
         # Re-enter training mode every epoch; presumably test_model switches
         # the model to eval mode (TODO confirm).
         self.model = self.model.train()
         correct = 0
         total = 0
         all_loss = 0
         lr_scheduler(optimizer, epoch)
         desc = ('[LR: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
             lr_scheduler.get_lr(optimizer), 0, 0, correct, total))
         prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
         for batch_idx, (inputs, targets) in prog_bar:
             optimizer.zero_grad()
             inputs, targets = inputs.to(device), targets.to(device)
             outputs = self.model(inputs)
             loss = criterion(outputs, targets)
             loss_value = loss.item()  # read the scalar once per batch
             self.writer.add_scalar('train_%d/loss' % self.iter, loss_value, iterations)
             iterations += 1
             all_loss += loss_value
             loss.backward()
             optimizer.step()
             # Prediction is the argmax over dimension 1 of the model output.
             _, predicted = outputs.max(1)
             total += targets.size(0)
             correct += predicted.eq(targets).sum().item()
             desc = ('[%d][LR: %.5f, WD: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                     (epoch, lr_scheduler.get_lr(optimizer), weight_decay, all_loss / (batch_idx + 1),
                      100. * correct / total, correct, total))
             prog_bar.set_description(desc, refresh=True)
         test_loss, test_acc = self.test_model(testloader, criterion, device)
         # Track the loss that goes with the best accuracy (a tie takes the newer loss).
         best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
         best_test_acc = max(test_acc, best_test_acc)
     print('** Finetuning finished. Stabilizing batch norm and test again!')
     stablize_bn(self.model, trainloader)
     test_loss, test_acc = self.test_model(testloader, criterion, device)
     best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
     best_test_acc = max(test_acc, best_test_acc)
     return best_test_loss, best_test_acc
Ejemplo n.º 2
0
    def fine_tune_model(self,
                        trainloader,
                        testloader,
                        criterion,
                        optim,
                        learning_rate,
                        weight_decay,
                        nepochs=10,
                        device='cuda'):
        """Fine-tune ``self.model`` with SGD, checkpointing the best epoch.

        The learning-rate schedule depends on ``self.config.dataset``:

        - ``"cifar10"``: ``learning_rate``, then x0.1 at 50% and x0.01 at 75%
          of ``nepochs``.
        - ``"imagenet"``: ``learning_rate``, then x0.1 at epoch 30 and x0.01
          at epoch 60.

        After every epoch the model is evaluated with ``self.test_model`` and
        the result is logged. Whenever top-1 accuracy improves, a dict with
        keys ``args``, ``net``, ``acc``, ``loss`` and ``epoch`` is written with
        ``torch.save`` to
        ``<config.checkpoint>/<dataset>_<network><depth>.pth.tar``. Finally,
        batch-norm statistics are re-estimated with ``stablize_bn`` and the
        model is evaluated once more.

        Args:
            trainloader: iterable of ``(inputs, targets)`` batches used for
                training; must support ``len()`` (used for the progress bar).
            testloader: loader forwarded to ``self.test_model``.
            criterion: loss function called as ``criterion(outputs, targets)``.
            optim: module exposing ``SGD`` (e.g. ``torch.optim``).
            learning_rate: base learning rate.
            weight_decay: SGD weight decay.
            nepochs: number of training epochs.
            device: device the model and every batch are moved to.

        Returns:
            ``(best_test_loss, best_test_acc)``: the best top-1 accuracy
            observed (including the final post-stabilization evaluation) and
            the test loss recorded alongside it.

        Raises:
            ValueError: if ``self.config.dataset`` is neither ``"cifar10"``
                nor ``"imagenet"``.
        """
        # Resolve the LR schedule first so that an unsupported dataset fails
        # fast with a clear message, before the model is moved or an
        # optimizer is built.
        dataset = self.config.dataset
        if dataset == "cifar10":
            lr_schedule = {0: learning_rate}
            for fraction, factor in ((0.5, 0.1), (0.75, 0.01)):
                milestone = int(nepochs * fraction)
                # Milestones that truncate to epoch 0 (very short runs) would
                # overwrite the base learning rate, so skip them.
                if milestone > 0:
                    lr_schedule[milestone] = learning_rate * factor
        elif dataset == "imagenet":
            lr_schedule = {
                0: learning_rate,
                30: learning_rate * 0.1,
                60: learning_rate * 0.01
            }
        else:
            raise ValueError(
                'Unsupported dataset %r for fine-tuning; expected "cifar10" '
                'or "imagenet".' % (dataset,))

        self.model = self.model.train()
        self.model = self.model.cpu()
        self.model = self.model.to(device)

        optimizer = optim.SGD(self.model.parameters(),
                              lr=learning_rate,
                              momentum=0.9,
                              weight_decay=weight_decay)
        lr_scheduler = PresetLRScheduler(lr_schedule)
        best_test_acc, best_test_loss = 0, 100
        iterations = 0  # global step passed to self.writer.add_scalar

        for epoch in range(nepochs):
            # Re-enter training mode every epoch; presumably test_model
            # switches the model to eval mode (TODO confirm).
            self.model = self.model.train()
            correct = 0
            total = 0
            all_loss = 0
            lr_scheduler(optimizer, epoch)
            desc = ('[LR: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                    (lr_scheduler.get_lr(optimizer), 0, 0, correct, total))
            prog_bar = tqdm(enumerate(trainloader),
                            total=len(trainloader),
                            desc=desc,
                            leave=True)

            for batch_idx, (inputs, targets) in prog_bar:
                optimizer.zero_grad()
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss_value = loss.item()  # read the scalar once per batch
                self.writer.add_scalar('train_%d/loss' % self.iter,
                                       loss_value, iterations)
                iterations += 1
                all_loss += loss_value
                loss.backward()
                optimizer.step()
                # Prediction is the argmax over dimension 1 of the output.
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                desc = (
                    '[%d][LR: %.5f, WD: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)'
                    % (epoch, lr_scheduler.get_lr(optimizer), weight_decay,
                       all_loss / (batch_idx + 1), 100. * correct / total,
                       correct, total))
                prog_bar.set_description(desc, refresh=True)

            test_loss, test_acc, top5_acc = self.test_model(
                testloader, criterion, device)
            # Single %-format (instead of an f-string followed by %) yields
            # the same message text.
            self.logger.info(
                '%d Test Loss: %.3f, Test Top1 %.2f%%(test), Test Top5 %.2f%%(test).'
                % (epoch, test_loss, test_acc, top5_acc))

            if test_acc > best_test_acc:
                best_test_loss = test_loss
                best_test_acc = test_acc
                network = self.config.network
                depth = self.config.depth
                path = os.path.join(
                    self.config.checkpoint,
                    '%s_%s%s.pth.tar' % (dataset, network, depth))
                save = {
                    'args': self.config,
                    'net': self.model,
                    'acc': test_acc,
                    'loss': test_loss,
                    'epoch': epoch
                }
                torch.save(save, path)
        print('** Finetuning finished. Stabilizing batch norm and test again!')
        stablize_bn(self.model, trainloader)
        test_loss, test_acc, top5_acc = self.test_model(
            testloader, criterion, device)
        # NOTE(review): if this post-stabilization evaluation is the best, the
        # returned accuracy has no matching checkpoint on disk; only the
        # in-loop improvements are saved.
        best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
        best_test_acc = max(test_acc, best_test_acc)
        return best_test_loss, best_test_acc