Example #1
class Trainer:
    def __init__(self, steps=0):
        self.steps = steps
        self.epochs = 0

        self.Dataset_Generate()
        self.Model_Generate()

        self.scalar_Dict = {
            'Train': defaultdict(float),
            'Evaluation': defaultdict(float),
        }

        self.writer_Dict = {
            'Train': Logger(os.path.join(hp.Log_Path, 'Train')),
            'Evaluation': Logger(os.path.join(hp.Log_Path, 'Evaluation')),
        }

        self.Load_Checkpoint()

    def Dataset_Generate(self):
        train_Dataset = Dataset(
            pattern_path=hp.Train.Train_Pattern.Path,
            metadata_file=hp.Train.Train_Pattern.Metadata_File,
            accumulated_dataset_epoch=hp.Train.Train_Pattern.
            Accumulated_Dataset_Epoch,
            mel_length_min=hp.Train.Train_Pattern.Mel_Length.Min,
            mel_length_max=hp.Train.Train_Pattern.Mel_Length.Max,
            text_length_min=hp.Train.Train_Pattern.Text_Length.Min,
            text_length_max=hp.Train.Train_Pattern.Text_Length.Max,
            use_cache=hp.Train.Use_Pattern_Cache)
        dev_Dataset = Dataset(
            pattern_path=hp.Train.Eval_Pattern.Path,
            metadata_file=hp.Train.Eval_Pattern.Metadata_File,
            mel_length_min=hp.Train.Eval_Pattern.Mel_Length.Min,
            mel_length_max=hp.Train.Eval_Pattern.Mel_Length.Max,
            text_length_min=hp.Train.Eval_Pattern.Text_Length.Min,
            text_length_max=hp.Train.Eval_Pattern.Text_Length.Max,
            use_cache=hp.Train.Use_Pattern_Cache)
        inference_Dataset = Inference_Dataset(
            pattern_path=hp.Train.Inference_Pattern_File_in_Train)
        logging.info('The number of train patterns = {}.'.format(
            len(train_Dataset) //
            hp.Train.Train_Pattern.Accumulated_Dataset_Epoch))
        logging.info('The number of development patterns = {}.'.format(
            len(dev_Dataset)))
        logging.info('The number of inference patterns = {}.'.format(
            len(inference_Dataset)))

        collater = Collater()
        inference_Collater = Inference_Collater()

        self.dataLoader_Dict = {}
        self.dataLoader_Dict['Train'] = torch.utils.data.DataLoader(
            dataset=train_Dataset,
            shuffle=True,
            collate_fn=collater,
            batch_size=hp.Train.Batch_Size,
            num_workers=hp.Train.Num_Workers,
            pin_memory=True)
        self.dataLoader_Dict['Dev'] = torch.utils.data.DataLoader(
            dataset=dev_Dataset,
            shuffle=True,
            collate_fn=collater,
            batch_size=hp.Train.Batch_Size,
            num_workers=hp.Train.Num_Workers,
            pin_memory=True)
        self.dataLoader_Dict['Inference'] = torch.utils.data.DataLoader(
            dataset=inference_Dataset,
            shuffle=False,
            collate_fn=inference_Collater,
            batch_size=hp.Inference_Batch_Size or hp.Train.Batch_Size,
            num_workers=hp.Train.Num_Workers,
            pin_memory=True)

        if hp.Mode in ['PE', 'GR']:
            self.dataLoader_Dict[
                'Prosody_Check'] = torch.utils.data.DataLoader(
                    dataset=Prosody_Check_Dataset(
                        pattern_path=hp.Train.Train_Pattern.Path,
                        metadata_file=hp.Train.Train_Pattern.Metadata_File,
                        mel_length_min=hp.Train.Train_Pattern.Mel_Length.Min,
                        mel_length_max=hp.Train.Train_Pattern.Mel_Length.Max,
                        use_cache=hp.Train.Use_Pattern_Cache),
                    shuffle=False,
                    collate_fn=Prosody_Check_Collater(),
                    batch_size=hp.Train.Batch_Size,
                    num_workers=hp.Train.Num_Workers,
                    pin_memory=True)

    def Model_Generate(self):
        self.model_Dict = {'GlowTTS': GlowTTS().to(device)}

        if hp.Speaker_Embedding.GE2E.Checkpoint_Path is not None:
            self.model_Dict['Speaker_Embedding'] = Speaker_Embedding(
                mel_dims=hp.Sound.Mel_Dim,
                lstm_size=hp.Speaker_Embedding.GE2E.LSTM.Sizes,
                lstm_stacks=hp.Speaker_Embedding.GE2E.LSTM.Stacks,
                embedding_size=hp.Speaker_Embedding.Embedding_Size,
            ).to(device)

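        # Losses: MLE for the flow's mel log-likelihood, MSE on log-durations,
        # and CE for the optional speaker classifier.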
        self.criterion_Dict = {
            'MSE': torch.nn.MSELoss().to(device),
            'MLE': MLE_Loss().to(device),
            'CE': torch.nn.CrossEntropyLoss().to(device)
        }
        self.optimizer = RAdam(params=self.model_Dict['GlowTTS'].parameters(),
                               lr=hp.Train.Learning_Rate.Initial,
                               betas=(hp.Train.ADAM.Beta1,
                                      hp.Train.ADAM.Beta2),
                               eps=hp.Train.ADAM.Epsilon,
                               weight_decay=hp.Train.Weight_Decay)
        self.scheduler = Modified_Noam_Scheduler(
            optimizer=self.optimizer, base=hp.Train.Learning_Rate.Base)

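        # Optional mixed precision: amp.initialize wraps the GlowTTS model and
        # the optimizer (Apex-style AMP API).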
        if hp.Use_Mixed_Precision:
            self.model_Dict['GlowTTS'], self.optimizer = amp.initialize(
                models=self.model_Dict['GlowTTS'], optimizers=self.optimizer)

        logging.info(self.model_Dict['GlowTTS'])

    def Train_Step(self, tokens, token_lengths, mels, mel_lengths, speakers,
                   mels_for_ge2e, pitches):
        loss_Dict = {}

        tokens = tokens.to(device)
        token_lengths = token_lengths.to(device)
        mels = mels.to(device)
        mel_lengths = mel_lengths.to(device)
        speakers = speakers.to(device)
        mels_for_ge2e = mels_for_ge2e.to(device)
        pitches = pitches.to(device)

        z, mel_Mean, mel_Log_Std, log_Dets, log_Durations, log_Duration_Targets, _, classified_Speakers = self.model_Dict[
            'GlowTTS'](tokens=tokens,
                       token_lengths=token_lengths,
                       mels=mels,
                       mel_lengths=mel_lengths,
                       speakers=speakers,
                       mels_for_ge2e=mels_for_ge2e,
                       pitches=pitches)

        loss_Dict['MLE'] = self.criterion_Dict['MLE'](z=z,
                                                      mean=mel_Mean,
                                                      std=mel_Log_Std,
                                                      log_dets=log_Dets,
                                                      lengths=mel_lengths)
        loss_Dict['Length'] = self.criterion_Dict['MSE'](log_Durations,
                                                         log_Duration_Targets)
        loss_Dict['Total'] = loss_Dict['MLE'] + loss_Dict['Length']

        loss = loss_Dict['Total']
        if classified_Speakers is not None:
            loss_Dict['Speaker'] = self.criterion_Dict['CE'](
                classified_Speakers, speakers)
            loss = loss_Dict['Total'] + loss_Dict['Speaker']

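        # Backward pass: scale the loss when AMP is enabled, clip gradients,
        # then step the optimizer and the Noam-style scheduler.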
        self.optimizer.zero_grad()
        if hp.Use_Mixed_Precision:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=amp.master_params(
                self.optimizer),
                                           max_norm=hp.Train.Gradient_Norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model_Dict['GlowTTS'].parameters(),
                max_norm=hp.Train.Gradient_Norm)
        self.optimizer.step()
        self.scheduler.step()
        self.steps += 1
        self.tqdm.update(1)

        for tag, loss in loss_Dict.items():
            self.scalar_Dict['Train']['Loss/{}'.format(tag)] += loss

    def Train_Epoch(self):
        for tokens, token_Lengths, mels, mel_Lengths, speakers, mels_for_GE2E, pitches in self.dataLoader_Dict[
                'Train']:
            self.Train_Step(tokens, token_Lengths, mels, mel_Lengths, speakers,
                            mels_for_GE2E, pitches)

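            # All periodic work (checkpointing, scalar logging, evaluation,
            # inference) is keyed on the global step counter.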
            if self.steps % hp.Train.Checkpoint_Save_Interval == 0:
                self.Save_Checkpoint()

            if self.steps % hp.Train.Logging_Interval == 0:
                self.scalar_Dict['Train'] = {
                    tag: loss / hp.Train.Logging_Interval
                    for tag, loss in self.scalar_Dict['Train'].items()
                }
                self.scalar_Dict['Train'][
                    'Learning_Rate'] = self.scheduler.get_last_lr()
                self.writer_Dict['Train'].add_scalar_dict(
                    self.scalar_Dict['Train'], self.steps)
                self.scalar_Dict['Train'] = defaultdict(float)

            if self.steps % hp.Train.Evaluation_Interval == 0:
                self.Evaluation_Epoch()

            if self.steps % hp.Train.Inference_Interval == 0:
                self.Inference_Epoch()

            if self.steps >= hp.Train.Max_Step:
                return

        self.epochs += hp.Train.Train_Pattern.Accumulated_Dataset_Epoch

    @torch.no_grad()
    def Evaluation_Step(self, tokens, token_lengths, mels, mel_lengths,
                        speakers, mels_for_ge2e, pitches):
        loss_Dict = {}

        tokens = tokens.to(device)
        token_lengths = token_lengths.to(device)
        mels = mels.to(device)
        mel_lengths = mel_lengths.to(device)
        speakers = speakers.to(device)
        mels_for_ge2e = mels_for_ge2e.to(device)
        pitches = pitches.to(device)

        z, mel_Mean, mel_Log_Std, log_Dets, log_Durations, log_Duration_Targets, attentions_from_Train, classified_Speakers = self.model_Dict[
            'GlowTTS'](tokens=tokens,
                       token_lengths=token_lengths,
                       mels=mels,
                       mel_lengths=mel_lengths,
                       speakers=speakers,
                       mels_for_ge2e=mels_for_ge2e,
                       pitches=pitches)

        loss_Dict['MLE'] = self.criterion_Dict['MLE'](z=z,
                                                      mean=mel_Mean,
                                                      std=mel_Log_Std,
                                                      log_dets=log_Dets,
                                                      lengths=mel_lengths)
        loss_Dict['Length'] = self.criterion_Dict['MSE'](log_Durations,
                                                         log_Duration_Targets)
        loss_Dict['Total'] = loss_Dict['MLE'] + loss_Dict['Length']
        if classified_Speakers is not None:
            loss_Dict['Speaker'] = self.criterion_Dict['CE'](
                classified_Speakers, speakers)

        for tag, loss in loss_Dict.items():
            self.scalar_Dict['Evaluation']['Loss/{}'.format(tag)] += loss

        # For tensorboard images
        mels, _, attentions_from_Inference = self.model_Dict[
            'GlowTTS'].inference(tokens=tokens,
                                 token_lengths=token_lengths,
                                 mels_for_prosody=mels,
                                 mel_lengths_for_prosody=mel_lengths,
                                 speakers=speakers,
                                 mels_for_ge2e=mels_for_ge2e,
                                 pitches=pitches,
                                 pitch_lengths=mel_lengths,
                                 length_scale=torch.FloatTensor([1.0
                                                                 ]).to(device))

        return mels, attentions_from_Train, attentions_from_Inference, classified_Speakers

    def Evaluation_Epoch(self):
        logging.info('(Steps: {}) Start evaluation.'.format(self.steps))

        for model in self.model_Dict.values():
            model.eval()

        for step, (tokens, token_Lengths, mels, mel_Lengths, speakers,
                   mels_for_GE2E, pitches) in tqdm(
                       enumerate(self.dataLoader_Dict['Dev'], 1),
                       desc='[Evaluation]',
                       total=math.ceil(
                           len(self.dataLoader_Dict['Dev'].dataset) /
                           hp.Train.Batch_Size)):
            mel_Predictions, attentions_from_Train, attentions_from_Inference, classified_Speakers = self.Evaluation_Step(
                tokens, token_Lengths, mels, mel_Lengths, speakers,
                mels_for_GE2E, pitches)

        self.scalar_Dict['Evaluation'] = {
            tag: loss / step
            for tag, loss in self.scalar_Dict['Evaluation'].items()
        }
        self.writer_Dict['Evaluation'].add_scalar_dict(
            self.scalar_Dict['Evaluation'], self.steps)
        self.writer_Dict['Evaluation'].add_histogram_model(
            self.model_Dict['GlowTTS'],
            self.steps,
            delete_keywords=['layer_Dict', 'layer', 'GE2E'])
        self.scalar_Dict['Evaluation'] = defaultdict(float)

        image_Dict = {
            'Mel/Target': (mels[-1].cpu().numpy(), None),
            'Mel/Prediction': (mel_Predictions[-1].cpu().numpy(), None),
            'Attention/From_Train':
            (attentions_from_Train[-1].cpu().numpy(), None),
            'Attention/From_Inference':
            (attentions_from_Inference[-1].cpu().numpy(), None)
        }
        if classified_Speakers is not None:
            image_Dict.update({
                'Speaker/Original': (torch.nn.functional.one_hot(
                    speakers,
                    hp.Speaker_Embedding.Num_Speakers).cpu().numpy(), None),
                'Speaker/Predicted':
                (torch.softmax(classified_Speakers,
                               dim=-1).cpu().numpy(), None),
            })
        self.writer_Dict['Evaluation'].add_image_dict(image_Dict, self.steps)

        for model in self.model_Dict.values():
            model.train()

        if (hp.Mode in ['PE', 'GR']
                and self.steps % hp.Train.Prosody_Check_Interval == 0):
            self.Prosody_Check_Epoch()

    @torch.no_grad()
    def Inference_Step(self,
                       tokens,
                       token_lengths,
                       mels_for_prosody,
                       mel_lengths_for_prosody,
                       speakers,
                       mels_for_ge2e,
                       pitches,
                       pitch_lengths,
                       length_scales,
                       labels,
                       texts,
                       start_index=0,
                       tag_step=False,
                       tag_index=False):
        tokens = tokens.to(device)
        token_lengths = token_lengths.to(device)
        mels_for_prosody = mels_for_prosody.to(device)
        mel_lengths_for_prosody = mel_lengths_for_prosody.to(device)
        speakers = speakers.to(device)
        mels_for_ge2e = mels_for_ge2e.to(device)
        pitches = pitches.to(device)
        length_scales = length_scales.to(device)

        mels, mel_Lengths, attentions = self.model_Dict['GlowTTS'].inference(
            tokens=tokens,
            token_lengths=token_lengths,
            mels_for_prosody=mels_for_prosody,
            mel_lengths_for_prosody=mel_lengths_for_prosody,
            speakers=speakers,
            mels_for_ge2e=mels_for_ge2e,
            pitches=pitches,
            pitch_lengths=pitch_lengths,
            length_scale=length_scales)

        files = []
        for index, label in enumerate(labels):
            tags = []
            if tag_step: tags.append('Step-{}'.format(self.steps))
            tags.append(label)
            if tag_index: tags.append('IDX_{}'.format(index + start_index))
            files.append('.'.join(tags))

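        # Render one mel/attention figure per sample and save it as a PNG.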
        os.makedirs(os.path.join(hp.Inference_Path,
                                 'Step-{}'.format(self.steps),
                                 'PNG').replace('\\', '/'),
                    exist_ok=True)
        for index, (mel, mel_Length, attention, label, text, length_Scale,
                    file) in enumerate(
                        zip(mels.cpu().numpy(),
                            mel_Lengths.cpu().numpy(),
                            attentions.cpu().numpy(), labels, texts,
                            length_scales, files)):
            mel = mel[:, :mel_Length]
            attention = attention[:len(text) + 2, :mel_Length]

            new_Figure = plt.figure(figsize=(20, 5 * 3), dpi=100)
            plt.subplot2grid((3, 1), (0, 0))
            plt.imshow(mel, aspect='auto', origin='lower')
            plt.title(
                'Mel    Label: {}    Text: {}    Length scale: {:.3f}'.format(
                    label, text if len(text) < 70 else text[:70] + '…',
                    length_Scale))
            plt.colorbar()
            plt.subplot2grid((3, 1), (1, 0), rowspan=2)
            plt.imshow(attention,
                       aspect='auto',
                       origin='lower',
                       interpolation='none')
            plt.title(
                'Attention    Label: {}    Text: {}    Length scale: {:.3f}'.
                format(label, text if len(text) < 70 else text[:70] + '…',
                       length_Scale))
            plt.yticks(range(len(text) + 2), ['<S>'] + list(text) + ['<E>'],
                       fontsize=10)
            plt.colorbar()
            plt.tight_layout()
            plt.savefig(
                os.path.join(hp.Inference_Path, 'Step-{}'.format(self.steps),
                             'PNG', '{}.PNG'.format(file)).replace('\\', '/'))
            plt.close(new_Figure)

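        # Also dump the predicted mel and attention matrices as .npy files.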
        os.makedirs(os.path.join(hp.Inference_Path,
                                 'Step-{}'.format(self.steps),
                                 'NPY').replace('\\', '/'),
                    exist_ok=True)
        os.makedirs(os.path.join(hp.Inference_Path,
                                 'Step-{}'.format(self.steps), 'NPY',
                                 'Mel').replace('\\', '/'),
                    exist_ok=True)
        os.makedirs(os.path.join(hp.Inference_Path,
                                 'Step-{}'.format(self.steps), 'NPY',
                                 'Attention').replace('\\', '/'),
                    exist_ok=True)

        for index, (mel, mel_Length, file) in enumerate(
                zip(mels.cpu().numpy(),
                    mel_Lengths.cpu().numpy(), files)):
            mel = mel[:, :mel_Length]

            np.save(os.path.join(hp.Inference_Path,
                                 'Step-{}'.format(self.steps), 'NPY', 'Mel',
                                 file).replace('\\', '/'),
                    mel.T,
                    allow_pickle=False)
            np.save(os.path.join(hp.Inference_Path,
                                 'Step-{}'.format(self.steps), 'NPY',
                                 'Attention', file).replace('\\', '/'),
                    attentions.cpu().numpy()[index],
                    allow_pickle=False)

    def Inference_Epoch(self):
        logging.info('(Steps: {}) Start inference.'.format(self.steps))

        for model in self.model_Dict.values():
            model.eval()

        for step, (tokens, token_Lengths, mels_for_Prosody,
                   mel_Lengths_for_Prosody, speakers, mels_for_GE2E, pitches,
                   pitch_Lengths, length_Scales, labels, texts) in tqdm(
                       enumerate(self.dataLoader_Dict['Inference']),
                       desc='[Inference]',
                       total=math.ceil(
                           len(self.dataLoader_Dict['Inference'].dataset) /
                           (hp.Inference_Batch_Size or hp.Train.Batch_Size))):
            self.Inference_Step(
                tokens,
                token_Lengths,
                mels_for_Prosody,
                mel_Lengths_for_Prosody,
                speakers,
                mels_for_GE2E,
                pitches,
                pitch_Lengths,
                length_Scales,
                labels,
                texts,
                start_index=step *
                (hp.Inference_Batch_Size or hp.Train.Batch_Size))

        for model in self.model_Dict.values():
            model.train()

    @torch.no_grad()
    def Prosody_Check_Step(self, mels, mel_lengths):
        mels = mels.to(device)
        mel_lengths = mel_lengths.to(device)
        prosodies = self.model_Dict['GlowTTS'].layer_Dict['Prosody_Encoder'](
            mels, mel_lengths)

        return prosodies

    def Prosody_Check_Epoch(self):
        logging.info('(Steps: {}) Start prosody check.'.format(self.steps))

        for model in self.model_Dict.values():
            model.eval()

        prosodies, labels = zip(
            *[(self.Prosody_Check_Step(mels, mel_Lengths), labels)
              for mels, mel_Lengths, labels in tqdm(
                  self.dataLoader_Dict['Prosody_Check'],
                  desc='[Prosody_Check]',
                  total=math.ceil(
                      len(self.dataLoader_Dict['Prosody_Check'].dataset) /
                      hp.Train.Batch_Size))])
        prosodies = torch.cat(prosodies, dim=0)
        labels = [label for sub_labels in labels for label in sub_labels]

        self.writer_Dict['Evaluation'].add_embedding(prosodies.cpu().numpy(),
                                                     metadata=labels,
                                                     global_step=self.steps,
                                                     tag='Prosodies')

        for model in self.model_Dict.values():
            model.train()

    def Load_Checkpoint(self):
        if self.steps == 0:
            paths = [
                os.path.join(root, file).replace('\\', '/')
                for root, _, files in os.walk(hp.Checkpoint_Path)
                for file in files if os.path.splitext(file)[1] == '.pt'
            ]
            if len(paths) > 0:
                path = max(paths, key=os.path.getctime)
            else:
                return  # Initial training
        else:
            path = os.path.join(
                hp.Checkpoint_Path,
                'S_{}.pt'.format(self.steps).replace('\\', '/'))

        state_Dict = torch.load(path, map_location='cpu')
        self.model_Dict['GlowTTS'].load_state_dict(state_Dict['Model'])
        self.optimizer.load_state_dict(state_Dict['Optimizer'])
        self.scheduler.load_state_dict(state_Dict['Scheduler'])
        self.steps = state_Dict['Steps']
        self.epochs = state_Dict['Epochs']

        if hp.Use_Mixed_Precision:
            if 'AMP' not in state_Dict:
                logging.info(
                    'No AMP state dict in the checkpoint. Assuming this checkpoint was trained without mixed precision.'
                )
            else:
                amp.load_state_dict(state_Dict['AMP'])

        for flow in self.model_Dict['GlowTTS'].layer_Dict[
                'Decoder'].layer_Dict['Flows']:
            # Activation norm is already initialized when loading a checkpoint.
            flow.layers[0].initialized = True

        logging.info('Checkpoint loaded at {} steps.'.format(self.steps))

        if 'GE2E' in self.model_Dict['GlowTTS'].layer_Dict.keys(
        ) and self.steps == 0:
            self.GE2E_Load_Checkpoint()

    def Save_Checkpoint(self):
        os.makedirs(hp.Checkpoint_Path, exist_ok=True)

        state_Dict = {
            'Model': self.model_Dict['GlowTTS'].state_dict(),
            'Optimizer': self.optimizer.state_dict(),
            'Scheduler': self.scheduler.state_dict(),
            'Steps': self.steps,
            'Epochs': self.epochs,
        }
        if hp.Use_Mixed_Precision:
            state_Dict['AMP'] = amp.state_dict()

        torch.save(
            state_Dict,
            os.path.join(hp.Checkpoint_Path,
                         'S_{}.pt'.format(self.steps).replace('\\', '/')))

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def GE2E_Load_Checkpoint(self):
        state_Dict = torch.load(hp.Speaker_Embedding.GE2E.Checkpoint_Path,
                                map_location='cpu')
        self.model_Dict['GlowTTS'].layer_Dict['GE2E'].load_state_dict(
            state_Dict['Model'])
        logging.info('Speaker embedding checkpoint \'{}\' loaded.'.format(
            hp.Speaker_Embedding.GE2E.Checkpoint_Path))

    def Train(self):
        hp_Path = os.path.join(hp.Checkpoint_Path,
                               'Hyper_Parameters.yaml').replace('\\', '/')
        if not os.path.exists(hp_Path):
            from shutil import copyfile
            os.makedirs(hp.Checkpoint_Path, exist_ok=True)
            copyfile('Hyper_Parameters.yaml', hp_Path)

        if self.steps == 0:
            self.Evaluation_Epoch()

        if hp.Train.Initial_Inference:
            self.Inference_Epoch()

        self.tqdm = tqdm(initial=self.steps,
                         total=hp.Train.Max_Step,
                         desc='[Training]')

        while self.steps < hp.Train.Max_Step:
            try:
                self.Train_Epoch()
            except KeyboardInterrupt:
                self.Save_Checkpoint()
                exit(1)

        self.tqdm.close()
        logging.info('Finished training.')
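
# A minimal launch sketch, assuming the hyper-parameter namespace `hp` and the
# global `device` are defined at module level; the entry point below is hypothetical.
if __name__ == '__main__':
    # With steps=0, Load_Checkpoint() resumes from the newest checkpoint if one exists.
    Trainer(steps=0).Train()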
Example #2
def run_train(train_data_file, dev_data_file):

    print("1. load config and dict")
    with open(config.data_path + "vocab.txt", "r",
              encoding="utf-8") as vocab_file:
        vocab_list = [word.strip() for word in vocab_file]
    if not os.path.exists(config.data_path + "emb_word.txt"):
        emb_file = "D:/emb/glove.6B/glove.6B.300d.txt"
        embeddings = read_emb(emb_file, vocab_list)
        with open(config.data_path + "emb_word.txt", "w",
                  encoding="utf-8") as emb_write:
            for emb in embeddings:
                emb_write.write(emb)
    else:
        with open(config.data_path + "emb_word.txt", "r",
                  encoding="utf-8") as embedding_file:
            embeddings = [emb.strip() for emb in embedding_file]
    embedding_word, vocab = process_emb(embeddings, emb_dim=config.emb_dim)

    idx2intent, intent2idx = lord_label_dict(config.data_path +
                                             "intent_label.txt")
    idx2slot, slot2idx = lord_label_dict(config.data_path + "slot_label.txt")
    n_slot_tag = len(idx2slot.items())
    n_intent_class = len(idx2intent.items())

    train_dir = os.path.join(config.data_path, train_data_file)
    dev_dir = os.path.join(config.data_path, dev_data_file)
    train_loader = read_corpus(train_dir,
                               max_length=config.max_len,
                               intent2idx=intent2idx,
                               slot2idx=slot2idx,
                               vocab=vocab,
                               is_train=True)
    dev_loader = read_corpus(dev_dir,
                             max_length=config.max_len,
                             intent2idx=intent2idx,
                             slot2idx=slot2idx,
                             vocab=vocab,
                             is_train=False)
    model = Joint_model(config, config.hidden_dim, config.batch_size,
                        config.max_len, n_intent_class, n_slot_tag,
                        embedding_word)

    if use_cuda:
        model.cuda()
    model.train()
    optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0.000001)
    # optimizer = getattr(optim,"Adam")
    # optimizer = optimizer(model.parameters(), lr = config.lr, weight_decay=0.00001)
    # Best-score snapshots; later stored as [sent_acc, slot_f1, intent_acc, epoch].
    best_slot_f1 = [0.0, 0.0, 0.0]
    best_intent_acc = [0.0, 0.0, 0.0]
    best_sent_acc = [0.0, 0.0, 0.0]
    # best_slot_f1 = 0.0
    # best_intent_acc = 0.0
    # best_sent_acc = 0.0
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, [40, 60, 80], gamma=config.lr_scheduler_gama, last_epoch=-1)

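    # Joint training: for the first 40 epochs the intent and slot losses are
    # summed equally; afterwards they are re-weighted 0.8 / 0.2 toward intent.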
    for epoch in trange(config.epoch, desc="Epoch"):
        print(scheduler.get_last_lr())
        step = 0
        for i, batch in enumerate(tqdm(train_loader, desc="batch_nums")):
            step += 1
            model.zero_grad()
            inputs, char_lists, slot_labels, intent_labels, masks = batch
            # inputs, masks, tags = Variable(inputs), Variable(masks), Variable(tags)
            if use_cuda:
                inputs, char_lists, masks, intent_labels, slot_labels = \
                    inputs.cuda(), char_lists.cuda(), masks.cuda(), intent_labels.cuda(), slot_labels.cuda()
            logits_intent, logits_slot = model.forward_logit(
                (inputs, char_lists), masks)
            loss_intent, loss_slot, = model.loss1(logits_intent, logits_slot,
                                                  intent_labels, slot_labels,
                                                  masks)

            if epoch < 40:
                loss = loss_slot + loss_intent
            else:
                loss = 0.8 * loss_intent + 0.2 * loss_slot
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print("loss:", loss.item())
                print('epoch: {} |    step: {} |    loss: {}'.format(
                    epoch, step, loss.item()))

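        # End-of-epoch evaluation; a model snapshot is saved whenever slot F1,
        # intent accuracy, or sentence accuracy improves.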
        intent_acc, slot_f1, sent_acc = dev(model, dev_loader, idx2slot)
        # if slot_f1 > best_slot_f1 or intent_acc > best_intent_acc or sent_acc > best_sent_acc:
        #     torch.save(model, config.model_save_dir + config.model_name)
        if slot_f1 > best_slot_f1[1]:
            best_slot_f1 = [sent_acc, slot_f1, intent_acc, epoch]
            torch.save(model, config.model_save_dir + config.model_path)
        if intent_acc > best_intent_acc[2]:
            torch.save(model, config.model_save_dir + config.model_path)
            best_intent_acc = [sent_acc, slot_f1, intent_acc, epoch]
        if sent_acc > best_sent_acc[0]:
            torch.save(model, config.model_save_dir + config.model_path)
            best_sent_acc = [sent_acc, slot_f1, intent_acc, epoch]
        scheduler.step()
    print("best_slot_f1:", best_slot_f1)
    print("best_intent_acc:", best_intent_acc)
    print("best_sent_acc:", best_sent_acc)
Example #3
class Trainer:
    def __init__(self, steps=0):
        self.steps = steps
        self.epochs = 0

        self.Dataset_Generate()
        self.Model_Generate()

        self.scalar_Dict = {
            'Train': defaultdict(float),
            'Evaluation': defaultdict(float),
        }

        self.writer_Dict = {
            'Train': Logger(os.path.join(hp_Dict['Log_Path'], 'Train')),
            'Evaluation':
            Logger(os.path.join(hp_Dict['Log_Path'], 'Evaluation')),
        }

        self.Load_Checkpoint()

    def Dataset_Generate(self):
        train_Dataset = Train_Dataset()
        accumulation_Dataset = Accumulation_Dataset()
        dev_Dataset = Dev_Dataset()
        inference_Dataset = Inference_Dataset()
        logging.info('The number of base train files = {}.'.format(
            len(train_Dataset) //
            hp_Dict['Train']['Train_Pattern']['Accumulated_Dataset_Epoch']))
        logging.info('The number of development patterns = {}.'.format(
            len(dev_Dataset)))
        logging.info('The number of inference patterns = {}.'.format(
            len(inference_Dataset)))

        collater = Collater()
        accumulation_Collater = Accumulation_Collater()
        inference_Collater = Inference_Collater()

        self.dataLoader_Dict = {}
        self.dataLoader_Dict['Train'] = torch.utils.data.DataLoader(
            dataset=train_Dataset,
            shuffle=True,
            collate_fn=collater,
            batch_size=hp_Dict['Train']['Batch_Size'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)
        self.dataLoader_Dict['Accumulation'] = torch.utils.data.DataLoader(
            dataset=accumulation_Dataset,
            shuffle=False,
            collate_fn=accumulation_Collater,
            batch_size=hp_Dict['Train']['Batch_Size'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)
        self.dataLoader_Dict['Dev'] = torch.utils.data.DataLoader(
            dataset=dev_Dataset,
            shuffle=True,  # to write tensorboard.
            collate_fn=collater,
            batch_size=hp_Dict['Train']['Batch_Size'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)
        self.dataLoader_Dict['Inference'] = torch.utils.data.DataLoader(
            dataset=inference_Dataset,
            shuffle=False,  # to write tensorboard.
            collate_fn=inference_Collater,
            batch_size=hp_Dict['Inference_Batch_Size']
            or hp_Dict['Train']['Batch_Size'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)

    def Model_Generate(self):
        self.model = PVCGAN().to(device)
        self.criterion_Dict = {
            'STFT':
            MultiResolutionSTFTLoss(
                fft_sizes=hp_Dict['STFT_Loss_Resolution']['FFT_Sizes'],
                shift_lengths=hp_Dict['STFT_Loss_Resolution']['Shfit_Lengths'],
                win_lengths=hp_Dict['STFT_Loss_Resolution']['Win_Lengths'],
            ).to(device),
            'MSE':
            torch.nn.MSELoss().to(device),
            'CE':
            torch.nn.CrossEntropyLoss().to(device),
            'MAE':
            torch.nn.L1Loss().to(device)
        }

        self.optimizer = RAdam(params=self.model.parameters(),
                               lr=hp_Dict['Train']['Learning_Rate']['Initial'],
                               betas=(hp_Dict['Train']['ADAM']['Beta1'],
                                      hp_Dict['Train']['ADAM']['Beta2']),
                               eps=hp_Dict['Train']['ADAM']['Epsilon'],
                               weight_decay=hp_Dict['Train']['Weight_Decay'])
        self.scheduler = Modified_Noam_Scheduler(
            optimizer=self.optimizer,
            base=hp_Dict['Train']['Learning_Rate']['Base'])

        if hp_Dict['Use_Mixed_Precision']:
            self.model, self.optimizer = amp.initialize(
                models=self.model, optimizers=self.optimizer)

        logging.info(self.model)

    def Train_Step(self, audios, mels, pitches, audio_Singers, mel_Singers,
                   noises):
        loss_Dict = {}

        audios = audios.to(device, non_blocking=True)
        mels = mels.to(device, non_blocking=True)
        pitches = pitches.to(device, non_blocking=True)
        audio_Singers = audio_Singers.to(device, non_blocking=True)
        mel_Singers = mel_Singers.to(device, non_blocking=True)
        noises = noises.to(device, non_blocking=True)

        # Generator
        fakes, predicted_Singers, predicted_Pitches, fakes_Discriminations, reals_Discriminations = self.model(
            mels=mels,
            pitches=pitches,
            singers=audio_Singers,
            noises=noises,
            discrimination=self.steps >=
            hp_Dict['Train']['Discriminator_Delay'],
            reals=audios)

        loss_Dict['Generator/Spectral_Convergence'], loss_Dict[
            'Generator/Magnitude'] = self.criterion_Dict['STFT'](fakes, audios)
        loss_Dict['Generator'] = loss_Dict[
            'Generator/Spectral_Convergence'] + loss_Dict['Generator/Magnitude']

        loss_Dict['Confuser/Singer'] = self.criterion_Dict['CE'](
            predicted_Singers, mel_Singers)
        loss_Dict['Confuser/Pitch'] = self.criterion_Dict['MAE'](
            predicted_Pitches, pitches)
        loss_Dict['Confuser'] = loss_Dict['Confuser/Singer'] + loss_Dict[
            'Confuser/Pitch']
        loss = loss_Dict['Generator'] + loss_Dict['Confuser']

        if self.steps >= hp_Dict['Train']['Discriminator_Delay']:
            loss_Dict['Discriminator/Fake'] = self.criterion_Dict['MSE'](
                fakes_Discriminations, torch.zeros_like(fakes_Discriminations)
            )  # The discriminator should score fakes as 0.
            loss_Dict['Discriminator/Real'] = self.criterion_Dict['MSE'](
                reals_Discriminations, torch.ones_like(reals_Discriminations)
            )  # The discriminator should score reals as 1.

            loss_Dict['Discriminator'] = loss_Dict[
                'Discriminator/Fake'] + loss_Dict['Discriminator/Real']
            loss += loss_Dict['Discriminator']

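        # Generator, confuser, and (after the delay) discriminator losses are
        # summed and optimized jointly by a single optimizer.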
        self.optimizer.zero_grad()

        if hp_Dict['Use_Mixed_Precision']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                parameters=amp.master_params(self.optimizer),
                max_norm=hp_Dict['Train']['Gradient_Norm'])
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model.parameters(),
                max_norm=hp_Dict['Train']['Gradient_Norm'])

        self.optimizer.step()
        self.scheduler.step()

        self.steps += 1
        self.tqdm.update(1)

        for tag, loss in loss_Dict.items():
            self.scalar_Dict['Train']['Loss/{}'.format(tag)] += loss

    def Train_Epoch(self):
        if any([
                hp_Dict['Train']['Train_Pattern']['Mixup']['Use']
                and self.steps >=
                hp_Dict['Train']['Train_Pattern']['Mixup']['Apply_Delay'],
                hp_Dict['Train']['Train_Pattern']['Back_Translate']['Use']
                and self.steps >= hp_Dict['Train']['Train_Pattern']
            ['Back_Translate']['Apply_Delay']
        ]):
            self.Data_Accumulation()

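        # The train dataset may have been renewed above with mixup and
        # back-translation patterns.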
        for audios, mels, pitches, audio_Singers, mel_Singers, noises in self.dataLoader_Dict[
                'Train']:
            self.Train_Step(audios, mels, pitches, audio_Singers, mel_Singers,
                            noises)

            if self.steps % hp_Dict['Train']['Checkpoint_Save_Interval'] == 0:
                self.Save_Checkpoint()

            if self.steps % hp_Dict['Train']['Logging_Interval'] == 0:
                self.scalar_Dict['Train'] = {
                    tag: loss / hp_Dict['Train']['Logging_Interval']
                    for tag, loss in self.scalar_Dict['Train'].items()
                }
                self.scalar_Dict['Train'][
                    'Learning_Rate'] = self.scheduler.get_last_lr()
                self.writer_Dict['Train'].add_scalar_dict(
                    self.scalar_Dict['Train'], self.steps)
                self.scalar_Dict['Train'] = defaultdict(float)

            if self.steps % hp_Dict['Train']['Evaluation_Interval'] == 0:
                self.Evaluation_Epoch()

            if self.steps % hp_Dict['Train']['Inference_Interval'] == 0:
                self.Inference_Epoch()

            if self.steps >= hp_Dict['Train']['Max_Step']:
                return

        self.epochs += hp_Dict['Train']['Train_Pattern'][
            'Accumulated_Dataset_Epoch']

    @torch.no_grad()
    def Evaluation_Step(self, audios, mels, pitches, audio_Singers,
                        mel_Singers, noises):
        loss_Dict = {}

        audios = audios.to(device, non_blocking=True)
        mels = mels.to(device, non_blocking=True)
        pitches = pitches.to(device, non_blocking=True)
        audio_Singers = audio_Singers.to(device, non_blocking=True)
        mel_Singers = mel_Singers.to(device, non_blocking=True)
        noises = noises.to(device, non_blocking=True)

        # Generator
        fakes, predicted_Singers, predicted_Pitches, fakes_Discriminations, reals_Discriminations = self.model(
            mels=mels,
            pitches=pitches,
            singers=audio_Singers,
            noises=noises,
            discrimination=self.steps >=
            hp_Dict['Train']['Discriminator_Delay'],
            reals=audios)

        loss_Dict['Generator/Spectral_Convergence'], loss_Dict[
            'Generator/Magnitude'] = self.criterion_Dict['STFT'](fakes, audios)
        loss_Dict['Generator'] = loss_Dict[
            'Generator/Spectral_Convergence'] + loss_Dict['Generator/Magnitude']

        loss_Dict['Confuser/Singer'] = self.criterion_Dict['CE'](
            predicted_Singers, mel_Singers)
        loss_Dict['Confuser/Pitch'] = self.criterion_Dict['MAE'](
            predicted_Pitches, pitches)
        loss_Dict['Confuser'] = loss_Dict['Confuser/Singer'] + loss_Dict[
            'Confuser/Pitch']
        loss = loss_Dict['Generator'] + loss_Dict['Confuser']

        if self.steps >= hp_Dict['Train']['Discriminator_Delay']:
            loss_Dict['Discriminator/Fake'] = self.criterion_Dict['MSE'](
                fakes_Discriminations, torch.zeros_like(fakes_Discriminations)
            )  # The discriminator should score fakes as 0.
            loss_Dict['Discriminator/Real'] = self.criterion_Dict['MSE'](
                reals_Discriminations, torch.ones_like(reals_Discriminations)
            )  # The discriminator should score reals as 1.

            loss_Dict['Discriminator'] = loss_Dict[
                'Discriminator/Fake'] + loss_Dict['Discriminator/Real']
            loss += loss_Dict['Discriminator']

        for tag, loss in loss_Dict.items():
            self.scalar_Dict['Evaluation']['Loss/{}'.format(tag)] += loss

        return fakes, predicted_Singers, predicted_Pitches

    def Evaluation_Epoch(self):
        logging.info('(Steps: {}) Start evaluation.'.format(self.steps))

        self.model.eval()

        for step, (audios, mels, pitches, audio_Singers, mel_Singers,
                   noises) in tqdm(
                       enumerate(self.dataLoader_Dict['Dev'], 1),
                       desc='[Evaluation]',
                       total=math.ceil(
                           len(self.dataLoader_Dict['Dev'].dataset) /
                           hp_Dict['Train']['Batch_Size'])):
            fakes, predicted_Singers, predicted_Pitches = self.Evaluation_Step(
                audios, mels, pitches, audio_Singers, mel_Singers, noises)

        self.scalar_Dict['Evaluation'] = {
            tag: loss / step
            for tag, loss in self.scalar_Dict['Evaluation'].items()
        }
        self.writer_Dict['Evaluation'].add_scalar_dict(
            self.scalar_Dict['Evaluation'], self.steps)
        self.writer_Dict['Evaluation'].add_histogram_model(
            self.model,
            self.steps,
            delete_keywords=['layer_Dict', '1', 'layer'])
        self.scalar_Dict['Evaluation'] = defaultdict(float)

        self.writer_Dict['Evaluation'].add_image_dict(
            {
                'Mel': (mels[-1].cpu().numpy(), None),
                'Audio/Original': (audios[-1].cpu().numpy(), None),
                'Audio/Predicted': (fakes[-1].cpu().numpy(), None),
                'Pitch/Original': (pitches[-1].cpu().numpy(), None),
                'Pitch/Predicted': (predicted_Pitches[-1].cpu().numpy(), None),
                'Singer/Original': (torch.nn.functional.one_hot(
                    mel_Singers, hp_Dict['Num_Singers']).cpu().numpy(), None),
                'Singer/Predicted':
                (torch.softmax(predicted_Singers, dim=-1).cpu().numpy(), None),
            }, self.steps)

        self.model.train()

    @torch.no_grad()
    def Inference_Step(self,
                       audios,
                       mels,
                       pitches,
                       singers,
                       noises,
                       source_Labels,
                       singer_Labels,
                       start_Index=0,
                       tag_step=False,
                       tag_Index=False):
        audios = audios.to(device, non_blocking=True)
        mels = mels.to(device, non_blocking=True)
        pitches = pitches.to(device, non_blocking=True)
        singers = singers.to(device, non_blocking=True)
        noises = noises.to(device, non_blocking=True)

        fakes, *_ = self.model(mels=mels,
                               pitches=pitches,
                               singers=singers,
                               noises=noises)

        files = []
        for index, (source_Label, singer_Label) in enumerate(
                zip(source_Labels, singer_Labels)):
            tags = []
            if tag_step: tags.append('Step-{}'.format(self.steps))
            tags.append('{}_to_{}'.format(source_Label, singer_Label))
            if tag_Index: tags.append('IDX_{}'.format(index + start_Index))
            files.append('.'.join(tags))

        os.makedirs(os.path.join(hp_Dict['Inference_Path'],
                                 'Step-{}'.format(self.steps),
                                 'PNG').replace('\\', '/'),
                    exist_ok=True)
        os.makedirs(os.path.join(hp_Dict['Inference_Path'],
                                 'Step-{}'.format(self.steps),
                                 'WAV').replace("\\", "/"),
                    exist_ok=True)
        for index, (real, fake, source_Label, singer_Label, file) in enumerate(
                zip(audios.cpu().numpy(),
                    fakes.cpu().numpy(), source_Labels, singer_Labels, files)):
            new_Figure = plt.figure(figsize=(80, 10 * 2), dpi=100)
            plt.subplot(211)
            plt.plot(real)
            plt.title('Original wav    Index: {}    {} -> {}'.format(
                index + start_Index, source_Label, singer_Label))
            plt.subplot(212)
            plt.plot(fake)
            plt.title('Fake wav    Index: {}    {} -> {}'.format(
                index + start_Index, source_Label, singer_Label))
            plt.tight_layout()
            plt.savefig(
                os.path.join(hp_Dict['Inference_Path'],
                             'Step-{}'.format(self.steps), 'PNG',
                             '{}.png'.format(file)).replace('\\', '/'))
            plt.close(new_Figure)

            wavfile.write(filename=os.path.join(
                hp_Dict['Inference_Path'], 'Step-{}'.format(self.steps), 'WAV',
                '{}.wav'.format(file)).replace('\\', '/'),
                          data=(np.clip(fake, -1.0 + 1e-7, 1.0 - 1e-7) *
                                32767.5).astype(np.int16),
                          rate=hp_Dict['Sound']['Sample_Rate'])

    def Inference_Epoch(self):
        logging.info('(Steps: {}) Start inference.'.format(self.steps))

        self.model.eval()

        for step, (audios, mels, pitches, singers, noises, source_Labels,
                   singer_Labels) in tqdm(
                       enumerate(self.dataLoader_Dict['Inference'], 1),
                       desc='[Inference]',
                       total=math.ceil(
                           len(self.dataLoader_Dict['Inference'].dataset) /
                           (hp_Dict['Inference_Batch_Size']
                            or hp_Dict['Train']['Batch_Size']))):
            self.Inference_Step(audios,
                                mels,
                                pitches,
                                singers,
                                noises,
                                source_Labels,
                                singer_Labels,
                                start_Index=(step - 1) *
                                (hp_Dict['Inference_Batch_Size']
                                 or hp_Dict['Train']['Batch_Size']))

        self.model.train()

    @torch.no_grad()
    def Back_Translate_Step(self, mels, pitches, singers, noises):
        mels = mels.to(device, non_blocking=True)
        pitches = pitches.to(device, non_blocking=True)
        singers = singers.to(device, non_blocking=True)
        noises = noises.to(device, non_blocking=True)

        fakes, *_ = self.model(mels=mels,
                               pitches=pitches,
                               singers=singers,
                               noises=noises)

        return fakes.cpu().numpy()

    def Data_Accumulation(self):
        def Mixup(audio, pitch):
            max_Offset = pitch.shape[0] - hp_Dict['Train'][
                'Wav_Length'] // hp_Dict['Sound']['Frame_Shift'] * 2
            offset1 = np.random.randint(0, max_Offset)
            offset2 = np.random.randint(0, max_Offset)
            beta = np.random.uniform(
                low=hp_Dict['Train']['Train_Pattern']['Mixup']['Min_Beta'],
                high=hp_Dict['Train']['Train_Pattern']['Mixup']['Max_Beta'],
            )

            new_Audio = \
                beta * audio[offset1 * hp_Dict['Sound']['Frame_Shift']:offset1 * hp_Dict['Sound']['Frame_Shift'] + hp_Dict['Train']['Wav_Length'] * 2] + \
                (1 - beta) * audio[offset2* hp_Dict['Sound']['Frame_Shift']:offset2 * hp_Dict['Sound']['Frame_Shift'] + hp_Dict['Train']['Wav_Length'] * 2]

            new_Pitch = \
                beta * pitch[offset1:offset1 + hp_Dict['Train']['Wav_Length'] // hp_Dict['Sound']['Frame_Shift'] * 2] + \
                (1 - beta) * pitch[offset2:offset2 + hp_Dict['Train']['Wav_Length'] // hp_Dict['Sound']['Frame_Shift'] * 2]

            _, new_Mel, _, _ = Pattern_Generate(audio=new_Audio)

            return new_Audio, new_Mel, new_Pitch

        def Back_Translate(mels, pitches, singers, noises):
            fakes = self.Back_Translate_Step(mels=mels,
                                             pitches=pitches,
                                             singers=singers,
                                             noises=noises)

            new_Mels = [Pattern_Generate(audio=fake)[1] for fake in fakes]

            return new_Mels

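        # Build mixup and back-translation pattern lists from the accumulation
        # loader, then hand them to the train dataset via Accumulation_Renew.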
        print()
        mixup_List = []
        back_Translate_List = []
        for total_Audios, total_Pitches, audios, mels, pitches, singers, noises in tqdm(
                self.dataLoader_Dict['Accumulation'],
                desc='[Accumulation]',
                total=math.ceil(
                    len(self.dataLoader_Dict['Accumulation'].dataset) /
                    hp_Dict['Train']['Batch_Size'])):
            # Mixup
            if hp_Dict['Train']['Train_Pattern']['Mixup'][
                    'Use'] and self.steps >= hp_Dict['Train']['Train_Pattern'][
                        'Mixup']['Apply_Delay']:
                for audio, pitch, singer in zip(total_Audios, total_Pitches,
                                                singers.numpy()):
                    mixup_Audio, mixup_Mel, mixup_Pitch = Mixup(audio, pitch)
                    mixup_List.append(
                        (mixup_Audio, mixup_Mel, mixup_Pitch, singer, singer))

            # Back-translate
            if hp_Dict['Train']['Train_Pattern']['Back_Translate'][
                    'Use'] and self.steps >= hp_Dict['Train']['Train_Pattern'][
                        'Back_Translate']['Apply_Delay']:
                mel_Singers = torch.LongTensor(
                    np.stack([
                        choice([
                            x for x in range(hp_Dict['Num_Singers'])
                            if x != singer
                        ]) for singer in singers
                    ],
                             axis=0))
                back_Translate_Mels = Back_Translate(mels, pitches,
                                                     mel_Singers, noises)
                for audio, back_Translate_Mel, pitch, audio_Singer, mel_Singer in zip(
                        audios.numpy(), back_Translate_Mels, pitches.numpy(),
                        singers.numpy(), mel_Singers.numpy()):
                    back_Translate_List.append(
                        (audio, back_Translate_Mel, pitch, audio_Singer,
                         mel_Singer))

        self.dataLoader_Dict['Train'].dataset.Accumulation_Renew(
            mixup_Pattern_List=mixup_List,
            back_Translate_Pattern_List=back_Translate_List)

    def Load_Checkpoint(self):
        if self.steps == 0:
            paths = [
                os.path.join(root, file).replace('\\', '/')
                for root, _, files in os.walk(hp_Dict['Checkpoint_Path'])
                for file in files if os.path.splitext(file)[1] == '.pt'
            ]
            if len(paths) > 0:
                path = max(paths, key=os.path.getctime)
            else:
                return  # Initial training
        else:
            path = os.path.join(
                hp_Dict['Checkpoint_Path'],
                'S_{}.pt'.format(self.steps).replace('\\', '/'))

        state_Dict = torch.load(path, map_location='cpu')
        self.model.load_state_dict(state_Dict['Model'])
        self.optimizer.load_state_dict(state_Dict['Optimizer'])
        self.scheduler.load_state_dict(state_Dict['Scheduler'])
        self.steps = state_Dict['Steps']
        self.epochs = state_Dict['Epochs']

        if hp_Dict['Use_Mixed_Precision']:
            if 'AMP' not in state_Dict:
                logging.info(
                    'No AMP state dict in the checkpoint. Assuming this checkpoint was trained without mixed precision.'
                )
            else:
                amp.load_state_dict(state_Dict['AMP'])

        logging.info('Checkpoint loaded at {} steps.'.format(self.steps))

    def Save_Checkpoint(self):
        os.makedirs(hp_Dict['Checkpoint_Path'], exist_ok=True)

        state_Dict = {
            'Model': self.model.state_dict(),
            'Optimizer': self.optimizer.state_dict(),
            'Scheduler': self.scheduler.state_dict(),
            'Steps': self.steps,
            'Epochs': self.epochs,
        }
        if hp_Dict['Use_Mixed_Precision']:
            state_Dict['AMP'] = amp.state_dict()

        torch.save(
            state_Dict,
            os.path.join(hp_Dict['Checkpoint_Path'],
                         'S_{}.pt'.format(self.steps).replace('\\', '/')))

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def Train(self):
        if not os.path.exists(
                os.path.join(hp_Dict['Checkpoint_Path'],
                             'Hyper_Parameters.yaml')):
            os.makedirs(hp_Dict['Checkpoint_Path'], exist_ok=True)
            with open(
                    os.path.join(hp_Dict['Checkpoint_Path'],
                                 'Hyper_Parameters.yaml').replace("\\", "/"),
                    "w") as f:
                yaml.dump(hp_Dict, f)

        if self.steps == 0:
            self.Evaluation_Epoch()

        if hp_Dict['Train']['Initial_Inference']:
            self.Inference_Epoch()

        self.tqdm = tqdm(initial=self.steps,
                         total=hp_Dict['Train']['Max_Step'],
                         desc='[Training]')

        while self.steps < hp_Dict['Train']['Max_Step']:
            try:
                self.Train_Epoch()
            except KeyboardInterrupt:
                self.Save_Checkpoint()
                exit(1)

        self.tqdm.close()
        logging.info('Finished training.')
Example #4
class Trainer:
    def __init__(self, steps=0):
        self.steps = steps
        self.epochs = 0

        self.Dataset_Generate()
        self.Model_Generate()

        self.writer = SummaryWriter(hp_Dict['Log_Path'])

        if self.steps > 0:
            self.Load_Checkpoint()

    def Dataset_Generate(self):
        train_Dataset = Train_Dataset()
        dev_Dataset = Dev_Dataset()
        logging.info('The number of train files = {}.'.format(
            len(train_Dataset)))
        logging.info('The number of development files = {}.'.format(
            len(dev_Dataset)))

        train_Collater = Train_Collater()
        dev_Collater = Dev_Collater()
        inference_Collater = Inference_Collater()

        self.dataLoader_Dict = {}
        self.dataLoader_Dict['Train'] = torch.utils.data.DataLoader(
            dataset=train_Dataset,
            shuffle=True,
            collate_fn=train_Collater,
            batch_size=hp_Dict['Train']['Batch']['Train']['Speaker'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)
        self.dataLoader_Dict['Dev'] = torch.utils.data.DataLoader(
            dataset=dev_Dataset,
            shuffle=True,
            collate_fn=dev_Collater,
            batch_size=hp_Dict['Train']['Batch']['Eval']['Speaker'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)
        self.dataLoader_Dict['Inference'] = torch.utils.data.DataLoader(
            dataset=dev_Dataset,
            shuffle=True,
            collate_fn=inference_Collater,
            batch_size=hp_Dict['Train']['Batch']['Eval']['Speaker'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)

    def Model_Generate(self):
        self.model = Encoder(
            mel_dims=hp_Dict['Sound']['Mel_Dim'],
            lstm_size=hp_Dict['Encoder']['LSTM']['Sizes'],
            lstm_stacks=hp_Dict['Encoder']['LSTM']['Stacks'],
            embedding_size=hp_Dict['Encoder']['Embedding_Size'],
        ).to(device)
        self.criterion = GE2E_Loss().to(device)
        self.optimizer = RAdam(
            params=self.model.parameters(),
            lr=hp_Dict['Train']['Learning_Rate']['Initial'],
            eps=hp_Dict['Train']['Learning_Rate']['Epsilon'],
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=self.optimizer,
            step_size=hp_Dict['Train']['Learning_Rate']['Decay_Step'],
            gamma=hp_Dict['Train']['Learning_Rate']['Decay_Rate'],
        )

        logging.info(self.model)

    def Train_Step(self, mels):
        mels = mels.to(device)
        embeddings = self.model(mels)
        loss = self.criterion(
            embeddings,
            hp_Dict['Train']['Batch']['Train']['Pattern_per_Speaker'], device)

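        # The GE2E loss receives the configured patterns-per-speaker count,
        # presumably to group consecutive embeddings by speaker.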
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            parameters=self.model.parameters(),
            max_norm=hp_Dict['Train']['Gradient_Norm'])
        self.optimizer.step()
        self.scheduler.step()

        self.steps += 1
        self.tqdm.update(1)

        self.train_Losses += loss

    def Train_Epoch(self):
        for mels in self.dataLoader_Dict['Train']:
            self.Train_Step(mels)

            if self.steps % hp_Dict['Train']['Checkpoint_Save_Interval'] == 0:
                self.Save_Checkpoint()

            if self.steps % hp_Dict['Train']['Logging_Interval'] == 0:
                self.writer.add_scalar(
                    'train/loss',
                    self.train_Losses / hp_Dict['Train']['Logging_Interval'],
                    self.steps)
                self.train_Losses = 0.0

            if self.steps % hp_Dict['Train']['Evaluation_Interval'] == 0:
                self.Evaluation_Epoch()
                self.Inference_Epoch()

            if self.steps >= hp_Dict['Train']['Max_Step']:
                return

        self.epochs += 1

    @torch.no_grad()
    def Evaluation_Step(self, mels):
        mels = mels.to(device)
        embeddings = self.model(mels)
        loss = self.criterion(
            embeddings,
            hp_Dict['Train']['Batch']['Eval']['Pattern_per_Speaker'], device)

        return embeddings, loss

    def Evaluation_Epoch(self):
        logging.info('(Steps: {}) Start evaluation.'.format(self.steps))

        self.model.eval()

        embeddings, losses, datasets, speakers = zip(
            *[(*self.Evaluation_Step(mels), datasets, speakers)
              for step, (
                  mels, datasets,
                  speakers) in tqdm(enumerate(self.dataLoader_Dict['Dev'], 1),
                                    desc='[Evaluation]')])

        losses = torch.stack(losses)
        self.writer.add_scalar('evaluation/loss', losses.sum(), self.steps)

        self.TSNE(embeddings=torch.cat(embeddings, dim=0),
                  datasets=[
                      dataset for dataset_List in datasets
                      for dataset in dataset_List
                  ],
                  speakers=[
                      speaker for speaker_list in speakers
                      for speaker in speaker_list
                  ],
                  tag='evaluation/tsne')

        self.model.train()

    @torch.no_grad()
    def Inference_Step(self, mels):
        return Normalize(self.model(mels.to(device)),
                         samples=hp_Dict['Train']['Inference']['Samples'])

    def Inference_Epoch(self):
        logging.info('(Steps: {}) Start inference.'.format(self.steps))

        self.model.eval()

        embeddings, datasets, speakers = zip(
            *[(self.Inference_Step(mels), datasets, speakers)
              for step, (mels, datasets, speakers) in tqdm(
                  enumerate(self.dataLoader_Dict['Inference'], 1),
                  desc='[Inference]')])

        self.TSNE(embeddings=torch.cat(embeddings, dim=0),
                  datasets=[
                      dataset for dataset_List in datasets
                      for dataset in dataset_List
                  ],
                  speakers=[
                      speaker for speaker_List in speakers
                      for speaker in speaker_List
                  ],
                  tag='inference/tsne')

        self.model.train()

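    # t-SNE projection of (at most) the first 10 speakers' embeddings; each
    # consecutive block of Pattern_per_Speaker rows is plotted as one speaker.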
    def TSNE(self, embeddings, datasets, speakers, tag):
        scatters = TSNE(n_components=2, random_state=0).fit_transform(
            embeddings[:10 * hp_Dict['Train']['Batch']['Eval']
                       ['Pattern_per_Speaker']].cpu().numpy())
        scatters = np.reshape(
            scatters,
            [-1, hp_Dict['Train']['Batch']['Eval']['Pattern_per_Speaker'], 2])

        fig = plt.figure(figsize=(8, 8))
        for scatter, dataset, speaker in zip(
                scatters, datasets[::hp_Dict['Train']['Batch']['Eval']
                                   ['Pattern_per_Speaker']],
                speakers[::hp_Dict['Train']['Batch']['Eval']
                         ['Pattern_per_Speaker']]):
            plt.scatter(scatter[:, 0],
                        scatter[:, 1],
                        label='{}.{}'.format(dataset, speaker))
        plt.legend()
        plt.tight_layout()
        self.writer.add_figure(tag, fig, self.steps)
        plt.close(fig)

    def Load_Checkpoint(self):
        state_Dict = torch.load(os.path.join(
            hp_Dict['Checkpoint_Path'],
            'S_{}.pkl'.format(self.steps).replace('\\', '/')),
                                map_location='cpu')

        self.model.load_state_dict(state_Dict['Model'])
        self.optimizer.load_state_dict(state_Dict['Optimizer'])
        self.scheduler.load_state_dict(state_Dict['Scheduler'])
        self.steps = state_Dict['Steps']
        self.epochs = state_Dict['Epochs']

        logging.info('Checkpoint loaded at {} steps.'.format(self.steps))

    def Save_Checkpoint(self):
        os.makedirs(hp_Dict['Checkpoint_Path'], exist_ok=True)

        state_Dict = {
            'Model': self.model.state_dict(),
            'Optimizer': self.optimizer.state_dict(),
            'Scheduler': self.scheduler.state_dict(),
            'Steps': self.steps,
            'Epochs': self.epochs,
        }

        torch.save(
            state_Dict,
            os.path.join(hp_Dict['Checkpoint_Path'],
                         'S_{}.pkl'.format(self.steps).replace('\\', '/')))

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def Train(self):
        self.tqdm = tqdm(initial=self.steps,
                         total=hp_Dict['Train']['Max_Step'],
                         desc='[Training]')
        self.train_Losses = 0.0

        if hp_Dict['Train']['Initial_Inference'] and self.steps == 0:
            self.Evaluation_Epoch()
            self.Inference_Epoch()

        while self.steps < hp_Dict['Train']['Max_Step']:
            try:
                self.Train_Epoch()
            except KeyboardInterrupt:
                self.Save_Checkpoint()
                exit(1)

        self.tqdm.close()
        logging.info('Finished training.')