# Example #1
# 0
class Trainer:
    def __init__(self, steps=0):
        """Build data, models and loggers, then restore any checkpoint.

        Args:
            steps: initial value of ``self.steps``. ``Load_Checkpoint`` uses it
                to choose between the newest checkpoint (0) and ``S_<steps>.pt``.
        """
        self.steps = steps
        self.epochs = 0

        self.Datset_Generate()
        self.Model_Generate()

        # One running-scalar accumulator and one logger per phase.
        phases = ('Train', 'Evaluation')
        self.scalar_Dict = {phase: defaultdict(float) for phase in phases}
        self.writer_Dict = {
            phase: Logger(os.path.join(hp.Log_Path, phase))
            for phase in phases
        }

        self.Load_Checkpoint()

    def Datset_Generate(self):
        """Create the datasets and DataLoaders used for training.

        Fills ``self.dataLoader_Dict`` with the 'Train', 'Dev' and 'Inference'
        loaders. When ``hp.Mode`` is 'PE' or 'GR' it also adds a
        'Prosody_Check' loader over the training patterns.
        """
        train_Pattern = hp.Train.Train_Pattern
        eval_Pattern = hp.Train.Eval_Pattern
        use_Cache = hp.Train.Use_Pattern_Cache

        train_Dataset = Dataset(
            pattern_path=train_Pattern.Path,
            metadata_file=train_Pattern.Metadata_File,
            accumulated_dataset_epoch=train_Pattern.Accumulated_Dataset_Epoch,
            mel_length_min=train_Pattern.Mel_Length.Min,
            mel_length_max=train_Pattern.Mel_Length.Max,
            text_length_min=train_Pattern.Text_Length.Min,
            text_length_max=train_Pattern.Text_Length.Max,
            use_cache=use_Cache)
        dev_Dataset = Dataset(
            pattern_path=eval_Pattern.Path,
            metadata_file=eval_Pattern.Metadata_File,
            mel_length_min=eval_Pattern.Mel_Length.Min,
            mel_length_max=eval_Pattern.Mel_Length.Max,
            text_length_min=eval_Pattern.Text_Length.Min,
            text_length_max=eval_Pattern.Text_Length.Max,
            use_cache=use_Cache)
        inference_Dataset = Inference_Dataset(
            pattern_path=hp.Train.Inference_Pattern_File_in_Train)

        # The train count is divided by Accumulated_Dataset_Epoch, presumably
        # because the dataset repeats its patterns that many times.
        train_Count = len(train_Dataset) // train_Pattern.Accumulated_Dataset_Epoch
        logging.info('The number of train patterns = {}.'.format(train_Count))
        logging.info('The number of development patterns = {}.'.format(
            len(dev_Dataset)))
        logging.info('The number of inference patterns = {}.'.format(
            len(inference_Dataset)))

        collater = Collater()
        inference_Collater = Inference_Collater()

        # Settings shared by every loader below.
        common_Kwargs = {
            'num_workers': hp.Train.Num_Workers,
            'pin_memory': True,
        }
        self.dataLoader_Dict = {
            'Train': torch.utils.data.DataLoader(
                dataset=train_Dataset,
                shuffle=True,
                collate_fn=collater,
                batch_size=hp.Train.Batch_Size,
                **common_Kwargs),
            'Dev': torch.utils.data.DataLoader(
                dataset=dev_Dataset,
                shuffle=True,
                collate_fn=collater,
                batch_size=hp.Train.Batch_Size,
                **common_Kwargs),
            'Inference': torch.utils.data.DataLoader(
                dataset=inference_Dataset,
                shuffle=False,
                collate_fn=inference_Collater,
                batch_size=hp.Inference_Batch_Size or hp.Train.Batch_Size,
                **common_Kwargs),
        }

        if hp.Mode not in ['PE', 'GR']:
            return

        prosody_Dataset = Prosody_Check_Dataset(
            pattern_path=train_Pattern.Path,
            metadata_file=train_Pattern.Metadata_File,
            mel_length_min=train_Pattern.Mel_Length.Min,
            mel_length_max=train_Pattern.Mel_Length.Max,
            use_cache=use_Cache)
        self.dataLoader_Dict['Prosody_Check'] = torch.utils.data.DataLoader(
            dataset=prosody_Dataset,
            shuffle=False,
            collate_fn=Prosody_Check_Collater(),
            batch_size=hp.Train.Batch_Size,
            **common_Kwargs)

    def Model_Generate(self):
        """Instantiate models, loss functions, optimizer and LR scheduler.

        GlowTTS is always built. A standalone ``Speaker_Embedding`` is added
        when ``hp.Speaker_Embedding.GE2E.Checkpoint_Path`` is set. Only the
        GlowTTS parameters are handed to the RAdam optimizer. With
        ``hp.Use_Mixed_Precision`` the GlowTTS model and the optimizer are
        replaced by the objects returned from ``amp.initialize``.
        """
        # GlowTTS is constructed first, as before, so parameter initialisation
        # draws from the RNG in the same order.
        self.model_Dict = {'GlowTTS': GlowTTS().to(device)}

        ge2e_Config = hp.Speaker_Embedding.GE2E
        if ge2e_Config.Checkpoint_Path is not None:
            # NOTE(review): in the code visible here this model is only switched
            # between train/eval mode; its weights are never loaded or used.
            self.model_Dict['Speaker_Embedding'] = Speaker_Embedding(
                mel_dims=hp.Sound.Mel_Dim,
                lstm_size=ge2e_Config.LSTM.Sizes,
                lstm_stacks=ge2e_Config.LSTM.Stacks,
                embedding_size=hp.Speaker_Embedding.Embedding_Size,
            ).to(device)

        self.criterion_Dict = {
            'MSE': torch.nn.MSELoss().to(device),
            'MLE': MLE_Loss().to(device),
            'CE': torch.nn.CrossEntropyLoss().to(device),
        }

        adam_Config = hp.Train.ADAM
        self.optimizer = RAdam(
            params=self.model_Dict['GlowTTS'].parameters(),
            lr=hp.Train.Learning_Rate.Initial,
            betas=(adam_Config.Beta1, adam_Config.Beta2),
            eps=adam_Config.Epsilon,
            weight_decay=hp.Train.Weight_Decay)
        self.scheduler = Modified_Noam_Scheduler(
            optimizer=self.optimizer, base=hp.Train.Learning_Rate.Base)

        if hp.Use_Mixed_Precision:
            wrapped_Model, self.optimizer = amp.initialize(
                models=self.model_Dict['GlowTTS'], optimizers=self.optimizer)
            self.model_Dict['GlowTTS'] = wrapped_Model

        logging.info(self.model_Dict['GlowTTS'])

    def Train_Step(self, tokens, token_lengths, mels, mel_lengths, speakers,
                   mels_for_ge2e, pitches):
        """Run one optimisation step on a training batch.

        The batch is moved to ``device`` and fed to GlowTTS. The step then:
        - computes the 'MLE' and duration ('Length') losses, plus a
          speaker-classification 'Speaker' loss when the model returns
          speaker logits;
        - back-propagates and clips gradients to ``hp.Train.Gradient_Norm``;
        - steps the optimizer and scheduler, and increments ``self.steps``.

        Loss values are accumulated as plain floats in
        ``self.scalar_Dict['Train']`` under 'Loss/<tag>'.

        Note: the logged 'Loss/Total' is MLE + Length only. The optimised loss
        also includes 'Speaker' when that loss is present.
        """
        loss_Dict = {}

        tokens = tokens.to(device)
        token_lengths = token_lengths.to(device)
        mels = mels.to(device)
        mel_lengths = mel_lengths.to(device)
        speakers = speakers.to(device)
        mels_for_ge2e = mels_for_ge2e.to(device)
        pitches = pitches.to(device)

        (z, mel_Mean, mel_Log_Std, log_Dets, log_Durations,
         log_Duration_Targets, _,
         classified_Speakers) = self.model_Dict['GlowTTS'](
             tokens=tokens,
             token_lengths=token_lengths,
             mels=mels,
             mel_lengths=mel_lengths,
             speakers=speakers,
             mels_for_ge2e=mels_for_ge2e,
             pitches=pitches)

        loss_Dict['MLE'] = self.criterion_Dict['MLE'](z=z,
                                                      mean=mel_Mean,
                                                      std=mel_Log_Std,
                                                      log_dets=log_Dets,
                                                      lengths=mel_lengths)
        loss_Dict['Length'] = self.criterion_Dict['MSE'](log_Durations,
                                                         log_Duration_Targets)
        loss_Dict['Total'] = loss_Dict['MLE'] + loss_Dict['Length']

        loss = loss_Dict['Total']
        if classified_Speakers is not None:
            loss_Dict['Speaker'] = self.criterion_Dict['CE'](
                classified_Speakers, speakers)
            loss = loss_Dict['Total'] + loss_Dict['Speaker']

        self.optimizer.zero_grad()
        if hp.Use_Mixed_Precision:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                parameters=amp.master_params(self.optimizer),
                max_norm=hp.Train.Gradient_Norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model_Dict['GlowTTS'].parameters(),
                max_norm=hp.Train.Gradient_Norm)
        self.optimizer.step()
        self.scheduler.step()
        self.steps += 1
        self.tqdm.update(1)

        # Accumulate Python floats. Summing the loss tensors themselves
        # would chain every step's autograd graph (and its device tensors)
        # into the running total until the next logging flush.
        for tag, value in loss_Dict.items():
            self.scalar_Dict['Train']['Loss/{}'.format(tag)] += value.item()

    def Train_Epoch(self):
        """Make one pass over the 'Train' loader, running periodic side tasks.

        After every step, depending on ``self.steps`` and the
        ``hp.Train.*_Interval`` settings, this:
        - saves a checkpoint;
        - flushes the interval-averaged training scalars and the current
          learning rate to the 'Train' logger;
        - runs evaluation and inference.

        It returns as soon as ``hp.Train.Max_Step`` is reached. Only a fully
        consumed loader advances ``self.epochs``, by
        ``Accumulated_Dataset_Epoch``.
        """
        train_Config = hp.Train
        for (tokens, token_Lengths, mels, mel_Lengths, speakers, mels_for_GE2E,
             pitches) in self.dataLoader_Dict['Train']:
            self.Train_Step(tokens, token_Lengths, mels, mel_Lengths, speakers,
                            mels_for_GE2E, pitches)

            if self.steps % train_Config.Checkpoint_Save_Interval == 0:
                self.Save_Checkpoint()

            if self.steps % train_Config.Logging_Interval == 0:
                interval = train_Config.Logging_Interval
                averaged_Scalars = {
                    tag: total / interval
                    for tag, total in self.scalar_Dict['Train'].items()
                }
                averaged_Scalars['Learning_Rate'] = self.scheduler.get_last_lr()
                self.scalar_Dict['Train'] = averaged_Scalars
                self.writer_Dict['Train'].add_scalar_dict(
                    averaged_Scalars, self.steps)
                self.scalar_Dict['Train'] = defaultdict(float)

            if self.steps % train_Config.Evaluation_Interval == 0:
                self.Evaluation_Epoch()

            if self.steps % train_Config.Inference_Interval == 0:
                self.Inference_Epoch()

            if self.steps >= train_Config.Max_Step:
                return

        self.epochs += train_Config.Train_Pattern.Accumulated_Dataset_Epoch

    @torch.no_grad()
    def Evaluation_Step(self, tokens, token_lengths, mels, mel_lengths,
                        speakers, mels_for_ge2e, pitches):
        """Score one dev batch and synthesize it for the image summaries.

        Adds the 'MLE', 'Length', 'Total' (= MLE + Length) and, when speaker
        logits are returned, 'Speaker' losses to
        ``self.scalar_Dict['Evaluation']`` under 'Loss/<tag>'.

        Returns:
            A tuple of:
            - mels produced by ``GlowTTS.inference``;
            - attentions from the forward pass that is given the target mels;
            - attentions from inference;
            - speaker logits (or None).
        """
        (tokens, token_lengths, mels, mel_lengths, speakers, mels_for_ge2e,
         pitches) = [
             tensor.to(device)
             for tensor in (tokens, token_lengths, mels, mel_lengths, speakers,
                            mels_for_ge2e, pitches)
         ]

        glowTTS = self.model_Dict['GlowTTS']
        (z, mel_Mean, mel_Log_Std, log_Dets, log_Durations,
         log_Duration_Targets, attentions_from_Train,
         classified_Speakers) = glowTTS(tokens=tokens,
                                        token_lengths=token_lengths,
                                        mels=mels,
                                        mel_lengths=mel_lengths,
                                        speakers=speakers,
                                        mels_for_ge2e=mels_for_ge2e,
                                        pitches=pitches)

        mle_Loss = self.criterion_Dict['MLE'](z=z,
                                              mean=mel_Mean,
                                              std=mel_Log_Std,
                                              log_dets=log_Dets,
                                              lengths=mel_lengths)
        length_Loss = self.criterion_Dict['MSE'](log_Durations,
                                                 log_Duration_Targets)
        loss_Dict = {
            'MLE': mle_Loss,
            'Length': length_Loss,
            'Total': mle_Loss + length_Loss,
        }
        if classified_Speakers is not None:
            loss_Dict['Speaker'] = self.criterion_Dict['CE'](
                classified_Speakers, speakers)

        evaluation_Scalars = self.scalar_Dict['Evaluation']
        for tag, value in loss_Dict.items():
            evaluation_Scalars['Loss/{}'.format(tag)] += value

        # Synthesis of the same batch, used only for the tensorboard images.
        inferred_Mels, _, attentions_from_Inference = glowTTS.inference(
            tokens=tokens,
            token_lengths=token_lengths,
            mels_for_prosody=mels,
            mel_lengths_for_prosody=mel_lengths,
            speakers=speakers,
            mels_for_ge2e=mels_for_ge2e,
            pitches=pitches,
            pitch_lengths=mel_lengths,
            length_scale=torch.FloatTensor([1.0]).to(device))

        return (inferred_Mels, attentions_from_Train, attentions_from_Inference,
                classified_Speakers)

    def Evaluation_Epoch(self):
        """Evaluate on the dev set and log losses, weight histograms and images.

        The pass runs as follows:
        - switch every model in ``self.model_Dict`` to eval mode;
        - run ``Evaluation_Step`` over the 'Dev' loader;
        - write the per-batch mean of the accumulated losses, a GlowTTS
          weight histogram and images of the last batch's last sample to
          the 'Evaluation' logger;
        - restore train mode.

        In 'PE'/'GR' mode a prosody check follows every
        ``hp.Train.Prosody_Check_Interval`` steps. An empty dev loader is
        logged and its summaries are skipped; previously it raised NameError
        on the unbound loop variables.
        """
        logging.info('(Steps: {}) Start evaluation.'.format(self.steps))

        for model in self.model_Dict.values():
            model.eval()

        step = 0  # Stays 0 when the dev loader yields no batch.
        for step, (tokens, token_Lengths, mels, mel_Lengths, speakers,
                   mels_for_GE2E, pitches) in tqdm(
                       enumerate(self.dataLoader_Dict['Dev'], 1),
                       desc='[Evaluation]',
                       total=math.ceil(
                           len(self.dataLoader_Dict['Dev'].dataset) /
                           hp.Train.Batch_Size)):
            (mel_Predictions, attentions_from_Train, attentions_from_Inference,
             classified_Speakers) = self.Evaluation_Step(
                 tokens, token_Lengths, mels, mel_Lengths, speakers,
                 mels_for_GE2E, pitches)

        if step == 0:
            # No batch means no mean loss (division by step) and no sample
            # to draw, so only report the situation.
            logging.warning(
                '(Steps: {}) The development dataloader is empty. Evaluation summaries are skipped.'
                .format(self.steps))
        else:
            self.scalar_Dict['Evaluation'] = {
                tag: loss / step
                for tag, loss in self.scalar_Dict['Evaluation'].items()
            }
            self.writer_Dict['Evaluation'].add_scalar_dict(
                self.scalar_Dict['Evaluation'], self.steps)
            self.writer_Dict['Evaluation'].add_histogram_model(
                self.model_Dict['GlowTTS'],
                self.steps,
                delete_keywords=['layer_Dict', 'layer', 'GE2E'])

            # Images come from the last sample of the last dev batch.
            image_Dict = {
                'Mel/Target': (mels[-1].cpu().numpy(), None),
                'Mel/Prediction': (mel_Predictions[-1].cpu().numpy(), None),
                'Attention/From_Train':
                (attentions_from_Train[-1].cpu().numpy(), None),
                'Attention/From_Inference':
                (attentions_from_Inference[-1].cpu().numpy(), None)
            }
            if classified_Speakers is not None:
                image_Dict.update({
                    'Speaker/Original': (torch.nn.functional.one_hot(
                        speakers,
                        hp.Speaker_Embedding.Num_Speakers).cpu().numpy(),
                                         None),
                    'Speaker/Predicted':
                    (torch.softmax(classified_Speakers,
                                   dim=-1).cpu().numpy(), None),
                })
            self.writer_Dict['Evaluation'].add_image_dict(
                image_Dict, self.steps)

        self.scalar_Dict['Evaluation'] = defaultdict(float)

        for model in self.model_Dict.values():
            model.train()

        if hp.Mode in ['PE', 'GR'
                       ] and self.steps % hp.Train.Prosody_Check_Interval == 0:
            self.Prosody_Check_Epoch()

    @torch.no_grad()
    def Inference_Step(self,
                       tokens,
                       token_lengths,
                       mels_for_prosody,
                       mel_lengths_for_prosody,
                       speakers,
                       mels_for_ge2e,
                       pitches,
                       pitch_lengths,
                       length_scales,
                       labels,
                       texts,
                       start_index=0,
                       tag_step=False,
                       tag_index=False):
        """Synthesize one inference batch and write PNG plots and NPY arrays.

        Everything goes under ``hp.Inference_Path/Step-<steps>/``:
        - ``PNG/``: one figure per sample with the mel and attention plots;
        - ``NPY/Mel/``: the length-trimmed, transposed mel;
        - ``NPY/Attention/``: the untrimmed attention matrix.

        File stems are '.'-joined from ``Step-<steps>`` (if ``tag_step``), the
        label, and ``IDX_<start_index + index>`` (if ``tag_index``).
        """
        tokens = tokens.to(device)
        token_lengths = token_lengths.to(device)
        mels_for_prosody = mels_for_prosody.to(device)
        mel_lengths_for_prosody = mel_lengths_for_prosody.to(device)
        speakers = speakers.to(device)
        mels_for_ge2e = mels_for_ge2e.to(device)
        pitches = pitches.to(device)
        length_scales = length_scales.to(device)
        # NOTE(review): pitch_lengths is not moved to `device`, unlike every
        # other tensor input; confirm the model handles that.

        mels, mel_Lengths, attentions = self.model_Dict['GlowTTS'].inference(
            tokens=tokens,
            token_lengths=token_lengths,
            mels_for_prosody=mels_for_prosody,
            mel_lengths_for_prosody=mel_lengths_for_prosody,
            speakers=speakers,
            mels_for_ge2e=mels_for_ge2e,
            pitches=pitches,
            pitch_lengths=pitch_lengths,
            length_scale=length_scales)

        # Copy outputs to host once. The NPY loop used to re-copy the whole
        # attention batch for every sample.
        mels = mels.cpu().numpy()
        mel_Lengths = mel_Lengths.cpu().numpy()
        attentions = attentions.cpu().numpy()

        files = []
        for index, label in enumerate(labels):
            tags = []
            if tag_step:
                tags.append('Step-{}'.format(self.steps))
            tags.append(label)
            if tag_index:
                tags.append('IDX_{}'.format(index + start_index))
            files.append('.'.join(tags))

        step_Path = os.path.join(hp.Inference_Path,
                                 'Step-{}'.format(self.steps))
        png_Path = os.path.join(step_Path, 'PNG').replace('\\', '/')
        mel_NPY_Path = os.path.join(step_Path, 'NPY',
                                    'Mel').replace('\\', '/')
        attention_NPY_Path = os.path.join(step_Path, 'NPY',
                                          'Attention').replace('\\', '/')

        os.makedirs(png_Path, exist_ok=True)
        for mel, mel_Length, attention, label, text, length_Scale, file in zip(
                mels, mel_Lengths, attentions, labels, texts, length_scales,
                files):
            mel = mel[:, :mel_Length]
            # Attention rows are <S>, one per character of `text`, then <E>.
            attention = attention[:len(text) + 2, :mel_Length]
            short_Text = text if len(text) < 70 else text[:70] + '…'

            new_Figure = plt.figure(figsize=(20, 5 * 3), dpi=100)
            plt.subplot2grid((3, 1), (0, 0))
            plt.imshow(mel, aspect='auto', origin='lower')
            plt.title(
                'Mel    Label: {}    Text: {}    Length scale: {:.3f}'.format(
                    label, short_Text, length_Scale))
            plt.colorbar()
            plt.subplot2grid((3, 1), (1, 0), rowspan=2)
            plt.imshow(attention,
                       aspect='auto',
                       origin='lower',
                       interpolation='none')
            plt.title(
                'Attention    Label: {}    Text: {}    Length scale: {:.3f}'.
                format(label, short_Text, length_Scale))
            plt.yticks(range(len(text) + 2), ['<S>'] + list(text) + ['<E>'],
                       fontsize=10)
            plt.colorbar()
            plt.tight_layout()
            plt.savefig(
                os.path.join(png_Path,
                             '{}.PNG'.format(file)).replace('\\', '/'))
            plt.close(new_Figure)

        # Creating the leaf directories also creates the shared NPY parent.
        os.makedirs(mel_NPY_Path, exist_ok=True)
        os.makedirs(attention_NPY_Path, exist_ok=True)
        for mel, mel_Length, attention, file in zip(mels, mel_Lengths,
                                                    attentions, files):
            np.save(os.path.join(mel_NPY_Path, file).replace('\\', '/'),
                    mel[:, :mel_Length].T,
                    allow_pickle=False)
            # The attention matrix is saved untrimmed, as before.
            np.save(os.path.join(attention_NPY_Path, file).replace('\\', '/'),
                    attention,
                    allow_pickle=False)

    def Inference_Epoch(self):
        """Run ``Inference_Step`` over every batch of the 'Inference' loader.

        Models are in eval mode during the pass and back in train mode after.
        Each batch's ``start_index`` is its position times the inference batch
        size. That size is ``hp.Inference_Batch_Size``, falling back to
        ``hp.Train.Batch_Size``.
        """
        logging.info('(Steps: {}) Start inference.'.format(self.steps))

        for model in self.model_Dict.values():
            model.eval()

        batch_Size = hp.Inference_Batch_Size or hp.Train.Batch_Size
        inference_Loader = self.dataLoader_Dict['Inference']
        batch_Iterator = tqdm(
            enumerate(inference_Loader),
            desc='[Inference]',
            total=math.ceil(len(inference_Loader.dataset) / batch_Size))
        for step, (tokens, token_Lengths, mels_for_Prosody,
                   mel_Lengths_for_Prosody, speakers, mels_for_GE2E, pitches,
                   pitch_Lengths, length_Scales, labels,
                   texts) in batch_Iterator:
            self.Inference_Step(tokens,
                                token_Lengths,
                                mels_for_Prosody,
                                mel_Lengths_for_Prosody,
                                speakers,
                                mels_for_GE2E,
                                pitches,
                                pitch_Lengths,
                                length_Scales,
                                labels,
                                texts,
                                start_index=step * batch_Size)

        for model in self.model_Dict.values():
            model.train()

    @torch.no_grad()
    def Prosody_Check_Step(self, mels, mel_lengths):
        """Encode one batch of mels with GlowTTS's 'Prosody_Encoder' layer.

        Returns whatever the prosody encoder returns for the batch, computed
        without gradient tracking.
        """
        mels, mel_lengths = mels.to(device), mel_lengths.to(device)
        prosody_Encoder = self.model_Dict['GlowTTS'].layer_Dict[
            'Prosody_Encoder']
        return prosody_Encoder(mels, mel_lengths)

    def Prosody_Check_Epoch(self):
        """Log prosody embeddings of the prosody-check set to the projector.

        Every batch of the 'Prosody_Check' loader is encoded in eval mode. The
        outputs are concatenated along dim 0 and written, with their labels, to
        the 'Evaluation' logger under tag 'Prosodies'. Train mode is restored
        afterwards.
        """
        logging.info('(Steps: {}) Start prosody check.'.format(self.steps))

        for model in self.model_Dict.values():
            model.eval()

        check_Loader = self.dataLoader_Dict['Prosody_Check']
        encoded_Batches = []
        for mels, mel_Lengths, batch_Labels in tqdm(
                check_Loader,
                desc='[Prosody_Check]',
                total=math.ceil(
                    len(check_Loader.dataset) / hp.Train.Batch_Size)):
            encoded_Batches.append(
                (self.Prosody_Check_Step(mels, mel_Lengths), batch_Labels))

        prosody_Batches, label_Batches = zip(*encoded_Batches)
        prosodies = torch.cat(prosody_Batches, dim=0)
        labels = [
            label for batch_Labels in label_Batches for label in batch_Labels
        ]

        self.writer_Dict['Evaluation'].add_embedding(prosodies.cpu().numpy(),
                                                     metadata=labels,
                                                     global_step=self.steps,
                                                     tag='Prosodies')

        for model in self.model_Dict.values():
            model.train()

    def Load_Checkpoint(self):
        """Restore training state from a checkpoint, or prepare a fresh run.

        Which checkpoint is loaded:
        - ``self.steps == 0``: the most recently created (by ctime) ``*.pt``
          file found recursively under ``hp.Checkpoint_Path``;
        - otherwise: ``S_<steps>.pt``.

        The model, optimizer, scheduler, step/epoch counters and, with mixed
        precision, the AMP state are restored.

        Pretrained GE2E weights are loaded through ``GE2E_Load_Checkpoint`` on
        a fresh run (no checkpoint found) or when the restored step count is 0.
        This requires GlowTTS to have a 'GE2E' layer and
        ``hp.Speaker_Embedding.GE2E.Checkpoint_Path`` to be set. Previously the
        fresh-run path returned before this load, so it never happened there.
        """
        def load_GE2E_If_Available():
            # torch.load(None) would fail, so also require a configured path.
            if ('GE2E' in self.model_Dict['GlowTTS'].layer_Dict.keys()
                    and hp.Speaker_Embedding.GE2E.Checkpoint_Path is not None):
                self.GE2E_Load_Checkpoint()

        if self.steps == 0:
            paths = [
                os.path.join(root, file).replace('\\', '/')
                for root, _, files in os.walk(hp.Checkpoint_Path)
                for file in files if os.path.splitext(file)[1] == '.pt'
            ]
            if len(paths) == 0:
                # Initial training: nothing to resume, but the speaker
                # encoder still needs its pretrained weights.
                load_GE2E_If_Available()
                return
            path = max(paths, key=os.path.getctime)
        else:
            path = os.path.join(hp.Checkpoint_Path,
                                'S_{}.pt'.format(self.steps))

        state_Dict = torch.load(path, map_location='cpu')
        self.model_Dict['GlowTTS'].load_state_dict(state_Dict['Model'])
        self.optimizer.load_state_dict(state_Dict['Optimizer'])
        self.scheduler.load_state_dict(state_Dict['Scheduler'])
        self.steps = state_Dict['Steps']
        self.epochs = state_Dict['Epochs']

        if hp.Use_Mixed_Precision:
            if 'AMP' not in state_Dict.keys():
                logging.info(
                    'No AMP state dict is in the checkpoint. Model regards this checkpoint is trained without mixed precision.'
                )
            else:
                amp.load_state_dict(state_Dict['AMP'])

        # Activation_Norm is already initialized when a checkpoint is loaded.
        for flow in self.model_Dict['GlowTTS'].layer_Dict[
                'Decoder'].layer_Dict['Flows']:
            flow.layers[0].initialized = True

        logging.info('Checkpoint loaded at {} steps.'.format(self.steps))

        if self.steps == 0:
            load_GE2E_If_Available()

    def Save_Checkpoint(self):
        """Write the current training state to ``hp.Checkpoint_Path/S_<steps>.pt``.

        The saved dict holds 'Model', 'Optimizer', 'Scheduler', 'Steps' and
        'Epochs'. It also holds 'AMP' when ``hp.Use_Mixed_Precision`` is set.
        """
        os.makedirs(hp.Checkpoint_Path, exist_ok=True)

        checkpoint = dict(
            Model=self.model_Dict['GlowTTS'].state_dict(),
            Optimizer=self.optimizer.state_dict(),
            Scheduler=self.scheduler.state_dict(),
            Steps=self.steps,
            Epochs=self.epochs,
        )
        if hp.Use_Mixed_Precision:
            checkpoint['AMP'] = amp.state_dict()

        checkpoint_Path = os.path.join(hp.Checkpoint_Path,
                                       'S_{}.pt'.format(self.steps))
        torch.save(checkpoint, checkpoint_Path)

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def GE2E_Load_Checkpoint(self):
        """Load pretrained weights into GlowTTS's 'GE2E' layer.

        Reads the 'Model' entry of the file at
        ``hp.Speaker_Embedding.GE2E.Checkpoint_Path`` (mapped to CPU).
        """
        ge2e_Path = hp.Speaker_Embedding.GE2E.Checkpoint_Path
        ge2e_Layer = self.model_Dict['GlowTTS'].layer_Dict['GE2E']
        ge2e_Layer.load_state_dict(
            torch.load(ge2e_Path, map_location='cpu')['Model'])
        logging.info(
            "Speaker embedding checkpoint '{}' loaded.".format(ge2e_Path))

    def Train(self):
        """Run training until ``hp.Train.Max_Step``.

        The run proceeds as follows:
        - copy 'Hyper_Parameters.yaml' into ``hp.Checkpoint_Path`` if it is
          not there yet;
        - evaluate once before a fresh run (``self.steps == 0``);
        - optionally run an inference pass (``hp.Train.Initial_Inference``);
        - call ``Train_Epoch`` repeatedly.

        On Ctrl-C a checkpoint is saved and the process exits with status 1.
        The progress bar is closed on every exit path.
        """
        # sys.exit always exists. The bare exit() builtin is injected by the
        # `site` module and is not meant for programs.
        import sys

        hp_Path = os.path.join(hp.Checkpoint_Path,
                               'Hyper_Parameters.yaml').replace('\\', '/')
        if not os.path.exists(hp_Path):
            from shutil import copyfile
            os.makedirs(hp.Checkpoint_Path, exist_ok=True)
            copyfile('Hyper_Parameters.yaml', hp_Path)

        if self.steps == 0:
            self.Evaluation_Epoch()

        if hp.Train.Initial_Inference:
            self.Inference_Epoch()

        self.tqdm = tqdm(initial=self.steps,
                         total=hp.Train.Max_Step,
                         desc='[Training]')

        try:
            while self.steps < hp.Train.Max_Step:
                self.Train_Epoch()
        except KeyboardInterrupt:
            self.Save_Checkpoint()
            sys.exit(1)
        finally:
            self.tqdm.close()

        logging.info('Finished training.')
