コード例 #1
0
ファイル: Synthesiser.py プロジェクト: sshuster/IdiapTTS
    def run_wavenet_vocoder(synth_output, hparams):
        """
        Generate audio for every entry in synth_output with a neural (WaveNet) vocoder and save it.

        Each entry is normalised with the parameters stored next to the vocoder model
        (<synth_vocoder_path without extension>_norm_params.npy), forwarded through the model loaded from
        hparams.synth_vocoder_path, converted back to a waveform, and written to hparams.synth_dir.
        While the files are written, hparams.synth_file_suffix is extended by '_' + hparams.synth_vocoder;
        the previous suffix is restored before returning, also when an error occurs.

        :param synth_output:   Dictionary mapping id names to the vocoder input features (T x C).
                               A shallow copy is used, so the dictionary of the caller is not modified.
        :param hparams:        Hyper-parameter container. Needs synth_vocoder_path, synth_vocoder,
                               frame_size_ms, frame_rate_output_Hz, synth_dir, model_name, synth_ext and
                               synth_fs. bit_depth is added with a value of 16 if it is missing.
        :raises FileNotFoundError:  When the normalisation parameters of the vocoder cannot be found.
        """
        # Import ModelHandlerPyTorch here to prevent circular dependencies.
        from idiaptts.src.neural_networks.pytorch.ModelHandlerPyTorch import ModelHandlerPyTorch

        assert hparams.synth_vocoder_path is not None, "Please set path to neural vocoder in hparams.synth_vocoder_path"
        # Add identifier to suffix.
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix += '_' + hparams.synth_vocoder

        # Everything below runs with the temporary suffix; the finally clause guarantees that
        # hparams is handed back unchanged even if loading or synthesis fails.
        try:
            if not hasattr(hparams, 'bit_depth'):
                hparams.add_hparam("bit_depth", 16)

            synth_output = copy.copy(synth_output)

            input_fs_Hz = 1000.0 / hparams.frame_size_ms
            assert hasattr(hparams, "frame_rate_output_Hz") and hparams.frame_rate_output_Hz is not None, \
                "hparams.frame_rate_output_Hz has to be set and match the trained WaveNet."
            # Number of output frames which have to be generated per input frame.
            in_to_out_multiplier = hparams.frame_rate_output_Hz / input_fs_Hz
            input_gen = WorldFeatLabelGen(
                None,
                add_deltas=False,
                sampling_fn=partial(sample_linearly,
                                    in_to_out_multiplier=in_to_out_multiplier,
                                    dtype=np.float32))

            # Load normalisation parameters for wavenet input.
            norm_params_path = os.path.splitext(hparams.synth_vocoder_path)[0] + "_norm_params.npy"
            try:
                input_gen.norm_params = np.load(norm_params_path).reshape(2, -1)
            except FileNotFoundError:
                logging.error(
                    "Cannot find normalisation parameters for WaveNet input at {}."
                    "Please save them there with numpy.save().".format(
                        norm_params_path))
                raise

            model_handler = ModelHandlerPyTorch()
            model_handler.model, *_ = model_handler.load_model(
                hparams.synth_vocoder_path, hparams, verbose=False)

            for id_name, output in synth_output.items():
                logging.info("Synthesise {} with {} vocoder.".format(
                    id_name, hparams.synth_vocoder_path))

                # Any other post-processing could be done here.

                # Normalize input.
                output = input_gen.preprocess_sample(output)

                # output (T x C) --transpose--> (C x T) --unsqueeze(0)--> (B x C x T).
                output = output.transpose()[None, ...]
                # Wavenet input has to be (B x C x T).
                output, _ = model_handler.forward(
                    output, hparams, batch_seq_lengths=(output.shape[-1], ))
                # Remove batch dim and transpose back to (T x C).
                output = output[0].transpose()

                out_channels = output.shape[1]
                if out_channels > 1:  # Check if the output is one-hot (quantized) or 1 (raw).
                    # Revert mu-law quantization.
                    waveform = RawWaveformLabelGen.mu_law_companding_reversed(
                        output.argmax(axis=1), out_channels)
                else:
                    # Raw output: drop the channel axis. Without this branch the unprocessed input
                    # features would be saved as audio instead of the vocoder output.
                    # NOTE(review): assumes raw_to_file accepts a 1-D waveform here, as in the
                    # quantized branch - confirm.
                    waveform = output.reshape(-1)
                synth_output[id_name] = waveform

                # Save the audio.
                wav_file_name = "".join(
                    (os.path.basename(id_name).rsplit('.', 1)[0], "_",
                     hparams.model_name, hparams.synth_file_suffix, ".",
                     hparams.synth_ext))
                wav_file_path = os.path.join(hparams.synth_dir, wav_file_name)
                Synthesiser.raw_to_file(wav_file_path, synth_output[id_name],
                                        hparams.synth_fs, hparams.bit_depth)
        finally:
            # Restore identifier. Can be None, thus no type check.
            hparams.setattr_no_type_check("synth_file_suffix",
                                          old_synth_file_suffix)
