class Trainer: def __init__(self, steps=0): self.steps = steps self.epochs = 0 self.Datset_Generate() self.Model_Generate() self.scalar_Dict = { 'Train': defaultdict(float), 'Evaluation': defaultdict(float), } self.writer_Dict = { 'Train': Logger(os.path.join(hp.Log_Path, 'Train')), 'Evaluation': Logger(os.path.join(hp.Log_Path, 'Evaluation')), } self.Load_Checkpoint() def Datset_Generate(self): train_Dataset = Dataset( pattern_path=hp.Train.Train_Pattern.Path, metadata_file=hp.Train.Train_Pattern.Metadata_File, accumulated_dataset_epoch=hp.Train.Train_Pattern. Accumulated_Dataset_Epoch, mel_length_min=hp.Train.Train_Pattern.Mel_Length.Min, mel_length_max=hp.Train.Train_Pattern.Mel_Length.Max, text_length_min=hp.Train.Train_Pattern.Text_Length.Min, text_length_max=hp.Train.Train_Pattern.Text_Length.Max, use_cache=hp.Train.Use_Pattern_Cache) dev_Dataset = Dataset( pattern_path=hp.Train.Eval_Pattern.Path, metadata_file=hp.Train.Eval_Pattern.Metadata_File, mel_length_min=hp.Train.Eval_Pattern.Mel_Length.Min, mel_length_max=hp.Train.Eval_Pattern.Mel_Length.Max, text_length_min=hp.Train.Eval_Pattern.Text_Length.Min, text_length_max=hp.Train.Eval_Pattern.Text_Length.Max, use_cache=hp.Train.Use_Pattern_Cache) inference_Dataset = Inference_Dataset( pattern_path=hp.Train.Inference_Pattern_File_in_Train) logging.info('The number of train patterns = {}.'.format( len(train_Dataset) // hp.Train.Train_Pattern.Accumulated_Dataset_Epoch)) logging.info('The number of development patterns = {}.'.format( len(dev_Dataset))) logging.info('The number of inference patterns = {}.'.format( len(inference_Dataset))) collater = Collater() inference_Collater = Inference_Collater() self.dataLoader_Dict = {} self.dataLoader_Dict['Train'] = torch.utils.data.DataLoader( dataset=train_Dataset, shuffle=True, collate_fn=collater, batch_size=hp.Train.Batch_Size, num_workers=hp.Train.Num_Workers, pin_memory=True) self.dataLoader_Dict['Dev'] = torch.utils.data.DataLoader( dataset=dev_Dataset, shuffle=True, collate_fn=collater, batch_size=hp.Train.Batch_Size, num_workers=hp.Train.Num_Workers, pin_memory=True) self.dataLoader_Dict['Inference'] = torch.utils.data.DataLoader( dataset=inference_Dataset, shuffle=False, collate_fn=inference_Collater, batch_size=hp.Inference_Batch_Size or hp.Train.Batch_Size, num_workers=hp.Train.Num_Workers, pin_memory=True) if hp.Mode in ['PE', 'GR']: self.dataLoader_Dict[ 'Prosody_Check'] = torch.utils.data.DataLoader( dataset=Prosody_Check_Dataset( pattern_path=hp.Train.Train_Pattern.Path, metadata_file=hp.Train.Train_Pattern.Metadata_File, mel_length_min=hp.Train.Train_Pattern.Mel_Length.Min, mel_length_max=hp.Train.Train_Pattern.Mel_Length.Max, use_cache=hp.Train.Use_Pattern_Cache), shuffle=False, collate_fn=Prosody_Check_Collater(), batch_size=hp.Train.Batch_Size, num_workers=hp.Train.Num_Workers, pin_memory=True) def Model_Generate(self): self.model_Dict = {'GlowTTS': GlowTTS().to(device)} if not hp.Speaker_Embedding.GE2E.Checkpoint_Path is None: self.model_Dict['Speaker_Embedding'] = Speaker_Embedding( mel_dims=hp.Sound.Mel_Dim, lstm_size=hp.Speaker_Embedding.GE2E.LSTM.Sizes, lstm_stacks=hp.Speaker_Embedding.GE2E.LSTM.Stacks, embedding_size=hp.Speaker_Embedding.Embedding_Size, ).to(device) self.criterion_Dict = { 'MSE': torch.nn.MSELoss().to(device), 'MLE': MLE_Loss().to(device), 'CE': torch.nn.CrossEntropyLoss().to(device) } self.optimizer = RAdam(params=self.model_Dict['GlowTTS'].parameters(), lr=hp.Train.Learning_Rate.Initial, betas=(hp.Train.ADAM.Beta1, hp.Train.ADAM.Beta2), 
eps=hp.Train.ADAM.Epsilon, weight_decay=hp.Train.Weight_Decay) self.scheduler = Modified_Noam_Scheduler( optimizer=self.optimizer, base=hp.Train.Learning_Rate.Base) if hp.Use_Mixed_Precision: self.model_Dict['GlowTTS'], self.optimizer = amp.initialize( models=self.model_Dict['GlowTTS'], optimizers=self.optimizer) logging.info(self.model_Dict['GlowTTS']) def Train_Step(self, tokens, token_lengths, mels, mel_lengths, speakers, mels_for_ge2e, pitches): loss_Dict = {} tokens = tokens.to(device) token_lengths = token_lengths.to(device) mels = mels.to(device) mel_lengths = mel_lengths.to(device) speakers = speakers.to(device) mels_for_ge2e = mels_for_ge2e.to(device) pitches = pitches.to(device) z, mel_Mean, mel_Log_Std, log_Dets, log_Durations, log_Duration_Targets, _, classified_Speakers = self.model_Dict[ 'GlowTTS'](tokens=tokens, token_lengths=token_lengths, mels=mels, mel_lengths=mel_lengths, speakers=speakers, mels_for_ge2e=mels_for_ge2e, pitches=pitches) loss_Dict['MLE'] = self.criterion_Dict['MLE'](z=z, mean=mel_Mean, std=mel_Log_Std, log_dets=log_Dets, lengths=mel_lengths) loss_Dict['Length'] = self.criterion_Dict['MSE'](log_Durations, log_Duration_Targets) loss_Dict['Total'] = loss_Dict['MLE'] + loss_Dict['Length'] loss = loss_Dict['Total'] if not classified_Speakers is None: loss_Dict['Speaker'] = self.criterion_Dict['CE']( classified_Speakers, speakers) loss = loss_Dict['Total'] + loss_Dict['Speaker'] self.optimizer.zero_grad() if hp.Use_Mixed_Precision: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(parameters=amp.master_params( self.optimizer), max_norm=hp.Train.Gradient_Norm) else: loss.backward() torch.nn.utils.clip_grad_norm_( parameters=self.model_Dict['GlowTTS'].parameters(), max_norm=hp.Train.Gradient_Norm) self.optimizer.step() self.scheduler.step() self.steps += 1 self.tqdm.update(1) for tag, loss in loss_Dict.items(): self.scalar_Dict['Train']['Loss/{}'.format(tag)] += loss def Train_Epoch(self): for tokens, token_Lengths, mels, mel_Lengths, speakers, mels_for_GE2E, pitches in self.dataLoader_Dict[ 'Train']: self.Train_Step(tokens, token_Lengths, mels, mel_Lengths, speakers, mels_for_GE2E, pitches) if self.steps % hp.Train.Checkpoint_Save_Interval == 0: self.Save_Checkpoint() if self.steps % hp.Train.Logging_Interval == 0: self.scalar_Dict['Train'] = { tag: loss / hp.Train.Logging_Interval for tag, loss in self.scalar_Dict['Train'].items() } self.scalar_Dict['Train'][ 'Learning_Rate'] = self.scheduler.get_last_lr() self.writer_Dict['Train'].add_scalar_dict( self.scalar_Dict['Train'], self.steps) self.scalar_Dict['Train'] = defaultdict(float) if self.steps % hp.Train.Evaluation_Interval == 0: self.Evaluation_Epoch() if self.steps % hp.Train.Inference_Interval == 0: self.Inference_Epoch() if self.steps >= hp.Train.Max_Step: return self.epochs += hp.Train.Train_Pattern.Accumulated_Dataset_Epoch @torch.no_grad() def Evaluation_Step(self, tokens, token_lengths, mels, mel_lengths, speakers, mels_for_ge2e, pitches): loss_Dict = {} tokens = tokens.to(device) token_lengths = token_lengths.to(device) mels = mels.to(device) mel_lengths = mel_lengths.to(device) speakers = speakers.to(device) mels_for_ge2e = mels_for_ge2e.to(device) pitches = pitches.to(device) z, mel_Mean, mel_Log_Std, log_Dets, log_Durations, log_Duration_Targets, attentions_from_Train, classified_Speakers = self.model_Dict[ 'GlowTTS'](tokens=tokens, token_lengths=token_lengths, mels=mels, mel_lengths=mel_lengths, speakers=speakers, 
mels_for_ge2e=mels_for_ge2e, pitches=pitches) loss_Dict['MLE'] = self.criterion_Dict['MLE'](z=z, mean=mel_Mean, std=mel_Log_Std, log_dets=log_Dets, lengths=mel_lengths) loss_Dict['Length'] = self.criterion_Dict['MSE'](log_Durations, log_Duration_Targets) loss_Dict['Total'] = loss_Dict['MLE'] + loss_Dict['Length'] if not classified_Speakers is None: loss_Dict['Speaker'] = self.criterion_Dict['CE']( classified_Speakers, speakers) for tag, loss in loss_Dict.items(): self.scalar_Dict['Evaluation']['Loss/{}'.format(tag)] += loss # For tensorboard images mels, _, attentions_from_Inference = self.model_Dict[ 'GlowTTS'].inference(tokens=tokens, token_lengths=token_lengths, mels_for_prosody=mels, mel_lengths_for_prosody=mel_lengths, speakers=speakers, mels_for_ge2e=mels_for_ge2e, pitches=pitches, pitch_lengths=mel_lengths, length_scale=torch.FloatTensor([1.0 ]).to(device)) return mels, attentions_from_Train, attentions_from_Inference, classified_Speakers def Evaluation_Epoch(self): logging.info('(Steps: {}) Start evaluation.'.format(self.steps)) for model in self.model_Dict.values(): model.eval() for step, (tokens, token_Lengths, mels, mel_Lengths, speakers, mels_for_GE2E, pitches) in tqdm( enumerate(self.dataLoader_Dict['Dev'], 1), desc='[Evaluation]', total=math.ceil( len(self.dataLoader_Dict['Dev'].dataset) / hp.Train.Batch_Size)): mel_Predictions, attentions_from_Train, attentions_from_Inference, classified_Speakers = self.Evaluation_Step( tokens, token_Lengths, mels, mel_Lengths, speakers, mels_for_GE2E, pitches) self.scalar_Dict['Evaluation'] = { tag: loss / step for tag, loss in self.scalar_Dict['Evaluation'].items() } self.writer_Dict['Evaluation'].add_scalar_dict( self.scalar_Dict['Evaluation'], self.steps) self.writer_Dict['Evaluation'].add_histogram_model( self.model_Dict['GlowTTS'], self.steps, delete_keywords=['layer_Dict', 'layer', 'GE2E']) self.scalar_Dict['Evaluation'] = defaultdict(float) image_Dict = { 'Mel/Target': (mels[-1].cpu().numpy(), None), 'Mel/Prediction': (mel_Predictions[-1].cpu().numpy(), None), 'Attention/From_Train': (attentions_from_Train[-1].cpu().numpy(), None), 'Attention/From_Inference': (attentions_from_Inference[-1].cpu().numpy(), None) } if not classified_Speakers is None: image_Dict.update({ 'Speaker/Original': (torch.nn.functional.one_hot( speakers, hp.Speaker_Embedding.Num_Speakers).cpu().numpy(), None), 'Speaker/Predicted': (torch.softmax(classified_Speakers, dim=-1).cpu().numpy(), None), }) self.writer_Dict['Evaluation'].add_image_dict(image_Dict, self.steps) for model in self.model_Dict.values(): model.train() if hp.Mode in ['PE', 'GR' ] and self.steps % hp.Train.Prosody_Check_Interval == 0: self.Prosody_Check_Epoch() @torch.no_grad() def Inference_Step(self, tokens, token_lengths, mels_for_prosody, mel_lengths_for_prosody, speakers, mels_for_ge2e, pitches, pitch_lengths, length_scales, labels, texts, start_index=0, tag_step=False, tag_index=False): tokens = tokens.to(device) token_lengths = token_lengths.to(device) mels_for_prosody = mels_for_prosody.to(device) mel_lengths_for_prosody = mel_lengths_for_prosody.to(device) speakers = speakers.to(device) mels_for_ge2e = mels_for_ge2e.to(device) pitches = pitches.to(device) length_scales = length_scales.to(device) mels, mel_Lengths, attentions = self.model_Dict['GlowTTS'].inference( tokens=tokens, token_lengths=token_lengths, mels_for_prosody=mels_for_prosody, mel_lengths_for_prosody=mel_lengths_for_prosody, speakers=speakers, mels_for_ge2e=mels_for_ge2e, pitches=pitches, pitch_lengths=pitch_lengths, 
length_scale=length_scales) files = [] for index, label in enumerate(labels): tags = [] if tag_step: tags.append('Step-{}'.format(self.steps)) tags.append(label) if tag_index: tags.append('IDX_{}'.format(index + start_index)) files.append('.'.join(tags)) os.makedirs(os.path.join(hp.Inference_Path, 'Step-{}'.format(self.steps), 'PNG').replace('\\', '/'), exist_ok=True) for index, (mel, mel_Length, attention, label, text, length_Scale, file) in enumerate( zip(mels.cpu().numpy(), mel_Lengths.cpu().numpy(), attentions.cpu().numpy(), labels, texts, length_scales, files)): mel = mel[:, :mel_Length] attention = attention[:len(text) + 2, :mel_Length] new_Figure = plt.figure(figsize=(20, 5 * 3), dpi=100) plt.subplot2grid((3, 1), (0, 0)) plt.imshow(mel, aspect='auto', origin='lower') plt.title( 'Mel Label: {} Text: {} Length scale: {:.3f}'.format( label, text if len(text) < 70 else text[:70] + '…', length_Scale)) plt.colorbar() plt.subplot2grid((3, 1), (1, 0), rowspan=2) plt.imshow(attention, aspect='auto', origin='lower', interpolation='none') plt.title( 'Attention Label: {} Text: {} Length scale: {:.3f}'. format(label, text if len(text) < 70 else text[:70] + '…', length_Scale)) plt.yticks(range(len(text) + 2), ['<S>'] + list(text) + ['<E>'], fontsize=10) plt.colorbar() plt.tight_layout() plt.savefig( os.path.join(hp.Inference_Path, 'Step-{}'.format(self.steps), 'PNG', '{}.PNG'.format(file)).replace('\\', '/')) plt.close(new_Figure) os.makedirs(os.path.join(hp.Inference_Path, 'Step-{}'.format(self.steps), 'NPY').replace('\\', '/'), exist_ok=True) os.makedirs(os.path.join(hp.Inference_Path, 'Step-{}'.format(self.steps), 'NPY', 'Mel').replace('\\', '/'), exist_ok=True) os.makedirs(os.path.join(hp.Inference_Path, 'Step-{}'.format(self.steps), 'NPY', 'Attention').replace('\\', '/'), exist_ok=True) for index, (mel, mel_Length, file) in enumerate( zip(mels.cpu().numpy(), mel_Lengths.cpu().numpy(), files)): mel = mel[:, :mel_Length] attention = attention[:len(text) + 2, :mel_Length] np.save(os.path.join(hp.Inference_Path, 'Step-{}'.format(self.steps), 'NPY', 'Mel', file).replace('\\', '/'), mel.T, allow_pickle=False) np.save(os.path.join(hp.Inference_Path, 'Step-{}'.format(self.steps), 'NPY', 'Attention', file).replace('\\', '/'), attentions.cpu().numpy()[index], allow_pickle=False) def Inference_Epoch(self): logging.info('(Steps: {}) Start inference.'.format(self.steps)) for model in self.model_Dict.values(): model.eval() for step, (tokens, token_Lengths, mels_for_Prosody, mel_Lengths_for_Prosody, speakers, mels_for_GE2E, pitches, pitch_Lengths, length_Scales, labels, texts) in tqdm( enumerate(self.dataLoader_Dict['Inference']), desc='[Inference]', total=math.ceil( len(self.dataLoader_Dict['Inference'].dataset) / (hp.Inference_Batch_Size or hp.Train.Batch_Size))): self.Inference_Step( tokens, token_Lengths, mels_for_Prosody, mel_Lengths_for_Prosody, speakers, mels_for_GE2E, pitches, pitch_Lengths, length_Scales, labels, texts, start_index=step * (hp.Inference_Batch_Size or hp.Train.Batch_Size)) for model in self.model_Dict.values(): model.train() @torch.no_grad() def Prosody_Check_Step(self, mels, mel_lengths): mels = mels.to(device) mel_lengths = mel_lengths.to(device) prosodies = self.model_Dict['GlowTTS'].layer_Dict['Prosody_Encoder']( mels, mel_lengths) return prosodies def Prosody_Check_Epoch(self): logging.info('(Steps: {}) Start prosody check.'.format(self.steps)) for model in self.model_Dict.values(): model.eval() prosodies, labels = zip( *[(self.Prosody_Check_Step(mels, mel_Lengths), labels) for 
mels, mel_Lengths, labels in tqdm( self.dataLoader_Dict['Prosody_Check'], desc='[Prosody_Check]', total=math.ceil( len(self.dataLoader_Dict['Prosody_Check'].dataset) / hp.Train.Batch_Size))]) prosodies = torch.cat(prosodies, dim=0) labels = [label for sub_labels in labels for label in sub_labels] self.writer_Dict['Evaluation'].add_embedding(prosodies.cpu().numpy(), metadata=labels, global_step=self.steps, tag='Prosodies') for model in self.model_Dict.values(): model.train() def Load_Checkpoint(self): if self.steps == 0: paths = [ os.path.join(root, file).replace('\\', '/') for root, _, files in os.walk(hp.Checkpoint_Path) for file in files if os.path.splitext(file)[1] == '.pt' ] if len(paths) > 0: path = max(paths, key=os.path.getctime) else: return # Initial training else: path = os.path.join( hp.Checkpoint_Path, 'S_{}.pt'.format(self.steps).replace('\\', '/')) state_Dict = torch.load(path, map_location='cpu') self.model_Dict['GlowTTS'].load_state_dict(state_Dict['Model']) self.optimizer.load_state_dict(state_Dict['Optimizer']) self.scheduler.load_state_dict(state_Dict['Scheduler']) self.steps = state_Dict['Steps'] self.epochs = state_Dict['Epochs'] if hp.Use_Mixed_Precision: if not 'AMP' in state_Dict.keys(): logging.info( 'No AMP state dict is in the checkpoint. Model regards this checkpoint is trained without mixed precision.' ) else: amp.load_state_dict(state_Dict['AMP']) for flow in self.model_Dict['GlowTTS'].layer_Dict[ 'Decoder'].layer_Dict['Flows']: flow.layers[ 0].initialized = True # Activation_Norm is already initialized when checkpoint is loaded. logging.info('Checkpoint loaded at {} steps.'.format(self.steps)) if 'GE2E' in self.model_Dict['GlowTTS'].layer_Dict.keys( ) and self.steps == 0: self.GE2E_Load_Checkpoint() def Save_Checkpoint(self): os.makedirs(hp.Checkpoint_Path, exist_ok=True) state_Dict = { 'Model': self.model_Dict['GlowTTS'].state_dict(), 'Optimizer': self.optimizer.state_dict(), 'Scheduler': self.scheduler.state_dict(), 'Steps': self.steps, 'Epochs': self.epochs, } if hp.Use_Mixed_Precision: state_Dict['AMP'] = amp.state_dict() torch.save( state_Dict, os.path.join(hp.Checkpoint_Path, 'S_{}.pt'.format(self.steps).replace('\\', '/'))) logging.info('Checkpoint saved at {} steps.'.format(self.steps)) def GE2E_Load_Checkpoint(self): state_Dict = torch.load(hp.Speaker_Embedding.GE2E.Checkpoint_Path, map_location='cpu') self.model_Dict['GlowTTS'].layer_Dict['GE2E'].load_state_dict( state_Dict['Model']) logging.info('Speaker embedding checkpoint \'{}\' loaded.'.format( hp.Speaker_Embedding.GE2E.Checkpoint_Path)) def Train(self): hp_Path = os.path.join(hp.Checkpoint_Path, 'Hyper_Parameters.yaml').replace('\\', '/') if not os.path.exists(hp_Path): from shutil import copyfile os.makedirs(hp.Checkpoint_Path, exist_ok=True) copyfile('Hyper_Parameters.yaml', hp_Path) if self.steps == 0: self.Evaluation_Epoch() if hp.Train.Initial_Inference: self.Inference_Epoch() self.tqdm = tqdm(initial=self.steps, total=hp.Train.Max_Step, desc='[Training]') while self.steps < hp.Train.Max_Step: try: self.Train_Epoch() except KeyboardInterrupt: self.Save_Checkpoint() exit(1) self.tqdm.close() logging.info('Finished training.')
class Trainer:
    def __init__(self, hp_path, steps=0):
        self.hp_Path = hp_path
        self.gpu_id = int(os.getenv('RANK', '0'))
        self.num_gpus = int(os.getenv("WORLD_SIZE", '1'))

        self.hp = Recursive_Parse(
            yaml.load(open(self.hp_Path, encoding='utf-8'), Loader=yaml.Loader))

        if not torch.cuda.is_available():
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda:{}'.format(self.gpu_id))
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
            torch.cuda.set_device(self.gpu_id)

        self.steps = steps

        self.Dataset_Generate()
        self.Model_Generate()
        self.Load_Checkpoint()
        self._Set_Distribution()

        self.scalar_dict = {
            'Train': defaultdict(float),
            'Evaluation': defaultdict(float),
        }
        self.writer_Dict = {
            'Train': Logger(os.path.join(self.hp.Log_Path, 'Train')),
            'Evaluation': Logger(os.path.join(self.hp.Log_Path, 'Evaluation')),
        }

    def Dataset_Generate(self):
        token_dict = yaml.load(open(self.hp.Token_Path), Loader=yaml.Loader)

        train_dataset = Dataset(
            token_dict=token_dict,
            pattern_path=self.hp.Train.Train_Pattern.Path,
            metadata_file=self.hp.Train.Train_Pattern.Metadata_File)
        dev_dataset = Dataset(
            token_dict=token_dict,
            pattern_path=self.hp.Train.Eval_Pattern.Path,
            metadata_file=self.hp.Train.Eval_Pattern.Metadata_File)
        inference_dataset = Inference_Dataset(
            token_dict=token_dict,
            pattern_paths=self.hp.Train.Inference_Pattern_in_Train)

        if self.gpu_id == 0:
            logging.info('The number of train patterns = {}.'.format(
                len(train_dataset)))
            logging.info('The number of development patterns = {}.'.format(
                len(dev_dataset)))
            logging.info('The number of inference patterns = {}.'.format(
                len(inference_dataset)))

        collater = Collater(token_dict=token_dict)
        inference_collater = Inference_Collater(token_dict=token_dict)

        self.dataloader_dict = {}
        self.dataloader_dict['Train'] = torch.utils.data.DataLoader(
            dataset=train_dataset,
            sampler=(torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
                     if self.hp.Use_Multi_GPU else
                     torch.utils.data.RandomSampler(train_dataset)),
            collate_fn=collater,
            batch_size=self.hp.Train.Batch_Size,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)
        self.dataloader_dict['Dev'] = torch.utils.data.DataLoader(
            dataset=dev_dataset,
            sampler=(torch.utils.data.DistributedSampler(dev_dataset, shuffle=True)
                     if self.num_gpus > 1 else
                     torch.utils.data.RandomSampler(dev_dataset)),
            collate_fn=collater,
            batch_size=self.hp.Train.Batch_Size,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)
        self.dataloader_dict['Inference'] = torch.utils.data.DataLoader(
            dataset=inference_dataset,
            sampler=torch.utils.data.SequentialSampler(inference_dataset),
            collate_fn=inference_collater,
            batch_size=self.hp.Inference_Batch_Size or self.hp.Train.Batch_Size,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)

    def Model_Generate(self):
        self.model = TacoSinger(self.hp).to(self.device)
        self.criterion_dict = {
            'MSE': torch.nn.MSELoss().to(self.device),
            'GAL': Guided_Attention_Loss(),
        }
        self.optimizer = RAdam(
            params=self.model.parameters(),
            lr=self.hp.Train.Learning_Rate.Initial,
            betas=(self.hp.Train.ADAM.Beta1, self.hp.Train.ADAM.Beta2),
            eps=self.hp.Train.ADAM.Epsilon,
            weight_decay=self.hp.Train.Weight_Decay)
        self.scheduler = Modified_Noam_Scheduler(
            optimizer=self.optimizer,
            base=self.hp.Train.Learning_Rate.Base)

        self.scaler = torch.cuda.amp.GradScaler(enabled=self.hp.Use_Mixed_Precision)

        if self.gpu_id == 0:
            logging.info(self.model)

    def Train_Step(self, tokens, notes, durations, token_lengths, features, feature_lengths):
        loss_dict = {}
        tokens = tokens.to(self.device, non_blocking=True)
        notes = notes.to(self.device, non_blocking=True)
        durations = durations.to(self.device, non_blocking=True)
        token_lengths = token_lengths.to(self.device, non_blocking=True)
        features = features.to(self.device, non_blocking=True)
        feature_lengths = feature_lengths.to(self.device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=self.hp.Use_Mixed_Precision):
            pre_features, post_features, alignments = self.model(
                tokens=tokens,
                notes=notes,
                durations=durations,
                token_lengths=token_lengths,
                features=features,
                feature_lengths=feature_lengths,
                is_training=True)

            loss_dict['Pre'] = self.criterion_dict['MSE'](pre_features, features)
            loss_dict['Post'] = self.criterion_dict['MSE'](post_features, features)
            loss_dict['Guided_Attention'] = self.criterion_dict['GAL'](
                alignments, feature_lengths, token_lengths)
            loss_dict['Total'] = loss_dict['Pre'] + loss_dict['Post'] + loss_dict['Guided_Attention']

        self.optimizer.zero_grad()
        self.scaler.scale(loss_dict['Total']).backward()
        if self.hp.Train.Gradient_Norm > 0.0:
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model.parameters(),
                max_norm=self.hp.Train.Gradient_Norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()
        self.steps += 1
        self.tqdm.update(1)

        for tag, loss in loss_dict.items():
            loss = reduce_tensor(loss.data, self.num_gpus).item() \
                if self.num_gpus > 1 else loss.item()
            self.scalar_dict['Train']['Loss/{}'.format(tag)] += loss

    def Train_Epoch(self):
        for tokens, notes, durations, token_lengths, features, feature_lengths in self.dataloader_dict['Train']:
            self.Train_Step(tokens, notes, durations, token_lengths, features, feature_lengths)

            if self.steps % self.hp.Train.Checkpoint_Save_Interval == 0:
                self.Save_Checkpoint()

            if self.steps % self.hp.Train.Logging_Interval == 0:
                self.scalar_dict['Train'] = {
                    tag: loss / self.hp.Train.Logging_Interval
                    for tag, loss in self.scalar_dict['Train'].items()
                }
                self.scalar_dict['Train']['Learning_Rate'] = self.scheduler.get_last_lr()
                self.writer_Dict['Train'].add_scalar_dict(self.scalar_dict['Train'], self.steps)
                self.scalar_dict['Train'] = defaultdict(float)

            if self.steps % self.hp.Train.Evaluation_Interval == 0:
                self.Evaluation_Epoch()

            if self.steps % self.hp.Train.Inference_Interval == 0:
                self.Inference_Epoch()

            if self.steps >= self.hp.Train.Max_Step:
                return

    @torch.no_grad()
    def Evaluation_Step(self, tokens, notes, durations, token_lengths, features, feature_lengths):
        loss_dict = {}
        tokens = tokens.to(self.device, non_blocking=True)
        notes = notes.to(self.device, non_blocking=True)
        durations = durations.to(self.device, non_blocking=True)
        token_lengths = token_lengths.to(self.device, non_blocking=True)
        features = features.to(self.device, non_blocking=True)
        feature_lengths = feature_lengths.to(self.device, non_blocking=True)

        pre_features, post_features, alignments = self.model(
            tokens=tokens,
            notes=notes,
            durations=durations,
            token_lengths=token_lengths,
            features=features,
            feature_lengths=feature_lengths,
            is_training=True)

        loss_dict['Pre'] = self.criterion_dict['MSE'](pre_features, features)
        loss_dict['Post'] = self.criterion_dict['MSE'](post_features, features)
        loss_dict['Guided_Attention'] = self.criterion_dict['GAL'](
            alignments, feature_lengths, token_lengths)
        loss_dict['Total'] = loss_dict['Pre'] + loss_dict['Post'] + loss_dict['Guided_Attention']

        for tag, loss in loss_dict.items():
            loss = reduce_tensor(loss.data, self.num_gpus).item() \
                if self.num_gpus > 1 else loss.item()
            self.scalar_dict['Evaluation']['Loss/{}'.format(tag)] += loss

        return pre_features, post_features, alignments

    def Evaluation_Epoch(self):
        logging.info('(Steps: {}) Start evaluation in GPU {}.'.format(
            self.steps, self.gpu_id))

        self.model.eval()

        for step, (tokens, notes, durations, token_lengths, features, feature_lengths) in tqdm(
                enumerate(self.dataloader_dict['Dev'], 1),
                desc='[Evaluation]',
                total=math.ceil(
                    len(self.dataloader_dict['Dev'].dataset) /
                    self.hp.Train.Batch_Size / self.num_gpus)):
            pre_features, post_features, alignments = self.Evaluation_Step(
                tokens, notes, durations, token_lengths, features, feature_lengths)

        if self.gpu_id == 0:
            self.scalar_dict['Evaluation'] = {
                tag: loss / step
                for tag, loss in self.scalar_dict['Evaluation'].items()
            }
            self.writer_Dict['Evaluation'].add_scalar_dict(
                self.scalar_dict['Evaluation'], self.steps)
            self.writer_Dict['Evaluation'].add_histogram_model(
                self.model, 'TacoSinger', self.steps, delete_keywords=[])

            image_dict = {
                'Feature/Target': (features[-1].cpu().numpy(), None, 'auto', None),
                'Feature/Pre': (pre_features[-1].cpu().numpy(), None, 'auto', None),
                'Feature/Post': (post_features[-1].cpu().numpy(), None, 'auto', None),
                'Alignment': (alignments[-1].cpu().numpy(), None, 'auto', None)
            }
            self.writer_Dict['Evaluation'].add_image_dict(image_dict, self.steps)

        self.scalar_dict['Evaluation'] = defaultdict(float)

        self.model.train()

    @torch.no_grad()
    def Inference_Step(self, tokens, notes, durations, token_lengths, feature_lengths,
                       texts, decomposed_texts, restoring,
                       start_index=0, tag_step=False):
        tokens = tokens.to(self.device, non_blocking=True)
        notes = notes.to(self.device, non_blocking=True)
        durations = durations.to(self.device, non_blocking=True)
        token_lengths = token_lengths.to(self.device, non_blocking=True)
        feature_lengths = feature_lengths.to(self.device, non_blocking=True)

        _, post_features, alignments = self.model(
            tokens=tokens,
            notes=notes,
            durations=durations,
            token_lengths=token_lengths,
            feature_lengths=feature_lengths,
            is_training=False)

        post_features = torch.stack(
            [post_features[index] for index in restoring], dim=0)
        alignments = torch.stack(
            [alignments[index] for index in restoring], dim=0)
        feature_lengths = torch.stack(
            [feature_lengths[index] for index in restoring], dim=0)
        texts = [texts[index] for index in restoring]
        decomposed_texts = [decomposed_texts[index] for index in restoring]

        audios = []
        for feature, length in zip(post_features, feature_lengths):
            feature = spectral_de_normalize_torch(feature[:, :length]).cpu().numpy()
            audio = griffinlim(feature)
            audios.append(audio)
        audios = [
            (audio / np.abs(audio).max() * 32767.5).astype(np.int16)
            for audio in audios
        ]

        files = []
        for index in range(post_features.size(0)):
            tags = []
            if tag_step:
                tags.append('Step-{}'.format(self.steps))
            tags.append('IDX_{}'.format(index + start_index))
            files.append('.'.join(tags))

        os.makedirs(os.path.join(self.hp.Inference_Path, 'Step-{}'.format(self.steps),
                                 'PNG').replace('\\', '/'), exist_ok=True)
        os.makedirs(os.path.join(self.hp.Inference_Path, 'Step-{}'.format(self.steps),
                                 'WAV').replace('\\', '/'), exist_ok=True)
        for index, (feature, alignment, text, decomposed_text, audio, file) in enumerate(
                zip(post_features.cpu().numpy(), alignments.cpu().numpy(),
                    texts, decomposed_texts, audios, files)):
            title = 'Text: {}'.format(text if len(text) < 90 else text[:90] + '…')
            new_Figure = plt.figure(figsize=(20, 5 * 3), dpi=100)
            plt.subplot2grid((4, 1), (0, 0))
            plt.imshow(feature, aspect='auto', origin='lower')
            plt.title('Feature {}'.format(title))
            plt.colorbar()
            plt.subplot2grid((4, 1), (1, 0), rowspan=2)
            plt.imshow(alignment[:len(decomposed_text) + 2], aspect='auto', origin='lower')
            plt.title('Alignment {}'.format(title))
            plt.yticks(range(len(decomposed_text)), list(decomposed_text), fontsize=10)
            plt.colorbar()
            plt.tight_layout()
            plt.savefig(
                os.path.join(self.hp.Inference_Path, 'Step-{}'.format(self.steps), 'PNG',
                             '{}.png'.format(file)).replace('\\', '/'))
            plt.close(new_Figure)

            wavfile.write(
                os.path.join(self.hp.Inference_Path, 'Step-{}'.format(self.steps), 'WAV',
                             '{}.wav'.format(file)).replace('\\', '/'),
                self.hp.Sound.Sample_Rate, audio)

    def Inference_Epoch(self):
        if self.gpu_id != 0:
            return

        logging.info('(Steps: {}) Start inference.'.format(self.steps))

        self.model.eval()

        batch_size = self.hp.Inference_Batch_Size or self.hp.Train.Batch_Size
        for step, (tokens, notes, durations, token_lengths, feature_lengths,
                   texts, decomposed_texts, restoring) in tqdm(
                enumerate(self.dataloader_dict['Inference']),
                desc='[Inference]',
                total=math.ceil(
                    len(self.dataloader_dict['Inference'].dataset) / batch_size)):
            self.Inference_Step(tokens, notes, durations, token_lengths, feature_lengths,
                                texts, decomposed_texts, restoring,
                                start_index=step * batch_size)

        self.model.train()

    def Load_Checkpoint(self):
        if self.steps == 0:
            paths = [
                os.path.join(root, file).replace('\\', '/')
                for root, _, files in os.walk(self.hp.Checkpoint_Path)
                for file in files
                if os.path.splitext(file)[1] == '.pt'
            ]
            if len(paths) > 0:
                path = max(paths, key=os.path.getctime)
            else:
                return  # Initial training
        else:
            path = os.path.join(
                self.hp.Checkpoint_Path,
                'S_{}.pt'.format(self.steps).replace('\\', '/'))

        state_dict = torch.load(path, map_location='cpu')
        self.model.load_state_dict(state_dict['Model'])
        self.optimizer.load_state_dict(state_dict['Optimizer'])
        self.scheduler.load_state_dict(state_dict['Scheduler'])
        self.steps = state_dict['Steps']

        logging.info('Checkpoint loaded at {} steps in GPU {}.'.format(
            self.steps, self.gpu_id))

    def Save_Checkpoint(self):
        if self.gpu_id != 0:
            return

        os.makedirs(self.hp.Checkpoint_Path, exist_ok=True)
        state_Dict = {
            'Model': self.model.state_dict(),
            'Optimizer': self.optimizer.state_dict(),
            'Scheduler': self.scheduler.state_dict(),
            'Steps': self.steps
        }
        torch.save(
            state_Dict,
            os.path.join(self.hp.Checkpoint_Path,
                         'S_{}.pt'.format(self.steps).replace('\\', '/')))

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def _Set_Distribution(self):
        if self.num_gpus > 1:
            self.model = apply_gradient_allreduce(self.model)

    def Train(self):
        hp_path = os.path.join(self.hp.Checkpoint_Path,
                               'Hyper_Parameters.yaml').replace('\\', '/')
        if not os.path.exists(hp_path):
            os.makedirs(self.hp.Checkpoint_Path, exist_ok=True)
            yaml.dump(self.hp, open(hp_path, 'w'))

        if self.steps == 0:
            self.Evaluation_Epoch()

        if self.hp.Train.Initial_Inference:
            self.Inference_Epoch()

        self.tqdm = tqdm(
            initial=self.steps,
            total=self.hp.Train.Max_Step,
            desc='[Training]')

        while self.steps < self.hp.Train.Max_Step:
            try:
                self.Train_Epoch()
            except KeyboardInterrupt:
                self.Save_Checkpoint()
                exit(1)

        self.tqdm.close()
        logging.info('Finished training.')
