示例#1
0
    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
        """Create the figures and the audio sample logged for a training step.

        The first item of the batch is re-synthesized with ``self.inference`` and
        plotted next to its ground-truth spectrogram and the training alignment.

        Args:
            ap: Audio processor used by the plotting helpers and to invert the
                predicted spectrogram with ``inv_melspectrogram``.
            batch: Training batch. Reads "text_input", "text_lengths",
                "mel_input", "d_vectors" and "speaker_ids".
            outputs: Model outputs of the training step. Reads "alignments".

        Returns:
            A tuple ``(figures, audios)`` where ``figures`` maps "prediction",
            "ground_truth" and "alignment" to plots and ``audios`` maps "audio"
            to the waveform inverted from the predicted spectrogram.
        """
        alignments = outputs["alignments"]
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]

        # Only the first sample is synthesized below, so the speaker conditioning
        # has to be cut down to that same single item; passing the whole batch
        # would pair one text with a batch of speakers.
        if d_vectors is not None:
            d_vectors = d_vectors[:1]
        if speaker_ids is not None:
            speaker_ids = speaker_ids[:1]

        # model runs reverse flow to predict spectrograms
        pred_outputs = self.inference(
            text_input[:1],
            aux_input={
                "x_lengths": text_lengths[:1],
                "d_vectors": d_vectors,
                "speaker_ids": speaker_ids,
            },
        )
        model_outputs = pred_outputs["model_outputs"]

        pred_spec = model_outputs[0].data.cpu().numpy()
        gt_spec = mel_input[0].data.cpu().numpy()
        align_img = alignments[0].data.cpu().numpy()

        figures = {
            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        # Sample audio
        train_audio = ap.inv_melspectrogram(pred_spec.T)
        return figures, {"audio": train_audio}
示例#2
0
    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
        """Create the figures and the audio sample logged for a training step.

        Args:
            ap: Audio processor handed to the spectrogram plots and used to
                invert the predicted spectrogram with ``inv_melspectrogram``.
            batch: Training batch. Reads "mel_input".
            outputs: Model outputs. Reads "model_outputs" and "alignments".

        Returns:
            A tuple ``(figures, audios)``: ``figures`` holds the "prediction",
            "ground_truth" and "alignment" plots of the first batch item,
            ``audios`` holds the inverted waveform under "audio".
        """

        def _first_item(tensor_batch):
            # detach the first item of the batch and move it to a numpy array
            return tensor_batch[0].data.cpu().numpy()

        predicted_batch = outputs["model_outputs"]
        attention_batch = outputs["alignments"]
        target_batch = batch["mel_input"]

        spec_hat = _first_item(predicted_batch)
        spec_ref = _first_item(target_batch)
        attention = _first_item(attention_batch)

        figures = {}
        figures["prediction"] = plot_spectrogram(spec_hat, ap, output_fig=False)
        figures["ground_truth"] = plot_spectrogram(spec_ref, ap, output_fig=False)
        figures["alignment"] = plot_alignment(attention, output_fig=False)

        # Sample audio
        audios = {"audio": ap.inv_melspectrogram(spec_hat.T)}
        return figures, audios
示例#3
0
    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
        """Create the figures and the audio sample logged for a training step.

        Args:
            ap: Audio processor handed to the plotting helpers and used to
                invert the predicted spectrogram with ``inv_melspectrogram``.
            batch: Training batch. Reads "mel_input" and, when
                ``self.args.use_pitch`` is set, "pitch".
            outputs: Model outputs. Reads "model_outputs" and "alignments";
                with pitch enabled also "pitch_avg", "durations", "x_mask" and
                "y_mask"; "attn_durations" is plotted when present.

        Returns:
            A tuple ``(figures, audios)``: ``figures`` maps plot names to plots
            of the first batch item, ``audios`` holds the inverted waveform
            under "audio".
        """
        predicted_batch = outputs["model_outputs"]
        attention_batch = outputs["alignments"]
        target_batch = batch["mel_input"]

        spec_hat = predicted_batch[0].data.cpu().numpy()
        spec_ref = target_batch[0].data.cpu().numpy()
        attention = attention_batch[0].data.cpu().numpy()

        figures = {}
        figures["prediction"] = plot_spectrogram(spec_hat, ap, output_fig=False)
        figures["ground_truth"] = plot_spectrogram(spec_ref, ap, output_fig=False)
        figures["alignment"] = plot_alignment(attention, output_fig=False)

        # plot pitch figures
        if self.args.use_pitch:
            pitch_ref = batch["pitch"]
            pitch_avg_hat, _ = self.expand_encoder_outputs(
                outputs["pitch_avg"],
                outputs["durations"],
                outputs["x_mask"],
                outputs["y_mask"],
            )
            # TODO: denormalize before plotting
            # until then the absolute value of the (normalized) pitch is plotted
            pitch_ref = abs(pitch_ref[0, 0].data.cpu().numpy())
            pitch_avg_hat = abs(pitch_avg_hat[0, 0]).data.cpu().numpy()
            figures["pitch_ground_truth"] = plot_pitch(pitch_ref, spec_ref, ap, output_fig=False)
            figures["pitch_avg_predicted"] = plot_pitch(pitch_avg_hat, spec_hat, ap, output_fig=False)

        # plot the attention mask computed from the predicted durations
        if "attn_durations" in outputs:
            attention_hat = outputs["attn_durations"][0].data.cpu().numpy()
            figures["alignment_hat"] = plot_alignment(attention_hat.T, output_fig=False)

        # Sample audio
        audios = {"audio": ap.inv_melspectrogram(spec_hat.T)}
        return figures, audios
示例#4
0
文件: tacotron2.py 项目: gerazov/TTS
    def train_log(self, ap: AudioProcessor, batch: dict,
                  outputs: dict) -> Tuple[Dict, Dict]:
        """Create the figures and the audio sample logged for a training step.

        Args:
            ap: Audio processor handed to the spectrogram plots and used to
                invert the predicted spectrogram with ``inv_melspectrogram``.
            batch: Training batch. Reads "mel_input".
            outputs: Model outputs. Reads "model_outputs", "alignments" and
                "alignments_backward" (the latter is always looked up, but only
                plotted when ``self.bidirectional_decoder`` or
                ``self.double_decoder_consistency`` is truthy).

        Returns:
            A tuple ``(figures, audios)``: ``figures`` maps plot names to plots
            of the first batch item, ``audios`` holds the inverted waveform
            under "audio".
        """
        predicted_batch = outputs["model_outputs"]
        attention_batch = outputs["alignments"]
        backward_attention_batch = outputs["alignments_backward"]
        target_batch = batch["mel_input"]

        spec_hat = predicted_batch[0].data.cpu().numpy()
        spec_ref = target_batch[0].data.cpu().numpy()
        attention = attention_batch[0].data.cpu().numpy()

        figures = {}
        figures["prediction"] = plot_spectrogram(spec_hat, ap, output_fig=False)
        figures["ground_truth"] = plot_spectrogram(spec_ref, ap, output_fig=False)
        figures["alignment"] = plot_alignment(attention, output_fig=False)

        # the backward alignment is only plotted for these two decoder setups
        if self.bidirectional_decoder or self.double_decoder_consistency:
            backward_attention = backward_attention_batch[0].data.cpu().numpy()
            figures["alignment_backward"] = plot_alignment(backward_attention, output_fig=False)

        # Sample audio
        audios = {"audio": ap.inv_melspectrogram(spec_hat.T)}
        return figures, audios
示例#5
0
class TestTTSDataset(unittest.TestCase):
    """Data-loader tests for ``TTSDataset.MyDataset`` on the LJSpeech metadata.

    Every test body is skipped (returns immediately) unless the module-level
    ``ok_ljspeech`` flag is truthy.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # number of batches after which every test stops iterating a loader
        self.max_loader_iter = 4
        self.ap = AudioProcessor(**c.audio)

    def _create_dataloader(self, batch_size, r, bgs):
        """Build a non-shuffling loader over the LJSpeech items.

        Args:
            batch_size: Batch size given to the ``DataLoader``.
            r: Passed as the first positional argument of ``MyDataset``.
            bgs: Passed to ``MyDataset`` as ``batch_group_size``.

        Returns:
            The tuple ``(dataloader, dataset)``.
        """
        items = ljspeech(c.data_path, "metadata.csv")
        dataset = TTSDataset.MyDataset(
            r,
            c.text_cleaner,
            compute_linear_spec=True,
            ap=self.ap,
            meta_data=items,
            tp=c.characters,
            batch_group_size=bgs,
            min_seq_len=c.min_seq_len,
            max_seq_len=float("inf"),
            use_phonemes=False,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=True,
            num_workers=c.num_loader_workers,
        )
        return dataloader, dataset

    def test_loader(self):
        """Check batch shapes, speaker name type and the mel value range."""
        if not ok_ljspeech:
            return
        dataloader, _ = self._create_dataloader(2, c.r, 0)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            text_input = data[0]
            speaker_name = data[2]
            linear_input = data[3]
            mel_input = data[4]

            check_count = len(text_input[text_input < 0])
            assert check_count == 0, " !! Negative values in text_input: {}".format(
                check_count)
            # TODO: more assertion here
            assert isinstance(speaker_name[0], str)
            assert linear_input.shape[0] == c.batch_size
            assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
            assert mel_input.shape[0] == c.batch_size
            assert mel_input.shape[2] == c.audio["num_mels"]
            # check normalization ranges
            if self.ap.symmetric_norm:
                assert mel_input.max() <= self.ap.max_norm
                assert mel_input.min() >= -self.ap.max_norm  # pylint: disable=invalid-unary-operand-type
                assert mel_input.min() < 0
            else:
                assert mel_input.max() <= self.ap.max_norm
                assert mel_input.min() >= 0

    def test_batch_group_shuffle(self):
        """Check that ``sort_items`` changes the item order when batch groups are used."""
        if not ok_ljspeech:
            return
        dataloader, dataset = self._create_dataloader(2, c.r, 16)
        # NOTE(review): last_length is never updated inside the loop, so the
        # assertion below only verifies that the average length is non-negative.
        last_length = 0
        frames = dataset.items
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            mel_lengths = data[5]
            avg_length = mel_lengths.numpy().mean()
            assert avg_length >= last_length
        dataloader.dataset.sort_items()
        # at least one position must hold a different item than before sorting
        is_items_reordered = any(
            item != frames[idx]
            for idx, item in enumerate(dataloader.dataset.items))
        assert is_items_reordered

    def test_padding_and_spec(self):
        """Check spectrogram consistency, padding and stop targets for batch sizes 1 and 2."""
        if not ok_ljspeech:
            return
        dataloader, _ = self._create_dataloader(1, 1, 0)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_target = data[6]
            item_idx = data[7]

            # check mel_spec consistency
            wav = np.asarray(self.ap.load_wav(item_idx[0]),
                             dtype=np.float32)
            mel = self.ap.melspectrogram(wav).astype("float32")
            mel = torch.FloatTensor(mel).contiguous()
            mel_dl = mel_input[0]
            # NOTE: Below needs to check == 0 but due to an unknown reason
            # there is a slight difference between two matrices.
            # TODO: Check this assert cond more in detail.
            assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T -
                                                         mel_dl).max()

            # check mel-spec correctness
            mel_spec = mel_input[0].cpu().numpy()
            wav = self.ap.inv_melspectrogram(mel_spec.T)
            self.ap.save_wav(wav, OUTPATH + "/mel_inv_dataloader.wav")
            shutil.copy(item_idx[0],
                        OUTPATH + "/mel_target_dataloader.wav")

            # check linear-spec
            linear_spec = linear_input[0].cpu().numpy()
            wav = self.ap.inv_spectrogram(linear_spec.T)
            self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav")
            shutil.copy(item_idx[0],
                        OUTPATH + "/linear_target_dataloader.wav")

            # with a single item per batch nothing is padded: the last two
            # frames must be non-zero and only the final stop target is set
            assert linear_input[0, -1].sum() != 0
            assert linear_input[0, -2].sum() != 0
            assert mel_input[0, -1].sum() != 0
            assert mel_input[0, -2].sum() != 0
            assert stop_target[0, -1] == 1
            assert stop_target[0, -2] == 0
            assert stop_target.sum() == 1
            assert len(mel_lengths.shape) == 1
            assert mel_lengths[0] == linear_input[0].shape[0]
            assert mel_lengths[0] == mel_input[0].shape[0]

        # Test for batch size 2
        dataloader, _ = self._create_dataloader(2, 1, 0)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_target = data[6]

            # idx points at the longer item of the pair (item 1 on ties)
            idx = 0 if mel_lengths[0] > mel_lengths[1] else 1

            # check the longer item in the batch: it must not be padded
            assert linear_input[idx, -1].sum() != 0
            assert linear_input[idx, -2].sum() != 0, linear_input
            assert mel_input[idx, -1].sum() != 0
            assert mel_input[idx, -2].sum() != 0, mel_input
            assert stop_target[idx, -1] == 1
            assert stop_target[idx, -2] == 0
            assert stop_target[idx].sum() == 1
            assert len(mel_lengths.shape) == 1
            assert mel_lengths[idx] == mel_input[idx].shape[0]
            assert mel_lengths[idx] == linear_input[idx].shape[0]

            # check the other item in the batch: its last frame is zero padding
            assert linear_input[1 - idx, -1].sum() == 0
            assert mel_input[1 - idx, -1].sum() == 0
            # NOTE(review): the stop-target checks address item 1 directly
            # rather than ``1 - idx`` -- confirm this is intended.
            assert stop_target[1, mel_lengths[1] - 1] == 1
            assert stop_target[1, mel_lengths[1]:].sum() == 0
            assert len(mel_lengths.shape) == 1
示例#6
0
class TestAudio(unittest.TestCase):
    """Sanity tests for ``AudioProcessor``: wav/mel round trip, normalization, scaler stats and F0."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ap = AudioProcessor(**conf)

    def test_audio_synthesis(self):
        """Run wav -> mel -> wav for several normalization setups and save each result.

        For every setup the reference wav is loaded, converted to a
        mel-spectrogram, inverted back to audio and written below ``OUT_PATH``.
        """
        print(" > Sanity check for the process wav -> mel -> wav")

        def _roundtrip(max_norm, signal_norm, symmetric_norm, clip_norm):
            processor = self.ap
            processor.max_norm = max_norm
            processor.signal_norm = signal_norm
            processor.symmetric_norm = symmetric_norm
            processor.clip_norm = clip_norm
            reference = processor.load_wav(WAV_FILE)
            restored = processor.inv_melspectrogram(processor.melspectrogram(reference))
            file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav".format(
                max_norm, signal_norm, symmetric_norm, clip_norm
            )
            print(" | > Creating wav file at : ", file_name)
            processor.save_wav(restored, OUT_PATH + file_name)

        # (signal_norm, symmetric_norm, clip_norm) setups, run once per max_norm
        flag_sets = (
            (False, False, False),
            (True, False, False),
            (True, True, False),
            (True, False, True),
            (True, True, True),
        )
        for max_norm in (1.0, 4.0):
            for signal_norm, symmetric_norm, clip_norm in flag_sets:
                _roundtrip(max_norm, signal_norm, symmetric_norm, clip_norm)

    def test_normalize(self):
        """Check the value range of ``normalize`` and the ``denormalize`` round trip for six setups."""
        print(" > Testing normalization and denormalization.")
        wav = self.ap.sound_norm(self.ap.load_wav(WAV_FILE))  # normalized audio gives a better range below
        self.ap.signal_norm = False
        x = self.ap.melspectrogram(wav)
        x_old = x

        def _normalize(max_norm, symmetric_norm, clip_norm=None):
            # Configure the processor, normalize ``x``, report the range and
            # make sure the input was not modified. ``clip_norm=None`` leaves
            # the clipping flag at whatever the previous setup selected.
            self.ap.signal_norm = True
            self.ap.symmetric_norm = symmetric_norm
            if clip_norm is not None:
                self.ap.clip_norm = clip_norm
            self.ap.max_norm = max_norm
            x_norm = self.ap.normalize(x)
            print(
                f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} --  {x_norm.min()}"
            )
            assert (x_old - x).sum() == 0
            return x_norm

        # max_norm 4.0, asymmetric, no clipping: range may overshoot by one
        x_norm = _normalize(4.0, symmetric_norm=False, clip_norm=False)
        assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
        assert x_norm.min() >= 0 - 1, x_norm.min()
        x_ = self.ap.denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        # max_norm 4.0, asymmetric, clipped
        x_norm = _normalize(4.0, symmetric_norm=False, clip_norm=True)
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= 0, x_norm.min()
        x_ = self.ap.denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        # max_norm 4.0, symmetric, no clipping
        x_norm = _normalize(4.0, symmetric_norm=True, clip_norm=False)
        assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max()
        assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min()  # pylint: disable=invalid-unary-operand-type
        assert x_norm.min() <= 0, x_norm.min()
        x_ = self.ap.denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        # max_norm 4.0, symmetric, clipped
        x_norm = _normalize(4.0, symmetric_norm=True, clip_norm=True)
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()  # pylint: disable=invalid-unary-operand-type
        assert x_norm.min() <= 0, x_norm.min()
        x_ = self.ap.denormalize(x_norm)
        assert (x - x_).sum() < 1e-3, (x - x_).mean()

        # max_norm 1.0, asymmetric, clipping flag left as set above
        x_norm = _normalize(1.0, symmetric_norm=False)
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= 0, x_norm.min()
        x_ = self.ap.denormalize(x_norm)
        assert (x - x_).sum() < 1e-3

        # max_norm 1.0, symmetric, clipping flag left as set above
        x_norm = _normalize(1.0, symmetric_norm=True)
        assert x_norm.max() <= self.ap.max_norm, x_norm.max()
        assert x_norm.min() >= -self.ap.max_norm, x_norm.min()  # pylint: disable=invalid-unary-operand-type
        assert x_norm.min() < 0, x_norm.min()
        x_ = self.ap.denormalize(x_norm)
        assert (x - x_).sum() < 1e-3

    def test_scaler(self):
        """Check that denormalizing a stats-scaled mel-spectrogram recovers the unnormalized one."""
        stats_file = os.path.join(get_tests_input_path(), "scale_stats.npy")
        conf.stats_path = stats_file
        conf.preemphasis = 0.0
        conf.do_trim_silence = True
        conf.signal_norm = True

        # processor that normalizes with the statistics loaded from the stats file
        scaled_ap = AudioProcessor(**conf)
        mean_mel, std_mel, mean_linear, std_linear, _ = scaled_ap.load_stats(stats_file)
        scaled_ap.setup_scaler(mean_mel, std_mel, mean_linear, std_linear)

        # reference processor: no signal normalization, no preemphasis
        self.ap.signal_norm = False
        self.ap.preemphasis = 0.0

        # test scaler forward and backward transforms
        wav = self.ap.load_wav(WAV_FILE)
        expected_mel = self.ap.melspectrogram(wav)
        restored_mel = scaled_ap.denormalize(scaled_ap.melspectrogram(wav))
        assert abs(expected_mel - restored_mel).max() < 1e-4

    def test_compute_f0(self):  # pylint: disable=no-self-use
        """Check that ``compute_f0`` yields one value per mel-spectrogram frame."""
        processor = AudioProcessor(**conf)
        wav = processor.load_wav(WAV_FILE)
        f0 = processor.compute_f0(wav)
        mel_spec = processor.melspectrogram(wav)
        assert f0.shape[0] == mel_spec.shape[1]
