Пример #1
0
class ModelTrainer:
    """Drive epoch-level training and evaluation of a model.

    Wires together a ``Train`` runner, a ``Test`` runner and a ``ModelStats``
    recorder (all defined elsewhere in the project) and loops over epochs in
    :meth:`run`.

    Args:
        model: The model being trained; forwarded to ``Train``, ``Test`` and
            ``ModelStats``.
        optimizer: Optimizer forwarded to ``Train``. The ``'lr'`` of its first
            param group is what gets recorded and printed as the learning rate.
        train_loader: Loader forwarded to ``Train``.
        test_loader: Loader forwarded to ``Test``.
        statspath: Forwarded to ``ModelStats``; presumably where ``save()``
            writes the statistics (confirm in ``ModelStats``).
        scheduler: Optional learning-rate scheduler.
        batch_scheduler: If true, the scheduler is handed to ``Train``
            (presumably stepped per batch there — confirm in ``Train``);
            otherwise :meth:`run` steps it once per epoch.
        L1lambda: Forwarded to ``Train`` (presumably an L1 penalty weight —
            confirm in ``Train``).
    """

    def __init__(self,
                 model,
                 optimizer,
                 train_loader,
                 test_loader,
                 statspath,
                 scheduler=None,
                 batch_scheduler=False,
                 L1lambda=0):
        self.model = model
        self.scheduler = scheduler
        self.batch_scheduler = batch_scheduler
        self.optimizer = optimizer
        self.stats = ModelStats(model, statspath)
        # Only a batch-level scheduler is given to Train; an epoch-level
        # scheduler is stepped in run() instead.
        train_scheduler = (self.scheduler
                           if self.scheduler and self.batch_scheduler else None)
        self.train = Train(model, train_loader, optimizer, self.stats,
                           train_scheduler, L1lambda)
        self.test = Test(model, test_loader, self.stats)

    def _current_lr(self):
        """Return the learning rate of the optimizer's first param group."""
        return self.optimizer.param_groups[0]['lr']

    def run(self, epochs=10):
        """Train and evaluate for ``epochs`` epochs, then save the stats.

        Each epoch does the following, in order:

        1. Runs training, then evaluation.
        2. Records the learning rate used for the epoch.
        3. Prints the epoch summary.
        4. Steps an epoch-level scheduler, if one is attached.
        5. Prints the resulting learning rate (only when a scheduler is
           attached).

        Args:
            epochs: Number of epochs to run.
        """
        pbar = tqdm_notebook(range(1, epochs + 1), desc="Epochs")
        for _epoch in pbar:
            self.train.run()
            self.test.run()
            # Read the LR from the optimizer so the real value is recorded
            # even when no scheduler is attached. This also works for
            # schedulers that do not expose get_last_lr().
            self.stats.next_epoch(self._current_lr())
            pbar.write(self.stats.get_epoch_desc())
            if self.scheduler and not self.batch_scheduler:
                self.scheduler.step()
            if self.scheduler:
                pbar.write(f"Learning Rate = {self._current_lr():0.6f}")
        # save stats for later lookup
        self.stats.save()
Пример #2
0
class ModelTrainer:
    """Epoch loop that trains/evaluates a model and mirrors metrics elsewhere.

    Besides the local ``ModelStats`` recorder, every epoch's losses,
    accuracies and learning rate are also appended to an external
    ``statsmanager``.

    Args:
        statsmanager: External metrics sink. It must provide
            ``append_train_loss``, ``append_test_loss``,
            ``append_test_accuracy``, ``append_train_accuracy``,
            ``append_lr`` and a ``data['lr']`` sequence. It is also forwarded
            to ``Train`` and ``Test``.
        model: The model being trained.
        optimizer: Optimizer forwarded to ``Train``. Its first param group's
            ``'lr'`` is reported each epoch.
        train_loader: Loader forwarded to ``Train``.
        test_loader: Loader forwarded to ``Test``.
        statspath: Forwarded to ``ModelStats``.
        scheduler: Optional learning-rate scheduler. It is always forwarded
            to ``Test``, and to ``Train`` only when ``batch_scheduler`` is true.
        batch_scheduler: Selects whether ``Train`` receives the scheduler.
        L1lambda: Forwarded to ``Train``.
    """

    def __init__(self,
                 statsmanager,
                 model,
                 optimizer,
                 train_loader,
                 test_loader,
                 statspath,
                 scheduler=None,
                 batch_scheduler=False,
                 L1lambda=0):
        self.model = model
        self.statsmanager = statsmanager
        self.scheduler = scheduler
        self.batch_scheduler = batch_scheduler
        self.optimizer = optimizer
        self.stats = ModelStats(model, statspath)
        per_batch_scheduler = None
        if self.batch_scheduler:
            per_batch_scheduler = self.scheduler
        self.train = Train(statsmanager, model, train_loader, optimizer,
                           self.stats, per_batch_scheduler, L1lambda)
        self.test = Test(model, test_loader, self.stats, statsmanager,
                         self.scheduler)

    def _should_step_scheduler(self):
        """Return whether run() must call ``scheduler.step()`` after an epoch.

        No step is taken in any of these cases:

        * there is no scheduler;
        * the scheduler is batch-level (it was handed to ``Train``);
        * the scheduler is a ``ReduceLROnPlateau``. That scheduler is
          forwarded to ``Test``, presumably to be stepped there with a
          metric (confirm in ``Test``).

        TODO: make the scheduler dispatch more readable and support other
        scheduler kinds.
        """
        if not self.scheduler or self.batch_scheduler:
            return False
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau
        return not isinstance(self.scheduler, plateau)

    def _mirror_epoch_metrics(self):
        """Append the newest epoch metrics from ``self.stats`` to the manager.

        Accuracies are scaled by 100 before being appended. For the learning
        rate, the first entry ever written is ``stats.batch_lr[0]``. Every
        later entry is ``stats.lr[-1]``.
        """
        recorder, sink = self.stats, self.statsmanager
        sink.append_train_loss(recorder.avg_train_loss[-1])
        sink.append_test_loss(recorder.avg_test_loss[-1])
        sink.append_test_accuracy(100 * recorder.test_acc[-1])
        sink.append_train_accuracy(100 * recorder.train_acc[-1])
        lr_history_started = len(sink.data['lr']) > 0
        if lr_history_started:
            sink.append_lr(recorder.lr[-1])
        else:
            sink.append_lr(recorder.batch_lr[0])

    def run(self, epochs=10):
        """Train and evaluate for ``epochs`` epochs, then save the stats.

        Args:
            epochs: Number of epochs to run.
        """
        pbar = tqdm_notebook(range(1, epochs + 1), desc="Epochs")
        for _ in pbar:
            self.train.run()
            self.test.run()
            # The LR used during this epoch, read before any epoch-level step.
            current_lr = self.optimizer.param_groups[0]['lr']
            self.stats.next_epoch(current_lr)
            pbar.write(self.stats.get_epoch_desc())
            self._mirror_epoch_metrics()
            if self._should_step_scheduler():
                self.scheduler.step()
            pbar.write(f"Learning Rate = {current_lr:0.6f}")
        # Persist the collected stats once all epochs have finished.
        self.stats.save()