コード例 #2
0
class ModelTrainer(object):
    """
    Baseclass for all trainers.

    Load input and output data by generators (set by subclass). Perform normalisation and length mismatch fix.
    Select 5% of data (at least one) for testset. The model_handler implementation of a framework is used as interface.
    Subclasses have to set up the data and synthesize attributes by instantiating generators and possibly overwriting
    the synthesize method.
    """
    logger = logging.getLogger(__name__)

    # Default is hparams.out_dir/self.dir_extracted_acoustic_features, but can be overwritten by hparams.world_dir.
    dir_extracted_acoustic_features = "../WORLD/"

    def __init__(self, id_list, hparams):
        """Default constructor.

        Logs host and git information, reconciles the GPU settings in hparams with the available hardware,
        seeds the random number generators, partitions id_list into validation, training, and test ids
        (unless self.id_list_train was already set by a subclass), and creates the model handler.

        :param id_list:              List or tuple of ids as strings. This list is separated into evaluation, test, and training set.
        :param hparams:              An object holding all hyper parameters. It is modified in place
                                     (variable_sequence_length_*, use_gpu, optimiser_args["lr"]).
        """
        assert (hparams is not None)

        self.logger.info("Running on host {}.".format(platform.node()))
        try:
            repo = git.Repo(search_parent_directories=True)
            self.logger.info("Git: {} at {}".format(repo.git_dir,
                                                    repo.head.object.hexsha))
        except git.exc.InvalidGitRepositoryError:
            pass

        # The framework repository is only logged as additional information, so a missing
        # environment variable or path must not abort the construction of the trainer.
        framework_root = os.environ.get('IDIAPTTS_ROOT')
        if framework_root is not None:
            try:
                framework_repo = git.Repo(path=framework_root,
                                          search_parent_directories=True)
                self.logger.info("IdiapTTS framework git: {} at {}".format(
                    framework_repo.git_dir, framework_repo.head.object.hexsha))
            except (git.exc.InvalidGitRepositoryError,
                    git.exc.NoSuchPathError):
                pass

        # Variable sequence lengths are only used with batch sizes greater than one.
        if not hasattr(hparams,
                       "batch_size_train") or not hparams.batch_size_train > 1:
            hparams.variable_sequence_length_train = False
        if not hasattr(hparams,
                       "batch_size_val") or not hparams.batch_size_val > 1:
            hparams.variable_sequence_length_val = False

        if hparams.use_gpu:
            if hparams.num_gpus > 1:
                # CUDA expects a comma-separated list of device indices ("0,1,2"),
                # not the string representation of a tuple ("(0, 1, 2)").
                os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(
                    str(device_idx) for device_idx in range(hparams.num_gpus))
            if ModelHandlerPyTorch.cuda_is_available():
                device_count = ModelHandlerPyTorch.device_count()
                if not device_count == hparams.num_gpus:
                    self.logger.error(
                        "Specified GPU count in hparams.num_gpus ({}) doesn't match hardware ({})."
                        .format(hparams.num_gpus, device_count))
                assert (device_count == hparams.num_gpus
                        )  # Specified GPU count doesn't match hardware.
            else:
                self.logger.warning(
                    "No CUDA device available, use CPU mode instead.")
                hparams.use_gpu = False

        if "lr" not in hparams.optimiser_args\
                and hasattr(hparams, "learning_rate")\
                and hparams.learning_rate is not None:  # Backwards compatibility.
            hparams.optimiser_args["lr"] = hparams.learning_rate

        if hparams.seed is not None:
            ModelHandlerPyTorch.seed(hparams.seed)  # Seed the backend.
            np.random.seed(hparams.seed)
            random.seed(hparams.seed)

        if not hasattr(self, "id_list_train") or self.id_list_train is None:
            # The ids are only shuffled when a seed is given, otherwise their order is kept.
            id_list_shuffled = id_list
            if hparams.seed is not None:
                id_list_shuffled = random.sample(id_list, len(id_list))

            # Partition (randomly sorted) ids into [val_set, train_set, test_set].
            assert (hparams.test_set_perc + hparams.val_set_perc < 1)
            if hparams.val_set_perc > 0.0:
                num_valset = max(
                    1, int(len(id_list_shuffled) * hparams.val_set_perc))
                self.id_list_val = id_list_shuffled[:num_valset]
            else:
                num_valset = 0
                self.id_list_val = None
            if hparams.test_set_perc > 0.0:
                num_testset = max(
                    1, int(len(id_list_shuffled) * hparams.test_set_perc))
                self.id_list_test = id_list_shuffled[-num_testset:]
            else:
                num_testset = 0
                self.id_list_test = None
            # A slice ending at -0 would be empty, so the test set is only cut off when it exists.
            if num_testset > 0:
                self.id_list_train = id_list_shuffled[num_valset:-num_testset]
            else:
                self.id_list_train = id_list_shuffled[num_valset:]
            assert (len(self.id_list_train) > 0)

        # Create and initialize model.
        self.logger.info("Create ModelHandler.")
        self.model_handler = ModelHandlerPyTorch(
        )  # A handler for the NN models depending on the NN frameworks.

        # Data attributes.
        self.InputGen = None  # Used in the datasets.
        self.OutputGen = None  # Used in the datasets.
        self.dataset_train = None
        self.dataset_val = None

        self.batch_collate_fn = None  # Function used to combine samples to one batch.
        self.batch_decollate_fn = None  # Function used to split the batched output of a model.
        # Result is directly given to the gen_figure function.
        # Only the first element of the result is given to the post-processing function
        # of the OutputGen when result is a tuple or list.

        if not hasattr(
                self, "loss_function"
        ):  # Could have been set already in constructor of child class.
            self.loss_function = None  # Has to be defined by subclass.

        self.total_epoch = None  # Total number of epochs the current model was trained.

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """
        Create the hyper-parameter container by delegating to ExtendedHParams.create_hparams.

        :param hparams_string:   String with the non-default hyper-parameters to parse (optional).
        :param verbose:          Forwarded unchanged to ExtendedHParams.create_hparams.
        :return:                 The created hyper-parameter container.
        """
        hparams = ExtendedHParams.create_hparams(hparams_string, verbose)
        return hparams

    # def plot_outputs(self, max_epochs, id_name, outputs, target):
    #     plotter = DataPlotter()
    #     net_name = os.path.basename(self.model_handler.model_name)
    #     filename = str(os.path.join(self.out_dir, id_name + '.' + net_name))
    #     plotter.set_title(id_name + " - " + net_name)
    #
    #     # Create a plot for every dimension of outputs with its target.
    #     graphs_o = [None] * target.shape[1]
    #     graphs_t = [None] * target.shape[1]
    #     for out_idx in range(0, target.shape[1]):
    #         graphs_o[out_idx] = list()
    #         # Add all outputs to the plot.
    #         for idx, o in enumerate(outputs):
    #             # Handle special case where NN output has only one dimension.
    #             if len(o.shape) == 1:
    #                 o = o.reshape(-1, 1)
    #             graphs_o[out_idx].append((o[:, out_idx], 'e' + str(min(max_epochs, (idx + 1) * self.epochs_per_plot))))
    #         # Give data to plotter and leave two grid position for each output dimension (used for output and target).
    #         plotter.set_data_list(grid_idx=out_idx * 2, data_list=graphs_o[out_idx])
    #
    #         # Add target belonging to the output dimension.
    #         graphs_t[out_idx] = list()
    #         graphs_t[out_idx].append((target[:, out_idx], 'target[' + str(out_idx) + ']'))
    #         plotter.set_data_list(grid_idx=out_idx * 2 + 1, data_list=graphs_t[out_idx])
    #
    #     # Set label for all.
    #     plotter.set_label(xlabel='frames', ylabel='amp')
    #
    #     # Generate and save the plot.
    #     plotter.gen_plot()
    #     plotter.save_to_file(filename + ".OUTPUTS.png")
    #
    #     plotter.plt.show()

    def init(self, hparams):
        """
        Prepare the model of the model handler by loading a checkpoint or creating a new model.

        When hparams.epochs <= 0 the checkpoint at hparams.model_path is loaded; if it does not exist and
        hparams.model_type is given, a new model is created and saved immediately. Otherwise a checkpoint is
        loaded when hparams.model_type is None and a new model is created when it is set.
        hparams.model_path and hparams.model_name are derived from each other when only one is given.

        :param hparams:          Hyper-parameter container.
        :raises FileNotFoundError:  When no training is requested, no checkpoint exists, and no model_type is given.
        """

        def load_existing_checkpoint():
            # Use the learning rate from the optimiser arguments if one is configured.
            lr_is_set = hasattr(hparams, "optimiser_args") and "lr" in hparams.optimiser_args
            initial_lr = hparams.optimiser_args["lr"] if lr_is_set else None
            return self.model_handler.load_checkpoint(hparams.model_path, hparams, initial_lr)

        def create_new_model():
            # Model dimensions are taken from the training dataset.
            dim_in, dim_out = self.dataset_train.get_dims()
            self.model_handler.create_model(hparams, dim_in, dim_out)
            self.total_epoch = 0

        cpu_memory = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e3
        self.logger.info("CPU memory: " + str(cpu_memory) + " MB.")
        if hparams.use_gpu:
            gpu_memory = get_gpu_memory_map()
            self.logger.info("GPU memory: " + str(gpu_memory) + " MB.")

        # Create the necessary directories.
        checkpoints_path = os.path.join(hparams.out_dir, hparams.networks_dir,
                                        hparams.checkpoints_dir)
        makedirs_safe(checkpoints_path)

        # Create the default model path if not set or retrieve the name from the given path.
        if hparams.model_path is None:
            # A model_path or model_name has to be given. No default exists.
            assert (hparams.model_name is not None)
            hparams.model_path = os.path.join(hparams.out_dir,
                                              hparams.networks_dir,
                                              hparams.model_name)
        elif hparams.model_name is None:
            hparams.model_name = os.path.basename(hparams.model_path)

        model_path_out = os.path.join(hparams.out_dir, hparams.networks_dir,
                                      hparams.model_name)

        no_training_requested = hparams.epochs <= 0
        if not no_training_requested:
            # Training follows: continue from a checkpoint or start with a new model.
            if hparams.model_type is None:
                self.total_epoch = load_existing_checkpoint()
            else:
                create_new_model()
            self.logger.info("Model ready.")
            return

        # No training was requested: try to load the model. If it doesn't exist,
        # create a new one and save it so that it is available afterwards.
        try:
            self.total_epoch = load_existing_checkpoint()
        except FileNotFoundError:
            if hparams.model_type is None:
                self.logger.error(
                    "Model does not exist at {} and you didn't give model_type to create a new one."
                    .format(hparams.model_path))
                raise  # This will rethrow the last exception.
            self.logger.warning(
                'Model does not exist at {}. Creating a new one instead and saving it.'
                .format(hparams.model_path))
            create_new_model()
            self.model_handler.save_checkpoint(model_path_out,
                                               self.total_epoch)

        self.logger.info("Model ready.")

    def train(self, hparams):
        """
        Train the model. Use generators for data preparation and model_handler for access.
        Generators have to be set in constructor of subclasses.

        :param hparams:          Hyper-parameter container.
        :return:                 A tuple of (all test loss, all training loss, the model_handler object).
        """

        hparams.verify(
        )  # Verify that attributes were added correctly, print warning for wrongly initialized ones.
        self.logger.info(hparams.get_debug_string())

        assert (self.model_handler
                )  # The init function has be called before training.

        # Skip training if epochs is not greater 0.
        if hparams.epochs <= 0:
            self.logger.info(
                "Number of training epochs is {}. Skipping training.".format(
                    hparams.epochs))
            return list(), list(), self.model_handler

        # Log evaluation ids.
        if len(self.id_list_val) > 0:
            valset_keys = sorted(self.id_list_val)
            self.logger.info(
                "Validation set (" + str(len(valset_keys)) + "): " + " ".join([
                    os.path.join(
                        os.path.split(os.path.dirname(id_name))[-1],
                        os.path.splitext(os.path.basename(id_name))[0])
                    for id_name in valset_keys
                ]))
        # Log test ids.
        testset_keys = sorted(self.id_list_test)
        self.logger.info(
            "Test set (" + str(len(testset_keys)) + "): " + " ".join([
                os.path.join(
                    os.path.split(os.path.dirname(id_name))[-1],
                    os.path.splitext(os.path.basename(id_name))[0])
                for id_name in testset_keys
            ]))

        # Setup the dataloaders.
        self.model_handler.set_dataset(hparams, self.dataset_train,
                                       self.dataset_val, self.batch_collate_fn)

        self.logger.info(
            "CPU memory: " +
            str(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e3) +
            " MB.")
        if hparams.use_gpu:
            self.logger.info("GPU memory: " + str(get_gpu_memory_map()) +
                             " MB.")

        # Run model.
        # if self.epochs_per_plot > 0:
        #     outputs = list()
        #     num_iterations = int(math.ceil(float(epochs) / float(self.epochs_per_plot)))
        #     epochs_per_iter = min(epochs, self.epochs_per_plot)
        #     for e in range(num_iterations):
        #         epochs_this_iter = min(epochs_per_iter, epochs - e * epochs_per_iter)
        #         nn_model.run(epochs_this_iter, e * epochs_per_iter)
        #         outputs.append(nn_model.forward(dict_input_labels[self.plot_per_epoch_id_name]))
        #     self.plot_outputs(epochs, self.plot_per_epoch_id_name, outputs, dict_output_labels[self.plot_per_epoch_id_name])

        # Some sanity checks.
        if hparams.epochs_per_scheduler_step:
            if hparams.epochs_per_test > hparams.epochs_per_scheduler_step:
                self.logger.warning(
                    "Model is validated only every {} epochs, ".format(
                        hparams.epochs_per_test) +
                    "but scheduler is supposed to run every {} epochs.".format(
                        hparams.epochs_per_scheduler_step))
            if hparams.epochs_per_test % hparams.epochs_per_scheduler_step != 0:
                self.logger.warning(
                    "hparams.epochs_per_test ({}) % hparams.epochs_per_scheduler_step ({}) != 0. "
                    .format(hparams.epochs_per_test,
                            hparams.epochs_per_scheduler_step) +
                    "Note that the scheduler is only run when current_epoch % "
                    +
                    "hparams.epochs_per_scheduler_step == 0. Therefore hparams.epochs_per_scheduler_step "
                    + "should be a factor of hparams.epochs_per_test.")

        t_start = timer()
        self.logger.info('Start training: {}'.format(
            datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

        self.model_handler.set_optimiser(hparams)
        self.model_handler.set_scheduler(
            hparams,
            self.total_epoch if hparams.use_saved_learning_rate else 0)

        assert (self.loss_function
                )  # Please set self.loss_function in the trainer construction.
        loss_function = self.loss_function.cuda(
        ) if hparams.use_gpu else self.loss_function

        all_loss = list(
        )  # List which is returned, containing all loss so that progress is visible.
        all_loss_train = list()
        best_loss = np.nan
        start_epoch = self.total_epoch

        # Compute error before first iteration.
        if hparams.start_with_test:
            self.logger.info('Test epoch [{}/{}]:'.format(
                start_epoch, start_epoch + hparams.epochs))
            loss, loss_features = self.model_handler.test(
                hparams, start_epoch, start_epoch, loss_function)
            all_loss_train.append(
                -1.0)  # Set a placeholder at the train losses.
            all_loss.append(loss)
            best_loss = loss  # Variable to save the current best loss.

        for current_epoch in range(1, hparams.epochs + 1):
            # Increment epoch number.
            self.total_epoch += 1

            # Train one epoch.
            self.logger.info('Train epoch [{}/{}]:'.format(
                self.total_epoch, start_epoch + hparams.epochs))
            train_loss = self.model_handler.train(hparams, self.total_epoch,
                                                  current_epoch, loss_function)
            all_loss_train.append(train_loss)
            if np.isnan(train_loss):
                break

            # Test if requested.
            if self.total_epoch % hparams.epochs_per_test == 0:
                self.logger.info('Test epoch [{}/{}]:'.format(
                    self.total_epoch, start_epoch + hparams.epochs))
                # Compute error on validation set.
                loss, loss_features = self.model_handler.test(
                    hparams, self.total_epoch, current_epoch, loss_function)

                # Save loss in a list which is returned.
                all_loss.append(loss)

                # Stop when loss is NaN. Reloading from checkpoint if necessary.
                if np.isnan(loss):
                    break

                # Save checkpoint if path is given.
                if hparams.out_dir is not None:
                    path_checkpoint = os.path.join(hparams.out_dir,
                                                   hparams.networks_dir,
                                                   hparams.checkpoints_dir)
                    # Check when to save a checkpoint.
                    if hparams.epochs_per_checkpoint > 0 and self.total_epoch % hparams.epochs_per_checkpoint == 0:
                        model_name = "{}-e{}-{}".format(
                            hparams.model_name, self.total_epoch,
                            loss_function)
                        self.model_handler.save_checkpoint(
                            os.path.join(path_checkpoint, model_name),
                            self.total_epoch)
                    # Always save best checkpoint with special name.
                    if loss < best_loss or np.isnan(best_loss):
                        best_loss = loss
                        model_name = hparams.model_name + "-best"
                        self.model_handler.save_checkpoint(
                            os.path.join(path_checkpoint, model_name),
                            self.total_epoch)

                # Run the scheduler if requested.
                if hparams.epochs_per_scheduler_step:
                    if (self.total_epoch if hparams.use_saved_learning_rate else current_epoch)\
                            % hparams.epochs_per_scheduler_step == 0:
                        self.model_handler.run_scheduler(
                            loss, self.total_epoch + 1)

        t_training = timer() - t_start
        self.logger.info('Training time: ' +
                         str(timedelta(seconds=t_training)))
        self.logger.info('Loss progress: ' + ', '.join('{:.4f}'.format(l)
                                                       for l in all_loss))
        self.logger.info('Loss train progress: ' +
                         ', '.join('{:.4f}'.format(l) for l in all_loss_train))

        if hparams.out_dir is not None:
            # Check if best model should be used as final model. Only possible when it was save in out_dir.
            if hparams.use_best_as_final_model:
                best_model_path = os.path.join(hparams.out_dir,
                                               hparams.networks_dir,
                                               hparams.checkpoints_dir,
                                               hparams.model_name + "-best")
                try:
                    self.total_epoch = self.model_handler.load_checkpoint(
                        best_model_path, hparams, hparams.optimiser_args["lr"]
                        if hparams.optimiser_args["lr"] else
                        hparams.learning_rate)
                    if self.model_handler.ema:  # EMA model should be used as best model.
                        self.model_handler.model = self.model_handler.ema.model
                        self.model_handler.ema = None  # Reset this one so that a new one is created for further training.
                        self.logger.info(
                            "Using best EMA model (epoch {}) as final model.".
                            format(self.total_epoch))
                    else:
                        self.logger.info(
                            "Using best (epoch {}) as final model.".format(
                                self.total_epoch))
                except FileNotFoundError:
                    self.logger.warning(
                        "No best model exists yet. Continue with the current one."
                    )

            # Save the model if requested.
            if hparams.save_final_model:
                self.model_handler.save_checkpoint(
                    os.path.join(hparams.out_dir, hparams.networks_dir,
                                 hparams.model_name), self.total_epoch)

        return all_loss, all_loss_train, self.model_handler

    @staticmethod
    def _input_to_str_list(input):
        # Checks for string input first.
        if isinstance(input, str):
            # Check if string is a path by trying to read ids from file.
            try:
                with open(input) as f:
                    id_list = f.readlines()
                # Trim entries in-place.
                id_list[:] = [s.strip(' \t\n\r') for s in id_list]
                return id_list
            except IOError:
                # String is single input id, convert to list.
                return [input]
        # Checks for list or tuple.
        elif isinstance(input, (list, tuple)):
            # Ensure elements are strings.
            return list(map(str, input))
        raise ValueError("Unkown input {} of type {}.".format(
            input, type(input)))

    @staticmethod
    def split_batch(output,
                    hidden,
                    seq_length_output=None,
                    permutation=None,
                    batch_first=False):
        """
        Retrieve output and hidden from batch.

        :param output:             Batched output tensor given by network.
        :param hidden:             Batched hidden tensor given by network.
        :param seq_length_output:  Tuple containing the lengths of all samples in the batch.
        :param permutation:        Permutations previously applied to the batch, which are reverted here.
        :param batch_first:        Batch dimension is first in output.
        :return:                   List of outputs and list of hidden, where each entry corresponds to one sample in the batch.
        """

        # Split the output of the batch.
        return ModelTrainer._split_return_values(output, seq_length_output, permutation, batch_first),\
               ModelTrainer._split_return_values(hidden, seq_length_output, permutation, batch_first)

    @classmethod
    def _split_return_values(cls, input_values, seq_length_output, permutation,
                             batch_first):
        if input_values is None:
            return None

        # Special case for bidirectional layers where the hidden state is a tuple.
        if isinstance(input_values, tuple):
            # If hidden is a tuple of None, return it directly.
            if all(v is None for v in input_values):
                return input_values

            # Split hidden states in their batch dimension.
            tuple_splitted = tuple(
                map(
                    lambda x: cls._split_return_values(
                        x, seq_length_output, permutation, batch_first),
                    input_values))

            # Now sort into each batch.
            return_values = list()
            batch_size = len([t for t in tuple_splitted if t is not None
                              ][0])  # Get batch size from not None element.

            for index in range(batch_size):
                batch = list()
                for element in tuple_splitted:
                    if element is None or (isinstance(element, tuple)
                                           and all(v is None
                                                   for v in element)):
                        batch.append(
                            element)  # Handles None and tuples of None.
                    else:
                        batch.append(element[index])
                return_values.append(tuple(batch))

            return tuple(return_values)

        if not isinstance(input_values, np.ndarray):
            cls.logger.error(
                "Expected numpy tensor but input is of type {}.".format(
                    type(input_values)))
            raise TypeError()

        # Return value is tensor.
        if batch_first:
            return_values = np.split(input_values,
                                     input_values.shape[0],
                                     axis=0)
            return_values = list(
                map(partial(np.squeeze, axis=0), return_values))
        else:
            return_values = np.split(input_values,
                                     input_values.shape[1],
                                     axis=1)
            return_values = list(
                map(partial(np.squeeze, axis=1), return_values))

        if seq_length_output is not None and len(seq_length_output) > 1:
            for idx in range(len(return_values)):
                return_values[idx] = return_values[
                    idx][:seq_length_output[idx]]

        if permutation is not None:
            return_values_unsorted = return_values.copy()
            for org_index, current_index in enumerate(permutation):
                return_values_unsorted[current_index] = return_values[
                    org_index]
            return_values = return_values_unsorted

        return return_values

    def forward(self, hparams, ids_input):
        """
        Forward all given ids through the network in batches of hparams.batch_size_val.

        :param hparams:        Hyper-parameter container.
        :param ids_input:      Can be full path to file with ids, list of ids, or one id.or None.
        :return:               (Dictionary of network outputs, dictionary of post-processed (by self.OutputGen) network outputs)
        """
        assert (self.model_handler
                is not None)  # Check if trainer.init() was called before.
        id_list = ModelTrainer._input_to_str_list(ids_input)

        self.logger.info("Start forwarding [{0}]".format(", ".join(
            str(i) for i in id_list)))
        t_start = timer()
        model_output, model_output_post = self._forward_batched(
            hparams,
            id_list,
            hparams.batch_size_val,
            load_target=False,
            synth=False,
            benchmark=False,
            gen_figure=False)
        t_training = timer() - t_start
        self.logger.info('Forwarding time for {} sample(s): {}'.format(
            len(id_list), timedelta(seconds=t_training)))

        return model_output, model_output_post

    def synth(self, hparams, ids_input):
        """
        Forward all given ids in batches of hparams.batch_size_synth with synthesis enabled.

        The directory hparams.synth_dir is created if necessary. Figures are generated as well when
        hparams.synth_gen_figure is set.

        :param hparams:        Hyper-parameter container.
        :param ids_input:      Can be full path to file with ids, list of ids, or one id.
        :return:               (Dictionary of network outputs, dictionary of post-processed (by self.OutputGen) network outputs)
        """
        # Check if trainer.init() was called before.
        assert (self.model_handler is not None)
        # Directory to store the generated audio files has to be set.
        assert (hparams.synth_dir is not None)
        makedirs_safe(hparams.synth_dir)
        id_list = ModelTrainer._input_to_str_list(ids_input)

        joined_ids = ", ".join(map(str, id_list))
        self.logger.info("Start synthesising [{0}]".format(joined_ids))

        t_start = timer()
        model_output, model_output_post = self._forward_batched(
            hparams,
            id_list,
            hparams.batch_size_synth,
            load_target=False,
            synth=True,
            benchmark=False,
            gen_figure=hparams.synth_gen_figure)
        elapsed = timedelta(seconds=timer() - t_start)
        self.logger.info('Synthesis time for {} sample(s): {}'.format(
            len(id_list), elapsed))

        return model_output, model_output_post

    def gen_figure(self, hparams, ids_input):
        """
        Create figures for the given ids with self.gen_figure_from_output (has to be implemented).

        The ids are forwarded in chunks of hparams.batch_size_gen_figure. load_target is not passed on,
        so the default of self._forward_batched is used.

        :param hparams:        Hyper-parameter container.
        :param ids_input:      Can be full path to file with ids, list of ids, or one id.
        :return:               (Dictionary of network outputs, dictionary of post-processed network outputs)
                               as returned by self._forward_batched.
        """
        # Without a model handler trainer.init() has not been called yet.
        assert self.model_handler is not None

        ids = ModelTrainer._input_to_str_list(ids_input)
        joined_ids = ", ".join(map(str, ids))
        self.logger.info(
            "Start generating figures for [{0}]".format(joined_ids))

        started_at = timer()
        outputs, outputs_post = self._forward_batched(
            hparams,
            ids,
            hparams.batch_size_gen_figure,
            synth=False,
            benchmark=False,
            gen_figure=True)
        elapsed = timedelta(seconds=timer() - started_at)
        self.logger.info('Figure generation time for {} sample(s): {}'.format(
            len(ids), elapsed))

        return outputs, outputs_post

    def benchmark(self, hparams, ids_input=None):
        """
        Score the currently loaded model with self.compute_score (has to be implemented).

        The ids are forwarded in chunks of hparams.batch_size_benchmark.

        :param hparams:        Hyper-parameter container.
        :param ids_input:      Can be full path to file with ids, list of ids, one id, or None.
                               With None the sorted test set is used if it is not empty, otherwise the
                               sorted validation set.
        :return:               Score(s), i.e. whatever self._forward_batched returns in benchmark mode.
        :raises ValueError:    If ids_input is None and both the test and the validation set are None or empty.
        """
        # The concrete trainer has to provide compute_score.
        assert callable(getattr(self, 'compute_score', None))
        # Without a model handler trainer.init() has not been called yet.
        assert self.model_handler is not None

        # Explicitly given ids win; otherwise fall back to test set, then validation set.
        if ids_input is not None:
            id_list = ModelTrainer._input_to_str_list(ids_input)
            start_message = "Start benchmark on given input ({}): [{}]"
        elif self.id_list_test is not None and len(self.id_list_test) > 0:
            id_list = sorted(self.id_list_test)
            start_message = "Start benchmark on test set ({}): [{}]"
        elif self.id_list_val is not None and len(self.id_list_val) > 0:
            id_list = sorted(self.id_list_val)
            start_message = "Start benchmark on validation set ({}): [{}]"
        else:
            raise ValueError(
                "No id list can be selected for benchmark, because non was given as parameter "
                "and test and validation set are empty.")

        self.logger.info(
            start_message.format(len(id_list), ", ".join(map(str, id_list))))

        started_at = timer()
        scores = self._forward_batched(hparams,
                                       id_list,
                                       hparams.batch_size_benchmark,
                                       synth=False,
                                       benchmark=True,
                                       gen_figure=False)
        elapsed = timedelta(seconds=timer() - started_at)
        self.logger.info('Benchmark time for {} sample(s): {}'.format(
            len(id_list), elapsed))

        return scores

    def _forward_batched(self,
                         hparams,
                         id_list,
                         batch_size,
                         load_target=True,
                         synth=False,
                         benchmark=False,
                         gen_figure=False):
        """
        Push the samples of the given ids through the network, batch_size ids at a time.

        Each chunk is loaded with self.dataset_train.getitem_by_name, collated (self.batch_collate_fn, or
        self.model_handler.prepare_batch when no collate function is set), forwarded with
        self.model_handler.forward and split up again (self.batch_decollate_fn, or self.split_batch when no
        decollate function is set). Every sample output is post-processed with
        self.dataset_train.postprocess_sample; if an output is a list or tuple only its first element is
        post-processed and plotted.

        :param hparams:               Hyper-parameter container (num_gpus and batch_first are read here).
        :param id_list:               A list of ids which can be loaded through self.dataset_train.
        :param batch_size:            Max size of a chunk of ids forwarded.
        :param load_target:           Passed to getitem_by_name for every id (used in teacher forcing).
        :param synth:                 Call self.synthesize with all post-processed outputs at the end.
        :param benchmark:             Return self.compute_score(...) instead of the output dictionaries.
                                      Takes precedence over synth.
        :param gen_figure:            Call self.gen_figure_from_output for every sample.
        :return:                      (Dictionary of outputs, dictionary of post-processed outputs), both keyed
                                      by id, or the result of self.compute_score when benchmark is set.
        """

        self.logger.info("Get model outputs as batches of size {}.".format(
            min(batch_size, len(id_list))))
        dict_outputs, dict_outputs_post, dict_hiddens = {}, {}, {}

        for start in range(0, len(id_list), batch_size):
            batch_ids = id_list[start:start + batch_size]

            # Load the preprocessed samples of this chunk (no length check here).
            samples = [
                self.dataset_train.getitem_by_name(name, load_target)
                for name in batch_ids
            ]

            # A custom collate function takes precedence over the model handler's default.
            collate = self.batch_collate_fn
            if collate is None:
                collate = self.model_handler.prepare_batch
            batch_inputs, _batch_targets, lengths_in, lengths_out, *_, permutation = collate(
                samples,
                common_divisor=hparams.num_gpus,
                batch_first=hparams.batch_first)

            # Run forward pass of model.
            nn_output, nn_hidden = self.model_handler.forward(
                batch_inputs, hparams, lengths_in)

            # Split the batch back into per-sample results.
            decollate = self.batch_decollate_fn
            if decollate is None:
                decollate = self.split_batch
            outputs, hiddens = decollate(nn_output,
                                         nn_hidden,
                                         lengths_out,
                                         permutation,
                                         batch_first=hparams.batch_first)

            for idx, id_name in enumerate(batch_ids):
                sample_output = outputs[idx]
                dict_outputs[id_name] = sample_output
                sample_hidden = hiddens[idx] if hiddens is not None else None
                dict_hiddens[id_name] = sample_hidden

                # Only the first element of a list/tuple output is plotted and post-processed.
                if isinstance(sample_output, (tuple, list)):
                    main_output = sample_output[0]
                else:
                    main_output = sample_output

                # The figure (if requested) is generated before post-processing.
                if gen_figure:
                    self.gen_figure_from_output(id_name, main_output,
                                                sample_hidden, hparams)
                dict_outputs_post[
                    id_name] = self.dataset_train.postprocess_sample(
                        main_output)

        if benchmark:
            # Implementation of compute_score is checked in benchmark function.
            return self.compute_score(dict_outputs_post, dict_hiddens, hparams)
        if synth:
            self.synthesize(id_list, dict_outputs_post, hparams)

        return dict_outputs, dict_outputs_post

    def gen_figure_from_output(self, id_name, output, hidden, hparams):
        """
        Placeholder for the figure generation of a single sample; this implementation only raises.

        :param id_name:        Id of the sample.
        :param output:         Network output of the sample.
        :param hidden:         Hidden state of the sample (can be None).
        :param hparams:        Hyper-parameter container.
        :raises NotImplementedError:  Always; the message names the class of self.
        """
        message = "Class {} doesn't implement gen_figure_from_output(id_name, output, hidden, hparams)"
        raise NotImplementedError(message.format(self.__class__.__name__))

    def synth_ref(self, hparams, file_id_list):
        """
        Delegate reference synthesis for the given ids to Synthesiser.synth_ref.

        For the "WORLD" vocoder the feature directory is passed along: hparams.world_dir if that attribute
        exists and is not None, otherwise self.OutputGen.dir_labels joined with
        self.dir_extracted_acoustic_features.

        :param hparams:        Hyper-parameter container.
        :param file_id_list:   Ids handed to Synthesiser.synth_ref.
        """
        if hparams.synth_vocoder != "WORLD":
            Synthesiser.synth_ref(hparams, file_id_list)
            return

        world_dir = getattr(hparams, "world_dir", None)
        if world_dir is None:
            # Fall back to the directory of the extracted acoustic features.
            world_dir = os.path.join(self.OutputGen.dir_labels,
                                     self.dir_extracted_acoustic_features)

        Synthesiser.synth_ref(hparams, file_id_list, world_dir)
        # NOTE(review): the suffix is extended only after synthesis and is never restored,
        # so every call appends to it again - confirm this is intended.
        hparams.synth_file_suffix += str(hparams.num_coded_sps) + 'sp'

    def synthesize(self, file_id_list, synth_output, hparams):

        # Create speaker subdirectories if necessary.
        for id_name in file_id_list:
            path_split = os.path.split(id_name)
            if len(path_split) > 2:
                makedirs_safe(os.path.join(hparams.synth_dir,
                                           *path_split[:-1]))

        if hparams.synth_vocoder == "WORLD":
            Synthesiser.run_world_synth(synth_output, hparams)
        # elif hparams.synth_vocoder == "STRAIGHT":  # Add further vocoders here.

        elif hparams.synth_vocoder == "r9y9wavenet_mulaw_16k_world_feats_English":
            Synthesiser.run_r9y9wavenet_mulaw_world_feats_synth(
                synth_output, hparams)

        elif hparams.synth_vocoder == "raw":
            # The features in the synth_output dictionary are raw waveforms and can be written directly to the file.
            Synthesiser.run_raw_synth(synth_output, hparams)

        elif hparams.synth_vocoder == "80_SSRN_English_GL":
            # Use a pre-trained spectrogram super resolution network for English and Griffin-Lim.
            # The features in the synth_output should be mfbanks.
            raise NotImplementedError()  # TODO

        elif hparams.synth_vocoder == "r9y9wavenet":
            # Synthesise with a pre-trained r9y9 WaveNet. The hyper-parameters have to match the model.
            Synthesiser.run_wavenet_vocoder(synth_output, hparams)