class Trainer: def __init__(self, steps=0): self.steps = steps self.epochs = 0 self.Datset_Generate() self.Model_Generate() self.scalar_Dict = { 'Train': defaultdict(float), 'Evaluation': defaultdict(float), } self.writer_Dict = { 'Train': Logger(os.path.join(hp_Dict['Log_Path'], 'Train')), 'Evaluation': Logger(os.path.join(hp_Dict['Log_Path'], 'Evaluation')), } self.Load_Checkpoint() def Datset_Generate(self): train_Dataset = Train_Dataset() accumulation_Dataset = Accumulation_Dataset() dev_Dataset = Dev_Dataset() inference_Dataset = Inference_Dataset() logging.info('The number of base train files = {}.'.format( len(train_Dataset) // hp_Dict['Train']['Train_Pattern']['Accumulated_Dataset_Epoch'])) logging.info('The number of development patterns = {}.'.format( len(dev_Dataset))) logging.info('The number of inference patterns = {}.'.format( len(inference_Dataset))) collater = Collater() accumulation_Collater = Accumulation_Collater() inference_Collater = Inference_Collater() self.dataLoader_Dict = {} self.dataLoader_Dict['Train'] = torch.utils.data.DataLoader( dataset=train_Dataset, shuffle=True, collate_fn=collater, batch_size=hp_Dict['Train']['Batch_Size'], num_workers=hp_Dict['Train']['Num_Workers'], pin_memory=True) self.dataLoader_Dict['Accumulation'] = torch.utils.data.DataLoader( dataset=accumulation_Dataset, shuffle=False, collate_fn=accumulation_Collater, batch_size=hp_Dict['Train']['Batch_Size'], num_workers=hp_Dict['Train']['Num_Workers'], pin_memory=True) self.dataLoader_Dict['Dev'] = torch.utils.data.DataLoader( dataset=dev_Dataset, shuffle=True, # to write tensorboard. collate_fn=collater, batch_size=hp_Dict['Train']['Batch_Size'], num_workers=hp_Dict['Train']['Num_Workers'], pin_memory=True) self.dataLoader_Dict['Inference'] = torch.utils.data.DataLoader( dataset=inference_Dataset, shuffle=False, # to write tensorboard. 
collate_fn=inference_Collater, batch_size=hp_Dict['Inference_Batch_Size'] or hp_Dict['Train']['Batch_Size'], num_workers=hp_Dict['Train']['Num_Workers'], pin_memory=True) def Model_Generate(self): self.model = PVCGAN().to(device) self.criterion_Dict = { 'STFT': MultiResolutionSTFTLoss( fft_sizes=hp_Dict['STFT_Loss_Resolution']['FFT_Sizes'], shift_lengths=hp_Dict['STFT_Loss_Resolution']['Shfit_Lengths'], win_lengths=hp_Dict['STFT_Loss_Resolution']['Win_Lengths'], ).to(device), 'MSE': torch.nn.MSELoss().to(device), 'CE': torch.nn.CrossEntropyLoss().to(device), 'MAE': torch.nn.L1Loss().to(device) } self.optimizer = RAdam(params=self.model.parameters(), lr=hp_Dict['Train']['Learning_Rate']['Initial'], betas=(hp_Dict['Train']['ADAM']['Beta1'], hp_Dict['Train']['ADAM']['Beta2']), eps=hp_Dict['Train']['ADAM']['Epsilon'], weight_decay=hp_Dict['Train']['Weight_Decay']) self.scheduler = Modified_Noam_Scheduler( optimizer=self.optimizer, base=hp_Dict['Train']['Learning_Rate']['Base']) if hp_Dict['Use_Mixed_Precision']: self.model, self.optimizer = amp.initialize( models=self.model, optimizers=self.optimizer) logging.info(self.model) def Train_Step(self, audios, mels, pitches, audio_Singers, mel_Singers, noises): loss_Dict = {} audios = audios.to(device, non_blocking=True) mels = mels.to(device, non_blocking=True) pitches = pitches.to(device, non_blocking=True) audio_Singers = audio_Singers.to(device, non_blocking=True) mel_Singers = mel_Singers.to(device, non_blocking=True) noises = noises.to(device, non_blocking=True) # Generator fakes, predicted_Singers, predicted_Pitches, fakes_Discriminations, reals_Discriminations = self.model( mels=mels, pitches=pitches, singers=audio_Singers, noises=noises, discrimination=self.steps >= hp_Dict['Train']['Discriminator_Delay'], reals=audios) loss_Dict['Generator/Spectral_Convergence'], loss_Dict[ 'Generator/Magnitude'] = self.criterion_Dict['STFT'](fakes, audios) loss_Dict['Generator'] = loss_Dict[ 'Generator/Spectral_Convergence'] + loss_Dict['Generator/Magnitude'] loss_Dict['Confuser/Singer'] = self.criterion_Dict['CE']( predicted_Singers, mel_Singers) loss_Dict['Confuser/Pitch'] = self.criterion_Dict['MAE']( predicted_Pitches, pitches) loss_Dict['Confuser'] = loss_Dict['Confuser/Singer'] + loss_Dict[ 'Confuser/Pitch'] loss = loss_Dict['Generator'] + loss_Dict['Confuser'] if self.steps >= hp_Dict['Train']['Discriminator_Delay']: loss_Dict['Discriminator/Fake'] = self.criterion_Dict['MSE']( fakes_Discriminations, torch.zeros_like(fakes_Discriminations) ) # Discriminator thinks that 0 is correct. loss_Dict['Discriminator/Real'] = self.criterion_Dict['MSE']( reals_Discriminations, torch.ones_like(reals_Discriminations) ) # Discriminator thinks that 1 is correct. 
loss_Dict['Discriminator'] = loss_Dict[ 'Discriminator/Fake'] + loss_Dict['Discriminator/Real'] loss += loss_Dict['Discriminator'] self.optimizer.zero_grad() if hp_Dict['Use_Mixed_Precision']: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_( parameters=amp.master_params(self.optimizer), max_norm=hp_Dict['Train']['Gradient_Norm']) else: loss.backward() torch.nn.utils.clip_grad_norm_( parameters=self.model.parameters(), max_norm=hp_Dict['Train']['Gradient_Norm']) self.optimizer.step() self.scheduler.step() self.steps += 1 self.tqdm.update(1) for tag, loss in loss_Dict.items(): self.scalar_Dict['Train']['Loss/{}'.format(tag)] += loss def Train_Epoch(self): if any([ hp_Dict['Train']['Train_Pattern']['Mixup']['Use'] and self.steps >= hp_Dict['Train']['Train_Pattern']['Mixup']['Apply_Delay'], hp_Dict['Train']['Train_Pattern']['Back_Translate']['Use'] and self.steps >= hp_Dict['Train']['Train_Pattern'] ['Back_Translate']['Apply_Delay'] ]): self.Data_Accumulation() for audios, mels, pitches, audio_Singers, mel_Singers, noises in self.dataLoader_Dict[ 'Train']: self.Train_Step(audios, mels, pitches, audio_Singers, mel_Singers, noises) if self.steps % hp_Dict['Train']['Checkpoint_Save_Interval'] == 0: self.Save_Checkpoint() if self.steps % hp_Dict['Train']['Logging_Interval'] == 0: self.scalar_Dict['Train'] = { tag: loss / hp_Dict['Train']['Logging_Interval'] for tag, loss in self.scalar_Dict['Train'].items() } self.scalar_Dict['Train'][ 'Learning_Rate'] = self.scheduler.get_last_lr() self.writer_Dict['Train'].add_scalar_dict( self.scalar_Dict['Train'], self.steps) self.scalar_Dict['Train'] = defaultdict(float) if self.steps % hp_Dict['Train']['Evaluation_Interval'] == 0: self.Evaluation_Epoch() if self.steps % hp_Dict['Train']['Inference_Interval'] == 0: self.Inference_Epoch() if self.steps >= hp_Dict['Train']['Max_Step']: return self.epochs += hp_Dict['Train']['Train_Pattern'][ 'Accumulated_Dataset_Epoch'] @torch.no_grad() def Evaluation_Step(self, audios, mels, pitches, audio_Singers, mel_Singers, noises): loss_Dict = {} audios = audios.to(device, non_blocking=True) mels = mels.to(device, non_blocking=True) pitches = pitches.to(device, non_blocking=True) audio_Singers = audio_Singers.to(device, non_blocking=True) mel_Singers = mel_Singers.to(device, non_blocking=True) noises = noises.to(device, non_blocking=True) # Generator fakes, predicted_Singers, predicted_Pitches, fakes_Discriminations, reals_Discriminations = self.model( mels=mels, pitches=pitches, singers=audio_Singers, noises=noises, discrimination=self.steps >= hp_Dict['Train']['Discriminator_Delay'], reals=audios) loss_Dict['Generator/Spectral_Convergence'], loss_Dict[ 'Generator/Magnitude'] = self.criterion_Dict['STFT'](fakes, audios) loss_Dict['Generator'] = loss_Dict[ 'Generator/Spectral_Convergence'] + loss_Dict['Generator/Magnitude'] loss_Dict['Confuser/Singer'] = self.criterion_Dict['CE']( predicted_Singers, mel_Singers) loss_Dict['Confuser/Pitch'] = self.criterion_Dict['MAE']( predicted_Pitches, pitches) loss_Dict['Confuser'] = loss_Dict['Confuser/Singer'] + loss_Dict[ 'Confuser/Pitch'] loss = loss_Dict['Generator'] + loss_Dict['Confuser'] if self.steps >= hp_Dict['Train']['Discriminator_Delay']: loss_Dict['Discriminator/Fake'] = self.criterion_Dict['MSE']( fakes_Discriminations, torch.zeros_like(fakes_Discriminations) ) # Discriminator thinks that 0 is correct. 
loss_Dict['Discriminator/Real'] = self.criterion_Dict['MSE']( reals_Discriminations, torch.ones_like(reals_Discriminations) ) # Discriminator thinks that 1 is correct. loss_Dict['Discriminator'] = loss_Dict[ 'Discriminator/Fake'] + loss_Dict['Discriminator/Real'] loss += loss_Dict['Discriminator'] for tag, loss in loss_Dict.items(): self.scalar_Dict['Evaluation']['Loss/{}'.format(tag)] += loss return fakes, predicted_Singers, predicted_Pitches def Evaluation_Epoch(self): logging.info('(Steps: {}) Start evaluation.'.format(self.steps)) self.model.eval() for step, (audios, mels, pitches, audio_Singers, mel_Singers, noises) in tqdm( enumerate(self.dataLoader_Dict['Dev'], 1), desc='[Evaluation]', total=math.ceil( len(self.dataLoader_Dict['Dev'].dataset) / hp_Dict['Train']['Batch_Size'])): fakes, predicted_Singers, predicted_Pitches = self.Evaluation_Step( audios, mels, pitches, audio_Singers, mel_Singers, noises) self.scalar_Dict['Evaluation'] = { tag: loss / step for tag, loss in self.scalar_Dict['Evaluation'].items() } self.writer_Dict['Evaluation'].add_scalar_dict( self.scalar_Dict['Evaluation'], self.steps) self.writer_Dict['Evaluation'].add_histogram_model( self.model, self.steps, delete_keywords=['layer_Dict', '1', 'layer']) self.scalar_Dict['Evaluation'] = defaultdict(float) self.writer_Dict['Evaluation'].add_image_dict( { 'Mel': (mels[-1].cpu().numpy(), None), 'Audio/Original': (audios[-1].cpu().numpy(), None), 'Audio/Predicted': (fakes[-1].cpu().numpy(), None), 'Pitch/Original': (pitches[-1].cpu().numpy(), None), 'Pitch/Predicted': (predicted_Pitches[-1].cpu().numpy(), None), 'Singer/Original': (torch.nn.functional.one_hot( mel_Singers, hp_Dict['Num_Singers']).cpu().numpy(), None), 'Singer/Predicted': (torch.softmax(predicted_Singers, dim=-1).cpu().numpy(), None), }, self.steps) self.model.train() @torch.no_grad() def Inference_Step(self, audios, mels, pitches, singers, noises, source_Labels, singer_Labels, start_Index=0, tag_step=False, tag_Index=False): audios = audios.to(device, non_blocking=True) mels = mels.to(device, non_blocking=True) pitches = pitches.to(device, non_blocking=True) singers = singers.to(device, non_blocking=True) noises = noises.to(device, non_blocking=True) fakes, *_ = self.model(mels=mels, pitches=pitches, singers=singers, noises=noises) files = [] for index, (source_Label, singer_Label) in enumerate( zip(source_Labels, singer_Labels)): tags = [] if tag_step: tags.append('Step-{}'.format(self.steps)) tags.append('{}_to_{}'.format(source_Label, singer_Label)) if tag_Index: tags.append('IDX_{}'.format(index + start_Index)) files.append('.'.join(tags)) os.makedirs(os.path.join(hp_Dict['Inference_Path'], 'Step-{}'.format(self.steps), 'PNG').replace('\\', '/'), exist_ok=True) os.makedirs(os.path.join(hp_Dict['Inference_Path'], 'Step-{}'.format(self.steps), 'WAV').replace("\\", "/"), exist_ok=True) for index, (real, fake, source_Label, singer_Label, file) in enumerate( zip(audios.cpu().numpy(), fakes.cpu().numpy(), source_Labels, singer_Labels, files)): new_Figure = plt.figure(figsize=(80, 10 * 2), dpi=100) plt.subplot(211) plt.plot(real) plt.title('Original wav Index: {} {} -> {}'.format( index + start_Index, source_Label, singer_Label)) plt.subplot(212) plt.plot(fake) plt.title('Fake wav Index: {} {} -> {}'.format( index + start_Index, source_Label, singer_Label)) plt.tight_layout() plt.savefig( os.path.join(hp_Dict['Inference_Path'], 'Step-{}'.format(self.steps), 'PNG', '{}.png'.format(file)).replace('\\', '/')) plt.close(new_Figure) 
wavfile.write(filename=os.path.join( hp_Dict['Inference_Path'], 'Step-{}'.format(self.steps), 'WAV', '{}.wav'.format(file)).replace('\\', '/'), data=(np.clip(fake, -1.0 + 1e-7, 1.0 - 1e-7) * 32767.5).astype(np.int16), rate=hp_Dict['Sound']['Sample_Rate']) def Inference_Epoch(self): logging.info('(Steps: {}) Start inference.'.format(self.steps)) self.model.eval() for step, (audios, mels, pitches, singers, noises, source_Labels, singer_Labels) in tqdm( enumerate(self.dataLoader_Dict['Inference'], 1), desc='[Inference]', total=math.ceil( len(self.dataLoader_Dict['Inference'].dataset) / (hp_Dict['Inference_Batch_Size'] or hp_Dict['Train']['Batch_Size']))): self.Inference_Step(audios, mels, pitches, singers, noises, source_Labels, singer_Labels, start_Index=(step - 1) * hp_Dict['Train']['Batch_Size']) self.model.train() @torch.no_grad() def Back_Translate_Step(self, mels, pitches, singers, noises): mels = mels.to(device, non_blocking=True) pitches = pitches.to(device, non_blocking=True) singers = singers.to(device, non_blocking=True) noises = noises.to(device, non_blocking=True) fakes, *_ = self.model(mels=mels, pitches=pitches, singers=singers, noises=noises) return fakes.cpu().numpy() def Data_Accumulation(self): def Mixup(audio, pitch): max_Offset = pitch.shape[0] - hp_Dict['Train'][ 'Wav_Length'] // hp_Dict['Sound']['Frame_Shift'] * 2 offset1 = np.random.randint(0, max_Offset) offset2 = np.random.randint(0, max_Offset) beta = np.random.uniform( low=hp_Dict['Train']['Train_Pattern']['Mixup']['Min_Beta'], high=hp_Dict['Train']['Train_Pattern']['Mixup']['Max_Beta'], ) new_Audio = \ beta * audio[offset1 * hp_Dict['Sound']['Frame_Shift']:offset1 * hp_Dict['Sound']['Frame_Shift'] + hp_Dict['Train']['Wav_Length'] * 2] + \ (1 - beta) * audio[offset2* hp_Dict['Sound']['Frame_Shift']:offset2 * hp_Dict['Sound']['Frame_Shift'] + hp_Dict['Train']['Wav_Length'] * 2] new_Pitch = \ beta * pitch[offset1:offset1 + hp_Dict['Train']['Wav_Length'] // hp_Dict['Sound']['Frame_Shift'] * 2] + \ (1 - beta) * pitch[offset2:offset2 + hp_Dict['Train']['Wav_Length'] // hp_Dict['Sound']['Frame_Shift'] * 2] _, new_Mel, _, _ = Pattern_Generate(audio=new_Audio) return new_Audio, new_Mel, new_Pitch def Back_Translate(mels, pitches, singers, noises): fakes = self.Back_Translate_Step(mels=mels, pitches=pitches, singers=singers, noises=noises) new_Mels = [Pattern_Generate(audio=fake)[1] for fake in fakes] return new_Mels print() mixup_List = [] back_Translate_List = [] for total_Audios, total_Pitches, audios, mels, pitches, singers, noises in tqdm( self.dataLoader_Dict['Accumulation'], desc='[Accumulation]', total=math.ceil( len(self.dataLoader_Dict['Accumulation'].dataset) / hp_Dict['Train']['Batch_Size'])): #Mixup if hp_Dict['Train']['Train_Pattern']['Mixup'][ 'Use'] and self.steps >= hp_Dict['Train']['Train_Pattern'][ 'Mixup']['Apply_Delay']: for audio, pitch, singer in zip(total_Audios, total_Pitches, singers.numpy()): mixup_Audio, mixup_Mel, mixup_Pitch = Mixup(audio, pitch) mixup_List.append( (mixup_Audio, mixup_Mel, mixup_Pitch, singer, singer)) #Backtranslate if hp_Dict['Train']['Train_Pattern']['Back_Translate'][ 'Use'] and self.steps >= hp_Dict['Train']['Train_Pattern'][ 'Back_Translate']['Apply_Delay']: mel_Singers = torch.LongTensor( np.stack([ choice([ x for x in range(hp_Dict['Num_Singers']) if x != singer ]) for singer in singers ], axis=0)) back_Translate_Mels = Back_Translate(mels, pitches, mel_Singers, noises) for audio, back_Translate_Mel, pitch, audio_Singer, mel_Singer in zip( audios.numpy(), 
back_Translate_Mels, pitches.numpy(), singers.numpy(), mel_Singers.numpy()): back_Translate_List.append( (audio, back_Translate_Mel, pitch, audio_Singer, mel_Singer)) self.dataLoader_Dict['Train'].dataset.Accumulation_Renew( mixup_Pattern_List=mixup_List, back_Translate_Pattern_List=back_Translate_List) def Load_Checkpoint(self): if self.steps == 0: paths = [ os.path.join(root, file).replace('\\', '/') for root, _, files in os.walk(hp_Dict['Checkpoint_Path']) for file in files if os.path.splitext(file)[1] == '.pt' ] if len(paths) > 0: path = max(paths, key=os.path.getctime) else: return # Initial training else: path = os.path.join( hp_Dict['Checkpoint_Path'], 'S_{}.pt'.format(self.steps).replace('\\', '/')) state_Dict = torch.load(path, map_location='cpu') self.model.load_state_dict(state_Dict['Model']) self.optimizer.load_state_dict(state_Dict['Optimizer']) self.scheduler.load_state_dict(state_Dict['Scheduler']) self.steps = state_Dict['Steps'] self.epochs = state_Dict['Epochs'] if hp_Dict['Use_Mixed_Precision']: if not 'AMP' in state_Dict.keys(): logging.info( 'No AMP state dict is in the checkpoint. Model regards this checkpoint is trained without mixed precision.' ) else: amp.load_state_dict(state_Dict['AMP']) logging.info('Checkpoint loaded at {} steps.'.format(self.steps)) def Save_Checkpoint(self): os.makedirs(hp_Dict['Checkpoint_Path'], exist_ok=True) state_Dict = { 'Model': self.model.state_dict(), 'Optimizer': self.optimizer.state_dict(), 'Scheduler': self.scheduler.state_dict(), 'Steps': self.steps, 'Epochs': self.epochs, } if hp_Dict['Use_Mixed_Precision']: state_Dict['AMP'] = amp.state_dict() torch.save( state_Dict, os.path.join(hp_Dict['Checkpoint_Path'], 'S_{}.pt'.format(self.steps).replace('\\', '/'))) logging.info('Checkpoint saved at {} steps.'.format(self.steps)) def Train(self): if not os.path.exists( os.path.join(hp_Dict['Checkpoint_Path'], 'Hyper_Parameters.yaml')): os.makedirs(hp_Dict['Checkpoint_Path'], exist_ok=True) with open( os.path.join(hp_Dict['Checkpoint_Path'], 'Hyper_Parameters.yaml').replace("\\", "/"), "w") as f: yaml.dump(hp_Dict, f) if self.steps == 0: self.Evaluation_Epoch() if hp_Dict['Train']['Initial_Inference']: self.Inference_Epoch() self.tqdm = tqdm(initial=self.steps, total=hp_Dict['Train']['Max_Step'], desc='[Training]') while self.steps < hp_Dict['Train']['Max_Step']: try: self.Train_Epoch() except KeyboardInterrupt: self.Save_Checkpoint() exit(1) self.tqdm.close() logging.info('Finished training.')
class Trainer:
    def __init__(self, hp_path, steps=0):
        self.hp_path = hp_path
        self.gpu_id = int(os.getenv('RANK', '0'))
        self.num_gpus = int(os.getenv("WORLD_SIZE", '1'))

        self.hp = Recursive_Parse(
            yaml.load(open(hp_path, encoding='utf-8'), Loader=yaml.Loader))

        if not torch.cuda.is_available():
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda:{}'.format(self.gpu_id))
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
            torch.cuda.set_device(self.gpu_id)

        self.steps = steps

        self.Dataset_Generate()
        self.Model_Generate()
        self.Load_Checkpoint()
        self._Set_Distribution()

        self.scalar_dict = {
            'Train': defaultdict(float),
            'Evaluation': defaultdict(float),
        }
        self.writer_dict = {
            'Train': Logger(os.path.join(self.hp.Log_Path, 'Train')),
            'Evaluation': Logger(os.path.join(self.hp.Log_Path, 'Evaluation')),
        }

    def Dataset_Generate(self):
        train_dataset = Dataset(
            pattern_path=self.hp.Train.Train_Pattern.Path,
            metadata_file=self.hp.Train.Train_Pattern.Metadata_File,
            pattern_per_speaker=self.hp.Train.Batch.Train.Pattern_per_Speaker)
        dev_dataset = Dataset(
            pattern_path=self.hp.Train.Eval_Pattern.Path,
            metadata_file=self.hp.Train.Eval_Pattern.Metadata_File,
            pattern_per_speaker=self.hp.Train.Batch.Eval.Pattern_per_Speaker)
        inference_dataset = Dataset(
            pattern_path=self.hp.Train.Eval_Pattern.Path,
            metadata_file=self.hp.Train.Eval_Pattern.Metadata_File,
            pattern_per_speaker=self.hp.Train.Batch.Eval.Pattern_per_Speaker,
            num_speakers=50,  # Maximum number by tensorboard.
            )

        logging.info('The number of train speakers = {}.'.format(len(train_dataset)))
        logging.info('The number of development speakers = {}.'.format(len(dev_dataset)))

        collater = Collater(
            min_frame_length=self.hp.Train.Frame_Length.Min,
            max_frame_length=self.hp.Train.Frame_Length.Max)
        inference_collater = Inference_Collater(
            samples=self.hp.Train.Inference.Samples,
            frame_length=self.hp.Train.Inference.Frame_Length,
            overlap_length=self.hp.Train.Inference.Overlap_Length)

        self.dataloader_dict = {}
        self.dataloader_dict['Train'] = torch.utils.data.DataLoader(
            dataset=train_dataset,
            sampler=(torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
                     if self.hp.Use_Multi_GPU else
                     torch.utils.data.RandomSampler(train_dataset)),
            collate_fn=collater,
            batch_size=self.hp.Train.Batch.Train.Speaker,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)
        self.dataloader_dict['Dev'] = torch.utils.data.DataLoader(
            dataset=dev_dataset,
            sampler=(torch.utils.data.DistributedSampler(dev_dataset, shuffle=True)
                     if self.num_gpus > 1 else
                     torch.utils.data.RandomSampler(dev_dataset)),
            collate_fn=collater,
            batch_size=self.hp.Train.Batch.Eval.Speaker,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)
        self.dataloader_dict['Inference'] = torch.utils.data.DataLoader(
            dataset=inference_dataset,
            shuffle=True,
            collate_fn=inference_collater,
            batch_size=self.hp.Train.Batch.Eval.Speaker,
            num_workers=self.hp.Train.Num_Workers,
            pin_memory=True)

    def Model_Generate(self):
        self.model = GE2E(self.hp).to(self.device)
        self.criterion = GE2E_Loss().to(self.device)
        self.optimizer = RAdam(
            params=self.model.parameters(),
            lr=self.hp.Train.Learning_Rate.Initial,
            betas=(self.hp.Train.ADAM.Beta1, self.hp.Train.ADAM.Beta2),
            eps=self.hp.Train.ADAM.Epsilon,
            weight_decay=self.hp.Train.Weight_Decay)
        self.scheduler = Modified_Noam_Scheduler(
            optimizer=self.optimizer,
            base=self.hp.Train.Learning_Rate.Base,
            )
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.hp.Use_Mixed_Precision)

        if self.gpu_id == 0:
            logging.info(self.model)

    def Train_Step(self, features):
        loss_dict = {}
        features = features.to(self.device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=self.hp.Use_Mixed_Precision):
            embeddings = self.model(features)
            loss_dict['Embedding'] = self.criterion(
                embeddings, self.hp.Train.Batch.Train.Pattern_per_Speaker)

        self.optimizer.zero_grad()
        self.scaler.scale(loss_dict['Embedding']).backward()
        if self.hp.Train.Gradient_Norm > 0.0:
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                parameters=self.model.parameters(),
                max_norm=self.hp.Train.Gradient_Norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()
        self.steps += 1
        self.tqdm.update(1)

        for tag, loss in loss_dict.items():
            loss = reduce_tensor(loss.data, self.num_gpus).item() \
                if self.num_gpus > 1 else loss.item()
            self.scalar_dict['Train']['Loss/{}'.format(tag)] += loss

    def Train_Epoch(self):
        for features in self.dataloader_dict['Train']:
            self.Train_Step(features)

            if self.steps % self.hp.Train.Checkpoint_Save_Interval == 0:
                self.Save_Checkpoint()

            if self.steps % self.hp.Train.Logging_Interval == 0:
                self.scalar_dict['Train'] = {
                    tag: loss / self.hp.Train.Logging_Interval
                    for tag, loss in self.scalar_dict['Train'].items()
                }
                self.scalar_dict['Train']['Learning_Rate'] = self.scheduler.get_last_lr()
                self.writer_dict['Train'].add_scalar_dict(self.scalar_dict['Train'], self.steps)
                self.scalar_dict['Train'] = defaultdict(float)

            if self.steps % self.hp.Train.Evaluation_Interval == 0:
                self.Evaluation_Epoch()

            if self.steps % self.hp.Train.Inference_Interval == 0:
                self.Inference_Epoch()

            if self.steps >= self.hp.Train.Max_Step:
                return

    @torch.no_grad()
    def Evaluation_Step(self, features):
        loss_dict = {}
        features = features.to(self.device, non_blocking=True)

        embeddings = self.model(features)
        loss_dict['Embedding'] = self.criterion(
            embeddings, self.hp.Train.Batch.Eval.Pattern_per_Speaker)

        for tag, loss in loss_dict.items():
            loss = reduce_tensor(loss.data, self.num_gpus).item() \
                if self.num_gpus > 1 else loss.item()
            self.scalar_dict['Evaluation']['Loss/{}'.format(tag)] += loss

    def Evaluation_Epoch(self):
        logging.info('(Steps: {}) Start evaluation in GPU {}.'.format(
            self.steps, self.gpu_id))

        self.model.eval()

        for step, features in tqdm(
                enumerate(self.dataloader_dict['Dev'], 1),
                desc='[Evaluation]',
                total=math.ceil(
                    len(self.dataloader_dict['Dev'].dataset) /
                    self.hp.Train.Batch.Eval.Speaker /
                    self.hp.Train.Batch.Eval.Pattern_per_Speaker)):
            self.Evaluation_Step(features)

        self.scalar_dict['Evaluation'] = {
            tag: loss / step
            for tag, loss in self.scalar_dict['Evaluation'].items()
        }
        self.writer_dict['Evaluation'].add_scalar_dict(
            self.scalar_dict['Evaluation'], self.steps)
        self.writer_dict['Evaluation'].add_histogram_model(
            self.model, 'GE2E', self.steps, delete_keywords=['layer_Dict', 'layer'])
        self.scalar_dict['Evaluation'] = defaultdict(float)

        self.model.train()

    @torch.no_grad()
    def Inference_Step(self, features):
        return self.model(
            features=features.to(self.device, non_blocking=True),
            samples=self.hp.Train.Inference.Samples)

    def Inference_Epoch(self):
        if self.gpu_id != 0:
            return

        logging.info('(Steps: {}) Start inference.'.format(self.steps))

        self.model.eval()

        embeddings, speakers = zip(*[
            (self.Inference_Step(features), speakers)
            for features, speakers in tqdm(self.dataloader_dict['Inference'], desc='[Inference]')
        ])
        embeddings = torch.cat(embeddings, dim=0).cpu().numpy()
        speakers = [
            speaker for speaker_list in speakers for speaker in speaker_list
        ]
        self.writer_dict['Evaluation'].add_embedding(
            embeddings,
            metadata=speakers,
            global_step=self.steps,
            tag='Embeddings')

        self.model.train()

    def Load_Checkpoint(self):
        if self.steps == 0:
            paths = [
                os.path.join(root, file).replace('\\', '/')
                for root, _, files in os.walk(self.hp.Checkpoint_Path)
                for file in files
                if os.path.splitext(file)[1] == '.pt'
            ]
            if len(paths) > 0:
                path = max(paths, key=os.path.getctime)
            else:
                return  # Initial training
        else:
            path = os.path.join(
                self.hp.Checkpoint_Path,
                'S_{}.pt'.format(self.steps).replace('\\', '/'))

        state_Dict = torch.load(path, map_location='cpu')
        self.model.load_state_dict(state_Dict['Model'])
        self.optimizer.load_state_dict(state_Dict['Optimizer'])
        self.scheduler.load_state_dict(state_Dict['Scheduler'])
        self.steps = state_Dict['Steps']

        logging.info('Checkpoint loaded at {} steps.'.format(self.steps))

    def Save_Checkpoint(self):
        if self.gpu_id != 0:
            return

        os.makedirs(self.hp.Checkpoint_Path, exist_ok=True)
        state_Dict = {
            'Model': self.model.state_dict(),
            'Optimizer': self.optimizer.state_dict(),
            'Scheduler': self.scheduler.state_dict(),
            'Steps': self.steps
        }
        torch.save(
            state_Dict,
            os.path.join(self.hp.Checkpoint_Path,
                         'S_{}.pt'.format(self.steps).replace('\\', '/')))

        logging.info('Checkpoint saved at {} steps.'.format(self.steps))

    def _Set_Distribution(self):
        if self.num_gpus > 1:
            self.model = apply_gradient_allreduce(self.model)

    def Train(self):
        hp_path = os.path.join(self.hp.Checkpoint_Path,
                               'Hyper_Parameters.yaml').replace('\\', '/')
        if not os.path.exists(hp_path):
            os.makedirs(self.hp.Checkpoint_Path, exist_ok=True)
            yaml.dump(self.hp, open(hp_path, 'w'))

        if self.steps == 0:
            self.Evaluation_Epoch()

        if self.hp.Train.Initial_Inference:
            self.Inference_Epoch()

        self.tqdm = tqdm(
            initial=self.steps,
            total=self.hp.Train.Max_Step,
            desc='[Training]')

        while self.steps < self.hp.Train.Max_Step:
            try:
                self.Train_Epoch()
            except KeyboardInterrupt:
                self.Save_Checkpoint()
                exit(1)

        self.tqdm.close()
        logging.info('Finished training.')