def process_dataloader(self,
                           dataloader: DataLoader,
                           hparams: ExtendedHParams,
                           total_epoch: int,
                           total_steps: int,
                           current_epoch: int = None,
                           training: bool = True):
        if hparams.use_gpu:
            assert (hparams.num_gpus <= torch.cuda.device_count()), \
                "Specified number of GPUs is incorrect."

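        # TensorBoard support is optional; fall back to no summary writer when it cannot be imported.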
        try:
            from torch.utils.tensorboard import SummaryWriter

            if hparams.has_value("tensorboard_dir"):
                tensorboard_dir = hparams.tensorboard_dir
            else:
                tensorboard_dir = os.path.join(hparams.out_dir,
                                               hparams.model_name,
                                               "tensorboard")
            tb_writer = SummaryWriter(log_dir=tensorboard_dir)
        except ImportError:
            tb_writer = None

        model = self.model
        if training:
            model.train()
            msg = "{}: Train with {} on ".format(
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                self.optimiser)
            if hparams.use_gpu:
                msg += str(torch.cuda.device_count()) + " GPU(s)."
            else:
                msg += "1 CPU."
            self.logger.info(msg)
        else:
            if self.ema is not None:
                self.logger.info("Using averaged model for validation.")
                model = self.ema.model
            model.eval()
            self.logger.info("{}: Compute loss of validation set.".format(
                datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

        if hparams.log_memory_consumption:
            self.logger.info('CPU: {:.0f} MB, GPU: {} MB'.format(
                resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e3,
                str(get_gpu_memory_map()) if hparams.use_gpu else "-"))

        # Multi-GPU support.
        if hparams.num_gpus > 1:
            model = DataParallel(model, dim=0 if hparams.batch_first else 1)
            # Make the init_hidden method directly accessible.
            model.init_hidden = model.module.init_hidden

        # Log loss after each <hparams.logging_batch_index_perc>% of batches.
        logging_batch_index = (len(dataloader) // hparams.logging_batch_index_perc) + 1

        total_losses = dict()

        # for params in reversed(list(self.model.parameters())):
        #         params.retain_grad()

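        # Iterate over mini batches; each batch unpacks into a dict of features and a dict of sequence lengths.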
        for batch_index, batch in enumerate(dataloader):

            if hparams.use_gpu:
                batch = self._batch_to_gpu(batch, hparams.dataset_load_async)

            data_dict, lengths = batch
            batch_size = len(next(iter(lengths.values())))
            model.init_hidden(batch_size)

            # Compute the max length per feature because DataParallel splits the lengths and each replica pads
            # to the maximum of its own subset, so combining the multi-GPU outputs would fail with a size mismatch.
            # https://pytorch.org/docs/stable/notes/faq.html#pack-rnn-unpack-with-data-parallelism
            max_lengths = dict()
            for key in data_dict.keys():
                if key in lengths:
                    l_max = max(lengths[key])
                    if hparams.use_gpu and hparams.num_gpus > 1:
                        l_max = l_max.repeat(hparams.num_gpus)
                    max_lengths[key] = l_max

            # Pass max_lengths to the forward call so every DataParallel replica pads to the same length.
            if training:
                model(data_dict, lengths, max_lengths)
            else:
                with torch.no_grad():
                    model(data_dict, lengths, max_lengths)

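            # Collect every named loss returned by the registered loss functions; NaNs, unexpected Infs,
            # and duplicate loss names abort the run.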
            losses = {}
            for loss_fn in self.losses:
                loss_ = loss_fn(data_dict, lengths, total_steps)
                for loss_name, l in loss_.items():
                    if torch.isnan(l):
                        raise ValueError("Found NaN in {} loss.".format(loss_name))
                    if not hparams.replace_inf_grads_by_zero and torch.isinf(l):
                        raise ValueError("Found +/-Inf in {} loss.".format(loss_name))
                    if loss_name in losses:
                        raise KeyError("Loss with name {} defined twice.".format(loss_name))
                    losses[loss_name] = l
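            # Back-propagate only the losses named in hparams.backprop_loss_names; the scheduler may
            # monitor a different (detached) subset.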
            backprop_loss = self.get_summed_losses_subset(
                loss_names=hparams.backprop_loss_names, losses=losses)
            if hparams.backprop_loss_names is None \
                    and hparams.scheduler_loss_names is None:
                scheduler_loss = backprop_loss.detach()
            else:
                scheduler_loss = self.get_summed_losses_subset(
                    loss_names=hparams.scheduler_loss_names, losses=losses).detach()

            if training:
                self.optimiser.zero_grad()
                backprop_loss.backward(retain_graph=hparams.backward_retain_graph)
                total_steps += 1

                # for params in reversed(list(self.model.parameters())):
                #     nan_or_inf |= torch.isnan(params.grad).any()
                #     nan_or_inf |= (params.grad == float("inf")).any()
                #     nan_or_inf |= (params.grad == -float("inf")).any()
                #     if nan_or_inf:
                #         raise ValueError("Found NaN/Inf in {}.".format(params))
                #         pdb.set_trace()

                if hparams.replace_inf_grads_by_zero:
                    self._replace_inf_grads_by_zero()

                if hparams.grad_clip_norm_type is not None:
                    # Clipping the gradient norm introduces a small bias in the updates.
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   hparams.grad_clip_max_norm,
                                                   hparams.grad_clip_norm_type)
                if hparams.grad_clip_thresh is not None:
                    # Clipping gradient values element-wise introduces a larger bias.
                    torch.nn.utils.clip_grad_value_(self.model.parameters(),
                                                    hparams.grad_clip_thresh)

                self.optimiser.step()

                # Update exponential moving average.
                if self.ema:
                    self.ema.update_params(model)

                current_iter = self._get_current_iteration(
                    batch_index=batch_index, current_epoch=current_epoch,
                    dataloader_length=len(dataloader), hparams=hparams,
                    total_epoch=total_epoch)
                self.run_scheduler(hparams=hparams, loss=scheduler_loss,
                                   current_iter=current_iter)

            # Logging current error.
            if batch_index % logging_batch_index == 0:
                log_message = "Train " if training else "Test "
                log_message += "mini batch [{:{front_pad}d}/{}]".format(
                    batch_index + 1, len(dataloader),
                    front_pad=len(str(len(dataloader))))
                log_message += "\tLoss: "
                log_message += " ".join(["{}: {:.3f}".format(key, loss) for
                                         key, loss in losses.items()])
                if hparams.log_memory_consumption:
                    log_message += "\tCPU: {:.0f} MB, ".format(
                        resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e3)
                    if hparams.use_gpu:
                        log_message += "GPU: {} MB".format(
                            str(get_gpu_memory_map()))

                self.logger.info(log_message)

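            # Detach the batch losses so accumulating them does not keep the computation graph alive.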
            losses = {k: l.detach() for k, l in losses.items()}
            for key, loss in losses.items():
                if key not in total_losses:
                    total_losses[key] = loss
                else:
                    total_losses[key] += loss

            if tb_writer is not None and training:
                tb_writer.add_scalars("Train loss", losses, total_steps)

            del data_dict, lengths, max_lengths, losses, backprop_loss, scheduler_loss

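        # Average the accumulated losses over the number of batches.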
        total_losses = {key: value / len(dataloader) for key, value in total_losses.items()}

        if not training:
            if tb_writer is not None:
                tb_writer.add_scalars("Validation loss", total_losses, total_steps)

            self.logger.info(
                'Validation set: Total loss: {}\nAverage loss:\n\t{}\n'.format(
                    sum(total_losses.values()),
                    "\n\t".join(["{}: {:.3f}".format(key, loss)
                                 for key, loss in total_losses.items()])))

            fn_log_per_test = getattr(self.model, "log_per_test", None)
            if callable(fn_log_per_test):
                fn_log_per_test()

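        # Return plain NumPy values so callers do not hold on to GPU tensors.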
        np_total_losses = {key: loss.cpu().numpy() for key, loss in total_losses.items()}
        del total_losses

        return np_total_losses
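
# A minimal usage sketch for the method above (hypothetical names: `trainer`, `train_set`,
# `train_loader`; the surrounding trainer class and dataset are not shown in this example):
#
#     train_loader = DataLoader(train_set, batch_size=hparams.batch_size)
#     epoch_losses = trainer.process_dataloader(
#         dataloader=train_loader, hparams=hparams, total_epoch=10,
#         total_steps=0, current_epoch=1, training=True)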
# Example #2
    def process_dataloader(self, dataloader, hparams, training=True):
        """
        Train or test the model by loading batches from the dataloader.

        :param dataloader:        Dataloader of the train/test set.
        :param hparams:           Hyper-parameter container.
        :param training:          Determines if it runs the training or testing loop.
        :return:                  Tuple of total loss and total loss per output feature.
        """

        model = self.model
        if training:
            model.train()
            self.logger.info("{}: Train with {} on {}".format(
                datetime.now(), self.optimiser,
                (str(torch.cuda.device_count()) + " GPU(s).")
                if hparams.use_gpu else "1 CPU."))
        else:
            self.logger.info(
                str(datetime.now()) + ": Compute loss of validation set.")
            if self.ema is not None:
                self.logger.info("Using averaged model for validation.")
                model = self.ema.get_averaged_model(self.model)
            model.eval()

        if hparams.log_memory_consumption:
            self.logger.info('CPU: {:.0f} MB, GPU: {} MB'.format(
                resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e3,
                str(get_gpu_memory_map()) if hparams.use_gpu else "-"))

        # Multi-GPU support.
        if hparams.num_gpus > 1:
            model = DataParallel(model, dim=0 if hparams.batch_first else 1)
            model.init_hidden = model.module.init_hidden  # Make the init_hidden method directly accessible.

        # Log loss after each <hparams.logging_batch_index_perc>% of batches.
        logging_batch_index = (len(dataloader) //
                               hparams.logging_batch_index_perc) + 1

        current_batch_index = -1  # Consider special case for first batch.
        current_batch = None
        loss = None
        total_loss = 0
        loss_features = None
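        # Epoch accumulators: `total_loss` sums the batch losses, `loss_features` collects the loss per output feature.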

        # FIXME: Experimental implementation to pre-load the next batch to GPU. Does not work yet because it blocks.
        if hparams.use_gpu and hparams.preload_next_batch_to_gpu:
            # Create an iterator around the dataloader to pop the first element before the for loop.
            dataloader = iter(dataloader)
            # Pop the first batch from the dataloader.
            current_batch = next(dataloader)
            current_batch_index = 0
            # Move the first batch to GPU.
            inputs, target, seq_lengths_input, seq_lengths_target, mask, _ = current_batch
            # `async` is a reserved word in Python 3.7+; PyTorch renamed the argument to `non_blocking`.
            inputs = inputs.cuda(non_blocking=hparams.dataset_load_async) \
                if inputs is not None else None
            seq_lengths_input = seq_lengths_input.cuda(
                non_blocking=hparams.dataset_load_async)
            target = target.cuda(non_blocking=hparams.dataset_load_async)
            seq_lengths_target = seq_lengths_target.cuda(
                non_blocking=hparams.dataset_load_async)
            mask = mask.cuda(non_blocking=hparams.dataset_load_async) \
                if mask is not None else None
            current_batch = inputs, target, seq_lengths_input, seq_lengths_target, mask, _