Example #1
    def __init__(self, dim_in, dim_out, hparams):
        super().__init__()

        # Store parameters.
        self.use_gpu = hparams.use_gpu
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dropout = hparams.dropout

        assert(not hparams.batch_first)  # This implementation doesn't work with batch_first.

        from idiaptts.src.neural_networks.pytorch.ModelHandlerPyTorch import ModelHandlerPyTorch
        self.model_handler_atoms = ModelHandlerPyTorch()
        if hasattr(hparams.hparams_atom, "learning_rate"):
            lr = hparams.hparams_atom.learning_rate
        elif hasattr(hparams.hparams_atom, "optimiser_args"):
            lr = hparams.hparams_atom.optimiser_args["lr"]
        elif hasattr(hparams, "learning_rate"):
            lr = hparams.learning_rate
        elif "lr" in hparams.optimiser_args:  # optimiser_args is dict-like; check for the key, not hasattr().
            lr = hparams.optimiser_args["lr"]
        else:
            lr = None
        self.model_handler_atoms.load_checkpoint(hparams.atom_model_path, hparams.hparams_atom, initial_lr=lr)
        self.add_module("atom_model", self.model_handler_atoms.model)  # Add atom model as submodule so that parameters are properly registered.

        if hparams.complex_poles:
            self.intonation_filters = ComplexModel(hparams.thetas, hparams.phase_init)
        else:
            self.intonation_filters = CriticalModel(hparams.thetas)
        self.add_module("intonation_filters", self.intonation_filters)
Example #2
    def test_save_load_equality(self):
        hparams = ModelTrainer.create_hparams()
        hparams.out_dir = os.path.join(
            self.out_dir,
            "test_save_load_equality")  # Add function name to path.
        model_path = os.path.join(hparams.out_dir, "test_model.nn")

        # Create a new model and save it.
        dim_in, dim_out = 10, 4
        total_epochs = 10
        model_handler = ModelHandlerPyTorch()
        model_handler.model = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out))
        model_handler.save_checkpoint(model_path, total_epochs)

        # Create a new model handler and test load save.
        hparams.model_type = None
        model_handler = ModelHandlerPyTorch()
        saved_total_epochs = model_handler.load_checkpoint(model_path, hparams)
        self.assertEqual(total_epochs,
                         saved_total_epochs,
                         msg="Saved and loaded total epochs do not match")
        model_copy_path = os.path.join(hparams.out_dir, "test_model_copy.nn")
        model_handler.save_checkpoint(model_copy_path, total_epochs)

        # self.assertTrue(filecmp.cmp(model_path, model_copy_path, False))  # Does not work: serialised checkpoints are not guaranteed to be byte-identical.
        self.assertTrue(equal_checkpoint(model_path, model_copy_path),
                        "Loaded and saved models are not the same.")

        shutil.rmtree(hparams.out_dir)
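
equal_checkpoint comes from the test utilities and is not shown here; the byte-level filecmp in the commented line fails, presumably because torch.save() output is not byte-stable across runs, so the comparison has to be semantic. A rough sketch of such a value-based check for two loaded models (our own helper, not the IdiapTTS implementation):

import torch

def state_dicts_equal(model_a, model_b):
    # Compare parameters and buffers by value rather than by serialised bytes.
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    return sd_a.keys() == sd_b.keys() and all(
        torch.equal(sd_a[key], sd_b[key]) for key in sd_a)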
Example #3
    def prepare_batch(batch,
                      common_divisor=1,
                      batch_first=False,
                      use_cond=True,
                      one_hot_target=True):
        inputs, targets, seq_lengths_input, seq_lengths_output, mask, permutation = ModelHandler.prepare_batch(
            batch, common_divisor=common_divisor, batch_first=batch_first)

        if batch_first:
            # inputs: (B x T x C) --permute--> (B x C x T)
            inputs = inputs.transpose(1, 2).contiguous()
        # TODO: Handle the batch_first=False case, e.g. inputs = inputs.permute(1, 2, 0).contiguous() to go (T x B x C) -> (B x C x T).

        if targets is not None:
            if batch_first:
                # targets: (B x T x C) --permute--> (B x C x T)
                targets = targets.transpose(1, 2).contiguous()

            if not one_hot_target:
                targets = targets.max(dim=1, keepdim=True)[1].float()

        if mask is not None:
            mask = mask[:, 1:].contiguous()

        return inputs if use_cond else None, targets, seq_lengths_input, seq_lengths_output, mask, permutation
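
A quick shape check of the conventions used above, with dummy tensors (B=2, T=5, C=3; all values illustrative):

import torch

inputs = torch.randn(2, 5, 3)                 # B x T x C (batch_first)
inputs = inputs.transpose(1, 2).contiguous()  # B x C x T
assert inputs.shape == (2, 3, 5)

one_hot = torch.eye(3)[[0, 2, 1]].t()[None, ...]         # B x C x T one-hot targets
class_ids = one_hot.max(dim=1, keepdim=True)[1].float()  # B x 1 x T class indices
assert class_ids.flatten().tolist() == [0., 2., 1.]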
Example #4
    def _load_pre_net(self, hparams):
        from idiaptts.src.neural_networks.pytorch.ModelHandlerPyTorch import ModelHandlerPyTorch
        from idiaptts.src.model_trainers.ModelTrainer import ModelTrainer

        model_path = ModelTrainer.get_model_path(hparams)
        self.pre_net, *_ = ModelHandlerPyTorch.load_model(model_path,
                                                          hparams,
                                                          verbose=True)
Example #5
    def prepare_batch(batch, common_divisor=1, batch_first=False):
        inputs, targets, seq_lengths_input, seq_lengths_output, mask, permutation = ModelHandler.prepare_batch(
            batch, common_divisor=common_divisor, batch_first=batch_first)

        if mask is None:
            mask = torch.ones((seq_lengths_output[0], 1, 1))
        mask = mask.expand(*mask.shape[:2], 2)
        # mask: T x B x 2 (lf0, vuv), add L1 error dimension.
        mask = torch.cat((mask, mask[..., -1:]), dim=-1).contiguous()

        # TODO: This is a dirty hack; it works, but only for a VUV weight of 0 (it completes the WMSELoss loss function).
        mask[..., 0] = mask[..., 0] * seq_lengths_output.float()
        ################################################

        return inputs, targets, seq_lengths_input, seq_lengths_output, mask, permutation
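
The mask construction above is easiest to follow with concrete shapes; a short walk-through with T=4 and B=1:

import torch

seq_lengths_output = torch.tensor([4])
mask = torch.ones((int(seq_lengths_output[0]), 1, 1))     # T x B x 1
mask = mask.expand(*mask.shape[:2], 2)                    # T x B x 2 (lf0, vuv)
mask = torch.cat((mask, mask[..., -1:]), dim=-1)          # T x B x 3 (adds the L1 error dim)
mask[..., 0] = mask[..., 0] * seq_lengths_output.float()  # The WMSELoss hack from above.
assert mask.shape == (4, 1, 3) and mask[0, 0].tolist() == [4., 1., 1.]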
Example #7
class WarpingLayer(nn.Module):
    IDENTIFIER = "VTLN"
    logger = logging.getLogger(__name__)

    def __init__(self, dim_in, dim_out, hparams):
        super().__init__()

        # Store parameters.
        self.use_gpu = hparams.use_gpu
        self.dim_in = dim_in
        self.dim_out = dim_out
        norm_params_dim = hparams.num_coded_sps * (3 if hparams.add_deltas else 1)
        self.mean = nn.Parameter(
            torch.zeros(norm_params_dim),
            requires_grad=False)  # TODO: Should not appear in state_dict.
        self.std_dev = nn.Parameter(torch.ones(norm_params_dim),
                                    requires_grad=False)
        # self.dropout = hparams.dropout
        self.batch_first = hparams.batch_first
        self.batch_dim = 0 if hparams.batch_first else 1
        self.time_dim = 1 if hparams.batch_first else 0

        # Create hparams for pre-net.
        self.hparams_prenet = copy.deepcopy(hparams)
        self.hparams_prenet.model_type = hparams.pre_net_model_type
        self.hparams_prenet.model_name = hparams.pre_net_model_name
        self.hparams_prenet.model_path = hparams.pre_net_model_path
        # Remove the embedding functions when they should not be passed on.
        if not hparams.pass_embs_to_pre_net:
            self.hparams_prenet.f_get_emb_index = None

        # Create the pre-net from its type if given; otherwise try to load it from the given path, or from the default path plus name.
        from idiaptts.src.neural_networks.pytorch.ModelHandlerPyTorch import ModelHandlerPyTorch
        self.model_handler_prenet = ModelHandlerPyTorch()
        if self.hparams_prenet.model_type is not None:
            num_emb_inputs = (0 if not hparams.f_get_emb_index
                              or hparams.pass_embs_to_pre_net
                              else len(hparams.f_get_emb_index))
            prenet_dim_in = (dim_in[0] - num_emb_inputs, *dim_in[1:])
            self.model_handler_prenet.create_model(self.hparams_prenet,
                                                   prenet_dim_in, dim_out)
        elif self.hparams_prenet.model_path is not None:
            self.model_handler_prenet.model, *_ = self.model_handler_prenet.load_model(
                self.hparams_prenet.model_path,
                self.hparams_prenet,
                verbose=True)
        elif self.hparams_prenet.model_name is not None:
            self.hparams_prenet.model_path = os.path.join(
                self.hparams_prenet.out_dir, self.hparams_prenet.networks_dir,
                self.hparams_prenet.model_name)
            self.model_handler_prenet.model, *_ = self.model_handler_prenet.load_model(
                self.hparams_prenet.model_path,
                self.hparams_prenet,
                verbose=True)
        else:
            self.logger.warning("No pre-net specified.")

        if self.model_handler_prenet.model is not None:
            self.model_handler_prenet.model.save_intermediate_outputs = True  # Used by RNNDyn.
            # Properly register the parameters of the submodule.
            self.add_module("pre_net", self.model_handler_prenet.model)
            if not hparams.train_pre_net:
                for param in self.model_handler_prenet.model.parameters():
                    param.requires_grad = False

        self.prenet_group_index_of_alpha = -2
        self.embedding_dim = hparams.speaker_emb_dim
        if hparams.num_speakers is None:
            self.logger.warning("Number of speakers is not defined. Assuming "
                                "a single speaker for the embedding.")
            self.num_speakers = 1
        else:
            self.num_speakers = hparams.num_speakers
        self.pass_embs_to_pre_net = hparams.pass_embs_to_pre_net

        self.n = hparams.num_coded_sps
        self.alpha_range = 0.2
        self.has_deltas = hparams.add_deltas
        self.max_polynomial = 2 * self.n

        # Reuse pre-net embeddings, or create new ones if none exist yet.
        if not hparams.pass_embs_to_pre_net or not self.model_handler_prenet.model:
            self.embeddings = nn.Embedding(self.num_speakers,
                                           self.embedding_dim)
        else:
            self.embeddings = self.model_handler_prenet.model.emb_groups[0]

        # Attach alpha layer to selected pre-net layer.
        if self.model_handler_prenet.model is not None:
            pre_net_layer_group = self.model_handler_prenet.model.layer_groups[
                self.prenet_group_index_of_alpha]
            self.alpha_layer = nn.Linear(
                pre_net_layer_group.out_dim *
                (2 if pre_net_layer_group.is_rnn else 1) + self.embedding_dim,
                1)
        else:
            self.alpha_layer = nn.Linear(
                np.prod(dim_in) + self.embedding_dim, 1)

        # self.alpha_layer = nn.Linear(53, 1)
        # self.alpha_layer = nn.Linear(self.embedding_dim, 1)

        # self.all_pass_warp_matrix = None
        # self.precision = 100
        # self.eps = 1e-45  # float(np.finfo(np.float32).eps)
        # self.pre_compute_warp_matrices(self.precision, requires_recursive_grad=True)

        self.computation_dtype = 'torch.FloatTensor'  # torch.float32 cannot be pickled.
        self.w_matrix_3d = self.gen_w_matrix_3d()

        # self.index_vec_pos = torch.arange(0, 2 * self.n, dtype=self.computation_dtype)
        # index_vec_neg_sign = torch.tensor([v * pow(-1, i) for i, v in enumerate(range(0, 2 * self.n))],
        #                                   dtype=self.computation_dtype, requires_grad=False).sign()
        # index_vec_neg_sign[0] = 1.
        #
        # self.w_matrix_3d_sign = self.w_matrix_3d.sign().type(self.computation_dtype)
        # self.w_matrix_3d_sign = torch.stack((self.w_matrix_3d_sign, self.w_matrix_3d_sign * index_vec_neg_sign[None, None, :]))
        # self.w_matrix_3d_log = torch.log(self.w_matrix_3d.abs()).type(self.index_vec_pos.dtype)

        self.w_matrix_3d = self.w_matrix_3d.type(self.computation_dtype)

        # self.compare_with_recursive(self.alpha_range)
        if self.use_gpu:
            self.w_matrix_3d = self.w_matrix_3d.cuda()
            # self.w_matrix_3d_sign = self.w_matrix_3d_sign.cuda()
            # self.w_matrix_3d_log = self.w_matrix_3d_log.cuda()
            # self.index_vec_pos = self.index_vec_pos.cuda()

    def set_norm_params(self, mean, std_dev):
        mean = torch.from_numpy(mean) if isinstance(mean, np.ndarray) else mean
        std_dev = torch.from_numpy(std_dev) if isinstance(
            std_dev, np.ndarray) else std_dev
        mean = mean.type(torch.float32)
        std_dev = std_dev.type(torch.float32)

        if self.use_gpu:
            mean = mean.cuda()
            std_dev = std_dev.cuda()

        self.mean = torch.nn.Parameter(mean, requires_grad=False)
        self.std_dev = torch.nn.Parameter(std_dev, requires_grad=False)

    def init_hidden(self, batch_size=1):
        return self.model_handler_prenet.model.init_hidden(batch_size)

    def set_gpu_flag(self, use_gpu):
        self.use_gpu = use_gpu
        if self.use_gpu:
            # self.alpha_list = self.alpha_list.cuda(async=True)  # Lazy loading.
            # self.warp_matrix_list = self.warp_matrix_list.cuda()  # Always required, no lazy loading.
            self.w_matrix_3d = self.w_matrix_3d.cuda()
            if self.mean is not None:
                self.mean = self.mean.cuda()
            if self.std_dev is not None:
                self.std_dev = self.std_dev.cuda()
        self.model_handler_prenet.model.set_gpu_flag(use_gpu)

    # def log_per_batch(self):
    #     emb = self.embeddings(torch.zeros(1).to(self.embeddings.weight.device).long()).unsqueeze(0)
    #
    #     alpha = self.alpha_layer(emb)
    #     alpha = torch.tanh(alpha) * self.alpha_range
    #     logging.info("alpha={}".format(alpha[0][0]))

    # def log_per_test(self):
    #     self.log_per_batch()

    def gen_w_matrix_3d(self):
        """
        Computes the entries of the warping matrix. The entry in the m-th row and k-th column is
        A(m, k) = 1/(k-1)! * sum_{n=max(0, k-m)}^{k} (k choose n) * (m+n-1)!/(m+n-k)! * (-1)^{n+m+k} * alpha^{2n+m-k}.
        Each entry is stored as its vector of coefficients for the powers of alpha (1, alpha, alpha^2, ..., alpha^{M-1}).

        :return:     Coefficient vectors of the warping matrix entries.
        """
        grad_matrix = np.zeros(
            (self.n, self.n, self.max_polynomial),
            dtype=np.float64)  # float64 needed, the factorials grow huge (32! = 2.6E35).

        grad_matrix[0, 0, 0] = 1.0
        max_degree = 0
        for m in range(0, self.n):
            for k in range(1, self.n):
                k_fac = 1 / math.factorial(k - 1) if k > 0 else 1
                for n in range(max(0, k - m), k + 1):
                    w = ncr(k, n) * math.pow(-1,
                                             n + m + k)  # * (2 * n + m - k)
                    w_fac = math.factorial(m + n - 1) if m + n - 1 > 0 else 1
                    w_fac /= math.factorial(m + n - k) if m + n - k > 0 else 1
                    w *= w_fac * k_fac
                    # if w != 0.0:
                    degree = 2 * n + m - k  # - 1
                    if degree < self.max_polynomial:
                        grad_matrix[m, k, degree] = w
                        # if degree > max_degree:
                        #     max_degree = degree
                        #     max_w = w
                        #     max_m = m
                        #     max_k = k

        # Deal with huge factorials.
        # w_matrix_3d[w_matrix_3d == np.inf] = np.finfo('f').max
        # w_matrix_3d[w_matrix_3d == -np.inf] = -np.finfo('f').max
        grad_matrix = torch.from_numpy(grad_matrix)

        grad_matrix = torch.transpose(grad_matrix, 0, 1).contiguous()
        return grad_matrix
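
    # Note: for a concrete alpha, the warp matrix is recovered by dotting each
    # coefficient vector with (1, alpha, alpha^2, ..., alpha^{2n-1});
    # get_warp_matrix() below implements exactly that as one batched matmul.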

    def gen_warp_matrix_recursively(self, alpha, requires_recursive_grad=True):
        m = [[
            torch.empty((1, ),
                        dtype=alpha.dtype,
                        requires_grad=requires_recursive_grad)
            for x in range(self.n)
        ] for y in range(self.n)]
        n = self.n
        m[0][0] = torch.ones((1, ),
                             dtype=alpha.dtype,
                             requires_grad=requires_recursive_grad)
        for r in range(1, n):
            m[r][0] = m[r - 1][0] * alpha
        for c in range(1, n):
            m[0][c] = torch.zeros(
                (1, ),
                dtype=alpha.dtype,
                requires_grad=requires_recursive_grad)  # Fix for transpose.
            for r in range(1, n):
                m[r][c] = m[r - 1][c - 1] + alpha * (m[r - 1][c] - m[r][c - 1])

        return torch.cat([torch.cat(x) for x in m]).view(self.n, self.n)
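
    # Worked example of the recursion above for n = 3 (a = alpha):
    #   m[0] = [1,    0,          0              ]
    #   m[1] = [a,    1 - a^2,    a^3 - a        ]
    #   m[2] = [a^2,  2a - 2a^3,  1 - 4a^2 + 3a^4]
    # The same values fall out of the closed-form coefficients in
    # gen_w_matrix_3d() above (after its transpose); see the standalone
    # cross-check after this class.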

    # def compare_with_recursive(self, alpha_range, precision=0.05, delta=0.001):
    #     """
    #     Compare the element-wise computed gradient matrix with the recursively generate matrix for alphas in
    #     range(-alpha_range, alpha_range, precision).
    #
    #     :param alpha_range:           Range of alpha to test in.
    #     :param precision:             Precision used for steps in that range.
    #     :param delta:                 Allowed delta of error.
    #
    #     :return:
    #     """
    #     assert(precision < 2 * alpha_range)  # Precision must fit in range.
    #
    #     for alpha_value in np.arange(-alpha_range, alpha_range + precision, precision):
    #         # Alpha value which receives the final gradient.
    #         alpha = torch.tensor(alpha_value, dtype=self.w_matrix_3d.dtype, requires_grad=True)
    #         alpha_eps = alpha
    #         alpha_eps = alpha_eps.repeat([1000, 1])
    #
    #         # Compute the warp matrix for each alpha.
    #         warp_matrix = self.get_warp_matrix_log(alpha_eps)
    #
    #         # Create the reference matrix recursively for the given alpha.
    #         ref_matrix = self.gen_warp_matrix_recursively(alpha)
    #
    #         # Compute the error.
    #         dist = (warp_matrix[10] - ref_matrix).abs()
    #         max_error = (dist / (ref_matrix.abs() + 1e-6)).max()
    #         error = dist.sum()
    #
    #         err_msg = "Max error between w_matrix_3d and recursive reference is {:.5f}% for alpha={:.2f}.".format(
    #             max_error * 100, alpha_value)
    #         logging.error(err_msg)
    #         if max_error > delta:
    #             raise ValueError(err_msg)
    #
    #         # Compute the gradient ratio error.
    #         ref_matrix.sum().backward()
    #         real_grad = torch.tensor(alpha.grad)
    #         alpha.grad.zero_()
    #         warp_matrix.sum().backward()
    #         approx_grad = alpha.grad / len(alpha_eps)
    #         dist_grad = (real_grad - approx_grad).abs()
    #         error_ratio = (dist_grad / real_grad.abs())
    #
    #         err_msg = "Gradient error between w_matrix_3d and recursive reference is {:.5f}% for alpha={:.2f}.".format(
    #             error_ratio * 100., alpha_value)
    #         logging.error(err_msg)
    #         if error_ratio > delta:
    #             raise ValueError(err_msg)
    #
    #     return True

    def pre_compute_warp_matrices(self,
                                  precision,
                                  requires_recursive_grad=True):
        """
        Recursively pre-compute warping matrices for [-1, 1] with the given precision.
        Impractical because the recursive backward pass takes too long.
        """

        self.precision = precision  # Stored for get_warp_matrix_index().
        self.warp_matrix_list = list()
        self.alpha_list = list()

        for alpha in range(-precision, precision):
            alpha = torch.tensor(alpha,
                                 dtype=torch.float32,
                                 requires_grad=requires_recursive_grad)
            self.alpha_list.append(alpha.unsqueeze(0) + precision)
            # Create a "continuous" matrix.
            self.warp_matrix_list.append(
                self.gen_warp_matrix_recursively(
                    alpha / precision, requires_recursive_grad).unsqueeze(0))

        self.warp_matrix_list = torch.cat(self.warp_matrix_list)
        self.alpha_list = torch.cat(self.alpha_list)

        if not requires_recursive_grad:
            self.warp_matrix_list.requires_grad_(True)
            self.alpha_list.requires_grad_(True)

        if self.use_gpu:
            # `async` is no longer a valid keyword argument; use non_blocking for the lazy transfer.
            self.alpha_list = self.alpha_list.cuda(non_blocking=True)
            self.warp_matrix_list = self.warp_matrix_list.cuda()  # Always required, no lazy loading.

    def get_warp_matrix_index(self, alphas):
        """
        Compute the warping matrix for a vector of alphas by linearly interpolating
        between the pre-computed matrices (see pre_compute_warp_matrices).

        :param alphas:       Vector of alphas with time and batch dimension merged (TB x 1).
        :return:             Warping matrix for each alpha value with merged time and batch dimension (TB x n x n).
        """

        alphas = (alphas + 1.) * self.precision

        alpha0 = alphas.floor().detach().squeeze(-1)
        alpha1 = alpha0 + 1

        warp_matrix0 = self.warp_matrix_list[alpha0.long()]
        warp_matrix1 = self.warp_matrix_list[alpha1.long()]

        W0 = (alpha1.unsqueeze(-1) -
              alphas).unsqueeze(-1).expand_as(warp_matrix0)
        W1 = (alphas -
              alpha0.unsqueeze(-1)).unsqueeze(-1).expand_as(warp_matrix1)

        warp_matrix = W0 * warp_matrix0 + W1 * warp_matrix1  # Doesn't work with negative indices.
        warp_matrix = warp_matrix.view(
            -1, self.n,
            self.n)  # Merge time and batch dimension to use torch.bmm().

        return warp_matrix

    def get_warp_matrix_log(self, alphas):
        """
        Compute the warping matrix for a vector of alphas in log space.
        Requires the pre-computed index_vec_pos, w_matrix_3d_sign, and w_matrix_3d_log
        buffers (currently only defined in the commented-out block in __init__).

        :param alphas:       Vector of alphas with time and batch dimension merged (TB x 1).
        :return:             Warping matrix for each alpha value with merged time and batch dimension (TB x n x n).
        """

        # Compute the log of alpha^{0..2n} in a way that is safe for alpha == 0.
        # alphas[alphas == 0] = alphas[alphas == 0] + self.eps
        log_alpha = alphas.abs().log()
        alpha_vec = torch.mm(log_alpha, self.index_vec_pos.view(1,
                                                                -1))  # TB x 2N

        # Compute elements of sum of warping matrix in third dimension.
        w_matrix_3d_expanded = self.w_matrix_3d_log.expand(
            alpha_vec.shape[0], *self.w_matrix_3d_log.shape)  # TB x n x n x 2n
        w_matrix_3d_alpha = w_matrix_3d_expanded + alpha_vec[:, None, None, :]
        w_matrix_3d_alpha = w_matrix_3d_alpha.exp()

        # Apply the correct sign to the elements in third dimension.
        alpha_negative_indices = alphas[:, 0] < 0
        w_matrix_3d_alpha = torch.index_select(
            self.w_matrix_3d_sign, dim=0,
            index=alpha_negative_indices.long()) * w_matrix_3d_alpha

        # Compute actual warping matrix.
        warp_matrix = w_matrix_3d_alpha.sum(dim=3)

        return warp_matrix

    def get_warp_matrix(self, alphas):
        """
        Compute warping matrix for vector of alphas.

        :param alphas:       Vector of alphas with time and batch dimension merged (TB x 1).
        :return:             Warping matrix for each alpha value with merged time and batch dimension (TB x n x n).
        """

        # Create alpha polynomial vector.
        alpha_list = [
            torch.ones((alphas.shape),
                       dtype=self.w_matrix_3d.dtype,
                       device=alphas.device,
                       requires_grad=True)
        ]
        for i in range(1, self.max_polynomial):
            alpha_list.append(alpha_list[i - 1] * alphas)
        alpha_vec = torch.cat(alpha_list, dim=1).unsqueeze(-1)  # T x 2n x 1

        # Do a batched matrix multiplication to get the elements of the warp matrix.
        grad_matrix_flat = self.w_matrix_3d.view(
            self.n * self.n,
            self.w_matrix_3d.shape[-1])  # n x n x 2n -> n^2 x 2n
        grad_matrix_ext_flat = grad_matrix_flat.expand(
            alpha_vec.shape[0], *grad_matrix_flat.shape[0:])  # TB x n^2 x 2n
        warp_matrix = torch.matmul(grad_matrix_ext_flat,
                                   alpha_vec).view(-1, self.n,
                                                   self.n)  # TB x n x n

        return warp_matrix
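
    # Shape walk-through for get_warp_matrix (TB = merged time * batch):
    #   alphas:       TB x 1
    #   alpha_vec:    TB x 2n x 1   (powers 1, alpha, ..., alpha^{2n-1})
    #   w_matrix_3d:  n x n x 2n, flattened and expanded to TB x n^2 x 2n
    #   warp_matrix:  TB x n x n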

    def forward_sample(self, in_tensor, alphas=None):
        """Forward one tensor through the layer."""
        if isinstance(in_tensor, np.ndarray):
            in_tensor = torch.from_numpy(in_tensor)
        in_tensor = in_tensor[:, None].to(self.w_matrix_3d.device)

        if alphas is not None:
            if isinstance(alphas, np.ndarray):
                alphas = torch.from_numpy(alphas)
            alphas = alphas[:, None].to(self.w_matrix_3d.device)

        return self.forward(in_tensor,
                            hidden=None,
                            seq_length_input=(len(in_tensor), ),
                            max_length_input=(len(in_tensor), ),
                            alphas=alphas)

    def forward(self,
                inputs,
                hidden,
                seq_length_input,
                max_length_input,
                target=None,
                seq_lengths_output=None,
                alphas=None):

        batch_size = inputs.shape[self.batch_dim]
        # num_frames = inputs.shape[self.time_dim]

        # Code for testing fixed alphas.
        if alphas is not None:
            alphas = alphas.to(self.w_matrix_3d.device)
            inputs = inputs.to(self.w_matrix_3d.device)
            output = inputs.type(self.w_matrix_3d.dtype)
            group_output = inputs.type(self.w_matrix_3d.dtype)
        else:
            inputs_emb = inputs[:, :, -1]
            if not self.pass_embs_to_pre_net:
                inputs = inputs[:, :, :-1]
            output, hidden = self.model_handler_prenet.model(
                inputs, hidden, seq_length_input, max_length_input, target,
                seq_lengths_output)
            group_output = self.model_handler_prenet.model.layer_groups[
                self.prenet_group_index_of_alpha].output

            group_output = group_output.view(
                output.shape[self.time_dim], batch_size, -1
            )  # View operation to get rid of possible bidirectional outputs.

            emb = self.embeddings(inputs_emb.long())  # Look up a speaker embedding per frame.
            #emb = emb.expand(-1, group_output.shape[1], -1) if self.batch_first else emb.expand(group_output.shape[0], -1, -1)  # Expand the temporal dimension.
            emb_out = torch.cat((emb, group_output), dim=2)

            alphas = self.alpha_layer(emb_out)
            # alphas = self.alpha_layer(inputs[:, :, 86:347:5])
            alphas = torch.tanh(alphas) * self.alpha_range
            # alphas = torch.zeros((*output.shape[:2], 1), device=output.device)

        alphas = alphas.view(-1, 1).type(
            self.w_matrix_3d.dtype)  # Merge time and batch dimension.
        warp_matrix = self.get_warp_matrix(alphas)

        if self.has_deltas:
            warped_feature_list = list()
            for start_index in range(0, 3):
                # Select the spectral features (statics, deltas, double deltas).
                feature = output[:, :, start_index * self.n:(start_index + 1) * self.n]

                # Denormalize before warping.
                if self.std_dev is not None:
                    feature = feature * self.std_dev[start_index * self.n:(start_index + 1) * self.n]
                if self.mean is not None:
                    feature = feature + self.mean[start_index * self.n:(start_index + 1) * self.n]
                feature[:, :, 0::self.n] /= 2.  # Adaptation for single-sided spectrogram.

                # Merge time and batch axis, do a batched (1 x N) * (N x N) vector-matrix
                # multiplication, then split time and batch axis again.
                feature_warped = torch.bmm(
                    feature.view(-1, 1, *feature.shape[2:]),
                    warp_matrix).view(-1, batch_size, *feature.shape[2:])

                feature_warped[:, :, 0::self.n] *= 2.  # Adaptation for single-sided spectrogram.
                # Normalize again for further processing.
                if self.mean is not None:
                    feature_warped = feature_warped - self.mean[start_index * self.n:(start_index + 1) * self.n]
                if self.std_dev is not None:
                    feature_warped = feature_warped / self.std_dev[start_index * self.n:(start_index + 1) * self.n]

                warped_feature_list.append(feature_warped)
            output = torch.cat((*warped_feature_list, output[:, :, 3 * self.n:]), dim=2)
        else:
            feature = output[:, :, :self.n]  # Select spectral features.

            # Denormalize before warping.
            if self.std_dev is not None:
                feature = feature * self.std_dev
            if self.mean is not None:
                feature = feature + self.mean
            feature[:, :, 0] /= 2.  # Adaptation for single-sided spectrogram.

            # Merge time and batch axis, do batched vector matrix multiplication with a (1 x N * N x N) matrix
            # multiplication, split time and batch axis back again.
            feature_warped = torch.bmm(feature.view(-1, 1, *feature.shape[2:]),
                                       warp_matrix).squeeze(1).view(
                                           -1, batch_size, *feature.shape[2:])

            feature_warped[:, :,
                           0] *= 2.  # Adaptation for single-sided spectrogram.
            # Normalize again for further processing.
            if self.mean is not None:
                feature_warped = feature_warped - self.mean
            if self.std_dev is not None:
                feature_warped = feature_warped / self.std_dev

            output = torch.cat((feature_warped, output[:, :, self.n:]), dim=2)

        return output, (hidden, alphas.view(-1, batch_size))
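
The denormalise -> warp -> renormalise step in forward() round-trips exactly for the identity warp, which makes a cheap sanity check; a minimal standalone sketch of that step (shapes and names are illustrative, not part of the class):

import torch

n = 4
mean, std_dev = torch.zeros(n), torch.ones(n)
feature = torch.randn(6, 1, n)              # merged (T*B) x 1 x n frames
warp_matrix = torch.eye(n).repeat(6, 1, 1)  # identity warp for every frame

denorm = feature * std_dev + mean
denorm[..., 0] /= 2.                        # single-sided spectrogram adaptation
warped = torch.bmm(denorm, warp_matrix)
warped[..., 0] *= 2.
renorm = (warped - mean) / std_dev
assert torch.allclose(renorm, feature)      # identity warp changes nothing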
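
The closed-form coefficients from gen_w_matrix_3d() and the recursion in gen_warp_matrix_recursively() should produce the same matrix; a self-contained NumPy cross-check of that claim (independent of the class above; math.comb stands in for the ncr helper the class uses):

import math
import numpy as np

def warp_matrix_closed_form(alpha, n):
    # Evaluate the closed-form entries used by gen_w_matrix_3d for one alpha.
    out = np.zeros((n, n))
    out[0, 0] = 1.0
    for m in range(n):
        for k in range(1, n):
            k_fac = 1.0 / math.factorial(k - 1)
            total = 0.0
            for j in range(max(0, k - m), k + 1):
                w = math.comb(k, j) * (-1.0) ** (j + m + k)
                w *= math.factorial(m + j - 1) if m + j - 1 > 0 else 1
                w /= math.factorial(m + j - k) if m + j - k > 0 else 1
                total += w * k_fac * alpha ** (2 * j + m - k)
            out[m, k] = total
    return out.T  # gen_w_matrix_3d transposes its result, too.

def warp_matrix_recursive(alpha, n):
    # Plain-NumPy version of gen_warp_matrix_recursively.
    m = np.zeros((n, n))
    m[:, 0] = alpha ** np.arange(n)
    for c in range(1, n):
        for r in range(1, n):
            m[r, c] = m[r - 1, c - 1] + alpha * (m[r - 1, c] - m[r, c - 1])
    return m

assert np.allclose(warp_matrix_closed_form(0.1, 8), warp_matrix_recursive(0.1, 8))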
Example #8
    def run_wavenet_vocoder(synth_output, hparams):
        # Import ModelHandlerPyTorch here to prevent circular dependencies.
        from idiaptts.src.neural_networks.pytorch.ModelHandlerPyTorch import ModelHandlerPyTorch

        assert hparams.synth_vocoder_path is not None, "Please set path to neural vocoder in hparams.synth_vocoder_path"
        # Add identifier to suffix.
        old_synth_file_suffix = hparams.synth_file_suffix
        hparams.synth_file_suffix += '_' + hparams.synth_vocoder

        if not hasattr(hparams, 'bit_depth'):
            hparams.add_hparam("bit_depth", 16)

        synth_output = copy.copy(synth_output)

        input_fs_Hz = 1000.0 / hparams.frame_size_ms
        assert hasattr(hparams, "frame_rate_output_Hz") and hparams.frame_rate_output_Hz is not None, \
            "hparams.frame_rate_output_Hz has to be set and match the trained WaveNet."
        in_to_out_multiplier = hparams.frame_rate_output_Hz / input_fs_Hz
        # dir_world_features = os.path.join(self.OutputGen.dir_labels, self.dir_extracted_acoustic_features)
        input_gen = WorldFeatLabelGen(
            None,
            add_deltas=False,
            sampling_fn=partial(sample_linearly,
                                in_to_out_multiplier=in_to_out_multiplier,
                                dtype=np.float32))
        # Load normalisation parameters for wavenet input.
        try:
            norm_params_path = os.path.splitext(
                hparams.synth_vocoder_path)[0] + "_norm_params.npy"
            input_gen.norm_params = np.load(norm_params_path).reshape(2, -1)
        except FileNotFoundError:
            logging.error(
                "Cannot find normalisation parameters for WaveNet input at {}."
                "Please save them there with numpy.save().".format(
                    norm_params_path))
            raise

        model_handler = ModelHandlerPyTorch()
        model_handler.model, *_ = model_handler.load_model(
            hparams.synth_vocoder_path, hparams, verbose=False)

        for id_name, output in synth_output.items():
            logging.info("Synthesise {} with {} vocoder.".format(
                id_name, hparams.synth_vocoder_path))

            # Any other post-processing could be done here.

            # Normalize input.
            output = input_gen.preprocess_sample(output)

            # output (T x C) --transpose--> (C x T) --unsqueeze(0)--> (B x C x T).
            output = output.transpose()[None, ...]
            # Wavenet input has to be (B x C x T).
            output, _ = model_handler.forward(
                output, hparams, batch_seq_lengths=(output.shape[-1], ))
            # output, _ = model_handler.forward(output[:, :, :1000], hparams, batch_seq_lengths=(1000,))  # DEBUG
            output = output[0].transpose()  # Remove batch dim and transpose back to (T x C).

            out_channels = output.shape[1]
            if out_channels > 1:  # One-hot (quantised) output rather than a single raw channel.
                # Revert mu-law quantization.
                output = output.argmax(axis=1)
                synth_output[id_name] = RawWaveformLabelGen.mu_law_companding_reversed(
                    output, out_channels)

            # Save the audio.
            wav_file_path = os.path.join(
                hparams.synth_dir, "".join(
                    (os.path.basename(id_name).rsplit('.', 1)[0], "_",
                     hparams.model_name, hparams.synth_file_suffix, ".",
                     hparams.synth_ext)))
            Synthesiser.raw_to_file(wav_file_path, synth_output[id_name],
                                    hparams.synth_fs, hparams.bit_depth)

        # Restore identifier.
        hparams.setattr_no_type_check(
            "synth_file_suffix",
            old_synth_file_suffix)  # Can be None, thus no type check.
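
RawWaveformLabelGen.mu_law_companding_reversed is not shown here; for reference, a sketch of a standard inverse mu-law mapping from class indices back to raw amplitudes (an assumption about its behaviour, not the library's code):

import numpy as np

def mu_law_expand(indices, quantization_channels):
    # Map class indices 0..Q-1 to [-1, 1], then invert the mu-law companding.
    mu = quantization_channels - 1
    y = 2.0 * indices.astype(np.float64) / mu - 1.0
    return np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu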
Example #9
    shutil.rmtree(dir_atoms)
    makedirs_safe(dir_atoms)
    atom_generator = AtomVUVDistPosLabelGen(
        os.path.join(os.path.dirname(os.environ["IDIAPTTS_ROOT"]), "tools",
                     "wcad"), dir_atoms, dir_world, thetas)
    atom_generator.gen_data(dir_wav, dir_atoms, id_list=id_list)

if retrain_models:
    raise NotImplementedError("Retraining of models is not implemented yet.")
elif save_models:
    from idiaptts.src.neural_networks.pytorch.ModelHandlerPyTorch import ModelHandlerPyTorch

    from idiaptts.src.model_trainers.wcad.AtomVUVDistPosModelTrainer import AtomVUVDistPosModelTrainer
    hparams = AtomVUVDistPosModelTrainer.create_hparams()
    hparams.model_name = "test_model_in409_out7.nn"
    model_handler = ModelHandlerPyTorch()
    # The commented-out code below loads the model via load_model and saves it back as a checkpoint.
    # model, model_type, dim_in, dim_out = model_handler.load_model(hparams.model_name, hparams)
    # model_handler.model_type = "RNNDYN-1_RELU_32-1_FC_7"
    # model_handler.dim_in = model.dim_in
    # model_handler.dim_out = model.dim_out
    # model_handler.model_name = hparams.model_name
    # model_handler.model = model
    # model_handler.save_checkpoint(os.path.realpath(hparams.model_name), 3)
    epochs = model_handler.load_checkpoint(hparams.model_name, hparams)
    model_handler.save_checkpoint(os.path.realpath(hparams.model_name), epochs)

    from idiaptts.src.model_trainers.wcad.AtomNeuralFilterModelTrainer import AtomNeuralFilterModelTrainer
    hparams = AtomNeuralFilterModelTrainer.create_hparams()
    hparams.model_name = "neural_filters_model_in409_out2.nn"
    hparams.atom_model_path = "test_model_in409_out7.nn"
Example #10
class NeuralFilters(nn.Module):
    IDENTIFIER = "NeuralFilters"

    def __init__(self, dim_in, dim_out, hparams):
        super().__init__()

        # Store parameters.
        self.use_gpu = hparams.use_gpu
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dropout = hparams.dropout

        assert(not hparams.batch_first)  # This implementation doesn't work with batch_first.

        from idiaptts.src.neural_networks.pytorch.ModelHandlerPyTorch import ModelHandlerPyTorch
        self.model_handler_atoms = ModelHandlerPyTorch()
        if hasattr(hparams.hparams_atom, "learning_rate"):
            lr = hparams.hparams_atom.learning_rate
        elif hasattr(hparams.hparams_atom, "optimiser_args"):
            lr = hparams.hparams_atom.optimiser_args["lr"]
        elif hasattr(hparams, "learning_rate"):
            lr = hparams.learning_rate
        elif "lr" in hparams.optimiser_args:  # optimiser_args is dict-like; check for the key, not hasattr().
            lr = hparams.optimiser_args["lr"]
        else:
            lr = None
        self.model_handler_atoms.load_checkpoint(hparams.atom_model_path, hparams.hparams_atom, initial_lr=lr)
        self.add_module("atom_model", self.model_handler_atoms.model)  # Add atom model as submodule so that parameters are properly registered.

        if hparams.complex_poles:
            self.intonation_filters = ComplexModel(hparams.thetas, hparams.phase_init)
        else:
            self.intonation_filters = CriticalModel(hparams.thetas)
        self.add_module("intonation_filters", self.intonation_filters)

    def forward(self, inputs, hidden, seq_lengths, max_length_inputs, *_):
        output_atoms, output_atoms_hidden = self.model_handler_atoms.model(inputs, hidden, seq_lengths, max_length_inputs)

        vuv = output_atoms[:, :, 0:1]
        amps = output_atoms[:, :, 1:-1]
        # pos = output_atoms[:, :, -1]

        if len(seq_lengths) > 1:
            pack_amps = pack_padded_sequence(amps, seq_lengths)  # TODO: Add batch_first parameter.

            output_filters = self.intonation_filters(pack_amps)  # The filter unpacks the sequence.

            # output, _ = pad_packed_sequence(output_filters, total_length=max_length_inputs)
            # # Pack sequence.
            # pack_amps = amps.squeeze().split(seq_lengths, dim=0)
            # pack_amps = pack_sequence(pack_amps)
            # # Run through filter.
            # output_filters = self.intonation_filters(pack_amps)
            # # Unpack sequence.
            # # output_filters, _ = pad_packed_sequence(output_filters)
            # output_filters = torch.cat([x[:seq_lengths[i], :] for i, x in enumerate(output_filters.split(1, dim=1))])
        else:
            output_filters = self.intonation_filters(amps)

        output_e2e = torch.cat((output_filters, vuv, amps), -1)

        return output_e2e, None

    def filters_forward(self, inputs, hidden, seq_lengths, max_length):
        """Get output of each filter without their superposition."""
        output_atoms, output_atoms_hidden = self.model_handler_atoms.model(inputs, hidden, seq_lengths, max_length)

        amps = output_atoms[:, :, 1:-1]

        if len(seq_lengths) > 1:
            # Pack sequence.
            pack_amps = pack_padded_sequence(amps, seq_lengths)  # TODO: Add batch_first parameter.
            # Run through filter.
            output_filters = self.intonation_filters.filters_forward(pack_amps)  # The filter unpacks the sequence.
        else:
            output_filters = self.intonation_filters.filters_forward(amps)

        return output_filters

    def set_gpu_flag(self, use_gpu):
        self.use_gpu = use_gpu
        self.model_handler_atoms.use_gpu = use_gpu
        self.model_handler_atoms.model.set_gpu_flag(use_gpu)

    def init_hidden(self, batch_size=1):
        self.model_handler_atoms.model.init_hidden(batch_size)
        return None

    def thetas_approx(self):
        roots = [np.roots(denom) for denom in self.intonation_filters.filters.denominator]
        modulus = np.array([np.abs(root[0]) for root in roots])
        return modulus_to_theta(modulus)
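
The packing in forward() only pays off for true batches; a toy example of the shapes involved (batch_first=False, so tensors are T x B x C; values illustrative):

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

amps = torch.randn(5, 2, 3)           # T=5, B=2, C=3 amplitude channels
seq_lengths = torch.tensor([5, 3])    # per-sequence lengths, longest first
packed = pack_padded_sequence(amps, seq_lengths)  # drops the padded steps
unpacked, lengths = pad_packed_sequence(packed, total_length=5)
assert unpacked.shape == (5, 2, 3) and lengths.tolist() == [5, 3]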
Example #12
class ModelTrainer(object):
    """
    Baseclass for all trainers.

    Load input and output data by generators (set by subclass). Perform normalisation and length mismatch fix.
    Select 5% of data (at least one) for testset. The model_handler implementation of a framework is used as interface.
    Subclasses have to set up the data and synthesize attributes by instantiating generators and possibly overwriting
    the synthesize method.
    """
    logger = logging.getLogger(__name__)

    # Default is hparams.out_dir/self.dir_extracted_acoustic_features, but can be overwritten by hparams.world_dir.
    dir_extracted_acoustic_features = "../WORLD/"

    def __init__(self, id_list, hparams):
        """Default constructor.

        :param id_list:              List or tuple of ids as strings. This list is separated into evaluation, test, and training set.
        :param hparams:              An object holding all hyper parameters.
        """

        self.logger.info("Running on host {}.".format(platform.node()))
        try:
            repo = git.Repo(search_parent_directories=True)
            self.logger.info("Git: {} at {}".format(repo.git_dir,
                                                    repo.head.object.hexsha))
        except git.exc.InvalidGitRepositoryError:
            pass
        try:
            framework_repo = git.Repo(path=os.environ['IDIAPTTS_ROOT'],
                                      search_parent_directories=True)
            self.logger.info("IdiapTTS framework git: {} at {}".format(
                framework_repo.git_dir, framework_repo.head.object.hexsha))
        except git.exc.InvalidGitRepositoryError:
            pass

        assert (hparams is not None)

        if not hasattr(hparams,
                       "batch_size_train") or not hparams.batch_size_train > 1:
            hparams.variable_sequence_length_train = False
        if not hasattr(hparams,
                       "batch_size_val") or not hparams.batch_size_val > 1:
            hparams.variable_sequence_length_val = False

        if hparams.use_gpu:
            if hparams.num_gpus > 1:
                # CUDA expects a comma-separated list like "0,1", not the str() of a tuple.
                os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(
                    map(str, range(hparams.num_gpus)))
            if ModelHandlerPyTorch.cuda_is_available():
                device_count = ModelHandlerPyTorch.device_count()
                if not device_count == hparams.num_gpus:
                    self.logger.error(
                        "Specified GPU count in hparams.num_gpus ({}) doesn't match hardware ({})."
                        .format(hparams.num_gpus, device_count))
                assert device_count == hparams.num_gpus  # Specified GPU count has to match the hardware.
            else:
                self.logger.warning(
                    "No CUDA device available, using CPU mode instead.")
                hparams.use_gpu = False

        if "lr" not in hparams.optimiser_args\
                and hasattr(hparams, "learning_rate")\
                and hparams.learning_rate is not None:  # Backwards compatibility.
            hparams.optimiser_args["lr"] = hparams.learning_rate

        if hparams.seed is not None:
            ModelHandlerPyTorch.seed(hparams.seed)  # Seed the backend.
            np.random.seed(hparams.seed)
            random.seed(hparams.seed)

        if not hasattr(self, "id_list_train") or self.id_list_train is None:
            id_list_shuffled = id_list
            if hparams.seed is not None:
                id_list_shuffled = random.sample(id_list, len(id_list))

            # Partition (randomly sorted) ids into [val_set, train_set, test_set].
            assert (hparams.test_set_perc + hparams.val_set_perc < 1)
            if hparams.val_set_perc > 0.0:
                num_valset = max(
                    1, int(len(id_list_shuffled) * hparams.val_set_perc))
                self.id_list_val = id_list_shuffled[:num_valset]
            else:
                num_valset = 0
                self.id_list_val = None
            if hparams.test_set_perc > 0.0:
                num_testset = max(
                    1, int(len(id_list_shuffled) * hparams.test_set_perc))
                self.id_list_test = id_list_shuffled[-num_testset:]
            else:
                num_testset = 0
                self.id_list_test = None
            self.id_list_train = id_list_shuffled[num_valset:-num_testset] if num_testset > 0\
                                                                           else id_list_shuffled[num_valset:]
            assert (len(self.id_list_train) > 0)
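
            # Worked example: 100 ids with val_set_perc=0.05 and
            # test_set_perc=0.05 yield id_list_val = ids[:5],
            # id_list_test = ids[-5:], and id_list_train = ids[5:-5] (90 ids).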

        # Create and initialize the model handler, the interface to the NN framework.
        self.logger.info("Create ModelHandler.")
        self.model_handler = ModelHandlerPyTorch()

        # Data attributes.
        self.InputGen = None  # Used in the datasets.
        self.OutputGen = None  # Used in the datasets.
        self.dataset_train = None
        self.dataset_val = None

        self.batch_collate_fn = None  # Function used to combine samples to one batch.
        self.batch_decollate_fn = None  # Function used to split the batched output of a model.
        # Result is directly given to the gen_figure function.
        # Only the first element of the result is given to the post-processing function
        # of the OutputGen when result is a tuple or list.

        # Could have been set already in the constructor of a child class.
        if not hasattr(self, "loss_function"):
            self.loss_function = None  # Has to be defined by the subclass.

        self.total_epoch = None  # Total number of epochs the current model was trained.

    @staticmethod
    def create_hparams(hparams_string=None, verbose=False):
        """Create model hyper-parameters. Parse non-default from given string."""
        return ExtendedHParams.create_hparams(hparams_string, verbose)

    # def plot_outputs(self, max_epochs, id_name, outputs, target):
    #     plotter = DataPlotter()
    #     net_name = os.path.basename(self.model_handler.model_name)
    #     filename = str(os.path.join(self.out_dir, id_name + '.' + net_name))
    #     plotter.set_title(id_name + " - " + net_name)
    #
    #     # Create a plot for every dimension of outputs with its target.
    #     graphs_o = [None] * target.shape[1]
    #     graphs_t = [None] * target.shape[1]
    #     for out_idx in range(0, target.shape[1]):
    #         graphs_o[out_idx] = list()
    #         # Add all outputs to the plot.
    #         for idx, o in enumerate(outputs):
    #             # Handle special case where NN output has only one dimension.
    #             if len(o.shape) == 1:
    #                 o = o.reshape(-1, 1)
    #             graphs_o[out_idx].append((o[:, out_idx], 'e' + str(min(max_epochs, (idx + 1) * self.epochs_per_plot))))
    #         # Give data to plotter and leave two grid positions for each output dimension (used for output and target).
    #         plotter.set_data_list(grid_idx=out_idx * 2, data_list=graphs_o[out_idx])
    #
    #         # Add target belonging to the output dimension.
    #         graphs_t[out_idx] = list()
    #         graphs_t[out_idx].append((target[:, out_idx], 'target[' + str(out_idx) + ']'))
    #         plotter.set_data_list(grid_idx=out_idx * 2 + 1, data_list=graphs_t[out_idx])
    #
    #     # Set label for all.
    #     plotter.set_label(xlabel='frames', ylabel='amp')
    #
    #     # Generate and save the plot.
    #     plotter.gen_plot()
    #     plotter.save_to_file(filename + ".OUTPUTS.png")
    #
    #     plotter.plt.show()

    def init(self, hparams):

        self.logger.info(
            "CPU memory: " +
            str(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e3) +
            " MB.")
        if hparams.use_gpu:
            self.logger.info("GPU memory: " + str(get_gpu_memory_map()) +
                             " MB.")

        # Create the necessary directories.
        makedirs_safe(
            os.path.join(hparams.out_dir, hparams.networks_dir,
                         hparams.checkpoints_dir))

        # Create the default model path if not set or retrieve the name from the given path.
        if hparams.model_path is None:
            assert hparams.model_name is not None, \
                "A model_path or model_name has to be given; no default exists."
            hparams.model_path = os.path.join(hparams.out_dir,
                                              hparams.networks_dir,
                                              hparams.model_name)
        elif hparams.model_name is None:
            hparams.model_name = os.path.basename(hparams.model_path)

        model_path_out = os.path.join(hparams.out_dir, hparams.networks_dir,
                                      hparams.model_name)
        if hparams.epochs <= 0:
            # Try to load the model. If it doesn't exist, create a new one and save it.
            # Return the loaded/created model, because no training was requested.
            try:
                self.total_epoch = self.model_handler.load_checkpoint(
                    hparams.model_path, hparams, hparams.optimiser_args["lr"]
                    if hasattr(hparams, "optimiser_args")
                    and "lr" in hparams.optimiser_args else None)
            except FileNotFoundError:
                if hparams.model_type is None:
                    self.logger.error(
                        "Model does not exist at {} and you didn't give model_type to create a new one."
                        .format(hparams.model_path))
                    raise  # This will rethrow the last exception.
                else:
                    self.logger.warning(
                        'Model does not exist at {}. Creating a new one instead and saving it.'
                        .format(hparams.model_path))
                    dim_in, dim_out = self.dataset_train.get_dims()
                    self.model_handler.create_model(hparams, dim_in, dim_out)
                    self.total_epoch = 0
                    self.model_handler.save_checkpoint(model_path_out,
                                                       self.total_epoch)

            self.logger.info("Model ready.")
            return

        if hparams.model_type is None:
            self.total_epoch = self.model_handler.load_checkpoint(
                hparams.model_path, hparams, hparams.optimiser_args["lr"]
                if hasattr(hparams, "optimiser_args")
                and "lr" in hparams.optimiser_args else None)
        else:
            dim_in, dim_out = self.dataset_train.get_dims()
            self.model_handler.create_model(hparams, dim_in, dim_out)
            self.total_epoch = 0

        self.logger.info("Model ready.")

    def train(self, hparams):
        """
        Train the model. Use generators for data preparation and model_handler for access.
        Generators have to be set in constructor of subclasses.

        :param hparams:          Hyper-parameter container.
        :return:                 A tuple of (all test loss, all training loss, the model_handler object).
        """

        # Verify that attributes were added correctly; print a warning for
        # wrongly initialised ones.
        hparams.verify()
        self.logger.info(hparams.get_debug_string())

        assert self.model_handler, \
            "The init function has to be called before training."

        # Skip training if the number of epochs is not greater than 0.
        if hparams.epochs <= 0:
            self.logger.info(
                "Number of training epochs is {}. Skipping training.".format(
                    hparams.epochs))
            return list(), list(), self.model_handler

        # Log evaluation ids.
        if self.id_list_val is not None and len(self.id_list_val) > 0:
            valset_keys = sorted(self.id_list_val)
            self.logger.info(
                "Validation set (" + str(len(valset_keys)) + "): " + " ".join([
                    os.path.join(
                        os.path.split(os.path.dirname(id_name))[-1],
                        os.path.splitext(os.path.basename(id_name))[0])
                    for id_name in valset_keys
                ]))
        # Log test ids.
        if self.id_list_test is not None and len(self.id_list_test) > 0:
            testset_keys = sorted(self.id_list_test)
            self.logger.info(
                "Test set (" + str(len(testset_keys)) + "): " + " ".join([
                    os.path.join(
                        os.path.split(os.path.dirname(id_name))[-1],
                        os.path.splitext(os.path.basename(id_name))[0])
                    for id_name in testset_keys
                ]))

        # Setup the dataloaders.
        self.model_handler.set_dataset(hparams, self.dataset_train,
                                       self.dataset_val, self.batch_collate_fn)

        self.logger.info(
            "CPU memory: " +
            str(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e3) +
            " MB.")
        if hparams.use_gpu:
            self.logger.info("GPU memory: " + str(get_gpu_memory_map()) +
                             " MB.")

        # Run model.
        # if self.epochs_per_plot > 0:
        #     outputs = list()
        #     num_iterations = int(math.ceil(float(epochs) / float(self.epochs_per_plot)))
        #     epochs_per_iter = min(epochs, self.epochs_per_plot)
        #     for e in range(num_iterations):
        #         epochs_this_iter = min(epochs_per_iter, epochs - e * epochs_per_iter)
        #         nn_model.run(epochs_this_iter, e * epochs_per_iter)
        #         outputs.append(nn_model.forward(dict_input_labels[self.plot_per_epoch_id_name]))
        #     self.plot_outputs(epochs, self.plot_per_epoch_id_name, outputs, dict_output_labels[self.plot_per_epoch_id_name])

        # Some sanity checks.
        if hparams.epochs_per_scheduler_step:
            if hparams.epochs_per_test > hparams.epochs_per_scheduler_step:
                self.logger.warning(
                    "Model is validated only every {} epochs, ".format(
                        hparams.epochs_per_test) +
                    "but scheduler is supposed to run every {} epochs.".format(
                        hparams.epochs_per_scheduler_step))
            if hparams.epochs_per_test % hparams.epochs_per_scheduler_step != 0:
                self.logger.warning(
                    "hparams.epochs_per_test ({}) % hparams.epochs_per_scheduler_step ({}) != 0. "
                    .format(hparams.epochs_per_test,
                            hparams.epochs_per_scheduler_step) +
                    "Note that the scheduler is only run when current_epoch % "
                    +
                    "hparams.epochs_per_scheduler_step == 0. Therefore hparams.epochs_per_scheduler_step "
                    + "should be a factor of hparams.epochs_per_test.")

        t_start = timer()
        self.logger.info('Start training: {}'.format(
            datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

        self.model_handler.set_optimiser(hparams)
        self.model_handler.set_scheduler(
            hparams,
            self.total_epoch if hparams.use_saved_learning_rate else 0)

        assert self.loss_function, \
            "Please set self.loss_function in the trainer constructor."
        loss_function = self.loss_function.cuda() if hparams.use_gpu \
            else self.loss_function

        # Lists which are returned, containing all losses so that progress is visible.
        all_loss = list()
        all_loss_train = list()
        best_loss = np.nan
        start_epoch = self.total_epoch

        # Compute error before first iteration.
        if hparams.start_with_test:
            self.logger.info('Test epoch [{}/{}]:'.format(
                start_epoch, start_epoch + hparams.epochs))
            loss, loss_features = self.model_handler.test(
                hparams, start_epoch, start_epoch, loss_function)
            all_loss_train.append(-1.0)  # Placeholder in the train losses.
            all_loss.append(loss)
            best_loss = loss  # Variable to save the current best loss.

        for current_epoch in range(1, hparams.epochs + 1):
            # Increment epoch number.
            self.total_epoch += 1

            # Train one epoch.
            self.logger.info('Train epoch [{}/{}]:'.format(
                self.total_epoch, start_epoch + hparams.epochs))
            train_loss = self.model_handler.train(hparams, self.total_epoch,
                                                  current_epoch, loss_function)
            all_loss_train.append(train_loss)
            if np.isnan(train_loss):
                break

            # Test if requested.
            if self.total_epoch % hparams.epochs_per_test == 0:
                self.logger.info('Test epoch [{}/{}]:'.format(
                    self.total_epoch, start_epoch + hparams.epochs))
                # Compute error on validation set.
                loss, loss_features = self.model_handler.test(
                    hparams, self.total_epoch, current_epoch, loss_function)

                # Save loss in a list which is returned.
                all_loss.append(loss)

                # Stop when the loss is NaN; reload from a checkpoint if necessary.
                if np.isnan(loss):
                    break

                # Save checkpoint if path is given.
                if hparams.out_dir is not None:
                    path_checkpoint = os.path.join(hparams.out_dir,
                                                   hparams.networks_dir,
                                                   hparams.checkpoints_dir)
                    # Check when to save a checkpoint.
                    if hparams.epochs_per_checkpoint > 0 and self.total_epoch % hparams.epochs_per_checkpoint == 0:
                        model_name = "{}-e{}-{}".format(
                            hparams.model_name, self.total_epoch,
                            loss_function)
                        self.model_handler.save_checkpoint(
                            os.path.join(path_checkpoint, model_name),
                            self.total_epoch)
                    # Always save best checkpoint with special name.
                    if loss < best_loss or np.isnan(best_loss):
                        best_loss = loss
                        model_name = hparams.model_name + "-best"
                        self.model_handler.save_checkpoint(
                            os.path.join(path_checkpoint, model_name),
                            self.total_epoch)

                # Run the scheduler if requested.
                if hparams.epochs_per_scheduler_step:
                    if (self.total_epoch if hparams.use_saved_learning_rate else current_epoch)\
                            % hparams.epochs_per_scheduler_step == 0:
                        self.model_handler.run_scheduler(
                            loss, self.total_epoch + 1)

        t_training = timer() - t_start
        self.logger.info('Training time: ' +
                         str(timedelta(seconds=t_training)))
        self.logger.info('Loss progress: ' + ', '.join('{:.4f}'.format(l)
                                                       for l in all_loss))
        self.logger.info('Loss train progress: ' +
                         ', '.join('{:.4f}'.format(l) for l in all_loss_train))

        if hparams.out_dir is not None:
            # Check if the best model should be used as the final model.
            # Only possible when it was saved in out_dir.
            if hparams.use_best_as_final_model:
                best_model_path = os.path.join(hparams.out_dir,
                                               hparams.networks_dir,
                                               hparams.checkpoints_dir,
                                               hparams.model_name + "-best")
                try:
                    self.total_epoch = self.model_handler.load_checkpoint(
                        best_model_path, hparams,
                        hparams.optimiser_args["lr"]
                        if "lr" in hparams.optimiser_args
                        else hparams.learning_rate)
                    if self.model_handler.ema:  # EMA model should be used as best model.
                        self.model_handler.model = self.model_handler.ema.model
                        self.model_handler.ema = None  # Reset this one so that a new one is created for further training.
                        self.logger.info(
                            "Using best EMA model (epoch {}) as final model.".
                            format(self.total_epoch))
                    else:
                        self.logger.info(
                            "Using best (epoch {}) as final model.".format(
                                self.total_epoch))
                except FileNotFoundError:
                    self.logger.warning("No best model exists yet. "
                                        "Continuing with the current one.")

            # Save the model if requested.
            if hparams.save_final_model:
                self.model_handler.save_checkpoint(
                    os.path.join(hparams.out_dir, hparams.networks_dir,
                                 hparams.model_name), self.total_epoch)

        return all_loss, all_loss_train, self.model_handler

    @staticmethod
    def _input_to_str_list(input_ids):  # Renamed from "input" to avoid shadowing the builtin.
        # Check for string input first.
        if isinstance(input_ids, str):
            # Check if the string is a path by trying to read ids from the file.
            try:
                with open(input_ids) as f:
                    id_list = f.readlines()
                # Trim entries in-place.
                id_list[:] = [s.strip(' \t\n\r') for s in id_list]
                return id_list
            except IOError:
                # The string is a single input id; convert it to a list.
                return [input_ids]
        # Check for a list or tuple.
        elif isinstance(input_ids, (list, tuple)):
            # Ensure all elements are strings.
            return list(map(str, input_ids))
        raise ValueError("Unknown input {} of type {}.".format(
            input_ids, type(input_ids)))
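    # Accepted forms (sketch; paths and ids are hypothetical):
    #     _input_to_str_list("database/file_id_list.txt")  # Ids read from file.
    #     _input_to_str_list(["speaker1/utt1", "speaker1/utt2"])  # Used as-is.
    #     _input_to_str_list("speaker1/utt1")  # Becomes ["speaker1/utt1"].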

    @staticmethod
    def split_batch(output,
                    hidden,
                    seq_length_output=None,
                    permutation=None,
                    batch_first=False):
        """
        Retrieve output and hidden from batch.

        :param output:             Batched output tensor given by network.
        :param hidden:             Batched hidden tensor given by network.
        :param seq_length_output:  Tuple containing the lengths of all samples in the batch.
        :param permutation:        Permutations previously applied to the batch, which are reverted here.
        :param batch_first:        Batch dimension is first in output.
        :return:                   List of outputs and list of hidden, where each entry corresponds to one sample in the batch.
        """

        # Split the output of the batch.
        return ModelTrainer._split_return_values(output, seq_length_output, permutation, batch_first),\
               ModelTrainer._split_return_values(hidden, seq_length_output, permutation, batch_first)

    @classmethod
    def _split_return_values(cls, input_values, seq_length_output, permutation,
                             batch_first):
        if input_values is None:
            return None

        # Special case for bidirectional layers where the hidden state is a tuple.
        if isinstance(input_values, tuple):
            # If hidden is a tuple of None, return it directly.
            if all(v is None for v in input_values):
                return input_values

            # Split hidden states in their batch dimension.
            tuple_splitted = tuple(
                map(
                    lambda x: cls._split_return_values(
                        x, seq_length_output, permutation, batch_first),
                    input_values))

            # Now sort into each batch.
            return_values = list()
            # Get the batch size from the first element that is not None.
            batch_size = len([t for t in tuple_splitted if t is not None][0])

            for index in range(batch_size):
                batch = list()
                for element in tuple_splitted:
                    if element is None or (isinstance(element, tuple)
                                           and all(v is None for v in element)):
                        batch.append(element)  # Handles None and tuples of None.
                    else:
                        batch.append(element[index])
                return_values.append(tuple(batch))

            return tuple(return_values)

        if not isinstance(input_values, np.ndarray):
            msg = "Expected numpy tensor but input is of type {}.".format(
                type(input_values))
            cls.logger.error(msg)
            raise TypeError(msg)

        # Return value is tensor.
        if batch_first:
            return_values = np.split(input_values,
                                     input_values.shape[0],
                                     axis=0)
            return_values = list(
                map(partial(np.squeeze, axis=0), return_values))
        else:
            return_values = np.split(input_values,
                                     input_values.shape[1],
                                     axis=1)
            return_values = list(
                map(partial(np.squeeze, axis=1), return_values))

        if seq_length_output is not None and len(seq_length_output) > 1:
            for idx in range(len(return_values)):
                return_values[idx] = return_values[idx][:seq_length_output[idx]]

        if permutation is not None:
            return_values_unsorted = return_values.copy()
            for org_index, current_index in enumerate(permutation):
                return_values_unsorted[current_index] = return_values[
                    org_index]
            return_values = return_values_unsorted

        return return_values

    def forward(self, hparams, ids_input):
        """
        Forward all given ids through the network in batches of hparams.batch_size_val.

        :param hparams:        Hyper-parameter container.
        :param ids_input:      Can be a full path to a file with ids, a list of ids, or one id.
        :return:               (Dictionary of network outputs, dictionary of post-processed (by self.OutputGen) network outputs)
        """
        assert self.model_handler is not None, \
            "trainer.init() has to be called before."
        id_list = ModelTrainer._input_to_str_list(ids_input)

        self.logger.info("Start forwarding [{0}]".format(", ".join(
            str(i) for i in id_list)))
        t_start = timer()
        model_output, model_output_post = self._forward_batched(
            hparams,
            id_list,
            hparams.batch_size_val,
            load_target=False,
            synth=False,
            benchmark=False,
            gen_figure=False)
        t_training = timer() - t_start
        self.logger.info('Forwarding time for {} sample(s): {}'.format(
            len(id_list), timedelta(seconds=t_training)))

        return model_output, model_output_post

    def synth(self, hparams, ids_input):
        """
        Synthesise all given ids with the self.synthesize function.

        :param hparams:        Hyper-parameter container.
        :param ids_input:      Can be full path to file with ids, list of ids, or one id.
        :return:               (Dictionary of network outputs, dictionary of post-processed (by self.OutputGen) network outputs)
        """

        assert self.model_handler is not None, \
            "trainer.init() has to be called before."
        assert hparams.synth_dir is not None, \
            "The directory to store the generated audio files has to be set."
        makedirs_safe(hparams.synth_dir)
        id_list = ModelTrainer._input_to_str_list(ids_input)

        self.logger.info("Start synthesising [{0}]".format(", ".join(
            str(i) for i in id_list)))
        t_start = timer()
        model_output, model_output_post = self._forward_batched(
            hparams,
            id_list,
            hparams.batch_size_synth,
            load_target=False,
            synth=True,
            benchmark=False,
            gen_figure=hparams.synth_gen_figure)
        t_training = timer() - t_start
        self.logger.info('Synthesis time for {} sample(s): {}'.format(
            len(id_list), timedelta(seconds=t_training)))

        return model_output, model_output_post

    def gen_figure(self, hparams, ids_input):
        """
        Generate figures for all given ids with the self.gen_figure_from_output function (has to be implemented).

        :param hparams:        Hyper-parameter container.
        :param ids_input:      Can be full path to file with ids, list of ids, or one id.
        :return:               (Dictionary of network outputs, dictionary of post-processed (by self.OutputGen) network outputs)
        """

        assert self.model_handler is not None, \
            "trainer.init() has to be called before."
        id_list = ModelTrainer._input_to_str_list(ids_input)

        self.logger.info("Start generating figures for [{0}]".format(", ".join(
            str(i) for i in id_list)))
        t_start = timer()
        model_output, model_output_post = self._forward_batched(
            hparams,
            id_list,
            hparams.batch_size_gen_figure,
            synth=False,
            benchmark=False,
            gen_figure=True)
        t_training = timer() - t_start
        self.logger.info('Figure generation time for {} sample(s): {}'.format(
            len(id_list), timedelta(seconds=t_training)))

        return model_output, model_output_post

    def benchmark(self, hparams, ids_input=None):
        """
        Benchmark the currently loaded model using the self.compute_score function (has to be implemented).

        :param hparams:        Hyper-parameter container.
        :param ids_input:      Can be a full path to a file with ids, a list of ids, one id, or None.
                               If ids_input is None, benchmark on the test set if it exists,
                               otherwise on the validation set.
        :return:               Score(s).
        """

        assert callable(getattr(self, 'compute_score', None)), \
            "compute_score has to be implemented for this trainer."
        assert self.model_handler is not None, \
            "trainer.init() has to be called before."

        # Select test or validation set when ids are not given explicitly.
        if ids_input is None:
            if self.id_list_test is not None and len(self.id_list_test) > 0:
                id_list = sorted(self.id_list_test)
                self.logger.info(
                    "Start benchmark on test set ({}): [{}]".format(
                        len(id_list), ", ".join(str(i) for i in id_list)))
            elif self.id_list_val is not None and len(self.id_list_val) > 0:
                id_list = sorted(self.id_list_val)
                self.logger.info(
                    "Start benchmark on validation set ({}): [{}]".format(
                        len(id_list), ", ".join(str(i) for i in id_list)))
            else:
                raise ValueError(
                    "No id list can be selected for benchmark, because none "
                    "was given as a parameter and the test and validation "
                    "sets are empty.")
        else:
            id_list = ModelTrainer._input_to_str_list(ids_input)
            self.logger.info(
                "Start benchmark on given input ({}): [{}]".format(
                    len(id_list), ", ".join(str(i) for i in id_list)))

        t_start = timer()
        model_scores = self._forward_batched(hparams,
                                             id_list,
                                             hparams.batch_size_benchmark,
                                             synth=False,
                                             benchmark=True,
                                             gen_figure=False)
        t_training = timer() - t_start
        self.logger.info('Benchmark time for {} sample(s): {}'.format(
            len(id_list), timedelta(seconds=t_training)))

        return model_scores

    def _forward_batched(self,
                         hparams,
                         id_list,
                         batch_size,
                         load_target=True,
                         synth=False,
                         benchmark=False,
                         gen_figure=False):
        """
        Forward the features for the given ids in batches through the network.

        :param hparams:               Hyper-parameter container.
        :param id_list:               A list of ids for which the features are accessible by the self.InputGen object.
        :param batch_size:            Max size of a chunk of ids forwarded.
        :param load_target:           Give the target to the model when forwarded (used in teacher forcing).
        :param synth:                 Use the self.synthesize method to generate audio.
        :param benchmark:             Benchmark the given ids with the self.compute_score function.
        :param gen_figure:            Generate figures with the self.gen_figure_from_output function.
        :return:                      (Dictionary of outputs, dictionary of post-processed (by self.OutputGen) outputs)
        """

        self.logger.info("Get model outputs as batches of size {}.".format(
            min(batch_size, len(id_list))))
        dict_outputs = dict()
        dict_outputs_post = dict()
        dict_hiddens = dict()

        for batch_index in range(0, len(id_list), batch_size):
            batch_id_list = id_list[batch_index:batch_index + batch_size]

            inputs = list()
            for id_name in batch_id_list:
                # Load preprocessed sample and add it to the inputs with target value.
                inputs.append(
                    self.dataset_train.getitem_by_name(
                        id_name, load_target))  # No length check here.

            if self.batch_collate_fn is not None:
                batch_input_labels, batch_target_labels, seq_length_inputs, seq_length_output, *_, permutation = self.batch_collate_fn(
                    inputs,
                    common_divisor=hparams.num_gpus,
                    batch_first=hparams.batch_first)
            else:
                batch_input_labels, batch_target_labels, seq_length_inputs, seq_length_output, *_, permutation = self.model_handler.prepare_batch(
                    inputs,
                    common_divisor=hparams.num_gpus,
                    batch_first=hparams.batch_first)

            # Run forward pass of model.
            nn_output, nn_hidden = self.model_handler.forward(
                batch_input_labels, hparams, seq_length_inputs)

            # Retrieve output from batch.
            if self.batch_decollate_fn is not None:
                outputs, hiddens = self.batch_decollate_fn(
                    nn_output,
                    nn_hidden,
                    seq_length_output,
                    permutation,
                    batch_first=hparams.batch_first)
            else:
                outputs, hiddens = self.split_batch(
                    nn_output,
                    nn_hidden,
                    seq_length_output,
                    permutation,
                    batch_first=hparams.batch_first)

            # Post-process samples and generate a figure if requested.
            for idx, id_name in enumerate(batch_id_list):
                dict_outputs[id_name] = outputs[idx]
                hidden = hiddens[idx] if hiddens is not None else None
                dict_hiddens[id_name] = hidden

                # If the output is a tuple or list, use only its first element
                # for figure generation and post-processing.
                output = outputs[idx][0] if isinstance(
                    outputs[idx], (tuple, list)) else outputs[idx]
                if gen_figure:
                    self.gen_figure_from_output(id_name, output, hidden,
                                                hparams)
                dict_outputs_post[id_name] = \
                    self.dataset_train.postprocess_sample(output)

        if benchmark:
            # Implementation of compute_score is checked in benchmark function.
            return self.compute_score(dict_outputs_post, dict_hiddens, hparams)
        if synth:
            self.synthesize(id_list, dict_outputs_post, hparams)

        return dict_outputs, dict_outputs_post

    def gen_figure_from_output(self, id_name, output, hidden, hparams):
        raise NotImplementedError(
            "Class {} doesn't implement gen_figure_from_output(id_name, output, hidden, hparams)"
            .format(self.__class__.__name__))

    def synth_ref(self, hparams, file_id_list):
        if hparams.synth_vocoder == "WORLD":
            world_dir = hparams.world_dir if hasattr(hparams, "world_dir") and hparams.world_dir is not None\
                                          else os.path.join(self.OutputGen.dir_labels, self.dir_extracted_acoustic_features)
            Synthesiser.synth_ref(hparams, file_id_list, world_dir)
            hparams.synth_file_suffix += str(hparams.num_coded_sps) + 'sp'
        else:
            Synthesiser.synth_ref(hparams, file_id_list)

    def synthesize(self, file_id_list, synth_output, hparams):

        # Create speaker subdirectories if necessary.
        for id_name in file_id_list:
            # os.path.split always returns a (head, tail) pair, so checking
            # len(...) > 2 could never succeed; create the subdirectory
            # whenever the head (directory part) is non-empty instead.
            dir_part = os.path.split(id_name)[0]
            if dir_part:
                makedirs_safe(os.path.join(hparams.synth_dir, dir_part))

        if hparams.synth_vocoder == "WORLD":
            Synthesiser.run_world_synth(synth_output, hparams)
        # elif hparams.synth_vocoder == "STRAIGHT":  # Add further vocoders here.

        elif hparams.synth_vocoder == "r9y9wavenet_mulaw_16k_world_feats_English":
            Synthesiser.run_r9y9wavenet_mulaw_world_feats_synth(
                synth_output, hparams)

        elif hparams.synth_vocoder == "raw":
            # The features in the synth_output dictionary are raw waveforms and can be written directly to the file.
            Synthesiser.run_raw_synth(synth_output, hparams)

        elif hparams.synth_vocoder == "80_SSRN_English_GL":
            # Use a pre-trained spectrogram super resolution network for English and Griffin-Lim.
            # The features in the synth_output should be mfbanks.
            raise NotImplementedError()  # TODO

        elif hparams.synth_vocoder == "r9y9wavenet":
            # Synthesise with a pre-trained r9y9 WaveNet. The hyper-parameters have to match the model.
            Synthesiser.run_wavenet_vocoder(synth_output, hparams)