Example 1
def train_once(mb, net, trainloader, testloader, writer, config, ckpt_path, learning_rate, weight_decay, num_epochs,
               iteration, logger):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
    lr_schedule = {0: learning_rate,
                   int(num_epochs * 0.5): learning_rate * 0.1,
                   int(num_epochs * 0.75): learning_rate * 0.01}
    lr_scheduler = PresetLRScheduler(lr_schedule)
    best_acc = 0
    best_epoch = 0
    for epoch in range(num_epochs):
        train(net, trainloader, optimizer, criterion, lr_scheduler, epoch, writer, iteration=iteration)
        test_acc = test(net, testloader, criterion, epoch, writer, iteration)
        if test_acc > best_acc:
            print('Saving..')
            state = {
                'net': net,
                'acc': test_acc,
                'epoch': epoch,
                'args': config,
                # 'mask': mb.masks,
                # 'ratio': mb.get_ratio_at_each_layer()
            }
            path = os.path.join(ckpt_path, 'finetune_%s_%s%s_r%s_it%d_best.pth.tar' % (config.dataset,
                                                                                       config.network,
                                                                                       config.depth,
                                                                                       config.target_ratio,
                                                                                       iteration))
            torch.save(state, path)
            best_acc = test_acc
            best_epoch = epoch
    logger.info('Iteration [%d], best acc: %.4f, epoch: %d' %
                (iteration, best_acc, best_epoch))
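The examples in this section all drive the optimizer through PresetLRScheduler, which is defined elsewhere in the repository and not shown here. Below is a minimal sketch of how such a preset scheduler could look, assuming it only needs to support construction from an {epoch: learning_rate} dict, being called as lr_scheduler(optimizer, epoch), and a get_lr(optimizer) accessor (the names come from the usage above; the body is an assumption, not the repository's implementation):

class PresetLRScheduler:
    """Piecewise-constant schedule: apply the lr of the latest milestone reached."""

    def __init__(self, lr_schedule):
        # lr_schedule maps an epoch threshold to a learning rate,
        # e.g. {0: 0.1, 50: 0.01, 75: 0.001}
        self.lr_schedule = dict(sorted(lr_schedule.items()))

    def __call__(self, optimizer, epoch):
        # pick the learning rate of the largest milestone that is <= epoch
        lr = next(iter(self.lr_schedule.values()))
        for milestone, value in self.lr_schedule.items():
            if epoch >= milestone:
                lr = value
        for group in optimizer.param_groups:
            group['lr'] = lr

    @staticmethod
    def get_lr(optimizer):
        # report the learning rate currently set on the first parameter group
        return optimizer.param_groups[0]['lr']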
Example 2
 def fine_tune_model(self, trainloader, testloader, criterion, optim, learning_rate, weight_decay, nepochs=10,
                     device='cuda'):
     self.model = self.model.train()
     optimizer = optim.SGD(self.model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
     # optimizer = optim.Adam(self.model.parameters(), weight_decay=5e-4)
     lr_schedule = {0: learning_rate, int(nepochs * 0.5): learning_rate * 0.1,
                    int(nepochs * 0.75): learning_rate * 0.01}
     lr_scheduler = PresetLRScheduler(lr_schedule)
     best_test_acc, best_test_loss = 0, 100
     iterations = 0
     for epoch in range(nepochs):
         self.model = self.model.train()
         correct = 0
         total = 0
         all_loss = 0
         lr_scheduler(optimizer, epoch)
         desc = ('[LR: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                 (lr_scheduler.get_lr(optimizer), 0, 0, correct, total))
         prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
         for batch_idx, (inputs, targets) in prog_bar:
             optimizer.zero_grad()
             inputs, targets = inputs.to(device), targets.to(device)
             outputs = self.model(inputs)
             loss = criterion(outputs, targets)
             self.writer.add_scalar('train_%d/loss' % self.iter, loss.item(), iterations)
             iterations += 1
             all_loss += loss.item()
             loss.backward()
             optimizer.step()
             _, predicted = outputs.max(1)
             total += targets.size(0)
             correct += predicted.eq(targets).sum().item()
             desc = ('[%d][LR: %.5f, WD: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                     (epoch, lr_scheduler.get_lr(optimizer), weight_decay, all_loss / (batch_idx + 1),
                      100. * correct / total, correct, total))
             prog_bar.set_description(desc, refresh=True)
         test_loss, test_acc = self.test_model(testloader, criterion, device)
         best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
         best_test_acc = max(test_acc, best_test_acc)
     print('** Finetuning finished. Stabilizing batch norm and testing again!')
     stablize_bn(self.model, trainloader)
     test_loss, test_acc = self.test_model(testloader, criterion, device)
     best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
     best_test_acc = max(test_acc, best_test_acc)
     return best_test_loss, best_test_acc
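Example 2 (and Example 3 below) finishes by calling stablize_bn before the final evaluation. That helper is not included here; a hypothetical sketch of what it might do, assuming its job is to re-estimate the BatchNorm running statistics by forwarding training batches without any gradient updates, is:

import torch
import torch.nn as nn

def stablize_bn(model, trainloader, device='cuda'):
    # Put the network in train mode so BatchNorm layers update running mean/var,
    # and reset the statistics accumulated before/under pruning.
    model = model.to(device).train()
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
    # Forward training batches without computing gradients or stepping an optimizer.
    with torch.no_grad():
        for inputs, _ in trainloader:
            model(inputs.to(device))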
Example 3
    def fine_tune_model(self,
                        trainloader,
                        testloader,
                        criterion,
                        optim,
                        learning_rate,
                        weight_decay,
                        nepochs=10,
                        device='cuda'):
        self.model = self.model.train()
        self.model = self.model.cpu()
        self.model = self.model.to(device)

        optimizer = optim.SGD(self.model.parameters(),
                              lr=learning_rate,
                              momentum=0.9,
                              weight_decay=weight_decay)
        # optimizer = optim.Adam(self.model.parameters(), weight_decay=5e-4)
        if self.config.dataset == "cifar10":
            lr_schedule = {
                0: learning_rate,
                int(nepochs * 0.5): learning_rate * 0.1,
                int(nepochs * 0.75): learning_rate * 0.01
            }

        elif self.config.dataset == "imagenet":
            lr_schedule = {
                0: learning_rate,
                30: learning_rate * 0.1,
                60: learning_rate * 0.01
            }

        lr_scheduler = PresetLRScheduler(lr_schedule)
        best_test_acc, best_test_loss = 0, 100
        iterations = 0

        for epoch in range(nepochs):
            self.model = self.model.train()
            correct = 0
            total = 0
            all_loss = 0
            lr_scheduler(optimizer, epoch)
            desc = ('[LR: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                    (lr_scheduler.get_lr(optimizer), 0, 0, correct, total))
            prog_bar = tqdm(enumerate(trainloader),
                            total=len(trainloader),
                            desc=desc,
                            leave=True)

            for batch_idx, (inputs, targets) in prog_bar:
                optimizer.zero_grad()
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                self.writer.add_scalar('train_%d/loss' % self.iter,
                                       loss.item(), iterations)
                iterations += 1
                all_loss += loss.item()
                loss.backward()
                optimizer.step()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                desc = ('[%d][LR: %.5f, WD: %.5f] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                        (epoch, lr_scheduler.get_lr(optimizer), weight_decay,
                         all_loss / (batch_idx + 1), 100. * correct / total,
                         correct, total))
                prog_bar.set_description(desc, refresh=True)

            test_loss, test_acc, top5_acc = self.test_model(
                testloader, criterion, device)
            self.logger.info(
                '%d Test Loss: %.3f, Test Top1 %.2f%%(test), Test Top5 %.2f%%(test).'
                % (epoch, test_loss, test_acc, top5_acc))

            if test_acc > best_test_acc:
                best_test_loss = test_loss
                best_test_acc = test_acc
                network = self.config.network
                depth = self.config.depth
                dataset = self.config.dataset
                path = os.path.join(
                    self.config.checkpoint,
                    '%s_%s%s.pth.tar' % (dataset, network, depth))
                save = {
                    'args': self.config,
                    'net': self.model,
                    'acc': test_acc,
                    'loss': test_loss,
                    'epoch': epoch
                }
                torch.save(save, path)
        print('** Finetuning finished. Stabilizing batch norm and testing again!')
        stablize_bn(self.model, trainloader)
        test_loss, test_acc, top5_acc = self.test_model(
            testloader, criterion, device)
        best_test_loss = best_test_loss if best_test_acc > test_acc else test_loss
        best_test_acc = max(test_acc, best_test_acc)
        return best_test_loss, best_test_acc
# init dataloader
trainloader, testloader = get_dataloader(dataset=args.dataset,
                                         train_batch_size=args.batch_size,
                                         test_batch_size=256)

# init optimizer and lr scheduler
optimizer = optim.SGD(net.parameters(),
                      lr=args.learning_rate,
                      momentum=0.9,
                      weight_decay=args.weight_decay)
lr_schedule = {
    0: args.learning_rate,
    int(args.epoch * 0.5): args.learning_rate * 0.1,
    int(args.epoch * 0.75): args.learning_rate * 0.01
}
lr_scheduler = PresetLRScheduler(lr_schedule)
# lr_scheduler = StairCaseLRScheduler(0, args.decay_every, args.decay_ratio)

# init criterion
criterion = nn.CrossEntropyLoss()

start_epoch = 0
best_acc = 0
if args.resume:
    print('==> Resuming from checkpoint..')
    assert os.path.isdir(
        'checkpoint/pretrain'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('checkpoint/pretrain/%s_%s%s_bn_best.t7' %
                            (args.dataset, args.network, args.depth))
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
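The standalone script above also depends on get_dataloader, which is not shown. A hypothetical sketch for the CIFAR-10 case (the torchvision-based body and the normalization constants are assumptions; the real helper presumably also covers ImageNet) could be:

import torch
import torchvision
import torchvision.transforms as transforms

def get_dataloader(dataset, train_batch_size, test_batch_size, num_workers=2):
    # Sketch covers only CIFAR-10, with the standard augmentation and the usual
    # per-channel normalization statistics for that dataset.
    assert dataset == 'cifar10', 'this sketch only handles cifar10'
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2470, 0.2435, 0.2616))
    train_tf = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.ToTensor(), normalize])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    trainset = torchvision.datasets.CIFAR10('./data', train=True, download=True,
                                            transform=train_tf)
    testset = torchvision.datasets.CIFAR10('./data', train=False, download=True,
                                           transform=test_tf)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
                                              shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
                                             shuffle=False, num_workers=num_workers)
    return trainloader, testloader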