예제 #1
0
def main():
    """
    Command-line entry point: parse arguments, prepare the output
    directories and start training.

    The input directory is ``<input_dir>/<dataset_name>`` and all outputs go
    under ``<output_dir>/<model_name>`` in the ``logs``, ``meta`` and
    ``samples`` sub-directories, which are created when missing. The parsed
    namespace, extended with ``in_dir``, ``out_dir``, ``log_dir``,
    ``meta_dir`` and ``sample_dir``, is handed to ``Tacotron.train``.

    ``--input_dir``, ``--output_dir`` and ``--dataset_name`` are required:
    the paths below cannot be built without them, so argparse reports the
    missing option instead of failing later inside ``os.path.join``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input_dir',
        required=True,
        help='Absolute path to base preprocessed data directory')
    parser.add_argument('--output_dir',
                        required=True,
                        help='Absolute path to the base output directory')
    parser.add_argument(
        '--dataset_name',
        required=True,
        help='The given dataset has to exist in the given input directory')
    parser.add_argument(
        '--model_name',
        help=
        'name of model to be trained on, defaults to main. This is just used for file-keeping.',
        default='main')
    parser.add_argument(
        '--hparams',
        default='',
        help=
        'Hyperparameter overrides as a comma-separated list of name=value pairs'
    )
    parser.add_argument('--restore_step',
                        type=int,
                        help='Global step to restore from checkpoint.')
    parser.add_argument('--summary_interval',
                        type=int,
                        default=100,
                        help='Steps between running summary ops.')
    parser.add_argument('--checkpoint_interval',
                        type=int,
                        default=1000,
                        help='Steps between writing checkpoints.')
    parser.add_argument('--msg_interval',
                        type=int,
                        default=100,
                        help='Interval of general training messages')
    # used for broadcasting training updates to slack.
    parser.add_argument('--slack_url',
                        help='Slack webhook URL to get periodic reports.')
    parser.add_argument('--tf_log_level',
                        type=int,
                        default=1,
                        help='Tensorflow C++ log level.')
    args = parser.parse_args()

    # The environment variable must be a string, hence the explicit str().
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)

    # Derived paths are attached to the namespace so the model receives
    # everything through the single ``args`` object.
    args.in_dir = os.path.join(args.input_dir, args.dataset_name)
    args.out_dir = os.path.join(args.output_dir, args.model_name)
    args.log_dir = os.path.join(args.out_dir, 'logs')
    args.meta_dir = os.path.join(args.out_dir, 'meta')
    args.sample_dir = os.path.join(args.out_dir, 'samples')

    # create the output directories if needed
    for directory in (args.log_dir, args.meta_dir, args.sample_dir):
        os.makedirs(directory, exist_ok=True)

    # Apply the command-line overrides to the shared hyperparameter object
    # before the model is built from it.
    hparams.parse(args.hparams)
    model = Tacotron(hparams)
    model.train(args)
예제 #2
0
def create_model(is_training=True):
    """
    Assemble a Tacotron model from freshly constructed sub-modules.

    Args:
        is_training: when truthy the model is switched to training mode
            via ``train()``; otherwise ``eval()`` is called

    Returns:
        the assembled model, already put into the requested mode
    """
    # Dict literals evaluate in order, so the sub-modules are still built
    # encoder -> decoder -> postnet -> post_cbhg.
    components = {
        'encoder': Encoder(),
        'decoder': Decoder(),
        'postnet': PostNet(),
        'post_cbhg': PostCBHG(),
    }
    net = Tacotron(**components)

    switch_mode = net.train if is_training else net.eval
    switch_mode()
    return net
예제 #3
0
class TacotronTrainer:
    """
    Training driver for Tacotron: owns the dataset split, the dataloaders,
    the model, loss, optimizer and learning-rate scheduler, and writes
    scalars, plots, audio and checkpoints while training.
    """
    TRAIN_STAGE = 'train'
    VAL_STAGE = 'val'
    VERSION_FORMAT = 'VERSION_{}'
    MODEL_SAVE_FORMAT = 'version_{version:03}_model_{step:010}.pth'

    def __init__(self,
                 batch_size: int = 32,
                 num_epoch: int = 100,
                 train_split: float = 0.9,
                 log_interval: int = 1000,
                 log_audio_factor: int = 5,
                 lr: float = 0.001,
                 num_data: int = None,
                 log_root: str = './tb_logs',
                 save_root: str = './checkpoints',
                 num_workers: int = 4,
                 version: int = None,
                 num_test_samples: int = 5):
        """
        Initialize tacotron trainer
        Args:
            batch_size: batch size
            num_epoch: total number of epochs to train
            train_split: train ratio of train-val split
            log_interval: interval, in training steps, between loggings of
                the training loss and test samples to tensorboard; a model
                checkpoint is saved at the same steps
            log_audio_factor: audio is logged only every
                ``log_interval * log_audio_factor`` training steps because
                it requires quite a lot of overhead
            lr: learning rate passed to the Adam optimizer
            num_data: number of datapoints to load in the dataset
            log_root: root directory for the tensorboard logging
            save_root: root directory for saving model
            num_workers: number of workers for dataloader
            version: version of training; when None, one past the highest
                version already present in ``log_root`` is used
            num_test_samples: number of test samples to generate
                for each logging (capped at the validation set size)
        """
        # Local import: only needed to clear a stale log directory below.
        import shutil

        os.makedirs(log_root, exist_ok=True)
        os.makedirs(save_root, exist_ok=True)

        is_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if is_cuda else 'cpu')

        self.train_split = train_split

        self.epoch_num = num_epoch

        self.splitted_dataset = self.__split_dataset(
            TorchLJSpeechDataset(num_data=num_data))

        self.dataloaders = self.__get_dataloaders(
            batch_size, num_workers=num_workers)

        self.tacotron = Tacotron()
        self.tacotron.to(self.device)

        self.loss = TacotronLoss()
        self.optimizer = Adam(self.tacotron.parameters(), lr=lr)
        self.lr_scheduler = StepLR(
            optimizer=self.optimizer,
            step_size=10000,
            gamma=0.9)

        # An explicitly requested version must be honoured; otherwise pick
        # the next free one from what already exists under log_root.
        if version is None:
            version = self.__next_version(os.listdir(log_root))
        self.version = version

        log_dir = os.path.join(
            log_root, self.VERSION_FORMAT.format(self.version))
        # Start from a clean log location. A directory cannot be deleted
        # with os.remove, so the two cases are handled separately.
        if os.path.islink(log_dir) or os.path.isfile(log_dir):
            os.remove(log_dir)
        elif os.path.isdir(log_dir):
            shutil.rmtree(log_dir)

        self.logger = SummaryWriter(log_dir)

        self.save_root = save_root

        self.log_interval = log_interval
        self.log_audio_factor = log_audio_factor

        self.global_step = 0
        self.running_count = {self.TRAIN_STAGE: 0,
                              self.VAL_STAGE: 0}
        self.running_loss = {self.TRAIN_STAGE: 0,
                             self.VAL_STAGE: 0}

        # Samples are drawn from the validation subset by position, so never
        # ask for more positions than that subset holds.
        num_val_data = len(self.splitted_dataset[self.VAL_STAGE])
        self.sample_indices = list(
            range(min(num_test_samples, num_val_data)))

    @classmethod
    def __next_version(cls, existing_names) -> int:
        """
        Determine the next unused training version
        Args:
            existing_names: entry names found in the log root directory

        Returns:
            one more than the highest version among the names that follow
            VERSION_FORMAT, or 0 when there is none; names that do not
            follow the format are ignored

        """
        prefix = cls.VERSION_FORMAT.format('')
        found = [int(name[len(prefix):]) for name in existing_names
                 if name.startswith(prefix)
                 and name[len(prefix):].isdecimal()]
        return max(found) + 1 if found else 0

    def fit_from_checkpoint(self, checkpoint_file: str):
        """
        Load model weights from a checkpoint, then run the full training
        Args:
            checkpoint_file: path of the checkpoint handed to the model's
                ``load`` method

        """
        self.tacotron.load(checkpoint_file, self.device)
        self.fit()

    def fit(self):
        """
        Run training for the configured number of epochs
        """
        for epoch in tqdm.tqdm(range(self.epoch_num),
                               total=self.epoch_num,
                               desc='Epoch'):
            self.__run_epoch(epoch)

    def __run_epoch(self, epoch: int):
        """
        Run every stage once over its dataloader and log epoch-level scalars
        Args:
            epoch: index of the current epoch

        """
        # reset running loss and count after each epoch
        self.__reset_loss()
        self.__reset_count()

        for stage, dataloader in self.dataloaders.items():
            # The bar is advanced manually inside __run_step; the context
            # manager makes sure it is closed once the stage is finished.
            with tqdm.tqdm(desc=f'{stage.capitalize()} in progress',
                           total=len(dataloader)) as prog_bar:
                for batch in dataloader:
                    self.__run_step(batch, stage, prog_bar)

        # epoch vs global step
        self.logger.add_scalar('epoch', epoch, global_step=self.global_step)

        # add loss to logger
        loss_dict = {stage: self.__calculate_mean_loss(stage)
                     for stage in self.running_loss}
        self.logger.add_scalars('loss', loss_dict, global_step=epoch)

    def __run_step(self, batch: TorchLJSpeechBatch, stage: str,
                   prog_bar: tqdm.tqdm):
        """
        Process one batch: forward pass, loss bookkeeping and, in the train
        stage, the optimization step plus periodic logging/checkpointing
        Args:
            batch: batch to process
            stage: train/val
            prog_bar: progress bar to advance and annotate with the
                running mean loss

        """
        is_train = stage == self.TRAIN_STAGE
        if is_train:
            self.tacotron.train()
            self.optimizer.zero_grad()
        else:
            self.tacotron.eval()

        batch = batch.to(self.device)

        # Gradients are only needed when the step is followed by backward();
        # validation runs without building the autograd graph.
        with torch.set_grad_enabled(is_train):
            output = self.tacotron.forward_train(batch)
            loss_val = self.loss(batch.mel_spec, output.pred_mel_spec,
                                 batch.lin_spec, output.pred_lin_spec)

        # Weight the batch loss by batch size so that the running mean is a
        # per-sample mean even when the last batch is smaller.
        num_samples = batch.mel_spec.size(0)
        self.running_loss[stage] += loss_val.item() * num_samples
        self.running_count[stage] += num_samples

        if is_train:
            loss_val.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

            if self.global_step % self.log_interval == 0:
                self.logger.add_scalar('training_loss',
                                       self.__calculate_mean_loss(stage),
                                       global_step=self.global_step)
                audio_interval = self.log_interval * self.log_audio_factor
                log_audio = self.global_step % audio_interval == 0
                for sample_result in self.__get_sample_results():
                    self.__log_sample_results(
                        self.global_step, sample_result, log_audio=log_audio)

                # __get_sample_results switches the model to eval mode
                self.tacotron.train()
                save_file = os.path.join(
                    self.save_root,
                    self.MODEL_SAVE_FORMAT.format(
                        version=self.version, step=self.global_step)
                )
                torch.save(self.tacotron.state_dict(), save_file)

            self.global_step += 1

        prog_bar.update()
        prog_bar.set_postfix(
            {'Running Loss': f'{self.__calculate_mean_loss(stage):.3f}'})

    def __log_sample_results(self, steps: int,
                             sample_result: SampleResult,
                             log_mel: bool = True,
                             log_spec: bool = True,
                             log_attention: bool = True,
                             log_audio: bool = True) -> None:
        """
        Log the sample results into tensorboard
        Args:
            steps: current step
            sample_result: sample result to log
            log_mel: if True, log mel spectrogram
            log_spec: if True, log spectrogram
            log_attention: if True, log attention
            log_audio: if True, log audio

        """
        if log_mel:
            title = f'Log Mel Spectrogram, Step:{steps}, ' \
                    f'Uid: {sample_result.uid}'

            fig = self.__get_spec_plot(
                pred_spec=sample_result.pred_mel_spec,
                truth_spec=sample_result.truth_mel_spec,
                suptitle=title,
                ylabel='Mel')
            img_tensor = self.__get_plot_tensor(fig)
            tag = f'mel_spec/{sample_result.uid}'
            self.logger.add_image(tag, img_tensor, global_step=steps)

        if log_spec:
            title = f'Log Spectrogram, Step:{steps}, ' \
                    f'Uid: {sample_result.uid}'
            fig = self.__get_spec_plot(
                pred_spec=sample_result.pred_lin_spec,
                truth_spec=sample_result.truth_lin_spec,
                suptitle=title,
                ylabel='DFT bins')
            img_tensor = self.__get_plot_tensor(fig)
            tag = f'lin_spec/{sample_result.uid}'
            self.logger.add_image(tag, img_tensor, global_step=steps)

        if log_attention:
            # The value shown is the global step, as in the other titles.
            title = f'Attention Weight, Step:{steps}, ' \
                    f'Uid: {sample_result.uid}'
            fig = self.__get_attention_plot(
                title=title,
                attention_weight=sample_result.attention_weight)
            img_tensor = self.__get_plot_tensor(fig)
            tag = f'attention/{sample_result.uid}'
            self.logger.add_image(tag, img_tensor, global_step=steps)

        if log_audio:
            pred_tag = f'audio/{sample_result.uid}_predicted'
            truth_tag = f'audio/{sample_result.uid}_truth'

            self.logger.add_audio(
                tag=pred_tag,
                snd_tensor=torch.from_numpy(
                    sample_result.pred_audio).unsqueeze(1),  # add channel dim
                global_step=steps,
                sample_rate=AudioProcessParam.sr
            )

            self.logger.add_audio(
                tag=truth_tag,
                snd_tensor=torch.from_numpy(
                    sample_result.truth_audio).unsqueeze(1),  # add channel dim
                global_step=steps,
                sample_rate=AudioProcessParam.sr
            )

    def __get_sample_results(self) -> List[SampleResult]:
        """
        Get sample results to show in tensorboard, including
            1. Predicted and ground truth spectrogram pairs
            2. Predicted and ground truth mel spectrogram pairs
            3. Predicted and ground truth audio pairs
            4. Attention weight

        The model is left in eval mode; callers that keep training have to
        switch it back themselves.

        Returns:
            list of sample results

        """
        val_dataset = self.splitted_dataset[self.VAL_STAGE]
        self.tacotron.eval()

        test_insts = []
        with torch.no_grad():
            for subset_i in self.sample_indices:
                datapoint: TorchLJSpeechData = val_dataset[subset_i]
                batch: TorchLJSpeechBatch = datapoint.add_batch_dim()
                batch = batch.to(self.device)

                # Map the position inside the validation subset back to the
                # index, and from there the uid, in the underlying dataset.
                ds_idx = val_dataset.indices[subset_i]
                uid = val_dataset.dataset.uids[ds_idx]

                # Transcription
                transcription = val_dataset.dataset.uid_to_transcription[uid]

                wav_filepath = os.path.join(
                    val_dataset.dataset.wav_save_dir, f'{uid}.wav')
                truth_audio = AudioProcessingHelper.load_audio(wav_filepath)

                taco_output = self.tacotron.forward_train(batch)

                spec = taco_output.pred_lin_spec.squeeze(0).cpu().numpy().T
                pred_audio = AudioProcessingHelper.spec2audio(spec)

                test_insts.append(
                    SampleResult(
                        uid=uid,
                        transcription=transcription,
                        truth_lin_spec=batch.lin_spec.squeeze(0).cpu().numpy().T,
                        pred_lin_spec=taco_output.pred_lin_spec.squeeze(0).cpu().numpy().T,
                        truth_mel_spec=batch.mel_spec.squeeze(0).cpu().numpy().T,
                        pred_mel_spec=taco_output.pred_mel_spec.squeeze(0).cpu().numpy().T,
                        attention_weight=taco_output.attention_weight.squeeze(0).cpu().numpy(),
                        truth_audio=truth_audio,
                        pred_audio=pred_audio
                    )
                )

        return test_insts

    @staticmethod
    def __get_attention_plot(
            title: str, attention_weight: np.ndarray) -> plt.Figure:
        """
        Get figure handle for attention plot

        Args:
            title: title of the plot
            attention_weight: attention weight to plot

        Returns:
            figure object

        """
        fig = plt.figure(figsize=(6, 5), dpi=80)
        plt.title(title)
        plt.imshow(attention_weight, aspect='auto')
        plt.colorbar()
        plt.xlabel('Encoder seq')
        plt.ylabel('Decoder seq')
        plt.gca().invert_yaxis()  # Let the x, y axis start from the left-bottom corner
        # Deregister the figure from pyplot so figures do not pile up across
        # logging calls; the returned object is still rendered by the caller.
        plt.close(fig)
        return fig

    @staticmethod
    def __get_spec_plot(pred_spec: np.ndarray, truth_spec: np.ndarray,
                        suptitle: str, ylabel: str) -> plt.Figure:
        """
        Get a juxtaposition two spectrograms with appropriate title
        Args:
            pred_spec: predicted spectrogram
            truth_spec: ground truth spectrogram
            suptitle: title of the plot
            ylabel: unit of frequency axis of the spectrograms

        Returns:
            figure object

        """
        # Shared color limits so both panels can use a single colorbar.
        vmin = min(np.min(truth_spec), np.min(pred_spec))
        vmax = max(np.max(truth_spec), np.max(pred_spec))

        fig = plt.figure(figsize=(11, 5), dpi=80)
        plt.suptitle(suptitle)

        ax1 = plt.subplot(121)
        plt.title('Ground Truth')
        plt.xlabel('Frame')
        plt.ylabel(ylabel)
        plt.imshow(truth_spec, vmin=vmin, vmax=vmax, aspect='auto')
        plt.gca().invert_yaxis()  # let the x, y axis start from the left-bottom corner

        ax2 = plt.subplot(122)
        plt.title('Predicted')
        plt.xlabel('Frame')
        im = plt.imshow(pred_spec, vmin=vmin, vmax=vmax, aspect='auto')
        plt.gca().invert_yaxis()  # let the x, y axis start from the left-bottom corner

        fig.tight_layout()
        fig.colorbar(im, ax=[ax1, ax2])
        plt.close(fig)

        return fig

    @staticmethod
    def __get_plot_tensor(fig) -> torch.Tensor:
        """
        Get tensor for the given figure object
        Args:
            fig: the figure object to convert into tensor

        Returns:
            tensor of the figure

        """
        # Render into an in-memory JPEG and read it back as an image.
        buf = io.BytesIO()
        fig.savefig(buf, format='jpeg')
        buf.seek(0)
        image = PIL.Image.open(buf)
        image = ToTensor()(image)
        return image

    def __calculate_mean_loss(self, stage: str) -> float:
        """
        Calculate mean loss for given stage (train/val)
        Args:
            stage: train/val

        Returns:
            mean loss, or NaN when no sample has been seen for the stage
            yet (e.g. an empty validation split)

        """
        count = self.running_count[stage]
        if not count:
            return float('nan')
        return self.running_loss[stage] / count

    def __reset_loss(self) -> None:
        """
        Zero the accumulated loss of every stage
        """
        self.running_loss = {self.TRAIN_STAGE: 0,
                             self.VAL_STAGE: 0}

    def __reset_count(self) -> None:
        """
        Zero the accumulated sample count of every stage
        """
        self.running_count = {self.TRAIN_STAGE: 0,
                              self.VAL_STAGE: 0}

    def __split_dataset(self, dataset: TorchLJSpeechDataset) -> Dict[str, Subset]:
        """
        Split the dataset into train/validation set
        Args:
            dataset: dataset to split

        Returns:
            splitted dataset

        """
        # The validation size is the remainder, so both sizes always add up
        # to the dataset length.
        num_train_data = int(len(dataset) * self.train_split)
        num_val_data = len(dataset) - num_train_data
        train_dataset, val_dataset = random_split(
            dataset, [num_train_data, num_val_data])

        return {self.TRAIN_STAGE: train_dataset,
                self.VAL_STAGE: val_dataset}

    def __get_dataloaders(
            self, batch_size: int, num_workers: int) -> Dict[str, DataLoader]:
        """
        Build one dataloader per stage; only the train stage is shuffled
        Args:
            batch_size: batch size
            num_workers: number of workers for dataloader

        Returns:
            dataloaders keyed by stage

        """
        return {stage: DataLoader(
            dataset, shuffle=(stage == self.TRAIN_STAGE),
            collate_fn=TorchLJSpeechDataset.batch_tacotron_input,
            pin_memory=True, batch_size=batch_size,
            num_workers=num_workers)
            for stage, dataset in self.splitted_dataset.items()
        }