class ModelTrainer:
    """Epoch loop with a custom loss criterion and per-epoch result printing.

    Args:
        model: The model being trained.
        optimizer: Optimizer forwarded to ``Train``. Its first param group's
            ``'lr'`` is reported each epoch.
        train_loader: Loader forwarded to ``Train``.
        test_loader: Loader forwarded to ``Test`` and ``Misclass``. It is also
            kept for the per-epoch ``printing_results`` call.
        statspath: Forwarded to ``ModelStats``.
        criterion: Loss criterion forwarded to ``Train`` and ``Test``.
        writer: Forwarded to ``Test``.
        scheduler: Optional learning-rate scheduler. It is always forwarded
            to ``Test``, and to ``Train`` only when ``batch_scheduler`` is true.
        batch_scheduler: Selects whether ``Train`` receives the scheduler.
        L1lambda: Forwarded to ``Train``.

    Note:
        Construction sets the process-wide ``torch.backends.cudnn.benchmark``
        flag to ``True``.
    """

    def __init__(self,
                 model,
                 optimizer,
                 train_loader,
                 test_loader,
                 statspath,
                 criterion,
                 writer,
                 scheduler=None,
                 batch_scheduler=False,
                 L1lambda=0):
        self.model = model
        self.scheduler = scheduler
        self.criterion = criterion
        self.batch_scheduler = batch_scheduler
        self.optimizer = optimizer
        self.stats = ModelStats(model, statspath)
        per_batch_scheduler = None
        if self.batch_scheduler:
            per_batch_scheduler = self.scheduler
        self.train = Train(model, train_loader, optimizer, self.stats,
                           per_batch_scheduler, L1lambda, criterion)
        self.test = Test(model, test_loader, self.stats, writer,
                         self.scheduler, criterion)
        # Built here but not used by run(); kept for external callers.
        self.misclass = Misclass(model, test_loader, self.stats)
        self.test_loader = test_loader
        torch.backends.cudnn.benchmark = True

    def _should_step_scheduler(self):
        """Return whether run() must call ``scheduler.step()`` after an epoch.

        No step is taken in any of these cases:

        * there is no scheduler;
        * the scheduler is batch-level (it was handed to ``Train``);
        * the scheduler is a ``ReduceLROnPlateau``. That scheduler is
          forwarded to ``Test``, presumably to be stepped there with a
          metric (confirm in ``Test``).

        TODO: make the scheduler dispatch more readable and support other
        scheduler kinds.
        """
        if not self.scheduler or self.batch_scheduler:
            return False
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau
        return not isinstance(self.scheduler, plateau)

    def run(self, epochs=10):
        """Train and evaluate for ``epochs`` epochs, then save the stats.

        After each epoch the summary and learning rate are written to the
        progress bar, and then ``printing_results`` is invoked with the
        1-based epoch number.

        Args:
            epochs: Number of epochs to run.
        """
        pbar = tqdm_notebook(range(1, epochs + 1), desc="Epochs")
        for epoch in pbar:
            self.train.run()
            self.test.run()
            # The LR used during this epoch, read before any epoch-level step.
            epoch_lr = self.optimizer.param_groups[0]['lr']
            self.stats.next_epoch(epoch_lr)
            pbar.write(self.stats.get_epoch_desc())
            if self._should_step_scheduler():
                self.scheduler.step()
            pbar.write(f"Learning Rate = {epoch_lr:0.6f}")
            print("printing results")
            printing_results(self.model, self.test_loader, epoch)

        # Persist the collected stats once all epochs have finished.
        self.stats.save()