# Example #2
# 0
class Trainer:
    def __init__(self, hp_path, steps=0):
        """Set up a TacoSinger trainer from a YAML hyper-parameter file.

        Args:
            hp_path: path of the hyper-parameter YAML file (read as UTF-8).
            steps: initial value of ``self.steps``.

        The environment variables ``RANK`` and ``WORLD_SIZE`` (defaults '0'
        and '1') give this process's GPU index and the number of processes.
        """
        self.hp_Path = hp_path
        self.gpu_id = int(os.getenv('RANK', '0'))
        self.num_gpus = int(os.getenv("WORLD_SIZE", '1'))

        # `with` closes the file; the bare open() left the handle to the GC.
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # That is acceptable only because this is trusted local config.
        with open(self.hp_Path, encoding='utf-8') as hp_File:
            hp_Dict = yaml.load(hp_File, Loader=yaml.Loader)
        self.hp = Recursive_Parse(hp_Dict)

        if not torch.cuda.is_available():
            self.device = torch.device('cpu')
        else:
            # NOTE(review): RANK is the global rank. On multi-node launches the
            # local GPU index is usually LOCAL_RANK — confirm the launch setup.
            self.device = torch.device('cuda:{}'.format(self.gpu_id))
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
            torch.cuda.set_device(self.gpu_id)

        self.steps = steps

        self.Datset_Generate()
        self.Model_Generate()
        self.Load_Checkpoint()
        self._Set_Distribution()

        self.scalar_dict = {
            'Train': defaultdict(float),
            'Evaluation': defaultdict(float),
        }

        self.writer_Dict = {
            'Train': Logger(os.path.join(self.hp.Log_Path, 'Train')),
            'Evaluation': Logger(os.path.join(self.hp.Log_Path, 'Evaluation')),
        }

    def Datset_Generate(self):
        """Build the train/dev/inference datasets and their DataLoaders.

        The token dictionary is read from ``self.hp.Token_Path`` and shared by
        every dataset and collater. Only rank 0 (``self.gpu_id == 0``) logs the
        dataset sizes. Loaders are stored in ``self.dataloader_dict`` under
        'Train', 'Dev' and 'Inference'.
        """
        # Close the file deterministically and decode it as UTF-8, the same
        # way __init__ reads the hyper-parameter file.
        with open(self.hp.Token_Path, encoding='utf-8') as token_File:
            token_dict = yaml.load(token_File, Loader=yaml.Loader)

        train_dataset = Dataset(
            token_dict=token_dict,
            pattern_path=self.hp.Train.Train_Pattern.Path,
            metadata_file=self.hp.Train.Train_Pattern.Metadata_File)
        dev_dataset = Dataset(
            token_dict=token_dict,
            pattern_path=self.hp.Train.Eval_Pattern.Path,
            metadata_file=self.hp.Train.Eval_Pattern.Metadata_File)
        inference_dataset = Inference_Dataset(
            token_dict=token_dict,
            pattern_paths=self.hp.Train.Inference_Pattern_in_Train)

        if self.gpu_id == 0:
            logging.info('The number of train patterns = {}.'.format(
                len(train_dataset)))
            logging.info('The number of development patterns = {}.'.format(
                len(dev_dataset)))
            logging.info('The number of inference patterns = {}.'.format(
                len(inference_dataset)))

        collater = Collater(token_dict=token_dict)
        inference_collater = Inference_Collater(token_dict=token_dict)

        # NOTE(review): the train sampler is chosen from hp.Use_Multi_GPU, but
        # the dev sampler from num_gpus > 1. DistributedSampler without
        # explicit num_replicas/rank needs an initialised process group.
        # Confirm both conditions always agree.
        self.dataloader_dict = {}
        self.dataloader_dict['Train'] = torch.utils.data.DataLoader(
            dataset=train_dataset,
            sampler=torch.utils.data.DistributedSampler(
                train_dataset, shuffle=True)
            if self.hp.Use_Multi_GPU else
            torch.utils.data.RandomSampler(train_dataset),
            collate_fn=collater,
            batch_size=self.hp.Train.Batch_Size,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)
        self.dataloader_dict['Dev'] = torch.utils.data.DataLoader(
            dataset=dev_dataset,
            sampler=torch.utils.data.DistributedSampler(
                dev_dataset, shuffle=True)
            if self.num_gpus > 1 else
            torch.utils.data.RandomSampler(dev_dataset),
            collate_fn=collater,
            batch_size=self.hp.Train.Batch_Size,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)
        self.dataloader_dict['Inference'] = torch.utils.data.DataLoader(
            dataset=inference_dataset,
            sampler=torch.utils.data.SequentialSampler(inference_dataset),
            collate_fn=inference_collater,
            batch_size=self.hp.Inference_Batch_Size
            or self.hp.Train.Batch_Size,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)

    def Model_Generate(self):
        """Create the TacoSinger model and its training machinery.

        This builds:
        - the model, on ``self.device``;
        - an 'MSE' and a guided-attention ('GAL') criterion;
        - a RAdam optimizer;
        - the Modified Noam LR scheduler;
        - a ``GradScaler`` that is active only with
          ``self.hp.Use_Mixed_Precision``.

        Rank 0 logs the model.
        """
        train_hp = self.hp.Train
        adam_hp = train_hp.ADAM

        self.model = TacoSinger(self.hp).to(self.device)
        self.criterion_dict = {
            'MSE': torch.nn.MSELoss().to(self.device),
            'GAL': Guided_Attention_Loss(),
        }
        self.optimizer = RAdam(params=self.model.parameters(),
                               lr=train_hp.Learning_Rate.Initial,
                               betas=(adam_hp.Beta1, adam_hp.Beta2),
                               eps=adam_hp.Epsilon,
                               weight_decay=train_hp.Weight_Decay)
        self.scheduler = Modified_Noam_Scheduler(
            optimizer=self.optimizer, base=train_hp.Learning_Rate.Base)
        self.scaler = torch.cuda.amp.GradScaler(
            enabled=self.hp.Use_Mixed_Precision)

        if self.gpu_id == 0:
            logging.info(self.model)

    def Train_Step(self, tokens, notes, durations, token_lengths, features,
                   feature_lengths):
        loss_dict = {}
        tokens = tokens.to(self.device, non_blocking=True)
        notes = notes.to(self.device, non_blocking=True)
        durations = durations.to(self.device, non_blocking=True)
        token_lengths = token_lengths.to(self.device, non_blocking=True)
        features = features.to(self.device, non_blocking=True)
        feature_lengths = feature_lengths.to(self.device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=self.hp.Use_Mixed_Precision):
            pre_features, post_features, alignments = self.model(
                tokens=tokens,
                notes=notes,
                durations=durations,
                token_lengths=token_lengths,
                features=features,
                feature_lengths=feature_lengths,
                is_training=True)

            loss_dict['Pre'] = self.criterion_dict['MSE'](pre_features,
                                                          features)
            loss_dict['Post'] = self.criterion_dict['MSE'](post_features,
                                                           features)
            loss_dict['Guided_Attention'] = self.criterion_dict['GAL'](
                alignments, feature_lengths, token_lengths)
            loss_dict['Total'] = loss_dict['Pre'] + loss_dict[
                'Post'] + loss_dict['Guided_Attention']

        self.optimizer.zero_grad()
        self.scaler.scale(loss_dict['Total']).backward()
        if self.hp.Train.Gradient_Norm > 0.0:
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model.parameters(),
                max_norm=self.hp.Train.Gradient_Norm)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()
        self.steps += 1
        self.tqdm.update(1)

        for tag, loss in loss_dict.items():
            loss = reduce_tensor(
                loss.data,
                self.num_gpus).item() if self.num_gpus > 1 else loss.item()
            self.scalar_dict['Train']['Loss/{}'.format(tag)] += loss

    def Train_Epoch(self):
        """Iterate once over the 'Train' loader, firing periodic hooks.

        After each ``Train_Step`` the updated ``self.steps`` counter is tested
        against the configured intervals. In order, it may:

        - save a checkpoint;
        - flush the interval-averaged training scalars, plus the current
          learning rate, to the 'Train' writer;
        - run evaluation;
        - run inference.

        Returns early, mid-epoch, as soon as ``Max_Step`` is reached.
        """
        train_hp = self.hp.Train

        for (tokens, notes, durations, token_lengths, features,
             feature_lengths) in self.dataloader_dict['Train']:
            self.Train_Step(tokens, notes, durations, token_lengths,
                            features, feature_lengths)

            if not self.steps % train_hp.Checkpoint_Save_Interval:
                self.Save_Checkpoint()

            if not self.steps % train_hp.Logging_Interval:
                # Turn the running sums into per-step means before writing.
                divisor = train_hp.Logging_Interval
                averaged = {}
                for tag, running_sum in self.scalar_dict['Train'].items():
                    averaged[tag] = running_sum / divisor
                self.scalar_dict['Train'] = averaged
                averaged['Learning_Rate'] = self.scheduler.get_last_lr()
                self.writer_Dict['Train'].add_scalar_dict(averaged, self.steps)
                # Open a fresh accumulation window.
                self.scalar_dict['Train'] = defaultdict(float)

            if not self.steps % train_hp.Evaluation_Interval:
                self.Evaluation_Epoch()

            if not self.steps % train_hp.Inference_Interval:
                self.Inference_Epoch()

            if self.steps >= train_hp.Max_Step:
                return

    @torch.no_grad()
    def Evaluation_Step(self, tokens, notes, durations, token_lengths,
                        features, feature_lengths):
        """Compute dev-set losses for one batch and add them to the running sums.

        The model is called with ``is_training=True`` and the ground-truth
        ``features``, exactly as in training, so the same MSE (pre/post) and
        guided-attention terms can be computed. Each term is converted to a
        Python float and added to ``self.scalar_dict['Evaluation']`` under
        ``Loss/<tag>``. With more than one GPU, the value is first passed
        through ``reduce_tensor``.

        Returns:
            ``(pre_features, post_features, alignments)`` from the model.
        """
        device = self.device
        (tokens, notes, durations, token_lengths, features,
         feature_lengths) = (
             tensor.to(device, non_blocking=True)
             for tensor in (tokens, notes, durations, token_lengths,
                            features, feature_lengths))

        pre_features, post_features, alignments = self.model(
            tokens=tokens,
            notes=notes,
            durations=durations,
            token_lengths=token_lengths,
            features=features,
            feature_lengths=feature_lengths,
            is_training=True)

        mse = self.criterion_dict['MSE']
        pre_loss = mse(pre_features, features)
        post_loss = mse(post_features, features)
        attention_loss = self.criterion_dict['GAL'](alignments,
                                                    feature_lengths,
                                                    token_lengths)
        loss_dict = {
            'Pre': pre_loss,
            'Post': post_loss,
            'Guided_Attention': attention_loss,
            'Total': pre_loss + post_loss + attention_loss,
        }

        sums = self.scalar_dict['Evaluation']
        for tag, value in loss_dict.items():
            if self.num_gpus > 1:
                scalar = reduce_tensor(value.data, self.num_gpus).item()
            else:
                scalar = value.item()
            sums['Loss/{}'.format(tag)] += scalar

        return pre_features, post_features, alignments

    def Evaluation_Epoch(self):
        """Evaluate on the 'Dev' loader and log the results.

        Every rank iterates the loader; ``Evaluation_Step`` reduces the
        losses across GPUs when ``num_gpus > 1``. Only GPU 0 writes:

        - the scalars averaged over the number of batches;
        - a parameter histogram;
        - feature/alignment images taken from index ``-1`` of the last batch.

        An empty loader is reported and its logging skipped instead of
        touching undefined batch tensors. The accumulated evaluation scalars
        are always reset and the model is always returned to training mode.
        """
        logging.info('(Steps: {}) Start evaluation in GPU {}.'.format(
            self.steps, self.gpu_id))

        self.model.eval()

        step = 0  # Stays 0 when the loader yields no batch.
        for step, (tokens, notes, durations, token_lengths, features,
                   feature_lengths) in tqdm(
                       enumerate(self.dataloader_dict['Dev'], 1),
                       desc='[Evaluation]',
                       total=math.ceil(
                           len(self.dataloader_dict['Dev'].dataset) /
                           self.hp.Train.Batch_Size / self.num_gpus)):
            pre_features, post_features, alignments = self.Evaluation_Step(
                tokens, notes, durations, token_lengths, features,
                feature_lengths)

        if step == 0:
            # Without a batch, there is nothing to average and no tensors
            # to draw images from.
            logging.warning(
                '(Steps: {}) No development batch in GPU {}; skipping '
                'evaluation logging.'.format(self.steps, self.gpu_id))
        elif self.gpu_id == 0:
            self.scalar_dict['Evaluation'] = {
                tag: loss / step
                for tag, loss in self.scalar_dict['Evaluation'].items()
            }
            self.writer_Dict['Evaluation'].add_scalar_dict(
                self.scalar_dict['Evaluation'], self.steps)
            self.writer_Dict['Evaluation'].add_histogram_model(
                self.model, 'TacoSinger', self.steps, delete_keywords=[])

            # Images come from the last batch seen in the loop above.
            image_dict = {
                'Feature/Target':
                (features[-1].cpu().numpy(), None, 'auto', None),
                'Feature/Pre':
                (pre_features[-1].cpu().numpy(), None, 'auto', None),
                'Feature/Post':
                (post_features[-1].cpu().numpy(), None, 'auto', None),
                'Alignment': (alignments[-1].cpu().numpy(), None, 'auto', None)
            }
            self.writer_Dict['Evaluation'].add_image_dict(
                image_dict, self.steps)

        self.scalar_dict['Evaluation'] = defaultdict(float)

        self.model.train()

    @torch.no_grad()
    def Inference_Step(self,
                       tokens,
                       notes,
                       durations,
                       token_lengths,
                       feature_lengths,
                       texts,
                       decomposed_texts,
                       restoring,
                       start_index=0,
                       tag_step=False):
        """Synthesize one inference batch and save a PNG and a WAV per item.

        Args:
            tokens, notes, durations, token_lengths, feature_lengths: model
                inputs, moved to ``self.device`` here.
            texts: per-item text shown in the plot titles (cut at 90 chars).
            decomposed_texts: per-item symbol sequences, used as y-tick
                labels of the alignment plot.
            restoring: indices used to reorder every per-item output
                (presumably undoing the collater's ordering; TODO confirm).
            start_index: offset added to the item index in the file names.
            tag_step: if True, file names are prefixed with ``Step-<steps>``.

        Files go to ``<Inference_Path>/Step-<steps>/PNG`` and ``.../WAV``.
        Audio is produced by ``griffinlim`` from the de-normalised
        ``post_features``, trimmed to each item's length, then
        peak-normalised to int16.
        """
        tokens = tokens.to(self.device, non_blocking=True)
        notes = notes.to(self.device, non_blocking=True)
        durations = durations.to(self.device, non_blocking=True)
        token_lengths = token_lengths.to(self.device, non_blocking=True)
        feature_lengths = feature_lengths.to(self.device, non_blocking=True)

        _, post_features, alignments = self.model(
            tokens=tokens,
            notes=notes,
            durations=durations,
            token_lengths=token_lengths,
            feature_lengths=feature_lengths,
            is_training=False)

        # Put every per-item output into the order given by `restoring`.
        post_features = torch.stack(
            [post_features[index] for index in restoring], dim=0)
        alignments = torch.stack([alignments[index] for index in restoring],
                                 dim=0)
        feature_lengths = torch.stack(
            [feature_lengths[index] for index in restoring], dim=0)
        texts = [texts[index] for index in restoring]
        decomposed_texts = [decomposed_texts[index] for index in restoring]

        audios = []
        for feature, length in zip(post_features, feature_lengths):
            feature = spectral_de_normalize_torch(
                feature[:, :length]).cpu().numpy()
            audio = griffinlim(feature)
            # Peak-normalise into the int16 range. A silent waveform would
            # divide by zero (NaN samples), and an empty one makes max()
            # raise, so those are written unscaled.
            peak = np.abs(audio).max() if audio.size > 0 else 0.0
            if peak > 0:
                audio = audio / peak * 32767.5
            audios.append(audio.astype(np.int16))

        files = []
        for index in range(post_features.size(0)):
            tags = []
            if tag_step:
                tags.append('Step-{}'.format(self.steps))
            tags.append('IDX_{}'.format(index + start_index))
            files.append('.'.join(tags))

        step_path = os.path.join(self.hp.Inference_Path,
                                 'Step-{}'.format(self.steps))
        png_path = os.path.join(step_path, 'PNG').replace('\\', '/')
        wav_path = os.path.join(step_path, 'WAV').replace('\\', '/')
        os.makedirs(png_path, exist_ok=True)
        os.makedirs(wav_path, exist_ok=True)

        for feature, alignment, text, decomposed_text, audio, file in zip(
                post_features.cpu().numpy(), alignments.cpu().numpy(), texts,
                decomposed_texts, audios, files):
            title = 'Text: {}'.format(text if len(text) < 90 else text[:90] +
                                      '…')
            new_Figure = plt.figure(figsize=(20, 5 * 3), dpi=100)
            plt.subplot2grid((4, 1), (0, 0))
            plt.imshow(feature, aspect='auto', origin='lower')
            plt.title('Feature    {}'.format(title))
            plt.colorbar()
            plt.subplot2grid((4, 1), (1, 0), rowspan=2)
            plt.imshow(alignment[:len(decomposed_text) + 2],
                       aspect='auto',
                       origin='lower')
            plt.title('Alignment    {}'.format(title))
            plt.yticks(range(len(decomposed_text)),
                       list(decomposed_text),
                       fontsize=10)
            plt.colorbar()
            plt.tight_layout()
            plt.savefig(
                os.path.join(png_path,
                             '{}.png'.format(file)).replace('\\', '/'))
            plt.close(new_Figure)

            wavfile.write(
                os.path.join(wav_path,
                             '{}.wav'.format(file)).replace('\\', '/'),
                self.hp.Sound.Sample_Rate, audio)

    def Inference_Epoch(self):
        """Run ``Inference_Step`` over every batch of the 'Inference' loader.

        Only GPU 0 does this; other ranks return immediately. Each batch
        gets ``start_index = step * batch_size`` so file indices continue
        across batches. This assumes every batch but the last holds exactly
        ``batch_size`` items (TODO confirm against the loader).
        """
        if self.gpu_id != 0:
            return

        logging.info('(Steps: {}) Start inference.'.format(self.steps))
        self.model.eval()

        batch_size = self.hp.Inference_Batch_Size or self.hp.Train.Batch_Size
        loader = self.dataloader_dict['Inference']
        progress = tqdm(enumerate(loader),
                        desc='[Inference]',
                        total=math.ceil(len(loader.dataset) / batch_size))
        for step, batch in progress:
            (tokens, notes, durations, token_lengths, feature_lengths, texts,
             decomposed_texts, restoring) = batch
            self.Inference_Step(tokens, notes, durations, token_lengths,
                                feature_lengths, texts, decomposed_texts,
                                restoring,
                                start_index=step * batch_size)

        self.model.train()

    def Load_Checkpoint(self):
        """Restore training state from a checkpoint, if one is available.

        With ``self.steps == 0``, the ``*.pt`` file with the newest
        ``os.path.getctime`` anywhere under ``Checkpoint_Path`` is used. If
        there is none, nothing is loaded (fresh training). Otherwise
        ``S_<steps>.pt`` is loaded directly.

        Restores model, optimizer and scheduler state and ``self.steps``. It
        also restores the AMP ``GradScaler`` state (``self.scaler``) when the
        checkpoint carries a non-empty ``'Scaler'`` entry. Older checkpoints
        lack it, and a disabled scaler saves ``{}``, which ``GradScaler``
        refuses to load.
        """
        if self.steps == 0:
            paths = [
                os.path.join(root, file).replace('\\', '/')
                for root, _, files in os.walk(self.hp.Checkpoint_Path)
                for file in files if os.path.splitext(file)[1] == '.pt'
            ]
            if not paths:
                return  # Initial training
            path = max(paths, key=os.path.getctime)
        else:
            path = os.path.join(self.hp.Checkpoint_Path,
                                'S_{}.pt'.format(self.steps)).replace(
                                    '\\', '/')

        state_dict = torch.load(path, map_location='cpu')
        self.model.load_state_dict(state_dict['Model'])
        self.optimizer.load_state_dict(state_dict['Optimizer'])
        self.scheduler.load_state_dict(state_dict['Scheduler'])

        # Resume the loss-scale schedule instead of restarting it.
        scaler = getattr(self, 'scaler', None)
        scaler_state = state_dict.get('Scaler')
        if scaler is not None and scaler_state:
            scaler.load_state_dict(scaler_state)

        self.steps = state_dict['Steps']

        logging.info('Checkpoint loaded at {} steps in GPU {}.'.format(
            self.steps, self.gpu_id))

    def Save_Checkpoint(self):
        """Write ``S_<steps>.pt`` to ``Checkpoint_Path``; only GPU 0 saves.

        Stores model, optimizer and scheduler state and ``self.steps``. When
        a ``self.scaler`` (AMP ``GradScaler``) exists, its state is saved
        under ``'Scaler'`` so a resumed mixed-precision run keeps its loss
        scale. A disabled scaler yields an empty dict, which loaders can
        skip.
        """
        if self.gpu_id != 0:
            return

        os.makedirs(self.hp.Checkpoint_Path, exist_ok=True)
        state_Dict = {
            'Model': self.model.state_dict(),
            'Optimizer': self.optimizer.state_dict(),
            'Scheduler': self.scheduler.state_dict(),
            'Steps': self.steps
        }
        scaler = getattr(self, 'scaler', None)
        if scaler is not None:
            state_Dict['Scaler'] = scaler.state_dict()

        torch.save(
            state_Dict,
            os.path.join(self.hp.Checkpoint_Path,
                         'S_{}.pt'.format(self.steps)).replace('\\', '/'))

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def _Set_Distribution(self):
        if self.num_gpus > 1:
            self.model = apply_gradient_allreduce(self.model)

    def Train(self):
        """Run training until ``self.hp.Train.Max_Step`` steps.

        In order:

        1. Write the hyper-parameters to ``Hyper_Parameters.yaml`` in the
           checkpoint directory if that file does not exist yet.
        2. Evaluate first when starting from step 0.
        3. Run inference if ``Initial_Inference`` is set.
        4. Loop over ``Train_Epoch``.

        A ``KeyboardInterrupt`` saves a checkpoint and exits with status 1.
        """
        hp_path = os.path.join(self.hp.Checkpoint_Path,
                               'Hyper_Parameters.yaml').replace('\\', '/')
        if not os.path.exists(hp_path):
            os.makedirs(self.hp.Checkpoint_Path, exist_ok=True)
            # Close the handle deterministically so the YAML is fully
            # flushed to disk before training starts.
            with open(hp_path, 'w') as f:
                yaml.dump(self.hp, f)

        if self.steps == 0:
            self.Evaluation_Epoch()

        if self.hp.Train.Initial_Inference:
            self.Inference_Epoch()

        self.tqdm = tqdm(initial=self.steps,
                         total=self.hp.Train.Max_Step,
                         desc='[Training]')

        while self.steps < self.hp.Train.Max_Step:
            try:
                self.Train_Epoch()
            except KeyboardInterrupt:
                self.Save_Checkpoint()
                exit(1)

        self.tqdm.close()
        logging.info('Finished training.')
