Example 1
    def optimize(self, trainer: 'CallbackTrainer'):
        # Rescale (clip) gradients and record the resulting norm before stepping the optimizer
        trainer.batch_grad_norm = training_util.rescale_gradients(trainer.model, self.grad_norm)
        trainer.optimizer.step()

        # Update the description with the latest metrics
        trainer.train_metrics.update(
                training_util.get_metrics(trainer.model, trainer.train_loss, trainer.batches_this_epoch)
        )
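
For context, the metric refresh at the end of this callback follows a running-average pattern: assuming (as in AllenNLP) that the reported loss is the cumulative loss divided by the number of batches seen so far, a small self-contained illustration with made-up batch losses looks like this.

train_loss = 0.0
train_metrics = {}

for batches_this_epoch, batch_loss in enumerate([0.9, 0.7, 0.5], start=1):
    train_loss += batch_loss
    # Mirrors the running "loss" entry: cumulative loss / batches seen so far
    train_metrics.update({"loss": train_loss / batches_this_epoch})
    print(batches_this_epoch, round(train_metrics["loss"], 3))  # 0.9, then 0.8, then 0.7
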
Example 2
    def train_on_batch(self, batch, optimizer):
        """
        written by ph to keep interface between models consistent
        """
        # to cuda
        batch['tokens']['tokens'] = batch['tokens']['tokens'].cuda()
        batch['verb_indicator'] = batch['verb_indicator'].cuda()
        batch['tags'] = batch['tags'].cuda()

        # forward + loss
        optimizer.zero_grad()
        output_dict = self(**batch)  # input is dict[str, tensor]
        loss = output_dict["loss"] + self.get_regularization_penalty()
        if torch.isnan(loss):
            raise ValueError("nan loss encountered")

        # backward + update
        loss.backward()
        rescale_gradients(self, self.max_grad_norm)  # clip gradients to max_grad_norm before the update
        optimizer.step()

        return loss
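
A self-contained sketch of the same forward / NaN check / backward / clip / step pattern, using a toy model and torch.nn.utils.clip_grad_norm_ in place of the snippet's rescale_gradients helper; all names below are illustrative and not part of the original code.

import torch
from torch import nn


def train_on_batch(model: nn.Module, batch: dict, optimizer: torch.optim.Optimizer,
                   max_grad_norm: float = 5.0) -> torch.Tensor:
    optimizer.zero_grad()
    loss = model(**batch)               # forward pass returns the loss directly
    if torch.isnan(loss):
        raise ValueError("nan loss encountered")
    loss.backward()                     # backward
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # clip
    optimizer.step()                    # update
    return loss


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, inputs, targets):
        return nn.functional.mse_loss(self.linear(inputs), targets)


model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = {"inputs": torch.randn(8, 4), "targets": torch.randn(8, 1)}
print(train_on_batch(model, batch, optimizer).item())
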
Example 3
    def rescale_gradients(self) -> Optional[float]:
        return training_util.rescale_gradients(self.model, self._grad_norm)
Example 4
    def rescale_gradients(self, trainer: 'CallbackTrainer'):
        trainer.batch_grad_norm = training_util.rescale_gradients(trainer.model, self.grad_norm)
Example 5
    def _rescale_gradients(self) -> Optional[float]:
        """
        Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
        """
        return training_util.rescale_gradients(self._model, self._grad_norm)
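
All three wrappers above delegate to training_util.rescale_gradients without showing its behavior. Below is a minimal sketch of what such a helper plausibly does, assuming it clips by total L2 norm; it is written here with torch.nn.utils.clip_grad_norm_ for simplicity and is not the library's actual implementation.

from typing import Optional

import torch


def rescale_gradients(model: torch.nn.Module, grad_norm: Optional[float]) -> Optional[float]:
    """Clip the model's gradients to a total L2 norm of grad_norm; no-op if grad_norm is None."""
    if grad_norm is None:
        return None
    parameters = [p for p in model.parameters() if p.grad is not None]
    # clip_grad_norm_ rescales gradients in place and returns the pre-clipping total norm
    return torch.nn.utils.clip_grad_norm_(parameters, grad_norm).item()
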
Example 6
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics.
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs)
        peak_cpu_usage = peak_memory_mb()
        logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
        gpu_usage = []
        for gpu, memory in gpu_memory_mb().items():
            gpu_usage.append((gpu, memory))
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        train_loss = 0.0
        # Set the model to "train" mode.
        self._pytorch_model.train()

        num_training_batches = [math.ceil(
            self.iterator.get_num_batches(train_data) / self._num_gradient_accumulation_steps
        ) for task, train_data in self.train_datas.items()]
        assert len(set(num_training_batches)) == 1, "num_training_batches must agree across tasks"
        tasks = list(self.batch_group_generators.keys())
        num_tasks = len(tasks)

        #if isinstance(self._learning_rate_scheduler, SlantedTriangular):
        #    old_num_steps_per_epoch = self._learning_rate_scheduler.num_steps_per_epoch
        #    self._learning_rate_scheduler.num_steps_per_epoch = num_training_batches[0]
        #    logger.info(f"modify num_steps_per_epoch of lr scheduler from"
        #                f"{old_num_steps_per_epoch} to {num_training_batches}")

        self._last_log = time.time()
        last_save_time = time.time()

        batches_this_epoch = 0
        if self._batch_num_total is None:
            self._batch_num_total = 0

        logger.info("Training")

        cumulative_batch_group_size = 0
        tqdm_bar = Tqdm.tqdm(range(num_training_batches[0]))
        for _ in tqdm_bar:
            randperms = torch.randperm(len(tasks)).tolist()
            sampled_tasks = [tasks[idx] for idx in randperms[:self._tasks_per_step]]
            sampled_task_generators = [next(self.batch_group_generators[task]) for task in sampled_tasks]

            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self.optimizer.zero_grad()

            task_metrics = self.wrapper(tasks=sampled_task_generators, train=True, meta_train=True)

            losses = [list(map(lambda x: x["loss"], metrics)) for metrics in task_metrics]
            LASes = [list(map(lambda x: x["metric"]["LAS"], metrics)) for metrics in task_metrics]

            names = ["loss", "LAS"]
            list_values = [losses, LASes]

            if self.has_VIB:
                KLDivs = [list(map(lambda x: x["metric"]["kl_div"], metrics)) for metrics in task_metrics]
                names.append("KLDiv")
                list_values.append(KLDivs)

            if self.has_pos:
                pos_accs = [list(map(lambda x: x["metric"].get("pos_accuracy", 0.0), metrics)) for metrics in task_metrics]
                names.append("pos_acc")
                list_values.append(pos_accs)

            for name, values in zip(names, list_values):
                self._writer.log({f"step_{name}_{task}_{i}": value
                                  for task, task_values in zip(sampled_tasks, values)
                                  for i, value in enumerate(task_values)},
                                 step=self._batch_num_total)
                values_inner_steps = list(map(np.mean, zip(*values)))
                self._writer.log({f"step_{name}_{i}": value for i, value in
                                  enumerate(values_inner_steps)},
                                 step=self._batch_num_total)
                if name == "loss":
                    train_loss += values_inner_steps[0]

            batch_grad_norm = self.rescale_gradients()

            # This does nothing if batch_num_total is None or you are using a
            # scheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)
            if self._momentum_scheduler:
                self._momentum_scheduler.step_batch(batch_num_total)

            # variational information bottleneck / meta-learning without memorization
            if self.has_VIB:
                kl_loss, kl_div, kl_div2 = ContinuousVIB.get_kl_loss(self.model, sampled_task_generators)
                kl_loss.backward()
                self._writer.log({"kl_loss": kl_loss.detach().item(),
                                  "kl_div": kl_div,
                                  "kl_div2": kl_div2},
                                  step=self._batch_num_total)

            # adversarial training
            if self.task_D and self.optim_D:
                # D training: first apply the accumulated gradients to the main model
                self.optimizer.step()
                steps_per_update = self.task_D.steps_per_update
                if (batch_num_total - 1) % steps_per_update == 0:
                    self.optim_D.zero_grad()
                    hidden_states, labels, masks = self.task_D.get_hidden_states(
                        self.model,
                        sampled_task_generators
                    )
                    D_loss, _, acc = self.task_D(hidden_states, labels, masks, detach=True)
                    D_loss.backward()
                    disc_grad_norm = training_util.rescale_gradients(self.task_D, self.task_D.disc_grad_norm)
                    self.optim_D.step()
                    self._writer.log({"D_loss": D_loss.detach().item(),
                                      "D_acc": acc},
                                     step=self._batch_num_total)
                    if disc_grad_norm:
                        self._writer.log({"D_grad_norm": disc_grad_norm.detach().item()},
                                         step=self._batch_num_total)

                # G training: update the main model (generator) to fool the discriminator
                hidden_states, labels, masks = self.task_D.get_hidden_states(
                    self.model,
                    sampled_task_generators
                )
                _, g_loss, acc = self.task_D(hidden_states, labels, masks)
                if self.task_D.weight:
                    alpha = self.task_D.weight
                else:
                    alpha = self.task_D.get_alpha(self._batch_num_total,
                                                  num_training_batches[0] * self._num_epochs)
                G_loss = -alpha * g_loss
                G_loss.backward()
                gen_grad_norm = training_util.rescale_gradients(self.model, self.task_D.gen_grad_norm)
                self._writer.log({"G_loss": g_loss.detach().item(), "alpha": alpha, "G_acc": acc},
                                 step=self._batch_num_total)
                if gen_grad_norm:
                    self._writer.log({"G_grad_norm": gen_grad_norm.detach().item()},
                                     step=self._batch_num_total)

            self.optimizer.step()

            # Update moving averages
            if self._moving_average is not None:
                self._moving_average.apply(batch_num_total)


            # Update the description with the latest metrics
            metrics = training_util.get_metrics(
                self.wrapper.container,
                train_loss,
                batches_this_epoch,
                world_size=self._world_size,
                cuda_device=[self.cuda_device],
            )

            # Update tqdm only on the master process; the other workers don't have a progress bar
            if self._master:
                description = training_util.description_from_metrics(metrics)
                tqdm_bar.set_description(description, refresh=False)

            # log learning rate.
            self._writer.log({"lr": self.optimizer.param_groups[0]['lr']},
                             step=self._batch_num_total)

            # Save model if needed.
            if (
                self._model_save_interval is not None
                and (time.time() - last_save_time > self._model_save_interval)
                and self._master
            ):
                last_save_time = time.time()
                self._save_checkpoint(
                    "{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time)))
                )

        # Let all workers finish their epoch before computing
        # the final statistics for the epoch.
        if self._distributed:
            dist.barrier()

        metrics = training_util.get_metrics(
            self.wrapper.container,
            train_loss,
            batches_this_epoch,
            reset=True,
            world_size=self._world_size,
            cuda_device=[self.cuda_device],
        )
        metrics["cpu_memory_MB"] = peak_cpu_usage
        for (gpu_num, memory) in gpu_usage:
            metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
        return metrics
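
One detail of the logging loop above that may not be obvious: `values` holds one list per sampled task, each listing that task's values across the inner meta-learning steps, so `list(map(np.mean, zip(*values)))` transposes the structure and averages across tasks at each inner step. A small illustration with made-up numbers:

import numpy as np

values = [
    [0.90, 0.70],  # task A: value after inner step 0 and inner step 1
    [1.10, 0.80],  # task B
]

# zip(*values) yields per-step tuples across tasks; np.mean averages each tuple,
# so values_inner_steps[i] is the mean over tasks at inner step i.
values_inner_steps = list(map(np.mean, zip(*values)))
print(values_inner_steps)  # [1.0, 0.75]
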