示例#7
0
class TestTTSDataset(unittest.TestCase):
    """Data-loader tests for ``TTSDataset`` on the samples described by ``dataset_config``.

    Every test body is skipped (returns immediately) unless the module-level
    ``ok_ljspeech`` flag is truthy.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # number of batches after which every test stops iterating a loader
        self.max_loader_iter = 4
        self.ap = AudioProcessor(**c.audio)

    def _create_dataloader(self, batch_size, r, bgs, start_by_longest=False):
        """Build a non-shuffling loader over the train and eval samples.

        Args:
            batch_size: Batch size given to the ``DataLoader``.
            r: Passed to ``TTSDataset`` as ``outputs_per_step``.
            bgs: Passed to ``TTSDataset`` as ``batch_group_size``.
            start_by_longest: Passed through to ``TTSDataset``.

        Returns:
            The tuple ``(dataloader, dataset)``.
        """
        # load dataset
        meta_data_train, meta_data_eval = load_tts_samples(dataset_config,
                                                           eval_split=True,
                                                           eval_split_size=0.2)
        items = meta_data_train + meta_data_eval

        tokenizer, _ = TTSTokenizer.init_from_config(c)
        dataset = TTSDataset(
            outputs_per_step=r,
            compute_linear_spec=True,
            return_wav=True,
            tokenizer=tokenizer,
            ap=self.ap,
            samples=items,
            batch_group_size=bgs,
            min_text_len=c.min_text_len,
            max_text_len=c.max_text_len,
            min_audio_len=c.min_audio_len,
            max_audio_len=c.max_audio_len,
            start_by_longest=start_by_longest,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=True,
            num_workers=c.num_loader_workers,
        )
        return dataloader, dataset

    def test_loader(self):
        """Check batch shapes, waveform/mel agreement and the mel value range."""
        if not ok_ljspeech:
            return
        batch_size = 1
        dataloader, _ = self._create_dataloader(batch_size, 1, 0)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            text_input = data["token_id"]
            # the discarded lookups still verify that the keys are present
            _ = data["token_id_lengths"]
            speaker_name = data["speaker_names"]
            linear_input = data["linear"]
            mel_input = data["mel"]
            mel_lengths = data["mel_lengths"]
            _ = data["stop_targets"]
            _ = data["item_idxs"]
            wavs = data["waveform"]

            neg_values = text_input[text_input < 0]
            check_count = len(neg_values)

            # check basic conditions
            self.assertEqual(check_count, 0)
            # The third positional argument of assertEqual is the failure
            # message, not a value to compare, so each batch dimension is
            # checked on its own against the batch size requested above
            # (drop_last=True guarantees full batches).
            self.assertEqual(linear_input.shape[0], batch_size)
            self.assertEqual(mel_input.shape[0], batch_size)
            self.assertEqual(linear_input.shape[2],
                             self.ap.fft_size // 2 + 1)
            self.assertEqual(mel_input.shape[2], c.audio["num_mels"])
            self.assertEqual(wavs.shape[1],
                             mel_input.shape[1] * c.audio.hop_length)
            self.assertIsInstance(speaker_name[0], str)

            # make sure that the computed mels and the waveform match and correctly computed
            mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy())
            # guarantee that both mel-spectrograms have the same size and that we will remove waveform padding
            mel_new = mel_new[:, :mel_lengths[0]]
            # the trailing frames affected by the waveform padding are left out of the comparison
            ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length)
            mel_diff = (mel_new[:, :mel_input.shape[1]] -
                        mel_input[0].T.numpy())[:, 0:ignore_seg]
            self.assertLess(abs(mel_diff.sum()), 1e-5)

            # check normalization ranges
            if self.ap.symmetric_norm:
                self.assertLessEqual(mel_input.max(), self.ap.max_norm)
                self.assertGreaterEqual(
                    mel_input.min(),
                    -self.ap.max_norm  # pylint: disable=invalid-unary-operand-type
                )
                self.assertLess(mel_input.min(), 0)
            else:
                self.assertLessEqual(mel_input.max(), self.ap.max_norm)
                self.assertGreaterEqual(mel_input.min(), 0)

    def test_batch_group_shuffle(self):
        """Check that ``preprocess_samples`` changes the sample order when batch groups are used."""
        if not ok_ljspeech:
            return
        dataloader, dataset = self._create_dataloader(2, c.r, 16)
        # NOTE(review): last_length is never updated, so the length assertion
        # at the end only verifies that the last average length is non-negative.
        last_length = 0
        frames = dataset.samples
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            mel_lengths = data["mel_lengths"]
            avg_length = mel_lengths.numpy().mean()
        dataloader.dataset.preprocess_samples()
        # at least one position must hold a different sample than before
        is_items_reordered = any(
            item != frames[idx]
            for idx, item in enumerate(dataloader.dataset.samples))
        self.assertGreaterEqual(avg_length, last_length)
        self.assertTrue(is_items_reordered)

    def test_start_by_longest(self):
        """Test start_by_longest option.

        The first item of the first batch must be at least as long as all the other items.
        """
        if not ok_ljspeech:
            return
        dataloader, _ = self._create_dataloader(2, c.r, 0, True)
        dataloader.dataset.preprocess_samples()
        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            mel_lengths = data["mel_lengths"]
            if i == 0:
                # remember the length of the very first item as the reference
                max_len = mel_lengths[0]
            print(mel_lengths)
            self.assertTrue(all(max_len >= mel_lengths))

    def test_padding_and_spectrograms(self):
        """Check spectrogram consistency, padding and stop targets for batch sizes 1 and 2."""

        def check_conditions(idx, linear_input, mel_input, stop_target,
                             mel_lengths):
            # item ``idx`` must not be padded: its last two frames are non-zero
            # and only its final stop target is set
            self.assertNotEqual(linear_input[idx, -1].sum(),
                                0)  # check padding
            self.assertNotEqual(linear_input[idx, -2].sum(), 0)
            self.assertNotEqual(mel_input[idx, -1].sum(), 0)
            self.assertNotEqual(mel_input[idx, -2].sum(), 0)
            self.assertEqual(stop_target[idx, -1], 1)
            self.assertEqual(stop_target[idx, -2], 0)
            self.assertEqual(stop_target[idx].sum(), 1)
            self.assertEqual(len(mel_lengths.shape), 1)
            self.assertEqual(mel_lengths[idx], linear_input[idx].shape[0])
            self.assertEqual(mel_lengths[idx], mel_input[idx].shape[0])

        if not ok_ljspeech:
            return
        dataloader, _ = self._create_dataloader(1, 1, 0)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data["linear"]
            mel_input = data["mel"]
            mel_lengths = data["mel_lengths"]
            stop_target = data["stop_targets"]
            item_idx = data["item_idxs"]

            # check mel_spec consistency
            wav = np.asarray(self.ap.load_wav(item_idx[0]),
                             dtype=np.float32)
            mel = self.ap.melspectrogram(wav).astype("float32")
            mel = torch.FloatTensor(mel).contiguous()
            mel_dl = mel_input[0]
            # NOTE: Below needs to check == 0 but due to an unknown reason
            # there is a slight difference between two matrices.
            # TODO: Check this assert cond more in detail.
            self.assertLess(abs(mel.T - mel_dl).max(), 1e-5)

            # check mel-spec correctness
            mel_spec = mel_input[0].cpu().numpy()
            wav = self.ap.inv_melspectrogram(mel_spec.T)
            self.ap.save_wav(wav, OUTPATH + "/mel_inv_dataloader.wav")
            shutil.copy(item_idx[0],
                        OUTPATH + "/mel_target_dataloader.wav")

            # check linear-spec
            linear_spec = linear_input[0].cpu().numpy()
            wav = self.ap.inv_spectrogram(linear_spec.T)
            self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav")
            shutil.copy(item_idx[0],
                        OUTPATH + "/linear_target_dataloader.wav")

            # check the outputs
            check_conditions(0, linear_input, mel_input, stop_target,
                             mel_lengths)

        # Test for batch size 2
        dataloader, _ = self._create_dataloader(2, 1, 0)

        for i, data in enumerate(dataloader):
            if i == self.max_loader_iter:
                break
            linear_input = data["linear"]
            mel_input = data["mel"]
            mel_lengths = data["mel_lengths"]
            stop_target = data["stop_targets"]
            # the discarded lookup still verifies that the key is present
            _ = data["item_idxs"]

            # set id to the longest sequence in the batch (item 1 on ties)
            idx = 0 if mel_lengths[0] > mel_lengths[1] else 1

            # check the longer item in the batch
            check_conditions(idx, linear_input, mel_input, stop_target,
                             mel_lengths)

            # check the other item in the batch: its last frame is zero padding
            self.assertEqual(linear_input[1 - idx, -1].sum(), 0)
            self.assertEqual(mel_input[1 - idx, -1].sum(), 0)
            # NOTE(review): the stop-target checks address item 1 directly
            # rather than ``1 - idx`` -- confirm this is intended.
            self.assertEqual(stop_target[1, mel_lengths[1] - 1], 1)
            self.assertEqual(stop_target[1, mel_lengths[1]:].sum(),
                             stop_target.shape[1] - mel_lengths[1])
            self.assertEqual(len(mel_lengths.shape), 1)