def test_regularization(self):
    penalty = self.model.get_regularization_penalty()
    assert penalty is None

    data_loader = PyTorchDataLoader(self.instances, batch_size=32)
    trainer = GradientDescentTrainer(self.model, optimizer=None, data_loader=data_loader)

    # You get a RuntimeError if you call `model.forward` twice on the same inputs.
    # The data and config are such that the whole dataset is one batch.
    training_batch = next(iter(data_loader))
    validation_batch = next(iter(data_loader))

    training_loss = trainer.batch_outputs(training_batch, for_training=True)["loss"].item()
    validation_loss = trainer.batch_outputs(validation_batch, for_training=False)["loss"].item()

    # Training loss would include the regularization penalty, but this model has no
    # penalty, so the two losses should be (almost) equal.
    numpy.testing.assert_almost_equal(training_loss, validation_loss)
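
# A hedged companion sketch (not part of the original suite): with a real penalty
# attached, training loss should exceed validation loss by exactly that penalty,
# since `batch_outputs(..., for_training=True)` adds it to the loss. `MyModel` is
# a hypothetical model class assumed to accept AllenNLP's `regularizer` argument.
def test_regularization_with_penalty(self):
    from allennlp.nn.regularizers import L2Regularizer, RegularizerApplicator

    regularizer = RegularizerApplicator([(".*", L2Regularizer(alpha=1.0))])
    model = MyModel(self.vocab, regularizer=regularizer)  # hypothetical model class

    data_loader = PyTorchDataLoader(self.instances, batch_size=32)
    trainer = GradientDescentTrainer(model, optimizer=None, data_loader=data_loader)

    # Fresh batches for each forward pass, as in the test above.
    training_batch = next(iter(data_loader))
    validation_batch = next(iter(data_loader))

    training_loss = trainer.batch_outputs(training_batch, for_training=True)["loss"].item()
    validation_loss = trainer.batch_outputs(validation_batch, for_training=False)["loss"].item()

    penalty = model.get_regularization_penalty().item()
    # Training loss = validation loss + regularization penalty.
    numpy.testing.assert_almost_equal(training_loss, validation_loss + penalty)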
def __call__(self, trainer: GradientDescentTrainer, metrics: Dict[str, Any], epoch: int, **kwargs):
    # Reset any metrics accumulated during training before evaluating.
    trainer.model.get_metrics(reset=True)
    if epoch < 0:
        # Epoch callbacks also fire once before training starts (with epoch = -1);
        # skip that call.
        return

    # If the trainer had moving averages, we would assign the averaged values here:
    # for moving_average in self.moving_averages:
    #     moving_average.assign_average_value()

    with torch.no_grad():
        logger.info("Testing")
        trainer.model.eval()
        batches_this_epoch = 0
        val_loss = 0
        bar = tqdm(self.test_iterator, desc="testing")
        for batch_group in bar:
            outs = trainer.batch_outputs(batch_group, for_training=False)
            loss = outs["loss"]
            if self.writer is not None:
                self.writer.write(outs)
            if loss is not None:
                # You shouldn't necessarily have to compute a loss for validation, so we
                # allow for `loss` to be None. We need to be careful, though -
                # `batches_this_epoch` is currently only used as the divisor for the loss,
                # so we can safely count only those batches for which we actually have a
                # loss. If this variable ever gets used for something else, we might need
                # to change things around a bit.
                batches_this_epoch += 1
                val_loss += loss.detach().cpu().numpy()

            # Update the progress bar description with the latest metrics. `get_metrics`
            # expects (total_loss, total_reg_loss, ...); this callback reports the same
            # value for both.
            val_metrics = get_metrics(trainer.model, val_loss, val_loss, batches_this_epoch)
            description = description_from_metrics(val_metrics)
            if self.name is not None:
                description = "epoch: %d, dataset: %s, %s" % (epoch, self.name, description)
            bar.set_description(description, refresh=False)

    trainer.val_metrics = get_metrics(trainer.model, val_loss, val_loss, batches_this_epoch, reset=False)
    if self.wandb_logger is not None:
        self.wandb_logger(trainer.val_metrics, epoch, prefix=self.name)

    # If the trainer had moving averages, we would restore the original values here:
    # for moving_average in self.moving_averages:
    #     moving_average.restore()

    if self.writer is not None:
        self.writer.set_epoch(epoch + 1)
        self.writer.reset()

    # Reset metrics again so they don't leak into the next training epoch.
    trainer.model.get_metrics(reset=True)
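
# Hedged usage sketch: wiring the callback above into the trainer via
# GradientDescentTrainer's `epoch_callbacks` hook, which invokes each callback after
# every epoch (and once with epoch = -1 before training, which the `epoch < 0` guard
# above skips). The class name `TestSetEvaluator` and its constructor arguments are
# hypothetical wrappers around the `__call__` defined above.
evaluator = TestSetEvaluator(  # hypothetical class exposing the __call__ above
    test_iterator=test_data_loader,
    name="test",
    writer=None,
    wandb_logger=None,
)
trainer = GradientDescentTrainer(
    model,
    optimizer,
    train_data_loader,
    num_epochs=10,
    epoch_callbacks=[evaluator],
)
trainer.train()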
def search_learning_rate(
    trainer: GradientDescentTrainer,
    start_lr: float = 1e-5,
    end_lr: float = 10,
    num_batches: int = 100,
    linear_steps: bool = False,
    stopping_factor: Optional[float] = None,
) -> Tuple[List[float], List[float]]:
    """
    Runs a training loop on the model using
    [`GradientDescentTrainer`](../training/trainer.md#gradientdescenttrainer),
    increasing the learning rate from `start_lr` to `end_lr` and recording the losses.

    # Parameters

    trainer: `GradientDescentTrainer`
    start_lr : `float`
        The learning rate at which to start the search.
    end_lr : `float`
        The learning rate up to which the search is done.
    num_batches : `int`
        Number of batches to run the learning rate finder for.
    linear_steps : `bool`
        If `True`, increase the learning rate linearly; otherwise exponentially.
    stopping_factor : `Optional[float]`
        Stop the search when the current loss exceeds the best recorded loss
        multiplied by `stopping_factor`. If `None`, the search proceeds until `end_lr`.

    # Returns

    (learning_rates, losses) : `Tuple[List[float], List[float]]`
        Returns lists of learning rates and corresponding losses.
        Note: the losses are recorded before applying the corresponding learning rate.
    """
    if num_batches <= 10:
        raise ConfigurationError(
            "The number of iterations for the learning rate finder should be greater than 10."
        )

    trainer.model.train()

    infinite_generator = itertools.cycle(trainer.data_loader)
    train_generator_tqdm = Tqdm.tqdm(infinite_generator, total=num_batches)

    learning_rates = []
    losses = []
    best = 1e9

    if linear_steps:
        lr_update_factor = (end_lr - start_lr) / num_batches
    else:
        lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)

    for i, batch in enumerate(train_generator_tqdm):
        if linear_steps:
            current_lr = start_lr + (lr_update_factor * i)
        else:
            current_lr = start_lr * (lr_update_factor ** i)

        for param_group in trainer.optimizer.param_groups:
            param_group["lr"] = current_lr

        trainer.optimizer.zero_grad()
        loss = trainer.batch_outputs(batch, for_training=True)["loss"]
        loss.backward()
        loss = loss.detach().cpu().item()

        if stopping_factor is not None and (math.isnan(loss) or loss > stopping_factor * best):
            logger.info(f"Loss ({loss}) exceeds stopping_factor * lowest recorded loss.")
            break

        trainer.rescale_gradients()
        trainer.optimizer.step()

        learning_rates.append(current_lr)
        losses.append(loss)

        # Ignore the first few noisy batches when tracking the best loss.
        if loss < best and i > 10:
            best = loss

        if i == num_batches:
            break

    return learning_rates, losses
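
# Minimal usage sketch: run the finder and plot smoothed loss against learning rate
# to pick a good value. Assumes `trainer` is an already-constructed
# GradientDescentTrainer; the smoothing weight and output path are illustrative
# choices, not part of this module.
import matplotlib.pyplot as plt

learning_rates, losses = search_learning_rate(trainer, num_batches=100)

# Exponentially smooth the losses so the curve is readable.
smoothed, last = [], losses[0]
for loss in losses:
    last = 0.98 * last + 0.02 * loss
    smoothed.append(last)

# Loss vs. learning rate on a log-scaled x-axis; a common heuristic is to pick a
# learning rate somewhat below the point where the loss starts climbing.
plt.plot(learning_rates, smoothed)
plt.xscale("log")
plt.xlabel("learning rate")
plt.ylabel("smoothed loss")
plt.savefig("lr_search.png")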