Exemple #3
0
class Trainer:
    """Trains the speaker ``Encoder`` with ``GE2E_Loss``.

    Hyper-parameters are read from the module-level ``hp_Dict``, and tensors
    are moved to the module-level ``device``. Scalars and t-SNE figures are
    written through ``self.writer`` (a ``SummaryWriter`` on ``Log_Path``).
    Checkpoints are saved to ``Checkpoint_Path`` as ``S_<steps>.pkl``.
    """

    def __init__(self, steps=0):
        """Build loaders, model and writer; resume from ``S_<steps>.pkl`` if ``steps > 0``."""
        self.steps = steps
        self.epochs = 0

        self.Datset_Generate()
        self.Model_Generate()

        self.writer = SummaryWriter(hp_Dict['Log_Path'])

        if self.steps > 0:
            self.Load_Checkpoint()

    def Datset_Generate(self):
        """Create the datasets and the 'Train', 'Dev' and 'Inference' loaders.

        'Inference' iterates the development dataset again, but batches it
        with ``Inference_Collater`` instead of ``Dev_Collater``.
        """
        train_Dataset = Train_Dataset()
        dev_Dataset = Dev_Dataset()
        logging.info('The number of train files = {}.'.format(
            len(train_Dataset)))
        logging.info('The number of development files = {}.'.format(
            len(dev_Dataset)))

        train_Collater = Train_Collater()
        dev_Collater = Dev_Collater()
        inference_Collater = Inference_Collater()

        self.dataLoader_Dict = {}
        self.dataLoader_Dict['Train'] = torch.utils.data.DataLoader(
            dataset=train_Dataset,
            shuffle=True,
            collate_fn=train_Collater,
            batch_size=hp_Dict['Train']['Batch']['Train']['Speaker'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)
        self.dataLoader_Dict['Dev'] = torch.utils.data.DataLoader(
            dataset=dev_Dataset,
            shuffle=True,
            collate_fn=dev_Collater,
            batch_size=hp_Dict['Train']['Batch']['Eval']['Speaker'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)
        self.dataLoader_Dict['Inference'] = torch.utils.data.DataLoader(
            dataset=dev_Dataset,
            shuffle=True,
            collate_fn=inference_Collater,
            batch_size=hp_Dict['Train']['Batch']['Eval']['Speaker'],
            num_workers=hp_Dict['Train']['Num_Workers'],
            pin_memory=True)

    def Model_Generate(self):
        """Create the encoder, GE2E criterion, RAdam optimizer and StepLR scheduler."""
        self.model = Encoder(
            mel_dims=hp_Dict['Sound']['Mel_Dim'],
            lstm_size=hp_Dict['Encoder']['LSTM']['Sizes'],
            lstm_stacks=hp_Dict['Encoder']['LSTM']['Stacks'],
            embedding_size=hp_Dict['Encoder']['Embedding_Size'],
        ).to(device)
        self.criterion = GE2E_Loss().to(device)
        self.optimizer = RAdam(
            params=self.model.parameters(),
            lr=hp_Dict['Train']['Learning_Rate']['Initial'],
            eps=hp_Dict['Train']['Learning_Rate']['Epsilon'],
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer=self.optimizer,
            step_size=hp_Dict['Train']['Learning_Rate']['Decay_Step'],
            gamma=hp_Dict['Train']['Learning_Rate']['Decay_Rate'],
        )

        logging.info(self.model)

    def Train_Step(self, mels):
        """Run one optimisation step on a batch of mels.

        ``GE2E_Loss`` is given ``Pattern_per_Speaker``, so the batch is
        presumably grouped as speakers x utterances (TODO confirm against
        ``Train_Collater``).
        """
        mels = mels.to(device)
        embeddings = self.model(mels)
        loss = self.criterion(
            embeddings,
            hp_Dict['Train']['Batch']['Train']['Pattern_per_Speaker'], device)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            parameters=self.model.parameters(),
            max_norm=hp_Dict['Train']['Gradient_Norm'])
        self.optimizer.step()
        self.scheduler.step()

        self.steps += 1
        self.tqdm.update(1)

        # Accumulate a Python float. Summing the loss tensor itself would
        # chain autograd history across steps until the next logging flush.
        self.train_Losses += loss.item()

    def Train_Epoch(self):
        """One pass over 'Train' with periodic checkpoint, logging and evaluation.

        Returns early once ``Max_Step`` is reached. ``self.epochs`` is only
        incremented when the loader is exhausted.
        """
        for mels in self.dataLoader_Dict['Train']:
            self.Train_Step(mels)

            if self.steps % hp_Dict['Train']['Checkpoint_Save_Interval'] == 0:
                self.Save_Checkpoint()

            if self.steps % hp_Dict['Train']['Logging_Interval'] == 0:
                self.writer.add_scalar(
                    'train/loss',
                    self.train_Losses / hp_Dict['Train']['Logging_Interval'],
                    self.steps)
                self.train_Losses = 0.0

            if self.steps % hp_Dict['Train']['Evaluation_Interval'] == 0:
                self.Evaluation_Epoch()
                self.Inference_Epoch()

            if self.steps >= hp_Dict['Train']['Max_Step']:
                return

        self.epochs += 1

    @torch.no_grad()
    def Evaluation_Step(self, mels):
        """Return ``(embeddings, loss)`` for one dev batch, without gradients."""
        mels = mels.to(device)
        embeddings = self.model(mels)
        loss = self.criterion(
            embeddings,
            hp_Dict['Train']['Batch']['Eval']['Pattern_per_Speaker'], device)

        return embeddings, loss

    def Evaluation_Epoch(self):
        """Evaluate on 'Dev': log the mean batch loss and a t-SNE plot."""
        logging.info('(Steps: {}) Start evaluation.'.format(self.steps))

        self.model.eval()

        embeddings, losses, datasets, speakers = [], [], [], []
        for mels, batch_Datasets, batch_Speakers in tqdm(
                self.dataLoader_Dict['Dev'], desc='[Evaluation]'):
            batch_Embeddings, batch_Loss = self.Evaluation_Step(mels)
            embeddings.append(batch_Embeddings)
            losses.append(batch_Loss)
            datasets.extend(batch_Datasets)
            speakers.extend(batch_Speakers)

        if not embeddings:
            logging.warning(
                '(Steps: {}) Development loader is empty; evaluation '
                'skipped.'.format(self.steps))
            self.model.train()
            return

        # Mean over batches, so the value does not grow with the dev-set size
        # and is comparable to the per-step average logged for training.
        self.writer.add_scalar('evaluation/loss',
                               torch.stack(losses).mean(), self.steps)

        self.TSNE(embeddings=torch.cat(embeddings, dim=0),
                  datasets=datasets,
                  speakers=speakers,
                  tag='evaluation/tsne')

        self.model.train()

    @torch.no_grad()
    def Inference_Step(self, mels):
        """Embed ``mels`` and pass the result through ``Normalize``."""
        return Normalize(self.model(mels.to(device)),
                         samples=hp_Dict['Train']['Inference']['Samples'])

    def Inference_Epoch(self):
        """Embed the 'Inference' loader and log a t-SNE plot of the result."""
        logging.info('(Steps: {}) Start inference.'.format(self.steps))

        self.model.eval()

        embeddings, datasets, speakers = [], [], []
        for mels, batch_Datasets, batch_Speakers in tqdm(
                self.dataLoader_Dict['Inference'], desc='[Inference]'):
            embeddings.append(self.Inference_Step(mels))
            datasets.extend(batch_Datasets)
            speakers.extend(batch_Speakers)

        if not embeddings:
            logging.warning(
                '(Steps: {}) Inference loader is empty; inference '
                'skipped.'.format(self.steps))
            self.model.train()
            return

        self.TSNE(embeddings=torch.cat(embeddings, dim=0),
                  datasets=datasets,
                  speakers=speakers,
                  tag='inference/tsne')

        self.model.train()

    def TSNE(self, embeddings, datasets, speakers, tag):
        """Log a 2-D t-SNE scatter of at most the first 10 groups of embeddings.

        Rows are taken in consecutive groups of ``Pattern_per_Speaker``, and
        each group is labelled with the dataset/speaker of its first row.
        This assumes the inputs are ordered speaker by speaker (TODO confirm
        against the collaters).
        """
        pattern_Per_Speaker = hp_Dict['Train']['Batch']['Eval'][
            'Pattern_per_Speaker']
        # `TSNE` here is the module-level name, not this method.
        scatters = TSNE(n_components=2, random_state=0).fit_transform(
            embeddings[:10 * pattern_Per_Speaker].cpu().numpy())
        scatters = np.reshape(scatters, [-1, pattern_Per_Speaker, 2])

        fig = plt.figure(figsize=(8, 8))
        for scatter, dataset, speaker in zip(
                scatters, datasets[::pattern_Per_Speaker],
                speakers[::pattern_Per_Speaker]):
            plt.scatter(scatter[:, 0],
                        scatter[:, 1],
                        label='{}.{}'.format(dataset, speaker))
        plt.legend()
        plt.tight_layout()
        self.writer.add_figure(tag, fig, self.steps)
        plt.close(fig)

    def Load_Checkpoint(self):
        """Restore model, optimizer, scheduler, steps and epochs from ``S_<steps>.pkl``."""
        path = os.path.join(hp_Dict['Checkpoint_Path'],
                            'S_{}.pkl'.format(self.steps)).replace('\\', '/')
        state_Dict = torch.load(path, map_location='cpu')

        self.model.load_state_dict(state_Dict['Model'])
        self.optimizer.load_state_dict(state_Dict['Optimizer'])
        self.scheduler.load_state_dict(state_Dict['Scheduler'])
        self.steps = state_Dict['Steps']
        self.epochs = state_Dict['Epochs']

        logging.info('Checkpoint loaded at {} steps.'.format(self.steps))

    def Save_Checkpoint(self):
        """Write ``S_<steps>.pkl`` with model, optimizer, scheduler, steps and epochs."""
        os.makedirs(hp_Dict['Checkpoint_Path'], exist_ok=True)

        state_Dict = {
            'Model': self.model.state_dict(),
            'Optimizer': self.optimizer.state_dict(),
            'Scheduler': self.scheduler.state_dict(),
            'Steps': self.steps,
            'Epochs': self.epochs,
        }

        torch.save(
            state_Dict,
            os.path.join(hp_Dict['Checkpoint_Path'],
                         'S_{}.pkl'.format(self.steps)).replace('\\', '/'))

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def Train(self):
        """Train until ``Max_Step``.

        Runs an initial evaluation and inference only when starting from
        step 0 with ``Initial_Inference`` set. A ``KeyboardInterrupt`` saves
        a checkpoint and exits with status 1.
        """
        self.tqdm = tqdm(initial=self.steps,
                         total=hp_Dict['Train']['Max_Step'],
                         desc='[Training]')
        self.train_Losses = 0.0

        if hp_Dict['Train']['Initial_Inference'] and self.steps == 0:
            self.Evaluation_Epoch()
            self.Inference_Epoch()

        while self.steps < hp_Dict['Train']['Max_Step']:
            try:
                self.Train_Epoch()
            except KeyboardInterrupt:
                self.Save_Checkpoint()
                exit(1)

        self.tqdm.close()
        logging.info('Finished training.')
Exemple #4
0
class Trainer:
    def __init__(self, steps=0):
        """Set up data, model, scalar accumulators and loggers, then load a checkpoint.

        Args:
            steps: initial step counter, stored as ``self.steps`` before
                ``Load_Checkpoint`` is called.
        """
        self.steps = steps
        self.epochs = 0

        self.Datset_Generate()
        self.Model_Generate()

        phases = ('Train', 'Evaluation')
        # Running sums of logged scalars, one accumulator per phase.
        self.scalar_Dict = {phase: defaultdict(float) for phase in phases}
        # One Logger per phase, each in its own sub-directory of Log_Path.
        self.writer_Dict = {
            phase: Logger(os.path.join(hp_Dict['Log_Path'], phase))
            for phase in phases
        }

        self.Load_Checkpoint()

    def Datset_Generate(self):
        """Create the four datasets and their loaders in ``self.dataLoader_Dict``.

        Loaders:

        - 'Train': shuffled.
        - 'Accumulation': in order.
        - 'Dev': shuffled.
        - 'Inference': in order; batch size is ``Inference_Batch_Size``,
          falling back to ``Train.Batch_Size`` when that is falsy.

        All loaders share ``Num_Workers`` and ``pin_memory=True``.
        """
        train_Dataset = Train_Dataset()
        accumulation_Dataset = Accumulation_Dataset()
        dev_Dataset = Dev_Dataset()
        inference_Dataset = Inference_Dataset()

        train_Config = hp_Dict['Train']
        base_Train_Count = len(train_Dataset) // train_Config[
            'Train_Pattern']['Accumulated_Dataset_Epoch']
        logging.info(
            'The number of base train files = {}.'.format(base_Train_Count))
        logging.info('The number of development patterns = {}.'.format(
            len(dev_Dataset)))
        logging.info('The number of inference patterns = {}.'.format(
            len(inference_Dataset)))

        collater = Collater()
        accumulation_Collater = Accumulation_Collater()
        inference_Collater = Inference_Collater()

        make_Loader = torch.utils.data.DataLoader
        shared_Kwargs = dict(num_workers=train_Config['Num_Workers'],
                             pin_memory=True)
        batch_Size = train_Config['Batch_Size']

        self.dataLoader_Dict = {}
        self.dataLoader_Dict['Train'] = make_Loader(
            dataset=train_Dataset, shuffle=True, collate_fn=collater,
            batch_size=batch_Size, **shared_Kwargs)
        self.dataLoader_Dict['Accumulation'] = make_Loader(
            dataset=accumulation_Dataset, shuffle=False,
            collate_fn=accumulation_Collater, batch_size=batch_Size,
            **shared_Kwargs)
        # shuffle=True: to write tensorboard.
        self.dataLoader_Dict['Dev'] = make_Loader(
            dataset=dev_Dataset, shuffle=True, collate_fn=collater,
            batch_size=batch_Size, **shared_Kwargs)
        # shuffle=False: to write tensorboard.
        self.dataLoader_Dict['Inference'] = make_Loader(
            dataset=inference_Dataset, shuffle=False,
            collate_fn=inference_Collater,
            batch_size=hp_Dict['Inference_Batch_Size'] or batch_Size,
            **shared_Kwargs)

    def Model_Generate(self):
        """Create the PVCGAN model, loss functions, RAdam optimizer and scheduler.

        When ``Use_Mixed_Precision`` is set, model and optimizer are replaced
        by the results of ``amp.initialize``.
        """
        self.model = PVCGAN().to(device)

        stft_Config = hp_Dict['STFT_Loss_Resolution']
        self.criterion_Dict = {
            'STFT': MultiResolutionSTFTLoss(
                fft_sizes=stft_Config['FFT_Sizes'],
                # 'Shfit_Lengths' (sic) is the key read from hp_Dict.
                shift_lengths=stft_Config['Shfit_Lengths'],
                win_lengths=stft_Config['Win_Lengths'],
            ).to(device),
            'MSE': torch.nn.MSELoss().to(device),
            'CE': torch.nn.CrossEntropyLoss().to(device),
            'MAE': torch.nn.L1Loss().to(device),
        }

        train_Config = hp_Dict['Train']
        adam_Config = train_Config['ADAM']
        self.optimizer = RAdam(
            params=self.model.parameters(),
            lr=train_Config['Learning_Rate']['Initial'],
            betas=(adam_Config['Beta1'], adam_Config['Beta2']),
            eps=adam_Config['Epsilon'],
            weight_decay=train_Config['Weight_Decay'])
        self.scheduler = Modified_Noam_Scheduler(
            optimizer=self.optimizer,
            base=train_Config['Learning_Rate']['Base'])

        if hp_Dict['Use_Mixed_Precision']:
            self.model, self.optimizer = amp.initialize(
                models=self.model, optimizers=self.optimizer)

        logging.info(self.model)

    def Train_Step(self, audios, mels, pitches, audio_Singers, mel_Singers,
                   noises):
        """Run one optimisation step and accumulate the logged losses.

        The loss sums two groups:

        - generator: the multi-resolution STFT terms;
        - confuser: singer CE and pitch MAE.

        Once ``self.steps`` reaches ``Discriminator_Delay``, the
        ``discrimination`` flag passed to the model is set and the
        discriminator MSE terms are added as well.

        The total is back-propagated, through ``amp.scale_loss`` when
        ``Use_Mixed_Precision``. Gradients are clipped to ``Gradient_Norm``,
        and the optimizer and scheduler are stepped. Every loss entry is
        added, as a Python float, to ``self.scalar_Dict['Train']`` under
        ``Loss/<tag>``.
        """
        loss_Dict = {}

        audios = audios.to(device, non_blocking=True)
        mels = mels.to(device, non_blocking=True)
        pitches = pitches.to(device, non_blocking=True)
        audio_Singers = audio_Singers.to(device, non_blocking=True)
        mel_Singers = mel_Singers.to(device, non_blocking=True)
        noises = noises.to(device, non_blocking=True)

        # Generator
        fakes, predicted_Singers, predicted_Pitches, fakes_Discriminations, reals_Discriminations = self.model(
            mels=mels,
            pitches=pitches,
            singers=audio_Singers,
            noises=noises,
            discrimination=self.steps >=
            hp_Dict['Train']['Discriminator_Delay'],
            reals=audios)

        loss_Dict['Generator/Spectral_Convergence'], loss_Dict[
            'Generator/Magnitude'] = self.criterion_Dict['STFT'](fakes, audios)
        loss_Dict['Generator'] = loss_Dict[
            'Generator/Spectral_Convergence'] + loss_Dict['Generator/Magnitude']

        loss_Dict['Confuser/Singer'] = self.criterion_Dict['CE'](
            predicted_Singers, mel_Singers)
        loss_Dict['Confuser/Pitch'] = self.criterion_Dict['MAE'](
            predicted_Pitches, pitches)
        loss_Dict['Confuser'] = loss_Dict['Confuser/Singer'] + loss_Dict[
            'Confuser/Pitch']
        loss = loss_Dict['Generator'] + loss_Dict['Confuser']

        if self.steps >= hp_Dict['Train']['Discriminator_Delay']:
            loss_Dict['Discriminator/Fake'] = self.criterion_Dict['MSE'](
                fakes_Discriminations, torch.zeros_like(fakes_Discriminations)
            )  # Discriminator thinks that 0 is correct.
            loss_Dict['Discriminator/Real'] = self.criterion_Dict['MSE'](
                reals_Discriminations, torch.ones_like(reals_Discriminations)
            )  # Discriminator thinks that 1 is correct.

            loss_Dict['Discriminator'] = loss_Dict[
                'Discriminator/Fake'] + loss_Dict['Discriminator/Real']
            loss += loss_Dict['Discriminator']

        self.optimizer.zero_grad()

        if hp_Dict['Use_Mixed_Precision']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                parameters=amp.master_params(self.optimizer),
                max_norm=hp_Dict['Train']['Gradient_Norm'])
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model.parameters(),
                max_norm=hp_Dict['Train']['Gradient_Norm'])

        self.optimizer.step()
        self.scheduler.step()

        self.steps += 1
        self.tqdm.update(1)

        for tag, value in loss_Dict.items():
            # Accumulate floats. Adding the loss tensors themselves would
            # chain autograd history and keep device tensors alive across
            # steps until the next logging flush.
            self.scalar_Dict['Train']['Loss/{}'.format(tag)] += value.item()

    def Train_Epoch(self):
        """Run one (accumulated) epoch over the 'Train' loader.

        First calls ``Data_Accumulation`` if Mixup or Back_Translate is
        enabled and its ``Apply_Delay`` has been reached. After every
        ``Train_Step`` it checks the intervals to:

        - save a checkpoint;
        - flush averaged training scalars plus the learning rate;
        - evaluate;
        - run inference.

        Returns early at ``Max_Step``. Otherwise ``self.epochs`` grows by
        ``Accumulated_Dataset_Epoch`` once the loader is exhausted.
        """
        train_Config = hp_Dict['Train']
        pattern_Config = train_Config['Train_Pattern']
        mixup_Active = (pattern_Config['Mixup']['Use'] and
                        self.steps >= pattern_Config['Mixup']['Apply_Delay'])
        back_Translate_Active = (
            pattern_Config['Back_Translate']['Use'] and
            self.steps >= pattern_Config['Back_Translate']['Apply_Delay'])
        if mixup_Active or back_Translate_Active:
            self.Data_Accumulation()

        for (audios, mels, pitches, audio_Singers, mel_Singers,
             noises) in self.dataLoader_Dict['Train']:
            self.Train_Step(audios, mels, pitches, audio_Singers, mel_Singers,
                            noises)

            if not self.steps % train_Config['Checkpoint_Save_Interval']:
                self.Save_Checkpoint()

            if not self.steps % train_Config['Logging_Interval']:
                # Convert the running sums into per-step means, then flush.
                divisor = train_Config['Logging_Interval']
                averaged = {}
                for tag, running_Sum in self.scalar_Dict['Train'].items():
                    averaged[tag] = running_Sum / divisor
                self.scalar_Dict['Train'] = averaged
                averaged['Learning_Rate'] = self.scheduler.get_last_lr()
                self.writer_Dict['Train'].add_scalar_dict(averaged, self.steps)
                self.scalar_Dict['Train'] = defaultdict(float)

            if not self.steps % train_Config['Evaluation_Interval']:
                self.Evaluation_Epoch()

            if not self.steps % train_Config['Inference_Interval']:
                self.Inference_Epoch()

            if self.steps >= train_Config['Max_Step']:
                return

        self.epochs += train_Config['Train_Pattern'][
            'Accumulated_Dataset_Epoch']

    @torch.no_grad()
    def Evaluation_Step(self, audios, mels, pitches, audio_Singers,
                        mel_Singers, noises):
        """Compute evaluation losses for one batch and accumulate them.

        Computes the same loss terms as ``Train_Step``, without a backward
        pass. Each term is added, as a Python float, to
        ``self.scalar_Dict['Evaluation']`` under ``Loss/<tag>``.

        Returns:
            ``(fakes, predicted_Singers, predicted_Pitches)`` from the model.
        """
        loss_Dict = {}

        audios = audios.to(device, non_blocking=True)
        mels = mels.to(device, non_blocking=True)
        pitches = pitches.to(device, non_blocking=True)
        audio_Singers = audio_Singers.to(device, non_blocking=True)
        mel_Singers = mel_Singers.to(device, non_blocking=True)
        noises = noises.to(device, non_blocking=True)

        # Generator
        fakes, predicted_Singers, predicted_Pitches, fakes_Discriminations, reals_Discriminations = self.model(
            mels=mels,
            pitches=pitches,
            singers=audio_Singers,
            noises=noises,
            discrimination=self.steps >=
            hp_Dict['Train']['Discriminator_Delay'],
            reals=audios)

        loss_Dict['Generator/Spectral_Convergence'], loss_Dict[
            'Generator/Magnitude'] = self.criterion_Dict['STFT'](fakes, audios)
        loss_Dict['Generator'] = loss_Dict[
            'Generator/Spectral_Convergence'] + loss_Dict['Generator/Magnitude']

        loss_Dict['Confuser/Singer'] = self.criterion_Dict['CE'](
            predicted_Singers, mel_Singers)
        loss_Dict['Confuser/Pitch'] = self.criterion_Dict['MAE'](
            predicted_Pitches, pitches)
        loss_Dict['Confuser'] = loss_Dict['Confuser/Singer'] + loss_Dict[
            'Confuser/Pitch']

        if self.steps >= hp_Dict['Train']['Discriminator_Delay']:
            loss_Dict['Discriminator/Fake'] = self.criterion_Dict['MSE'](
                fakes_Discriminations, torch.zeros_like(fakes_Discriminations)
            )  # Discriminator thinks that 0 is correct.
            loss_Dict['Discriminator/Real'] = self.criterion_Dict['MSE'](
                reals_Discriminations, torch.ones_like(reals_Discriminations)
            )  # Discriminator thinks that 1 is correct.

            loss_Dict['Discriminator'] = loss_Dict[
                'Discriminator/Fake'] + loss_Dict['Discriminator/Real']

        for tag, value in loss_Dict.items():
            # Floats, so the running sums hold no device tensors.
            self.scalar_Dict['Evaluation']['Loss/{}'.format(
                tag)] += value.item()

        return fakes, predicted_Singers, predicted_Pitches

    def Evaluation_Epoch(self):
        """Evaluate on the dev loader and log averaged losses and sample plots.

        Loss sums gathered by :meth:`Evaluation_Step` are divided by the number
        of batches and written to the ``'Evaluation'`` logger, followed by
        parameter histograms and images built from the last item of the final
        batch. The accumulator is then reset.

        An empty dev loader is logged and skipped; previously it raised a
        ``NameError`` on the unbound ``step``. The model is always put back
        into training mode, even if evaluation raises.
        """
        logging.info('(Steps: {}) Start evaluation.'.format(self.steps))

        self.model.eval()
        try:
            step = 0
            for step, (audios, mels, pitches, audio_Singers, mel_Singers,
                       noises) in tqdm(
                           enumerate(self.dataLoader_Dict['Dev'], 1),
                           desc='[Evaluation]',
                           total=math.ceil(
                               len(self.dataLoader_Dict['Dev'].dataset) /
                               hp_Dict['Train']['Batch_Size'])):
                fakes, predicted_Singers, predicted_Pitches = self.Evaluation_Step(
                    audios, mels, pitches, audio_Singers, mel_Singers, noises)

            if step == 0:
                # No batch was seen: nothing to average and no sample to plot.
                logging.warning(
                    '(Steps: {}) The dev loader is empty; evaluation skipped.'.
                    format(self.steps))
                self.scalar_Dict['Evaluation'] = defaultdict(float)
                return

            self.scalar_Dict['Evaluation'] = {
                tag: loss / step
                for tag, loss in self.scalar_Dict['Evaluation'].items()
            }
            self.writer_Dict['Evaluation'].add_scalar_dict(
                self.scalar_Dict['Evaluation'], self.steps)
            self.writer_Dict['Evaluation'].add_histogram_model(
                self.model,
                self.steps,
                delete_keywords=['layer_Dict', '1', 'layer'])
            self.scalar_Dict['Evaluation'] = defaultdict(float)

            # Visualise the last item of the final dev batch.
            self.writer_Dict['Evaluation'].add_image_dict(
                {
                    'Mel': (mels[-1].cpu().numpy(), None),
                    'Audio/Original': (audios[-1].cpu().numpy(), None),
                    'Audio/Predicted': (fakes[-1].cpu().numpy(), None),
                    'Pitch/Original': (pitches[-1].cpu().numpy(), None),
                    'Pitch/Predicted':
                    (predicted_Pitches[-1].cpu().numpy(), None),
                    'Singer/Original': (torch.nn.functional.one_hot(
                        mel_Singers,
                        hp_Dict['Num_Singers']).cpu().numpy(), None),
                    'Singer/Predicted': (torch.softmax(
                        predicted_Singers, dim=-1).cpu().numpy(), None),
                }, self.steps)
        finally:
            self.model.train()

    @torch.no_grad()
    def Inference_Step(self,
                       audios,
                       mels,
                       pitches,
                       singers,
                       noises,
                       source_Labels,
                       singer_Labels,
                       start_Index=0,
                       tag_step=False,
                       tag_Index=False):
        """Convert one inference batch and write a PNG plot and WAV per item.

        Files go to ``<Inference_Path>/Step-<steps>/PNG`` and ``.../WAV``. Each
        file name is built from ``'<source>_to_<singer>'``. It is optionally
        prefixed with ``'Step-<steps>'`` (``tag_step``) and suffixed with
        ``'IDX_<index + start_Index>'`` (``tag_Index``). ``start_Index`` also
        offsets the index shown in the plot titles.

        The generated waveform is clipped to just inside (-1, 1), scaled by
        32767.5 and written as int16 at ``hp_Dict['Sound']['Sample_Rate']``.
        """
        # ``audios`` is only plotted, never fed to the model, so it is not
        # sent to ``device`` (that only cost a round-trip back to the CPU).
        mels = mels.to(device, non_blocking=True)
        pitches = pitches.to(device, non_blocking=True)
        singers = singers.to(device, non_blocking=True)
        noises = noises.to(device, non_blocking=True)

        fakes, *_ = self.model(mels=mels,
                               pitches=pitches,
                               singers=singers,
                               noises=noises)

        file_Names = []
        for index, (source_Label, singer_Label) in enumerate(
                zip(source_Labels, singer_Labels)):
            tags = []
            if tag_step:
                tags.append('Step-{}'.format(self.steps))
            tags.append('{}_to_{}'.format(source_Label, singer_Label))
            if tag_Index:
                tags.append('IDX_{}'.format(index + start_Index))
            file_Names.append('.'.join(tags))

        step_Path = os.path.join(hp_Dict['Inference_Path'],
                                 'Step-{}'.format(self.steps))
        png_Path = os.path.join(step_Path, 'PNG').replace('\\', '/')
        wav_Path = os.path.join(step_Path, 'WAV').replace('\\', '/')
        os.makedirs(png_Path, exist_ok=True)
        os.makedirs(wav_Path, exist_ok=True)

        for index, (real, fake, source_Label, singer_Label,
                    file_Name) in enumerate(
                        zip(audios.cpu().numpy(),
                            fakes.cpu().numpy(), source_Labels, singer_Labels,
                            file_Names)):
            new_Figure = plt.figure(figsize=(80, 10 * 2), dpi=100)
            plt.subplot(211)
            plt.plot(real)
            plt.title('Original wav    Index: {}    {} -> {}'.format(
                index + start_Index, source_Label, singer_Label))
            plt.subplot(212)
            plt.plot(fake)
            plt.title('Fake wav    Index: {}    {} -> {}'.format(
                index + start_Index, source_Label, singer_Label))
            plt.tight_layout()
            plt.savefig(
                os.path.join(png_Path,
                             '{}.png'.format(file_Name)).replace('\\', '/'))
            plt.close(new_Figure)

            wavfile.write(
                filename=os.path.join(
                    wav_Path, '{}.wav'.format(file_Name)).replace('\\', '/'),
                data=(np.clip(fake, -1.0 + 1e-7, 1.0 - 1e-7) *
                      32767.5).astype(np.int16),
                rate=hp_Dict['Sound']['Sample_Rate'])

    def Inference_Epoch(self):
        """Run :meth:`Inference_Step` over the whole inference loader.

        The index passed as ``start_Index`` is now a running count of the
        items already processed. It was previously derived from
        ``Train.Batch_Size``, which is wrong when the loader batches by
        ``Inference_Batch_Size``. The model is restored to training mode even
        if inference raises.
        """
        logging.info('(Steps: {}) Start inference.'.format(self.steps))

        dataloader = self.dataLoader_Dict['Inference']
        # Same batch size the progress-bar total has always assumed.
        batch_Size = (hp_Dict['Inference_Batch_Size']
                      or hp_Dict['Train']['Batch_Size'])

        self.model.eval()
        try:
            start_Index = 0
            for audios, mels, pitches, singers, noises, source_Labels, \
                    singer_Labels in tqdm(
                        dataloader,
                        desc='[Inference]',
                        total=math.ceil(len(dataloader.dataset) /
                                        batch_Size)):
                self.Inference_Step(audios,
                                    mels,
                                    pitches,
                                    singers,
                                    noises,
                                    source_Labels,
                                    singer_Labels,
                                    start_Index=start_Index)
                start_Index += len(source_Labels)
        finally:
            self.model.train()

    @torch.no_grad()
    def Back_Translate_Step(self, mels, pitches, singers, noises):
        """Synthesize a batch with the current model and return it as numpy.

        Every input is moved to ``device`` before the forward pass. Only the
        first model output is kept; it is copied to the CPU and converted
        with ``.numpy()``.
        """
        inputs = {
            'mels': mels,
            'pitches': pitches,
            'singers': singers,
            'noises': noises,
        }
        inputs = {
            name: tensor.to(device, non_blocking=True)
            for name, tensor in inputs.items()
        }

        generated, *_ = self.model(**inputs)
        return generated.cpu().numpy()

    def Data_Accumulation(self):
        """Build augmented patterns and hand them to the train dataset.

        Iterates ``self.dataLoader_Dict['Accumulation']`` once. Depending on
        ``hp_Dict['Train']['Train_Pattern']`` it collects:

        * Mixup patterns, when ``Mixup.Use`` is set and ``self.steps`` has
          reached ``Mixup.Apply_Delay``, as
          ``(audio, mel, pitch, singer, singer)`` tuples, and
        * back-translation patterns, when ``Back_Translate.Use`` is set and
          ``self.steps`` has reached ``Back_Translate.Apply_Delay``, as
          ``(audio, back_translated_mel, pitch, audio_singer, mel_singer)``
          tuples, where ``mel_singer`` is chosen to differ from
          ``audio_singer``.

        Both lists are passed to
        ``self.dataLoader_Dict['Train'].dataset.Accumulation_Renew``.
        """
        def Mixup(audio, pitch):
            """Blend two random windows of one utterance with a random weight.

            Two start offsets are drawn with ``np.random.randint(0, max_Offset)``
            (the upper bound is exclusive). ``beta`` is drawn with
            ``np.random.uniform`` between ``Mixup.Min_Beta`` and
            ``Mixup.Max_Beta``. The pitch slices span
            ``Wav_Length // Frame_Shift * 2`` entries from each offset. The
            audio slices start at ``offset * Frame_Shift`` and span
            ``Wav_Length * 2`` entries; presumably audio is indexed per sample
            and pitch per frame. The result is
            ``beta * slice1 + (1 - beta) * slice2``, and the mel is recomputed
            from the mixed audio with ``Pattern_Generate``.

            NOTE(review): ``np.random.randint`` raises ``ValueError`` when
            ``max_Offset <= 0`` (pitch not longer than the window), and the
            last valid start ``max_Offset`` itself is never drawn. Presumably
            the accumulation patterns are always long enough; confirm.

            Returns:
                ``(new_Audio, new_Mel, new_Pitch)``.
            """
            max_Offset = pitch.shape[0] - hp_Dict['Train'][
                'Wav_Length'] // hp_Dict['Sound']['Frame_Shift'] * 2
            # RNG call order (offset1, offset2, beta) determines the samples.
            offset1 = np.random.randint(0, max_Offset)
            offset2 = np.random.randint(0, max_Offset)
            beta = np.random.uniform(
                low=hp_Dict['Train']['Train_Pattern']['Mixup']['Min_Beta'],
                high=hp_Dict['Train']['Train_Pattern']['Mixup']['Max_Beta'],
            )

            new_Audio = \
                beta * audio[offset1 * hp_Dict['Sound']['Frame_Shift']:offset1 * hp_Dict['Sound']['Frame_Shift'] + hp_Dict['Train']['Wav_Length'] * 2] + \
                (1 - beta) * audio[offset2* hp_Dict['Sound']['Frame_Shift']:offset2 * hp_Dict['Sound']['Frame_Shift'] + hp_Dict['Train']['Wav_Length'] * 2]

            new_Pitch = \
                beta * pitch[offset1:offset1 + hp_Dict['Train']['Wav_Length'] // hp_Dict['Sound']['Frame_Shift'] * 2] + \
                (1 - beta) * pitch[offset2:offset2 + hp_Dict['Train']['Wav_Length'] // hp_Dict['Sound']['Frame_Shift'] * 2]

            # Only the second output of Pattern_Generate (the mel) is used.
            _, new_Mel, _, _ = Pattern_Generate(audio=new_Audio)

            return new_Audio, new_Mel, new_Pitch

        def Back_Translate(mels, pitches, singers, noises):
            """Synthesize the batch for the given target ``singers`` and return
            a list of mels recomputed from each generated waveform via
            ``Pattern_Generate(...)[1]``."""
            fakes = self.Back_Translate_Step(mels=mels,
                                             pitches=pitches,
                                             singers=singers,
                                             noises=noises)

            new_Mels = [Pattern_Generate(audio=fake)[1] for fake in fakes]

            return new_Mels

        # Blank line before the tqdm bar (presumably to keep it off the
        # training progress-bar line).
        print()
        mixup_List = []
        back_Translate_List = []
        for total_Audios, total_Pitches, audios, mels, pitches, singers, noises in tqdm(
                self.dataLoader_Dict['Accumulation'],
                desc='[Accumulation]',
                total=math.ceil(
                    len(self.dataLoader_Dict['Accumulation'].dataset) /
                    hp_Dict['Train']['Batch_Size'])):
            # Mixup: one mixed pattern per utterance. The same singer is used
            # as both the audio and the mel singer.
            if hp_Dict['Train']['Train_Pattern']['Mixup'][
                    'Use'] and self.steps >= hp_Dict['Train']['Train_Pattern'][
                        'Mixup']['Apply_Delay']:
                for audio, pitch, singer in zip(total_Audios, total_Pitches,
                                                singers.numpy()):
                    mixup_Audio, mixup_Mel, mixup_Pitch = Mixup(audio, pitch)
                    mixup_List.append(
                        (mixup_Audio, mixup_Mel, mixup_Pitch, singer, singer))

            # Back-translation: for every item pick a target singer different
            # from its own with ``choice`` (presumably ``random.choice``; the
            # candidate list is empty when Num_Singers == 1). Then re-derive
            # the mel from audio synthesized for that target singer.
            if hp_Dict['Train']['Train_Pattern']['Back_Translate'][
                    'Use'] and self.steps >= hp_Dict['Train']['Train_Pattern'][
                        'Back_Translate']['Apply_Delay']:
                mel_Singers = torch.LongTensor(
                    np.stack([
                        choice([
                            x for x in range(hp_Dict['Num_Singers'])
                            if x != singer
                        ]) for singer in singers
                    ],
                             axis=0))
                back_Translate_Mels = Back_Translate(mels, pitches,
                                                     mel_Singers, noises)
                for audio, back_Translate_Mel, pitch, audio_Singer, mel_Singer in zip(
                        audios.numpy(), back_Translate_Mels, pitches.numpy(),
                        singers.numpy(), mel_Singers.numpy()):
                    back_Translate_List.append(
                        (audio, back_Translate_Mel, pitch, audio_Singer,
                         mel_Singer))

        self.dataLoader_Dict['Train'].dataset.Accumulation_Renew(
            mixup_Pattern_List=mixup_List,
            back_Translate_Pattern_List=back_Translate_List)

    def Load_Checkpoint(self):
        """Restore model, optimizer, scheduler and counters from a checkpoint.

        If ``self.steps`` is non-zero, ``S_<steps>.pt`` inside
        ``hp_Dict['Checkpoint_Path']`` is loaded. Otherwise the most recently
        created (``os.path.getctime``) ``*.pt`` file found anywhere under that
        directory is used, and the method returns without doing anything when
        there is none. With ``Use_Mixed_Precision`` on, the ``'AMP'`` entry is
        restored through ``amp.load_state_dict`` if the checkpoint has one.
        """
        checkpoint_Root = hp_Dict['Checkpoint_Path']

        if self.steps != 0:
            path = os.path.join(
                checkpoint_Root,
                'S_{}.pt'.format(self.steps).replace('\\', '/'))
        else:
            candidates = []
            for directory, _, file_Names in os.walk(checkpoint_Root):
                for file_Name in file_Names:
                    if os.path.splitext(file_Name)[1] != '.pt':
                        continue
                    candidates.append(
                        os.path.join(directory, file_Name).replace('\\', '/'))
            if not candidates:
                return  # Initial training
            path = max(candidates, key=os.path.getctime)

        state_Dict = torch.load(path, map_location='cpu')
        for component, key in ((self.model, 'Model'),
                               (self.optimizer, 'Optimizer'),
                               (self.scheduler, 'Scheduler')):
            component.load_state_dict(state_Dict[key])
        self.steps = state_Dict['Steps']
        self.epochs = state_Dict['Epochs']

        if hp_Dict['Use_Mixed_Precision']:
            if 'AMP' in state_Dict:
                amp.load_state_dict(state_Dict['AMP'])
            else:
                logging.info(
                    'No AMP state dict is in the checkpoint. Model regards this checkpoint is trained without mixed precision.'
                )

        logging.info('Checkpoint loaded at {} steps.'.format(self.steps))

    def Save_Checkpoint(self):
        """Write the current training state to ``S_<steps>.pt``.

        The file is placed in ``hp_Dict['Checkpoint_Path']``, which is created
        if missing. It holds the model, optimizer and scheduler state dicts
        plus the step and epoch counters. The apex ``amp`` state is added
        under ``'AMP'`` when ``Use_Mixed_Precision`` is on.
        """
        checkpoint_Root = hp_Dict['Checkpoint_Path']
        os.makedirs(checkpoint_Root, exist_ok=True)

        state_Dict = dict(
            Model=self.model.state_dict(),
            Optimizer=self.optimizer.state_dict(),
            Scheduler=self.scheduler.state_dict(),
            Steps=self.steps,
            Epochs=self.epochs,
        )
        if hp_Dict['Use_Mixed_Precision']:
            state_Dict['AMP'] = amp.state_dict()

        file_Name = 'S_{}.pt'.format(self.steps).replace('\\', '/')
        torch.save(state_Dict, os.path.join(checkpoint_Root, file_Name))

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def Train(self):
        """Run training until ``hp_Dict['Train']['Max_Step']`` is reached.

        Steps:
        * Dump ``hp_Dict`` to ``Hyper_Parameters.yaml`` in the checkpoint
          directory, unless the file already exists.
        * Evaluate once when starting from step 0.
        * Run an initial inference when ``Train.Initial_Inference`` is set.
        * Repeat :meth:`Train_Epoch` until the step limit is reached.

        On ``KeyboardInterrupt`` a checkpoint is saved, the progress bar is
        closed and the process exits with status 1.
        """
        hp_Path = os.path.join(hp_Dict['Checkpoint_Path'],
                               'Hyper_Parameters.yaml').replace('\\', '/')
        if not os.path.exists(hp_Path):
            os.makedirs(hp_Dict['Checkpoint_Path'], exist_ok=True)
            with open(hp_Path, "w") as f:
                yaml.dump(hp_Dict, f)

        if self.steps == 0:
            self.Evaluation_Epoch()

        if hp_Dict['Train']['Initial_Inference']:
            self.Inference_Epoch()

        self.tqdm = tqdm(initial=self.steps,
                         total=hp_Dict['Train']['Max_Step'],
                         desc='[Training]')

        while self.steps < hp_Dict['Train']['Max_Step']:
            try:
                self.Train_Epoch()
            except KeyboardInterrupt:
                self.Save_Checkpoint()
                self.tqdm.close()
                # ``exit`` is an interactive helper injected by ``site`` and
                # may be absent (e.g. ``python -S``); SystemExit is what it
                # raises anyway.
                raise SystemExit(1)

        self.tqdm.close()
        logging.info('Finished training.')
Example #5
0
class Trainer:
    """Trainer for the GE2E speaker-embedding model.

    Reads hyper-parameters from the YAML file at ``hp_path``. It then builds:

    * the train/dev/inference data loaders,
    * the ``GE2E`` model and ``GE2E_Loss``,
    * a ``RAdam`` optimizer and a ``Modified_Noam_Scheduler``,
    * a ``torch.cuda.amp.GradScaler``.

    It optionally resumes from a checkpoint and drives the
    train / evaluation / inference loop through :meth:`Train`.

    The process rank and world size come from the ``RANK`` and
    ``WORLD_SIZE`` environment variables. With ``WORLD_SIZE > 1`` the train
    and dev loaders use ``DistributedSampler``, which needs an initialised
    default process group. NOTE(review): that initialisation is expected to
    happen before this class is constructed; confirm against the launcher.
    """

    def __init__(self, hp_path, steps=0):
        """Build all training components and load a checkpoint if available.

        Args:
            hp_path: Path of the YAML hyper-parameter file.
            steps: Checkpoint step to resume from. ``0`` resumes from the
                newest ``*.pt`` under ``Checkpoint_Path`` if one exists.
        """
        self.hp_path = hp_path
        self.gpu_id = int(os.getenv('RANK', '0'))
        self.num_gpus = int(os.getenv("WORLD_SIZE", '1'))

        # NOTE: ``yaml.Loader`` can construct arbitrary Python objects, so
        # only point ``hp_path`` at trusted files. The ``with`` block closes
        # the file handle instead of leaking it.
        with open(hp_path, encoding='utf-8') as f:
            self.hp = Recursive_Parse(yaml.load(f, Loader=yaml.Loader))

        if not torch.cuda.is_available():
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda:{}'.format(self.gpu_id))
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
            torch.cuda.set_device(self.gpu_id)

        self.steps = steps

        self.Dataset_Generate()
        self.Model_Generate()
        self.Load_Checkpoint()
        self._Set_Distribution()

        self.scalar_dict = {
            'Train': defaultdict(float),
            'Evaluation': defaultdict(float),
        }

        self.writer_dict = {
            'Train': Logger(os.path.join(self.hp.Log_Path, 'Train')),
            'Evaluation': Logger(os.path.join(self.hp.Log_Path, 'Evaluation')),
        }

    def Dataset_Generate(self):
        """Create the train/dev/inference datasets and their DataLoaders.

        The results are stored in ``self.dataloader_dict`` under ``'Train'``,
        ``'Dev'`` and ``'Inference'``.
        """
        train_dataset = Dataset(
            pattern_path=self.hp.Train.Train_Pattern.Path,
            metadata_file=self.hp.Train.Train_Pattern.Metadata_File,
            pattern_per_speaker=self.hp.Train.Batch.Train.Pattern_per_Speaker)
        dev_dataset = Dataset(
            pattern_path=self.hp.Train.Eval_Pattern.Path,
            metadata_file=self.hp.Train.Eval_Pattern.Metadata_File,
            pattern_per_speaker=self.hp.Train.Batch.Eval.Pattern_per_Speaker)
        inference_dataset = Dataset(
            pattern_path=self.hp.Train.Eval_Pattern.Path,
            metadata_file=self.hp.Train.Eval_Pattern.Metadata_File,
            pattern_per_speaker=self.hp.Train.Batch.Eval.Pattern_per_Speaker,
            num_speakers=50,  # Maximum number by tensorboard.
        )
        logging.info('The number of train speakers = {}.'.format(
            len(train_dataset)))
        logging.info('The number of development speakers = {}.'.format(
            len(dev_dataset)))

        collater = Collater(min_frame_length=self.hp.Train.Frame_Length.Min,
                            max_frame_length=self.hp.Train.Frame_Length.Max)
        inference_collater = Inference_Collater(
            samples=self.hp.Train.Inference.Samples,
            frame_length=self.hp.Train.Inference.Frame_Length,
            overlap_length=self.hp.Train.Inference.Overlap_Length)

        # Shard the data with the same criterion the rest of the class uses
        # for multi-process work (``reduce_tensor``, ``_Set_Distribution``).
        # Gating train on ``hp.Use_Multi_GPU`` and dev on the world size
        # could disagree and build a DistributedSampler without a process
        # group.
        distributed = self.num_gpus > 1

        self.dataloader_dict = {}
        self.dataloader_dict['Train'] = torch.utils.data.DataLoader(
            dataset=train_dataset,
            sampler=torch.utils.data.DistributedSampler(
                train_dataset, shuffle=True)
            if distributed else torch.utils.data.RandomSampler(train_dataset),
            collate_fn=collater,
            batch_size=self.hp.Train.Batch.Train.Speaker,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)
        self.dataloader_dict['Dev'] = torch.utils.data.DataLoader(
            dataset=dev_dataset,
            sampler=torch.utils.data.DistributedSampler(
                dev_dataset, shuffle=True)
            if distributed else torch.utils.data.RandomSampler(dev_dataset),
            collate_fn=collater,
            batch_size=self.hp.Train.Batch.Eval.Speaker,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)
        self.dataloader_dict['Inference'] = torch.utils.data.DataLoader(
            dataset=inference_dataset,
            shuffle=True,
            collate_fn=inference_collater,
            batch_size=self.hp.Train.Batch.Eval.Speaker,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)

    def Model_Generate(self):
        """Create the model, loss, optimizer, LR scheduler and AMP scaler."""
        self.model = GE2E(self.hp).to(self.device)
        self.criterion = GE2E_Loss().to(self.device)
        self.optimizer = RAdam(params=self.model.parameters(),
                               lr=self.hp.Train.Learning_Rate.Initial,
                               betas=(self.hp.Train.ADAM.Beta1,
                                      self.hp.Train.ADAM.Beta2),
                               eps=self.hp.Train.ADAM.Epsilon,
                               weight_decay=self.hp.Train.Weight_Decay)
        self.scheduler = Modified_Noam_Scheduler(
            optimizer=self.optimizer,
            base=self.hp.Train.Learning_Rate.Base,
        )

        self.scaler = torch.cuda.amp.GradScaler(
            enabled=self.hp.Use_Mixed_Precision)

        if self.gpu_id == 0:
            logging.info(self.model)

    def Train_Step(self, features):
        """Run one optimisation step on a batch of features.

        The forward pass and loss run under autocast when
        ``Use_Mixed_Precision`` is set, and gradients go through the
        ``GradScaler``. Gradients are clipped to ``Train.Gradient_Norm`` when
        it is positive. The scheduler and step counter advance by one. The
        loss (combined across ranks with ``reduce_tensor`` when
        ``num_gpus > 1``) is added to ``self.scalar_dict['Train']``.
        """
        loss_dict = {}

        features = features.to(self.device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=self.hp.Use_Mixed_Precision):
            embeddings = self.model(features)
            loss_dict['Embedding'] = self.criterion(
                embeddings, self.hp.Train.Batch.Train.Pattern_per_Speaker)

        self.optimizer.zero_grad()
        self.scaler.scale(loss_dict['Embedding']).backward()
        if self.hp.Train.Gradient_Norm > 0.0:
            # Gradients must be unscaled before clipping by their true norm.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model.parameters(),
                max_norm=self.hp.Train.Gradient_Norm)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()
        self.steps += 1
        self.tqdm.update(1)

        for tag, loss in loss_dict.items():
            loss = reduce_tensor(
                loss.data,
                self.num_gpus).item() if self.num_gpus > 1 else loss.item()
            self.scalar_dict['Train']['Loss/{}'.format(tag)] += loss

    def Train_Epoch(self):
        """Run one pass over the train loader with periodic side tasks.

        After every step it saves a checkpoint, flushes averaged train
        scalars, evaluates and runs inference at their configured intervals.
        It returns early once ``Train.Max_Step`` is reached.
        """
        sampler = self.dataloader_dict['Train'].sampler
        if isinstance(sampler, torch.utils.data.DistributedSampler):
            # DistributedSampler reshuffles only when its epoch changes.
            # Every rank runs the same number of steps, so ``self.steps`` is
            # an identical, ever-increasing value on all ranks.
            sampler.set_epoch(self.steps)

        for features in self.dataloader_dict['Train']:
            self.Train_Step(features)

            if self.steps % self.hp.Train.Checkpoint_Save_Interval == 0:
                self.Save_Checkpoint()

            if self.steps % self.hp.Train.Logging_Interval == 0:
                self.scalar_dict['Train'] = {
                    tag: loss / self.hp.Train.Logging_Interval
                    for tag, loss in self.scalar_dict['Train'].items()
                }
                self.scalar_dict['Train'][
                    'Learning_Rate'] = self.scheduler.get_last_lr()
                self.writer_dict['Train'].add_scalar_dict(
                    self.scalar_dict['Train'], self.steps)
                self.scalar_dict['Train'] = defaultdict(float)

            if self.steps % self.hp.Train.Evaluation_Interval == 0:
                self.Evaluation_Epoch()

            if self.steps % self.hp.Train.Inference_Interval == 0:
                self.Inference_Epoch()

            if self.steps >= self.hp.Train.Max_Step:
                return

    @torch.no_grad()
    def Evaluation_Step(self, features):
        """Compute the GE2E loss on one dev batch and accumulate it."""
        loss_dict = {}

        features = features.to(self.device, non_blocking=True)

        embeddings = self.model(features)
        loss_dict['Embedding'] = self.criterion(
            embeddings, self.hp.Train.Batch.Eval.Pattern_per_Speaker)

        for tag, loss in loss_dict.items():
            loss = reduce_tensor(
                loss.data,
                self.num_gpus).item() if self.num_gpus > 1 else loss.item()
            self.scalar_dict['Evaluation']['Loss/{}'.format(tag)] += loss

    def Evaluation_Epoch(self):
        """Evaluate on the dev loader and log averaged losses and histograms.

        An empty dev loader is logged and skipped rather than failing on an
        unbound ``step``. The model is always returned to training mode.
        """
        logging.info('(Steps: {}) Start evaluation in GPU {}.'.format(
            self.steps, self.gpu_id))

        self.model.eval()
        try:
            dev_loader = self.dataloader_dict['Dev']
            step = 0
            # ``len(DataLoader)`` is the number of batches this process will
            # receive; it accounts for the sampler (including distributed
            # sharding) and the batch size.
            for step, features in tqdm(enumerate(dev_loader, 1),
                                       desc='[Evaluation]',
                                       total=len(dev_loader)):
                self.Evaluation_Step(features)

            if step == 0:
                logging.warning(
                    '(Steps: {}) The dev loader is empty; evaluation skipped.'.
                    format(self.steps))
                self.scalar_dict['Evaluation'] = defaultdict(float)
                return

            self.scalar_dict['Evaluation'] = {
                tag: loss / step
                for tag, loss in self.scalar_dict['Evaluation'].items()
            }
            self.writer_dict['Evaluation'].add_scalar_dict(
                self.scalar_dict['Evaluation'], self.steps)
            self.writer_dict['Evaluation'].add_histogram_model(
                self.model,
                'GE2E',
                self.steps,
                delete_keywords=['layer_Dict', 'layer'])
            self.scalar_dict['Evaluation'] = defaultdict(float)
        finally:
            self.model.train()

    @torch.no_grad()
    def Inference_Step(self, features):
        """Embed one inference batch.

        ``Train.Inference.Samples`` is forwarded to the model as ``samples``.
        """
        return self.model(features=features.to(self.device, non_blocking=True),
                          samples=self.hp.Train.Inference.Samples)

    def Inference_Epoch(self):
        """Embed the inference set and log the embeddings (rank 0 only).

        An empty inference loader is logged and skipped; ``zip(*[])`` used to
        make the unpacking fail. The model is always returned to training
        mode.
        """
        if self.gpu_id != 0:
            return

        logging.info('(Steps: {}) Start inference.'.format(self.steps))

        self.model.eval()
        try:
            embeddings = []
            speakers = []
            for features, batch_speakers in tqdm(
                    self.dataloader_dict['Inference'], desc='[Inference]'):
                embeddings.append(self.Inference_Step(features))
                speakers.extend(batch_speakers)

            if not embeddings:
                logging.warning(
                    '(Steps: {}) The inference loader is empty; inference skipped.'
                    .format(self.steps))
                return

            self.writer_dict['Evaluation'].add_embedding(
                torch.cat(embeddings, dim=0).cpu().numpy(),
                metadata=speakers,
                global_step=self.steps,
                tag='Embeddings')
        finally:
            self.model.train()

    def Load_Checkpoint(self):
        """Restore training state from a checkpoint.

        With ``self.steps == 0`` the most recently created ``*.pt`` file under
        ``hp.Checkpoint_Path`` (searched recursively) is loaded. The method
        returns silently when there is none. Otherwise ``S_<steps>.pt`` is
        loaded.

        The ``GradScaler`` state is restored only when the checkpoint carries
        a non-empty ``'Scaler'`` entry. Older checkpoints without one still
        load, and a disabled scaler's empty state is never fed to an enabled
        scaler.
        """
        if self.steps == 0:
            paths = [
                os.path.join(root, file).replace('\\', '/')
                for root, _, files in os.walk(self.hp.Checkpoint_Path)
                for file in files if os.path.splitext(file)[1] == '.pt'
            ]
            if not paths:
                return  # Initial training
            path = max(paths, key=os.path.getctime)
        else:
            path = os.path.join(self.hp.Checkpoint_Path,
                                'S_{}.pt'.format(self.steps)).replace(
                                    '\\', '/')

        state_dict = torch.load(path, map_location='cpu')
        self.model.load_state_dict(state_dict['Model'])
        self.optimizer.load_state_dict(state_dict['Optimizer'])
        self.scheduler.load_state_dict(state_dict['Scheduler'])
        if state_dict.get('Scaler'):
            self.scaler.load_state_dict(state_dict['Scaler'])
        self.steps = state_dict['Steps']

        logging.info('Checkpoint loaded at {} steps.'.format(self.steps))

    def Save_Checkpoint(self):
        """Save model/optimizer/scheduler/scaler state to ``S_<steps>.pt``.

        Only rank 0 writes. The ``GradScaler`` state is included so that
        mixed-precision runs resume with the same loss scale.
        """
        if self.gpu_id != 0:
            return

        os.makedirs(self.hp.Checkpoint_Path, exist_ok=True)

        state_dict = {
            'Model': self.model.state_dict(),
            'Optimizer': self.optimizer.state_dict(),
            'Scheduler': self.scheduler.state_dict(),
            'Scaler': self.scaler.state_dict(),
            'Steps': self.steps
        }

        torch.save(
            state_dict,
            os.path.join(self.hp.Checkpoint_Path,
                         'S_{}.pt'.format(self.steps)).replace('\\', '/'))

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def _Set_Distribution(self):
        """Enable gradient all-reduce when running with several processes."""
        if self.num_gpus > 1:
            self.model = apply_gradient_allreduce(self.model)

    def Train(self):
        """Run training until ``hp.Train.Max_Step`` is reached.

        Steps:
        * Dump the hyper-parameters to ``Hyper_Parameters.yaml`` in the
          checkpoint directory, unless the file already exists.
        * Evaluate once when starting from step 0.
        * Run an initial inference when configured.
        * Repeat :meth:`Train_Epoch` until the step limit is reached.

        On ``KeyboardInterrupt`` a checkpoint is saved and the process exits
        with status 1.
        """
        hp_path = os.path.join(self.hp.Checkpoint_Path,
                               'Hyper_Parameters.yaml').replace('\\', '/')
        if not os.path.exists(hp_path):
            os.makedirs(self.hp.Checkpoint_Path, exist_ok=True)
            with open(hp_path, 'w', encoding='utf-8') as f:
                yaml.dump(self.hp, f)

        if self.steps == 0:
            self.Evaluation_Epoch()

        if self.hp.Train.Initial_Inference:
            self.Inference_Epoch()

        self.tqdm = tqdm(initial=self.steps,
                         total=self.hp.Train.Max_Step,
                         desc='[Training]')

        while self.steps < self.hp.Train.Max_Step:
            try:
                self.Train_Epoch()
            except KeyboardInterrupt:
                self.Save_Checkpoint()
                self.tqdm.close()
                # ``exit`` is an interactive helper injected by ``site`` and
                # may be absent; SystemExit is what it raises anyway.
                raise SystemExit(1)

        self.tqdm.close()
        logging.info('Finished training.')