class ModelTrainer:
    """Run a train/test loop over several epochs and record per-epoch stats.

    Wires together a ``Train`` runner, a ``Test`` runner and a ``ModelStats``
    collector. An optional learning-rate scheduler is either handed to
    ``Train`` (``batch_scheduler=True``) or stepped once per epoch in
    :meth:`run`.
    """

    def __init__(self, model, optimizer, train_loader, test_loader, statspath,
                 scheduler=None, batch_scheduler=False, L1lambda=0):
        """Build the train/test runners and the stats collector.

        Args:
            model: the model being trained; shared by ``Train``, ``Test`` and
                ``ModelStats``.
            optimizer: optimizer passed to ``Train``; the learning rate of its
                first parameter group is what gets recorded each epoch.
            train_loader: data loader handed to ``Train``.
            test_loader: data loader handed to ``Test``.
            statspath: path handed to ``ModelStats``.
            scheduler: optional learning-rate scheduler.
            batch_scheduler: when True (and a scheduler is given) the
                scheduler is handed to ``Train`` instead of being stepped
                per epoch here.
            L1lambda: forwarded to ``Train`` (presumably an L1 penalty
                weight — confirm against ``Train``).
        """
        self.model = model
        self.scheduler = scheduler
        self.batch_scheduler = batch_scheduler
        self.optimizer = optimizer
        self.stats = ModelStats(model, statspath)
        self.train = Train(
            model, train_loader, optimizer, self.stats,
            self.scheduler if self.scheduler and self.batch_scheduler else None,
            L1lambda)
        self.test = Test(model, test_loader, self.stats)

    def _current_lr(self):
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]['lr']

    def run(self, epochs=10):
        """Train and evaluate for ``epochs`` epochs, then save the stats.

        Each epoch runs ``Train`` then ``Test``, records the learning rate
        used, prints the epoch summary, and (for a per-epoch scheduler)
        steps the scheduler.

        Args:
            epochs: number of epochs to run.
        """
        pbar = tqdm_notebook(range(1, epochs + 1), desc="Epochs")
        for _ in pbar:
            self.train.run()
            self.test.run()
            # Read the LR from the optimizer instead of the scheduler. This
            # records the real rate even when no scheduler is attached, where
            # it was previously logged as 0. It also works for schedulers
            # without get_last_lr(). For standard schedulers the value equals
            # get_last_lr()[0].
            self.stats.next_epoch(self._current_lr())
            pbar.write(self.stats.get_epoch_desc())
            if self.scheduler and not self.batch_scheduler:
                # NOTE(review): ReduceLROnPlateau.step() requires a metric
                # argument, so that scheduler cannot be used per-epoch here.
                self.scheduler.step()
            if self.scheduler:
                # Read after stepping, i.e. the rate the next epoch starts with.
                pbar.write(f"Learning Rate = {self._current_lr():0.6f}")
        # save stats for later lookup
        self.stats.save()
class ModelTrainer:
    """Train/evaluate a model epoch by epoch and mirror results elsewhere.

    Besides the local ``ModelStats`` collector, every epoch's losses,
    accuracies (multiplied by 100) and learning rate are pushed into an
    external ``statsmanager`` object.
    """

    def __init__(self, statsmanager, model, optimizer, train_loader,
                 test_loader, statspath, scheduler=None,
                 batch_scheduler=False, L1lambda=0):
        """Build the train/test runners and the local stats collector.

        Args:
            statsmanager: external sink that receives per-epoch metrics; it
                is also handed to ``Train`` and ``Test``.
            model: the model being trained.
            optimizer: optimizer passed to ``Train``; its first param group's
                learning rate is recorded each epoch.
            train_loader: data loader handed to ``Train``.
            test_loader: data loader handed to ``Test``.
            statspath: path handed to ``ModelStats``.
            scheduler: optional learning-rate scheduler. It is always given to
                ``Test``, and given to ``Train`` only when
                ``batch_scheduler`` is True.
            batch_scheduler: whether the scheduler is driven by ``Train``
                rather than stepped once per epoch here.
            L1lambda: forwarded to ``Train``.
        """
        self.model = model
        self.statsmanager = statsmanager
        self.scheduler = scheduler
        self.batch_scheduler = batch_scheduler
        self.optimizer = optimizer
        self.stats = ModelStats(model, statspath)
        train_scheduler = self.scheduler if self.batch_scheduler else None
        self.train = Train(statsmanager, model, train_loader, optimizer,
                           self.stats, train_scheduler, L1lambda)
        self.test = Test(model, test_loader, self.stats, statsmanager,
                         self.scheduler)

    def _should_step_per_epoch(self):
        """Return True if :meth:`run` itself must step the scheduler."""
        if not self.scheduler or self.batch_scheduler:
            return False
        # ReduceLROnPlateau needs a metric to step. It is handed to Test,
        # which presumably steps it with the validation metric — confirm.
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau
        return not isinstance(self.scheduler, plateau)

    def _mirror_epoch_to_manager(self):
        """Copy the latest epoch's metrics from ``self.stats`` to the manager."""
        local, sink = self.stats, self.statsmanager
        sink.append_train_loss(local.avg_train_loss[-1])
        sink.append_test_loss(local.avg_test_loss[-1])
        sink.append_test_accuracy(100 * local.test_acc[-1])
        sink.append_train_accuracy(100 * local.train_acc[-1])
        # On the first recorded epoch the manager gets the first per-batch LR.
        # Afterwards it gets the per-epoch LR that next_epoch() just recorded.
        nothing_recorded = len(sink.data['lr']) == 0
        sink.append_lr(local.batch_lr[0] if nothing_recorded else local.lr[-1])

    def run(self, epochs=10):
        """Train and evaluate for ``epochs`` epochs, then save local stats.

        Args:
            epochs: number of epochs to run.
        """
        progress = tqdm_notebook(range(1, epochs + 1), desc="Epochs")
        for _ in progress:
            self.train.run()
            self.test.run()
            epoch_lr = self.optimizer.param_groups[0]['lr']
            self.stats.next_epoch(epoch_lr)
            progress.write(self.stats.get_epoch_desc())
            self._mirror_epoch_to_manager()
            if self._should_step_per_epoch():
                self.scheduler.step()
            # Reports the rate used during this epoch (read before stepping).
            progress.write(f"Learning Rate = {epoch_lr:0.6f}")
        # save stats for later lookup
        self.stats.save()
