def _foreach_batch(dl: DataLoader,
                   forward_fn: Callable[[Any], BatchResult],
                   verbose=True, max_batches=None) -> EpochResult:
    """
    Evaluates the given forward-function on batches from the given
    dataloader, and prints progress along the way.

    :param dl: Dataloader providing the batches to evaluate.
    :param forward_fn: Function applied to each batch; returns a
        BatchResult exposing ``loss`` and ``num_correct``.
    :param verbose: Whether to show a tqdm progress bar on stdout.
    :param max_batches: If given, stop after this many batches.
    :return: EpochResult with the per-batch losses and overall accuracy.
    """
    losses = []
    num_correct = 0
    num_samples = len(dl.sampler)
    num_batches = len(dl.batch_sampler)

    if max_batches is not None:
        if max_batches < num_batches:
            num_batches = max_batches
            # NOTE(review): assumes every evaluated batch is full-sized
            # (dl.batch_size); a smaller trailing batch is never reached
            # when truncating, so this is consistent.
            num_samples = num_batches * dl.batch_size

    # When not verbose, send tqdm's output to the null device.
    # BUG FIX: the devnull handle was previously opened and never closed
    # (resource leak); track it and close it in the finally below.
    if verbose:
        pbar_file = sys.stdout
        devnull_file = None
    else:
        pbar_file = devnull_file = open(os.devnull, 'w')

    try:
        pbar_name = forward_fn.__name__
        with tqdm.tqdm(desc=pbar_name, total=num_batches,
                       file=pbar_file) as pbar:
            dl_iter = iter(dl)
            for batch_idx in range(num_batches):
                data = next(dl_iter)
                batch_res = forward_fn(data)

                pbar.set_description(f'{pbar_name} ({batch_res.loss:.3f})')
                pbar.update()

                losses.append(batch_res.loss)
                num_correct += batch_res.num_correct

            avg_loss = sum(losses) / num_batches
            accuracy = 100. * num_correct / num_samples
            pbar.set_description(f'{pbar_name} '
                                 f'(Avg. Loss {avg_loss:.3f}, '
                                 f'Accuracy {accuracy:.1f})')
    finally:
        if devnull_file is not None:
            devnull_file.close()

    return EpochResult(losses=losses, accuracy=accuracy)
def fit(self, dl_train: DataLoader, dl_test: DataLoader, num_epochs,
        checkpoints: str = None, early_stopping: int = None,
        print_every=1, post_epoch_fn=None, **kw) -> FitResult:
    """
    Trains the model for multiple epochs with a given training set,
    and calculates validation loss over a given validation set.

    :param dl_train: Dataloader for the training set.
    :param dl_test: Dataloader for the test set.
    :param num_epochs: Number of epochs to train for.
    :param checkpoints: Whether to save model to file every time the
        test set accuracy improves. Should be a string containing a
        filename without extension.
    :param early_stopping: Whether to stop training early if there is no
        test loss improvement for this number of epochs.
    :param print_every: Print progress every this number of epochs.
    :param post_epoch_fn: A function to call after each epoch completes.
    :param kw: Extra keyword args forwarded to train_epoch/test_epoch.
    :return: A FitResult object containing train and test losses per epoch.
    """
    actual_num_epochs = 0
    train_loss, train_acc, test_loss, test_acc = [], [], [], []

    best_acc = None
    epochs_without_improvement = 0

    checkpoint_filename = None
    if checkpoints is not None:
        checkpoint_filename = f'{checkpoints}.pt'
        # parents=True so a nested checkpoint directory is created as well.
        Path(os.path.dirname(checkpoint_filename)).mkdir(parents=True,
                                                         exist_ok=True)
        if os.path.isfile(checkpoint_filename):
            print(f'*** Loading checkpoint file {checkpoint_filename}')
            saved_state = torch.load(checkpoint_filename,
                                     map_location=self.device)
            best_acc = saved_state.get('best_acc', best_acc)
            epochs_without_improvement = \
                saved_state.get('ewi', epochs_without_improvement)
            self.model.load_state_dict(saved_state['model_state'])

    for epoch in range(num_epochs):
        save_checkpoint = False
        verbose = False  # pass this to train/test_epoch.
        if epoch % print_every == 0 or epoch == num_epochs - 1:
            verbose = True
        self._print(f'--- EPOCH {epoch+1}/{num_epochs} ---', verbose)

        # Train & evaluate for one epoch.
        # BUG FIX: forward verbose (and the otherwise-unused **kw) to the
        # per-epoch methods so progress output respects print_every, as the
        # original in-code comment instructed.
        epoch_train_loss, epoch_train_acc = \
            self.train_epoch(dl_train, verbose=verbose, **kw)
        train_loss += epoch_train_loss
        train_acc.append(epoch_train_acc)

        epoch_test_loss, epoch_test_acc = \
            self.test_epoch(dl_test, verbose=verbose, **kw)
        test_loss += epoch_test_loss
        test_acc.append(epoch_test_acc)

        actual_num_epochs += 1

        if best_acc is None or epoch_test_acc > best_acc:
            best_acc = epoch_test_acc
            epochs_without_improvement = 0
            # BUG FIX: checkpoint only when test accuracy improves.
            # Previously save_checkpoint was set True unconditionally each
            # epoch, contradicting the documented checkpoint contract.
            save_checkpoint = True
        else:
            epochs_without_improvement += 1

        train_result = EpochResult(losses=epoch_train_loss,
                                   accuracy=epoch_train_acc)
        test_result = EpochResult(losses=epoch_test_loss,
                                  accuracy=epoch_test_acc)

        # Save model checkpoint if requested
        if save_checkpoint and checkpoint_filename is not None:
            saved_state = dict(best_acc=best_acc,
                               ewi=epochs_without_improvement,
                               model_state=self.model.state_dict())
            torch.save(saved_state, checkpoint_filename)
            print(f'*** Saved checkpoint {checkpoint_filename} '
                  f'at epoch {epoch+1}')

        if post_epoch_fn:
            post_epoch_fn(epoch, train_result, test_result, verbose)

        # Early stopping: a simple and highly effective regularization
        # technique. Checked last so the stopping epoch is still fully
        # processed (checkpoint save and post-epoch hook above).
        # BUG FIX: previously the break happened before checkpoint saving
        # and before post_epoch_fn, silently skipping both on the final
        # epoch.
        if early_stopping is not None and early_stopping > 0:
            if epochs_without_improvement >= early_stopping:
                break

    return FitResult(actual_num_epochs,
                     train_loss, train_acc, test_loss, test_acc)