def optimize(self, trainer: 'CallbackTrainer'):
    trainer.batch_grad_norm = training_util.rescale_gradients(trainer.model, self.grad_norm)
    trainer.optimizer.step()

    # Update the description with the latest metrics
    trainer.train_metrics.update(
        training_util.get_metrics(trainer.model, trainer.train_loss, trainer.batches_this_epoch)
    )
def train_on_batch(self, batch, optimizer):
    """
    Written by ph to keep the interface between models consistent.
    """
    # Move the batch tensors to the GPU.
    batch['tokens']['tokens'] = batch['tokens']['tokens'].cuda()
    batch['verb_indicator'] = batch['verb_indicator'].cuda()
    batch['tags'] = batch['tags'].cuda()

    # Forward pass + loss; the input is a Dict[str, torch.Tensor].
    optimizer.zero_grad()
    output_dict = self(**batch)
    loss = output_dict["loss"] + self.get_regularization_penalty()
    if torch.isnan(loss):
        raise ValueError("nan loss encountered")

    # Backward pass + parameter update.
    loss.backward()
    rescale_gradients(self, self.max_grad_norm)
    optimizer.step()

    return loss
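# Hedged usage sketch (not from the source): one way a model exposing the
# train_on_batch interface above could be driven from a plain loop.
# `model` and `train_batches` are hypothetical stand-ins, and the Adam
# learning rate is an arbitrary example value.
import torch

def run_simple_training_loop(model, train_batches, num_epochs=1):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in train_batches:
            # Each batch is assumed to be a dict holding 'tokens', 'verb_indicator', and 'tags'.
            loss = model.train_on_batch(batch, optimizer)
            epoch_loss += loss.item()
        print(f"epoch {epoch}: mean loss {epoch_loss / max(1, len(train_batches)):.4f}")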
def rescale_gradients(self) -> Optional[float]:
    return training_util.rescale_gradients(self.model, self._grad_norm)
def rescale_gradients(self, trainer: 'CallbackTrainer'):
    trainer.batch_grad_norm = training_util.rescale_gradients(trainer.model, self.grad_norm)
def _rescale_gradients(self) -> Optional[float]:
    """
    Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
    """
    return training_util.rescale_gradients(self._model, self._grad_norm)
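# Hedged sketch (an assumption, not the library's exact implementation): the
# training_util.rescale_gradients helper used throughout these snippets behaves
# roughly like clipping the global gradient norm and returning the pre-clip norm,
# or doing nothing when no threshold is configured. torch.nn.utils.clip_grad_norm_
# stands in here for the dense-gradient case.
from typing import Optional

import torch

def rescale_gradients_sketch(model: torch.nn.Module, grad_norm: Optional[float]) -> Optional[float]:
    if grad_norm:
        parameters_to_clip = [p for p in model.parameters() if p.grad is not None]
        # clip_grad_norm_ rescales gradients in place and returns the total norm
        # computed before clipping.
        return float(torch.nn.utils.clip_grad_norm_(parameters_to_clip, grad_norm))
    return None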
def _train_epoch(self, epoch: int) -> Dict[str, float]:
    """
    Trains one epoch and returns metrics.
    """
    logger.info("Epoch %d/%d", epoch, self._num_epochs)
    peak_cpu_usage = peak_memory_mb()
    logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
    gpu_usage = []
    for gpu, memory in gpu_memory_mb().items():
        gpu_usage.append((gpu, memory))
        logger.info(f"GPU {gpu} memory usage MB: {memory}")

    train_loss = 0.0

    # Set the model to "train" mode.
    self._pytorch_model.train()

    num_training_batches = [
        math.ceil(
            self.iterator.get_num_batches(train_data) / self._num_gradient_accumulation_steps
        )
        for task, train_data in self.train_datas.items()
    ]
    assert len(set(num_training_batches)) == 1, "num_training_batches doesn't agree"

    tasks = list(self.batch_group_generators.keys())
    num_tasks = len(tasks)

    # if isinstance(self._learning_rate_scheduler, SlantedTriangular):
    #     old_num_steps_per_epoch = self._learning_rate_scheduler.num_steps_per_epoch
    #     self._learning_rate_scheduler.num_steps_per_epoch = num_training_batches[0]
    #     logger.info(f"modify num_steps_per_epoch of lr scheduler from "
    #                 f"{old_num_steps_per_epoch} to {num_training_batches}")

    self._last_log = time.time()
    last_save_time = time.time()

    batches_this_epoch = 0
    if self._batch_num_total is None:
        self._batch_num_total = 0

    logger.info("Training")

    cumulative_batch_group_size = 0
    tqdm_bar = Tqdm.tqdm(range(num_training_batches[0]))
    for _ in tqdm_bar:
        # Sample a subset of tasks for this meta-training step.
        randperms = torch.randperm(len(tasks)).tolist()
        sampled_tasks = [tasks[idx] for idx in randperms[:self._tasks_per_step]]
        sampled_task_generators = [next(self.batch_group_generators[task]) for task in sampled_tasks]

        batches_this_epoch += 1
        self._batch_num_total += 1
        batch_num_total = self._batch_num_total

        self.optimizer.zero_grad()

        # Run the meta-learning wrapper on the sampled tasks and collect per-step metrics.
        task_metrics = self.wrapper(tasks=sampled_task_generators, train=True, meta_train=True)
        losses = [list(map(lambda x: x["loss"], metrics)) for metrics in task_metrics]
        LASes = [list(map(lambda x: x["metric"]["LAS"], metrics)) for metrics in task_metrics]
        names = ["loss", "LAS"]
        list_values = [losses, LASes]
        if self.has_VIB:
            KLDivs = [list(map(lambda x: x["metric"]["kl_div"], metrics)) for metrics in task_metrics]
            names.append("KLDiv")
            list_values.append(KLDivs)
        if self.has_pos:
            pos_accs = [
                list(map(lambda x: x["metric"].get("pos_accuracy", 0.0), metrics))
                for metrics in task_metrics
            ]
            names.append("pos_acc")
            list_values.append(pos_accs)

        # Log per-task values and the average across tasks for each tracked quantity.
        for name, values in zip(names, list_values):
            self._writer.log(
                {
                    f"step_{name}_{task}_{i}": value
                    for task, task_values in zip(sampled_tasks, values)
                    for i, value in enumerate(task_values)
                },
                step=self._batch_num_total,
            )
            values_inner_steps = list(map(np.mean, zip(*values)))
            self._writer.log(
                {f"step_{name}_{i}": value for i, value in enumerate(values_inner_steps)},
                step=self._batch_num_total,
            )
            if name == "loss":
                train_loss += values_inner_steps[0]

        batch_grad_norm = self.rescale_gradients()

        # This does nothing if batch_num_total is None or you are using a
        # scheduler which doesn't update per batch.
        if self._learning_rate_scheduler:
            self._learning_rate_scheduler.step_batch(batch_num_total)
        if self._momentum_scheduler:
            self._momentum_scheduler.step_batch(batch_num_total)

        # Variational information bottleneck / meta-learning without memorization.
        if self.has_VIB:
            kl_loss, kl_div, kl_div2 = ContinuousVIB.get_kl_loss(self.model, sampled_task_generators)
            kl_loss.backward()
            self._writer.log(
                {"kl_loss": kl_loss.detach().item(), "kl_div": kl_div, "kl_div2": kl_div2},
                step=self._batch_num_total,
            )

        # Adversarial training.
        if self.task_D and self.optim_D:
            # Discriminator (D) training.
            self.optimizer.step()
            steps_per_update = self.task_D.steps_per_update
            if (batch_num_total - 1) % steps_per_update == 0:
                self.optim_D.zero_grad()
                hidden_states, labels, masks = self.task_D.get_hidden_states(
                    self.model, sampled_task_generators
                )
                D_loss, _, acc = self.task_D(hidden_states, labels, masks, detach=True)
                D_loss.backward()
                disc_grad_norm = training_util.rescale_gradients(self.task_D, self.task_D.disc_grad_norm)
                self.optim_D.step()
                self._writer.log(
                    {"D_loss": D_loss.detach().item(), "D_acc": acc},
                    step=self._batch_num_total,
                )
                if disc_grad_norm:
                    self._writer.log(
                        {"D_grad_norm": disc_grad_norm.detach().item()},
                        step=self._batch_num_total,
                    )

            # Generator (G) training.
            hidden_states, labels, masks = self.task_D.get_hidden_states(
                self.model, sampled_task_generators
            )
            _, g_loss, acc = self.task_D(hidden_states, labels, masks)
            if self.task_D.weight:
                alpha = self.task_D.weight
            else:
                alpha = self.task_D.get_alpha(
                    self._batch_num_total, num_training_batches[0] * self._num_epochs
                )
            G_loss = -alpha * g_loss
            G_loss.backward()
            gen_grad_norm = training_util.rescale_gradients(self.model, self.task_D.gen_grad_norm)
            self._writer.log(
                {"G_loss": g_loss.detach().item(), "alpha": alpha, "G_acc": acc},
                step=self._batch_num_total,
            )
            if gen_grad_norm:
                self._writer.log(
                    {"G_grad_norm": gen_grad_norm.detach().item()},
                    step=self._batch_num_total,
                )

        self.optimizer.step()

        # Update moving averages.
        if self._moving_average is not None:
            self._moving_average.apply(batch_num_total)

        # Update the description with the latest metrics.
        metrics = training_util.get_metrics(
            self.wrapper.container,
            train_loss,
            batches_this_epoch,
            world_size=self._world_size,
            cuda_device=[self.cuda_device],
        )

        # Update tqdm only for the master process; the other workers don't have a progress bar.
        if self._master:
            description = training_util.description_from_metrics(metrics)
            tqdm_bar.set_description(description, refresh=False)

        # Log the learning rate.
        self._writer.log({"lr": self.optimizer.param_groups[0]['lr']}, step=self._batch_num_total)

        # Save the model if needed.
        if (
            self._model_save_interval is not None
            and (time.time() - last_save_time > self._model_save_interval)
            and self._master
        ):
            last_save_time = time.time()
            self._save_checkpoint(
                "{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time)))
            )

    # Let all workers finish their epoch before computing
    # the final statistics for the epoch.
    if self._distributed:
        dist.barrier()

    metrics = training_util.get_metrics(
        self.wrapper.container,
        train_loss,
        batches_this_epoch,
        reset=True,
        world_size=self._world_size,
        cuda_device=[self.cuda_device],
    )
    metrics["cpu_memory_MB"] = peak_cpu_usage
    for (gpu_num, memory) in gpu_usage:
        metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
    return metrics
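# Hedged usage sketch (assumption, not from the source): a minimal outer loop
# driving the _train_epoch method above. `trainer` is a hypothetical instance of
# the meta-training trainer class these methods belong to.
def run_training(trainer):
    all_metrics = []
    for epoch in range(trainer._num_epochs):
        epoch_metrics = trainer._train_epoch(epoch)
        # Per-epoch metrics include the averaged training loss plus the memory
        # bookkeeping keys (cpu_memory_MB, gpu_*_memory_MB) added at the end of
        # _train_epoch.
        all_metrics.append(epoch_metrics)
    return all_metrics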