def train_loop(self, dataset):
    """Run the main optimization loop over *dataset*.

    Computes the iteration budget from the dataset size and the configured
    ``iterations``/``epochs``, then performs ``self.max_iter`` training steps,
    updating a tqdm progress bar with the per-example loss after each step.

    Args:
        dataset: Sized iterable of training examples, consumed via
            ``self.batch_iterator``.
    """
    self.model.train()
    # Shrink the batch size if the dataset is smaller than the configured one.
    self.set_batch_size(len(dataset))
    self.max_iter = self.get_num_iter(
        len(dataset), self.batch_size, self.iterations, self.epochs)
    epochs = self.iter_to_epochs(self.max_iter, len(dataset), self.batch_size)
    logger.info(
        f"{self.max_iter} iterations ~ {epochs:.2f} epochs / "
        f"({len(dataset)} examples / {self.batch_size} examples per batch)"
    )
    optimizer = self.get_optimizer()
    lr_scheduler = self.get_learning_rate_scheduler(optimizer)
    progress_bar = tqdm(
        total=self.max_iter,
        bar_format="num_examples={postfix[2]} {postfix[1][iter]}/{postfix[0]} loss={postfix[1][loss]:.4f}",
        postfix=[self.max_iter, {"iter": 0, "loss": float('inf')}, len(dataset)],
    )
    with progress_bar:
        # BUG FIX: the original rebuilt the batch iterator on EVERY step, so
        # each step consumed only the first batch of a fresh iterator and the
        # rest of the dataset was never seen. Build it once and restart it
        # when an epoch is exhausted.
        train_iter = iter(self.batch_iterator(dataset, is_train=True))
        for iteration in range(1, self.max_iter + 1):
            try:
                *inputs, labels = next(train_iter)
            except StopIteration:
                # Epoch boundary: start a new pass over the dataset.
                train_iter = iter(self.batch_iterator(dataset, is_train=True))
                *inputs, labels = next(train_iter)
            # Normalize the summed step loss to a per-example loss.
            train_loss = self.train_one_step(
                inputs, labels, optimizer, lr_scheduler) / len(labels)
            progress_bar.postfix[1].update({"iter": iteration, "loss": train_loss})
            progress_bar.update()
def save(self, filename):
    """Serialize the entire model object to *filename* via ``torch.save``.

    Args:
        filename: Destination path for the pickled model.
    """
    # BUG FIX: the original f-string had no placeholder and always logged
    # "(unknown)"; log the actual destination path instead.
    logger.info(f"Saving model to {filename}")
    torch.save(self.model, filename)
def load(self, filename):
    """Load a model previously written by :meth:`save` onto ``self.device``.

    Args:
        filename: Path to the saved model (str or path-like).
    """
    # BUG FIX: the original f-string had no placeholder and always logged
    # "(unknown)"; log the actual source path instead.
    logger.info(f"Loading model from {filename}")
    # NOTE(review): torch.load unpickles arbitrary code — only call this on
    # trusted checkpoint files.
    self.model = torch.load(str(filename), map_location=self.device)
def set_batch_size(self, train_size):
    """Clamp ``self.batch_size`` so a batch never exceeds the dataset size.

    Args:
        train_size: Number of examples in the training dataset.
    """
    # A single batch cannot hold more examples than the dataset provides.
    if self.batch_size > train_size:
        logger.info(f"Setting batch size to {train_size}")
        self.batch_size = train_size