class ModelTrainer:
    """Run a train/test loop with a configurable loss and report results.

    Wires ``Train``, ``Test`` and ``Misclass`` helpers around a shared
    ``ModelStats`` collector. After training it calls ``printing_results``
    for the final epoch.
    """

    def __init__(self, model, optimizer, train_loader, test_loader, statspath,
                 criterion, writer, scheduler=None, batch_scheduler=False,
                 L1lambda=0):
        """Build the runners and enable cuDNN autotuning.

        Args:
            model: the model being trained.
            optimizer: optimizer passed to ``Train``; its first param group's
                learning rate is recorded each epoch.
            train_loader: data loader handed to ``Train``.
            test_loader: data loader handed to ``Test`` and ``Misclass``, and
                used for the final ``printing_results`` call.
            statspath: path handed to ``ModelStats``.
            criterion: loss function forwarded to ``Train`` and ``Test``.
            writer: forwarded to ``Test`` (presumably a logging/summary
                writer — confirm).
            scheduler: optional learning-rate scheduler. It is always given to
                ``Test``, and given to ``Train`` only when
                ``batch_scheduler`` is True.
            batch_scheduler: whether the scheduler is driven by ``Train``
                rather than stepped once per epoch here.
            L1lambda: forwarded to ``Train``.
        """
        self.model = model
        self.scheduler = scheduler
        self.criterion = criterion
        self.batch_scheduler = batch_scheduler
        self.optimizer = optimizer
        self.stats = ModelStats(model, statspath)
        self.train = Train(model, train_loader, optimizer, self.stats,
                           self.scheduler if self.batch_scheduler else None,
                           L1lambda, criterion)
        self.test = Test(model, test_loader, self.stats, writer,
                         self.scheduler, criterion)
        # Not used inside this class; presumably consumed by callers.
        self.misclass = Misclass(model, test_loader, self.stats)
        self.test_loader = test_loader
        # Process-wide flag: let cuDNN benchmark and pick convolution
        # algorithms (helps when input sizes stay fixed).
        torch.backends.cudnn.benchmark = True

    def run(self, epochs=10):
        """Train and evaluate for ``epochs`` epochs, report, and save stats.

        Args:
            epochs: number of epochs to run. When it is less than 1 no
                training happens and ``printing_results`` is skipped. Before,
                this case raised ``UnboundLocalError`` on ``epoch``.
        """
        pbar = tqdm_notebook(range(1, epochs + 1), desc="Epochs")
        last_epoch = None
        for epoch in pbar:
            self.train.run()
            self.test.run()
            lr = self.optimizer.param_groups[0]['lr']
            self.stats.next_epoch(lr)
            pbar.write(self.stats.get_epoch_desc())
            # ReduceLROnPlateau needs a metric to step. It is handed to Test,
            # which presumably steps it itself. TODO: make this dispatch more
            # readable and allow for other schedulers.
            if self.scheduler and not self.batch_scheduler and not isinstance(
                    self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.scheduler.step()
            pbar.write(f"Learning Rate = {lr:0.6f}")
            last_epoch = epoch
        if last_epoch is not None:
            print("printing results")
            printing_results(self.model, self.test_loader, last_epoch)
        # save stats for later lookup
        self.stats.save()