class MaskCycleGANVCTraining(object):
    """Trainer for MaskCycleGAN-VC."""

    def __init__(self, args):
        """
        Args:
            args (Namespace): Program arguments from argparser
        """
        # Store args
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.generator_lr = args.generator_lr
        self.discriminator_lr = args.discriminator_lr
        self.decay_after = args.decay_after
        self.mini_batch_size = args.batch_size
        self.cycle_loss_lambda = args.cycle_loss_lambda
        self.identity_loss_lambda = args.identity_loss_lambda
        self.device = args.device
        self.epochs_per_save = args.epochs_per_save
        self.epochs_per_plot = args.epochs_per_plot

        # Initialize MelGAN vocoder used to decode Mel-spectrograms
        self.vocoder = torch.hub.load(
            'descriptinc/melgan-neurips', 'load_melgan')
        self.sample_rate = args.sample_rate

        # Initialize speaker A's dataset
        self.dataset_A = self.loadPickleFile(os.path.join(
            args.preprocessed_data_dir, args.speaker_A_id,
            f"{args.speaker_A_id}_normalized.pickle"))
        dataset_A_norm_stats = np.load(os.path.join(
            args.preprocessed_data_dir, args.speaker_A_id,
            f"{args.speaker_A_id}_norm_stat.npz"))
        self.dataset_A_mean = dataset_A_norm_stats['mean']
        self.dataset_A_std = dataset_A_norm_stats['std']

        # Initialize speaker B's dataset
        self.dataset_B = self.loadPickleFile(os.path.join(
            args.preprocessed_data_dir, args.speaker_B_id,
            f"{args.speaker_B_id}_normalized.pickle"))
        dataset_B_norm_stats = np.load(os.path.join(
            args.preprocessed_data_dir, args.speaker_B_id,
            f"{args.speaker_B_id}_norm_stat.npz"))
        self.dataset_B_mean = dataset_B_norm_stats['mean']
        self.dataset_B_std = dataset_B_norm_stats['std']

        # Compute lr decay rate
        self.n_samples = len(self.dataset_A)
        print(f'n_samples = {self.n_samples}')
        self.generator_lr_decay = self.generator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        self.discriminator_lr_decay = self.discriminator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        print(f'generator_lr_decay = {self.generator_lr_decay}')
        print(f'discriminator_lr_decay = {self.discriminator_lr_decay}')

        # Initialize train dataloader
        self.num_frames = args.num_frames
        self.dataset = VCDataset(datasetA=self.dataset_A,
                                 datasetB=self.dataset_B,
                                 n_frames=args.num_frames,
                                 max_mask_len=args.max_mask_len)
        self.train_dataloader = torch.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=self.mini_batch_size,
            shuffle=True,
            drop_last=False)

        # Initialize validation dataloader (used to generate intermediate outputs)
        self.validation_dataset = VCDataset(datasetA=self.dataset_A,
                                            datasetB=self.dataset_B,
                                            n_frames=args.num_frames_validation,
                                            max_mask_len=args.max_mask_len,
                                            valid=True)
        self.validation_dataloader = torch.utils.data.DataLoader(
            dataset=self.validation_dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False)

        # Initialize logger and saver objects
        self.logger = TrainLogger(args, len(self.train_dataloader.dataset))
        self.saver = ModelSaver(args)

        # Initialize generators and discriminators
        self.generator_A2B = Generator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)
        # Discriminators used to compute the two-step adversarial loss
        self.discriminator_A2 = Discriminator().to(self.device)
        self.discriminator_B2 = Discriminator().to(self.device)

        # Initialize optimizers
        g_params = list(self.generator_A2B.parameters()) + \
            list(self.generator_B2A.parameters())
        d_params = list(self.discriminator_A.parameters()) + \
            list(self.discriminator_B.parameters()) + \
            list(self.discriminator_A2.parameters()) + \
            list(self.discriminator_B2.parameters())
        self.generator_optimizer = torch.optim.Adam(
            g_params, lr=self.generator_lr, betas=(0.5, 0.999))
        self.discriminator_optimizer = torch.optim.Adam(
            d_params, lr=self.discriminator_lr, betas=(0.5, 0.999))

        # Load from previous ckpt
        if args.continue_train:
            self.saver.load_model(self.generator_A2B, "generator_A2B",
                                  None, self.generator_optimizer)
            self.saver.load_model(self.generator_B2A, "generator_B2A",
                                  None, None)
            self.saver.load_model(self.discriminator_A, "discriminator_A",
                                  None, self.discriminator_optimizer)
            self.saver.load_model(self.discriminator_B, "discriminator_B",
                                  None, None)
            self.saver.load_model(self.discriminator_A2, "discriminator_A2",
                                  None, None)
            self.saver.load_model(self.discriminator_B2, "discriminator_B2",
                                  None, None)

    def adjust_lr_rate(self, optimizer, generator):
        """Decays the learning rate linearly.

        Args:
            optimizer (torch.optim): torch optimizer
            generator (bool): Whether to adjust the generator lr.
        """
        if generator:
            self.generator_lr = max(
                0., self.generator_lr - self.generator_lr_decay)
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.generator_lr
        else:
            self.discriminator_lr = max(
                0., self.discriminator_lr - self.discriminator_lr_decay)
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.discriminator_lr

    def reset_grad(self):
        """Sets gradients of the generators and discriminators to zero before backpropagation."""
        self.generator_optimizer.zero_grad()
        self.discriminator_optimizer.zero_grad()

    def loadPickleFile(self, fileName):
        """Loads a pickle file.

        Args:
            fileName (str): pickle file path

        Returns:
            The unpickled object.
        """
        with open(fileName, 'rb') as f:
            return pickle.load(f)

    def train(self):
        """Implements the training loop for MaskCycleGAN-VC."""
        for epoch in range(self.start_epoch, self.num_epochs + 1):
            self.logger.start_epoch()

            for i, (real_A, mask_A, real_B, mask_B) in enumerate(tqdm(self.train_dataloader)):
                self.logger.start_iter()
                num_iterations = (
                    self.n_samples // self.mini_batch_size) * epoch + i

                real_A = real_A.to(self.device, dtype=torch.float)
                mask_A = mask_A.to(self.device, dtype=torch.float)
                real_B = real_B.to(self.device, dtype=torch.float)
                mask_B = mask_B.to(self.device, dtype=torch.float)

                # Train Generators
                # Generator feed-forward
                fake_B = self.generator_A2B(real_A, mask_A)
                cycle_A = self.generator_B2A(fake_B, torch.ones_like(fake_B))
                fake_A = self.generator_B2A(real_B, mask_B)
                cycle_B = self.generator_A2B(fake_A, torch.ones_like(fake_A))
                identity_A = self.generator_B2A(
                    real_A, torch.ones_like(real_A))
                identity_B = self.generator_A2B(
                    real_B, torch.ones_like(real_B))
                d_fake_A = self.discriminator_A(fake_A)
                d_fake_B = self.discriminator_B(fake_B)

                # For the two-step adversarial loss
                d_fake_cycle_A = self.discriminator_A2(cycle_A)
                d_fake_cycle_B = self.discriminator_B2(cycle_B)

                # Generator cycle loss
                cycleLoss = torch.mean(
                    torch.abs(real_A - cycle_A)) + torch.mean(torch.abs(real_B - cycle_B))

                # Generator identity loss
                identityLoss = torch.mean(
                    torch.abs(real_A - identity_A)) + torch.mean(torch.abs(real_B - identity_B))

                # Generator adversarial loss
                g_loss_A2B = torch.mean((1 - d_fake_B) ** 2)
                g_loss_B2A = torch.mean((1 - d_fake_A) ** 2)

                # Generator two-step adversarial loss
                generator_loss_A2B_2nd = torch.mean((1 - d_fake_cycle_B) ** 2)
                generator_loss_B2A_2nd = torch.mean((1 - d_fake_cycle_A) ** 2)

                # Total generator loss
                g_loss = g_loss_A2B + g_loss_B2A + \
                    generator_loss_A2B_2nd + generator_loss_B2A_2nd + \
                    self.cycle_loss_lambda * cycleLoss + \
                    self.identity_loss_lambda * identityLoss

                # Backprop for generators
                self.reset_grad()
                g_loss.backward()
                self.generator_optimizer.step()

                # Train Discriminators
                # Discriminator feed-forward
                d_real_A = self.discriminator_A(real_A)
                d_real_B = self.discriminator_B(real_B)
                d_real_A2 = self.discriminator_A2(real_A)
                d_real_B2 = self.discriminator_B2(real_B)
                generated_A = self.generator_B2A(real_B, mask_B)
                d_fake_A = self.discriminator_A(generated_A)

                # For the two-step adversarial loss A->B
                cycled_B = self.generator_A2B(
                    generated_A, torch.ones_like(generated_A))
                d_cycled_B = self.discriminator_B2(cycled_B)

                generated_B = self.generator_A2B(real_A, mask_A)
                d_fake_B = self.discriminator_B(generated_B)

                # For the two-step adversarial loss B->A
                cycled_A = self.generator_B2A(
                    generated_B, torch.ones_like(generated_B))
                d_cycled_A = self.discriminator_A2(cycled_A)

                # Discriminator losses
                d_loss_A_real = torch.mean((1 - d_real_A) ** 2)
                d_loss_A_fake = torch.mean((0 - d_fake_A) ** 2)
                d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0

                d_loss_B_real = torch.mean((1 - d_real_B) ** 2)
                d_loss_B_fake = torch.mean((0 - d_fake_B) ** 2)
                d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

                # Two-step adversarial loss
                d_loss_A_cycled = torch.mean((0 - d_cycled_A) ** 2)
                d_loss_B_cycled = torch.mean((0 - d_cycled_B) ** 2)
                d_loss_A2_real = torch.mean((1 - d_real_A2) ** 2)
                d_loss_B2_real = torch.mean((1 - d_real_B2) ** 2)
                d_loss_A_2nd = (d_loss_A2_real + d_loss_A_cycled) / 2.0
                d_loss_B_2nd = (d_loss_B2_real + d_loss_B_cycled) / 2.0

                # Final discriminator loss, including the two-step adversarial loss
                d_loss = (d_loss_A + d_loss_B) / 2.0 + \
                    (d_loss_A_2nd + d_loss_B_2nd) / 2.0

                # Backprop for discriminators
                self.reset_grad()
                d_loss.backward()
                self.discriminator_optimizer.step()

                # Log iteration
                self.logger.log_iter(
                    loss_dict={'g_loss': g_loss.item(), 'd_loss': d_loss.item()})
                self.logger.end_iter()

                # Adjust learning rates
                if self.logger.global_step > self.decay_after:
                    self.identity_loss_lambda = 0
                    self.adjust_lr_rate(
                        self.generator_optimizer, generator=True)
                    self.adjust_lr_rate(
                        self.discriminator_optimizer, generator=False)

            # Log intermediate outputs on Tensorboard
            if self.logger.epoch % self.epochs_per_plot == 0:
                # Log Mel-spectrograms as .png figures
                real_mel_A_fig = get_mel_spectrogram_fig(
                    real_A[0].detach().cpu())
                fake_mel_A_fig = get_mel_spectrogram_fig(
                    generated_A[0].detach().cpu())
                real_mel_B_fig = get_mel_spectrogram_fig(
                    real_B[0].detach().cpu())
                fake_mel_B_fig = get_mel_spectrogram_fig(
                    generated_B[0].detach().cpu())
                self.logger.visualize_outputs({"real_voc_spec": real_mel_A_fig,
                                               "fake_coraal_spec": fake_mel_B_fig,
                                               "real_coraal_spec": real_mel_B_fig,
                                               "fake_voc_spec": fake_mel_A_fig})

                # Convert Mel-spectrograms from the validation set to waveforms
                # and log them to Tensorboard
                real_mel_full_A, real_mel_full_B = next(
                    iter(self.validation_dataloader))
                real_mel_full_A = real_mel_full_A.to(
                    self.device, dtype=torch.float)
                real_mel_full_B = real_mel_full_B.to(
                    self.device, dtype=torch.float)
                fake_mel_full_B = self.generator_A2B(
                    real_mel_full_A, torch.ones_like(real_mel_full_A))
                fake_mel_full_A = self.generator_B2A(
                    real_mel_full_B, torch.ones_like(real_mel_full_B))
                real_wav_full_A = decode_melspectrogram(
                    self.vocoder, real_mel_full_A[0].detach().cpu(),
                    self.dataset_A_mean, self.dataset_A_std).cpu()
                fake_wav_full_A = decode_melspectrogram(
                    self.vocoder, fake_mel_full_A[0].detach().cpu(),
                    self.dataset_A_mean, self.dataset_A_std).cpu()
                real_wav_full_B = decode_melspectrogram(
                    self.vocoder, real_mel_full_B[0].detach().cpu(),
                    self.dataset_B_mean, self.dataset_B_std).cpu()
                fake_wav_full_B = decode_melspectrogram(
                    self.vocoder, fake_mel_full_B[0].detach().cpu(),
                    self.dataset_B_mean, self.dataset_B_std).cpu()
                self.logger.log_audio(
                    real_wav_full_A.T, "real_speaker_A_audio", self.sample_rate)
                self.logger.log_audio(
                    fake_wav_full_A.T, "fake_speaker_A_audio", self.sample_rate)
                self.logger.log_audio(
                    real_wav_full_B.T, "real_speaker_B_audio", self.sample_rate)
                self.logger.log_audio(
                    fake_wav_full_B.T, "fake_speaker_B_audio", self.sample_rate)

            # Save each model checkpoint
            if self.logger.epoch % self.epochs_per_save == 0:
                self.saver.save(self.logger.epoch, self.generator_A2B,
                                self.generator_optimizer, None,
                                self.device, "generator_A2B")
                self.saver.save(self.logger.epoch, self.generator_B2A,
                                self.generator_optimizer, None,
                                self.device, "generator_B2A")
                self.saver.save(self.logger.epoch, self.discriminator_A,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_A")
                self.saver.save(self.logger.epoch, self.discriminator_B,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_B")
                self.saver.save(self.logger.epoch, self.discriminator_A2,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_A2")
                self.saver.save(self.logger.epoch, self.discriminator_B2,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_B2")

            self.logger.end_epoch()
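# --- Usage sketch (not part of the original file) ---
# A minimal sketch of driving the trainer above from an argparse-style
# Namespace. Every value below is an illustrative assumption inferred from
# the attributes that __init__ reads off `args`; TrainLogger and ModelSaver
# may read additional fields not shown here.
from argparse import Namespace

train_args = Namespace(
    num_epochs=500, start_epoch=1, generator_lr=2e-4, discriminator_lr=1e-4,
    decay_after=10000, batch_size=1, cycle_loss_lambda=10.0,
    identity_loss_lambda=5.0, device='cuda', epochs_per_save=50,
    epochs_per_plot=10, sample_rate=22050,
    preprocessed_data_dir='preprocessed_data',       # hypothetical path
    speaker_A_id='VCC2SF3', speaker_B_id='VCC2TF1',  # hypothetical speaker IDs
    num_frames=64, num_frames_validation=320, max_mask_len=25,
    continue_train=False,
)
MaskCycleGANVCTraining(train_args).train()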
class CycleGANGenerate(object):
    def __init__(self, args):
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.mini_batch_size = args.batch_size
        self.device = args.device
        self.vocoder = torch.hub.load('descriptinc/melgan-neurips', 'load_melgan')
        self.sample_rate = args.sample_rate
        self.data_dir = args.data_dir
        self.source_id = args.source_id
        self.save_dir = args.save_dir
        self.saver = ModelSaver(args)

        # Generator
        self.generator_A2B = Generator().to(self.device)

        # Load from previous ckpt
        self.saver.load_model(self.generator_A2B, "generator_A2B",
                              args.ckpt_path, None, None)

        voc_wav_files = self.read_manifest(dataset="voc",
                                           speaker_id=self.source_id)
        print(f'Found {len(voc_wav_files)} wav files')
        self.dataset_A, self.dataset_A_mean, self.dataset_A_std = self.normalize_mel(
            voc_wav_files, self.data_dir, sr=self.sample_rate)
        self.n_samples = len(self.dataset_A)
        print(f'n_samples = {self.n_samples}')

    def read_manifest(self, split=None, dataset=None, speaker_id=None):
        # Load the manifest file which defines the dataset
        manifest_path = os.path.join('./manifests', f'{dataset}_manifest.csv')
        df = pd.read_csv(manifest_path, sep=',')

        # Filter by speaker_id
        df['speaker_id'] = df['speaker_id'].astype(str)
        df = df[df['speaker_id'] == speaker_id]
        wav_files = df['wav_file'].tolist()
        return wav_files

    def normalize_mel(self, wav_files, data_dir, sr=22050):
        vocoder = torch.hub.load('descriptinc/melgan-neurips', 'load_melgan')
        mel_list = dict()
        for wavpath in tqdm(wav_files, desc='Preprocess wav to mel'):
            wav_orig, _ = librosa.load(os.path.join(data_dir, wavpath),
                                       sr=sr, mono=True)
            spec = vocoder(torch.tensor([wav_orig]))
            assert wavpath not in mel_list
            mel_list[wavpath] = spec.cpu().detach().numpy()[0]

        # Compute normalization statistics over all utterances
        mel_concatenated = np.concatenate(list(mel_list.values()), axis=1)
        mel_mean = np.mean(mel_concatenated, axis=1, keepdims=True)
        mel_std = np.std(mel_concatenated, axis=1, keepdims=True) + 1e-9

        mel_normalized = dict()
        for wavpath, mel in mel_list.items():
            assert wavpath not in mel_normalized
            mel_normalized[wavpath] = (mel - mel_mean) / mel_std
        return mel_normalized, mel_mean, mel_std

    def save_pickle(self, variable, fileName):
        with open(fileName, 'wb') as f:
            pickle.dump(variable, f)

    def run(self):
        converted_specs = dict()
        for i, (wavpath, melspec) in enumerate(tqdm(self.dataset_A.items())):
            real_A = torch.tensor(melspec).unsqueeze(0).to(
                self.device, dtype=torch.float)
            fake_B_normalized = self.generator_A2B(
                real_A, torch.ones_like(real_A)).squeeze(0).detach().cpu().numpy()
            fake_B = fake_B_normalized * self.dataset_A_std + self.dataset_A_mean
            converted_specs[wavpath] = fake_B

        print(f"Saving to ~/data/converted/voc_converted_{self.source_id}.pickle")
        self.save_pickle(variable=converted_specs,
                         fileName=os.path.join(
                             '/home/ubuntu/data', "converted",
                             f"voc_converted_{self.source_id}.pickle"))
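# --- Usage sketch (not part of the original file) ---
# A hedged sketch of running CycleGANGenerate end-to-end: convert every
# manifest wav for one source speaker and pickle the converted
# Mel-spectrograms. All values are illustrative assumptions.
from argparse import Namespace

generate_args = Namespace(
    num_epochs=0, start_epoch=0, batch_size=1, device='cuda',
    sample_rate=22050,
    data_dir='/home/ubuntu/data/voc',                  # hypothetical path
    source_id='SP001',                                 # hypothetical speaker id
    save_dir='results',
    ckpt_path='results/ckpts/generator_A2B.pth.tar',   # hypothetical path
)
CycleGANGenerate(generate_args).run()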
def SaveSubgraph(option, subg):
    saver = ModelSaver(subg)
    if option.save_config:
        saver.SaveConfigInfo(option.save_prefix)
class MaskCycleGANVCTesting(object):
    """Tester for MaskCycleGAN-VC."""

    def __init__(self, args):
        """
        Args:
            args (Namespace): Program arguments from argparser
        """
        # Store args
        self.device = args.device
        self.converted_audio_dir = os.path.join(args.save_dir, args.name,
                                                'converted_audio')
        os.makedirs(self.converted_audio_dir, exist_ok=True)
        self.model_name = args.model_name
        self.speaker_A_id = args.speaker_A_id
        self.speaker_B_id = args.speaker_B_id

        # Initialize MelGAN vocoder used to decode Mel-spectrograms
        self.vocoder = torch.hub.load('descriptinc/melgan-neurips', 'load_melgan')
        self.sample_rate = args.sample_rate

        # Initialize speaker A's dataset
        self.dataset_A = self.loadPickleFile(
            os.path.join(args.preprocessed_data_dir, self.speaker_A_id,
                         f"{self.speaker_A_id}_normalized.pickle"))
        dataset_A_norm_stats = np.load(
            os.path.join(args.preprocessed_data_dir, self.speaker_A_id,
                         f"{self.speaker_A_id}_norm_stat.npz"))
        self.dataset_A_mean = dataset_A_norm_stats['mean']
        self.dataset_A_std = dataset_A_norm_stats['std']

        # Initialize speaker B's dataset
        self.dataset_B = self.loadPickleFile(
            os.path.join(args.preprocessed_data_dir, self.speaker_B_id,
                         f"{self.speaker_B_id}_normalized.pickle"))
        dataset_B_norm_stats = np.load(
            os.path.join(args.preprocessed_data_dir, self.speaker_B_id,
                         f"{self.speaker_B_id}_norm_stat.npz"))
        self.dataset_B_mean = dataset_B_norm_stats['mean']
        self.dataset_B_std = dataset_B_norm_stats['std']

        # Pick the source dataset for the requested conversion direction
        source_dataset = self.dataset_A if self.model_name == 'generator_A2B' else self.dataset_B
        self.dataset = VCDataset(datasetA=source_dataset,
                                 datasetB=None,
                                 valid=True)
        self.test_dataloader = torch.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False)

        # Generator
        self.generator_A2B = Generator().to(self.device)

        # Load generator from ckpt
        self.saver = ModelSaver(args)
        self.saver.load_model(self.generator_A2B, self.model_name)

    def loadPickleFile(self, fileName):
        """Loads a pickle file.

        Args:
            fileName (str): pickle file path

        Returns:
            The unpickled object.
        """
        with open(fileName, 'rb') as f:
            return pickle.load(f)

    def test(self):
        for i, real_A in enumerate(tqdm(self.test_dataloader)):
            real_A = real_A.to(self.device, dtype=torch.float)
            fake_B = self.generator_A2B(real_A, torch.ones_like(real_A))
            wav_fake_B = decode_melspectrogram(self.vocoder,
                                               fake_B[0].detach().cpu(),
                                               self.dataset_A_mean,
                                               self.dataset_A_std).cpu()
            if self.model_name == 'generator_A2B':
                save_path = os.path.join(
                    self.converted_audio_dir,
                    f"converted_{self.speaker_A_id}_to_{self.speaker_B_id}{i}.wav")
            else:
                save_path = os.path.join(
                    self.converted_audio_dir,
                    f"converted_{self.speaker_B_id}_to_{self.speaker_A_id}{i}.wav")
            torchaudio.save(save_path, wav_fake_B,
                            sample_rate=self.sample_rate)
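# --- Usage sketch (not part of the original file) ---
# A hedged sketch of converting speaker A's preprocessed utterances to
# speaker B's voice and writing wavs to disk. All values are illustrative
# assumptions inferred from the attributes __init__ reads off `args`.
from argparse import Namespace

test_args = Namespace(
    device='cuda', save_dir='results',
    name='mask_cyclegan_vc_VCC2SF3_VCC2TF1',          # hypothetical run name
    model_name='generator_A2B',
    speaker_A_id='VCC2SF3', speaker_B_id='VCC2TF1',   # hypothetical speaker IDs
    sample_rate=22050,
    preprocessed_data_dir='preprocessed_data',        # hypothetical path
)
MaskCycleGANVCTesting(test_args).test()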
def train(args):
    """Implements the training loop for the MultiTaskResnet3dClassifier.

    Args:
        args (Namespace): Program arguments
    """
    # Get model and loss function
    model = MTClassifier3D(args).to(args.device)

    # Initialize losses for each head
    loss_wrapper = MultiTaskLoss(args)
    loss_fn = nn.BCEWithLogitsLoss()  # unused here; heads are combined through loss_wrapper

    # Get train and validation dataloaders
    train_dataset = ClassifierDataset(
        args.csv_dir, 'train', args.features,
        resample=(args.num_slices, args.slice_size, args.slice_size))
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=True
    )
    peds_validation_dataset = ClassifierDataset(
        args.peds_csv_dir, 'val', args.peds_features,
        resample=(args.num_slices, args.slice_size, args.slice_size))
    peds_validation_loader = DataLoader(
        peds_validation_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=True
    )
    adult_validation_dataset = ClassifierDataset(
        args.adult_csv_dir, 'val', args.adult_features,
        resample=(args.num_slices, args.slice_size, args.slice_size))
    adult_validation_loader = DataLoader(
        adult_validation_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=True
    )

    # Get optimizer and scheduler
    optimizer = optim.Adam(model.parameters(), args.lr)
    warmup_iters = args.lr_warmup_epochs * len(train_loader)
    lr_milestones = [len(train_loader) * m for m in args.lr_milestones]
    lr_scheduler = WarmupMultiStepLR(
        optimizer, milestones=lr_milestones, gamma=args.lr_gamma,
        warmup_iters=warmup_iters, warmup_factor=1e-5)

    # Get saver, logger, and evaluator
    saver = ModelSaver(args, max_ckpts=args.max_ckpts,
                       metric_name=args.best_ckpt_metric,
                       maximize_metric=args.maximize_metric)
    # evaluator = ModelEvaluator(args, validation_loader, cls_loss_fn)

    # Load model from checkpoint if applicable
    if args.continue_train:
        saver.load_model(model, args.name, ckpt_path=args.load_path,
                         optimizer=optimizer, scheduler=lr_scheduler)

    logger = TrainLogger(args, len(train_loader.dataset))

    # Multi-GPU training if applicable
    if len(args.gpu_ids) > 1:
        print("Using", len(args.gpu_ids), "GPUs.")
        model = nn.DataParallel(model)

    loss_meter = meter.AverageValueMeter()

    # Train model
    logger.log_hparams(args)
    while not logger.is_finished_training():
        logger.start_epoch()
        for inputs, targets in tqdm(train_loader):
            logger.start_iter()
            with torch.set_grad_enabled(True):
                inputs = inputs.to(args.device)
                targets = targets.to(args.device)
                head_preds = model(inputs)
                loss = loss_wrapper(head_preds, targets)
                loss_meter.add(loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Log all train losses
            if logger.iter % args.steps_per_print == 0 and logger.iter != 0:
                logger.log_metrics({'train_loss': loss_meter.value()[0]})
                loss_meter.reset()
            logger.end_iter()

        # Evaluate model and save model ckpt
        if logger.epoch % args.epochs_per_eval == 0:
            peds_metrics = evaluate(args, model, loss_wrapper,
                                    peds_validation_loader, "validation",
                                    args.device, 'peds')
            logger.log_metrics(peds_metrics)
            adult_metrics = evaluate(args, model, loss_wrapper,
                                     adult_validation_loader, "validation",
                                     args.device, 'adult')
            logger.log_metrics(adult_metrics)
        if logger.epoch % args.epochs_per_save == 0:
            saver.save(logger.epoch, model, optimizer, lr_scheduler,
                       args.device, args.name)
        lr_scheduler.step()
        logger.end_epoch()
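# The trainer above depends on WarmupMultiStepLR, which is defined elsewhere.
# Below is a minimal sketch of such a schedule, patterned after torchvision's
# detection reference implementation; the actual class in this codebase may
# differ. It linearly ramps the lr from warmup_factor * base_lr up to base_lr
# over warmup_iters steps, then multiplies by gamma at each milestone.
from bisect import bisect_right

import torch


class WarmupMultiStepLRSketch(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1,
                 warmup_factor=1e-5, warmup_iters=500, last_epoch=-1):
        self.milestones = sorted(milestones)
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup = 1.0
        if self.last_epoch < self.warmup_iters:
            # Linearly interpolate from warmup_factor to 1 over warmup_iters steps
            alpha = self.last_epoch / self.warmup_iters
            warmup = self.warmup_factor * (1 - alpha) + alpha
        # Step decay: multiply by gamma once per milestone already passed
        return [base_lr * warmup *
                self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]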
def main(args):
    if args.librispeech:
        print("Loading Librispeech dataset!")
        train_dataset = torchaudio.datasets.LIBRISPEECH(
            args.data_dir, url="train-clean-360", download=True)
        valid_dataset = torchaudio.datasets.LIBRISPEECH(
            args.data_dir, url="test-clean", download=True)
    else:
        train_dataset = Dataset(args, "train", return_pair=args.return_pair)
        valid_dataset = Dataset(args, "val", return_pair=args.return_pair)
    print(f"Training set has {len(train_dataset)} samples. "
          f"Validation set has {len(valid_dataset)} samples.")

    # train_audio_transforms = get_audio_transforms('train')
    # valid_audio_transforms = get_audio_transforms('valid')
    text_transform = TextTransform()

    train_loader = data.DataLoader(dataset=train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   collate_fn=lambda x: data_processing(
                                       x, "train", text_transform),
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    valid_loader = data.DataLoader(dataset=valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   collate_fn=lambda x: data_processing(
                                       x, "valid", text_transform),
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    model = SpeechRecognitionModel(
        args.n_cnn_layers, args.n_rnn_layers, args.rnn_dim,
        args.n_class, args.n_feats, args.stride, args.dropout
    ).to(args.device)
    print('Num Model Parameters', sum(
        [param.nelement() for param in model.parameters()]))

    optimizer = optim.AdamW(model.parameters(), args.lr)
    criterion = nn.CTCLoss(blank=28).to(args.device)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        steps_per_epoch=int(len(train_loader)),
        epochs=args.num_epochs,
        anneal_strategy='linear')
    # scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.gamma)

    saver = ModelSaver(args, max_ckpts=args.max_ckpts,
                       metric_name="test_wer", maximize_metric=False)
    if args.continue_train:
        saver.load_model(model, "SpeechRecognitionModel",
                         args.ckpt_path, optimizer, scheduler)
    elif args.pretrained_ckpt_path:
        saver.load_model(model, "SpeechRecognitionModel",
                         args.pretrained_ckpt_path, None, None)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)

    logger = TrainLogger(args, len(train_loader.dataset))
    logger.log_hparams(args)

    for epoch in range(args.start_epoch, args.num_epochs + 1):
        train(args, model, train_loader, criterion, optimizer, scheduler, logger)
        if logger.epoch % args.epochs_per_save == 0:
            metric_dict = test(args, model, valid_loader, criterion, logger)
            saver.save(logger.epoch, model, optimizer, scheduler, args.device,
                       "SpeechRecognitionModel", metric_dict["test_wer"])
        logger.end_epoch()
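# main() above relies on a data_processing collate function and a
# TextTransform that are defined elsewhere. The sketch below follows the
# common LibriSpeech CTC recipe; the feature settings and the
# text_transform.text_to_int method name are assumptions about the unseen
# helpers, not this codebase's confirmed API.
import torch
import torch.nn as nn
import torchaudio

train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100),
)
valid_audio_transforms = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_mels=128)


def data_processing(data, split, text_transform):
    """Pads a batch of (waveform, ..., utterance, ...) samples for CTC training."""
    spectrograms, labels, input_lengths, label_lengths = [], [], [], []
    for waveform, _, utterance, *_ in data:
        transforms = train_audio_transforms if split == "train" else valid_audio_transforms
        # (1, n_mels, time) -> (time, n_mels) so pad_sequence can pad over time
        spec = transforms(waveform).squeeze(0).transpose(0, 1)
        spectrograms.append(spec)
        label = torch.Tensor(text_transform.text_to_int(utterance.lower()))
        labels.append(label)
        input_lengths.append(spec.shape[0] // 2)  # assumes the CNN front-end has stride 2
        label_lengths.append(len(label))
    spectrograms = nn.utils.rnn.pad_sequence(
        spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
    return spectrograms, labels, input_lengths, label_lengths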
class CycleGANTraining(object):
    def __init__(self, args):
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.generator_lr = args.generator_lr
        self.discriminator_lr = args.discriminator_lr
        self.decay_after = args.decay_after
        self.mini_batch_size = args.batch_size
        self.cycle_loss_lambda = args.cycle_loss_lambda
        self.identity_loss_lambda = args.identity_loss_lambda
        self.device = args.device
        self.epochs_per_save = args.epochs_per_save
        self.epochs_per_plot = args.epochs_per_plot

        self.vocoder = torch.hub.load('descriptinc/melgan-neurips', 'load_melgan')
        self.sample_rate = args.sample_rate

        self.dataset_A = self.loadPickleFile(args.normalized_dataset_A_path)
        dataset_A_norm_stats = np.load(args.norm_stats_A_path)
        # TODO: fix to mean and std after running data preprocessing script again
        self.dataset_A_mean = dataset_A_norm_stats['mean']
        self.dataset_A_std = dataset_A_norm_stats['std']

        self.dataset_B = self.loadPickleFile(args.normalized_dataset_B_path)
        dataset_B_norm_stats = np.load(args.norm_stats_B_path)
        self.dataset_B_mean = dataset_B_norm_stats['mean']
        self.dataset_B_std = dataset_B_norm_stats['std']

        self.n_samples = len(self.dataset_A)
        print(f'n_samples = {self.n_samples}')
        self.generator_lr_decay = self.generator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        self.discriminator_lr_decay = self.discriminator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        print(f'generator_lr_decay = {self.generator_lr_decay}')
        print(f'discriminator_lr_decay = {self.discriminator_lr_decay}')

        self.num_frames = args.num_frames
        self.dataset = trainingDataset(datasetA=self.dataset_A,
                                       datasetB=self.dataset_B,
                                       n_frames=args.num_frames,
                                       max_mask_len=args.max_mask_len)
        self.train_dataloader = torch.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=self.mini_batch_size,
            shuffle=True,
            drop_last=False)

        self.validation_dataset = trainingDataset(
            datasetA=self.dataset_A,
            datasetB=self.dataset_B,
            n_frames=args.num_frames_validation,
            max_mask_len=args.max_mask_len,
            valid=True)
        self.validation_dataloader = torch.utils.data.DataLoader(
            dataset=self.validation_dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False)

        self.logger = TrainLogger(args, len(self.train_dataloader.dataset))
        self.saver = ModelSaver(args)

        # Generators and Discriminators
        self.generator_A2B = Generator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)
        self.discriminator_A2 = Discriminator().to(self.device)
        self.discriminator_B2 = Discriminator().to(self.device)

        # Optimizers
        g_params = list(self.generator_A2B.parameters()) + \
            list(self.generator_B2A.parameters())
        d_params = list(self.discriminator_A.parameters()) + \
            list(self.discriminator_B.parameters()) + \
            list(self.discriminator_A2.parameters()) + \
            list(self.discriminator_B2.parameters())
        self.generator_optimizer = torch.optim.Adam(
            g_params, lr=self.generator_lr, betas=(0.5, 0.999))
        self.discriminator_optimizer = torch.optim.Adam(
            d_params, lr=self.discriminator_lr, betas=(0.5, 0.999))

        # Load from previous ckpt
        if args.continue_train:
            self.saver.load_model(self.generator_A2B, "generator_A2B",
                                  None, self.generator_optimizer)
            self.saver.load_model(self.generator_B2A, "generator_B2A",
                                  None, None)
            self.saver.load_model(self.discriminator_A, "discriminator_A",
                                  None, self.discriminator_optimizer)
            self.saver.load_model(self.discriminator_B, "discriminator_B",
                                  None, None)
            self.saver.load_model(self.discriminator_A2, "discriminator_A2",
                                  None, None)
            self.saver.load_model(self.discriminator_B2, "discriminator_B2",
                                  None, None)

    def adjust_lr_rate(self, optimizer, name='generator'):
        if name == 'generator':
            self.generator_lr = max(
                0., self.generator_lr - self.generator_lr_decay)
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.generator_lr
        else:
            self.discriminator_lr = max(
                0., self.discriminator_lr - self.discriminator_lr_decay)
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.discriminator_lr

    def reset_grad(self):
        self.generator_optimizer.zero_grad()
        self.discriminator_optimizer.zero_grad()

    def loadPickleFile(self, fileName):
        with open(fileName, 'rb') as f:
            return pickle.load(f)

    def train(self):
        for epoch in range(self.start_epoch, self.num_epochs):
            self.logger.start_epoch()

            for i, (real_A, mask_A, real_B, mask_B) in enumerate(tqdm(self.train_dataloader)):
                self.logger.start_iter()
                num_iterations = (
                    self.n_samples // self.mini_batch_size) * epoch + i

                real_A = real_A.to(self.device, dtype=torch.float)
                mask_A = mask_A.to(self.device, dtype=torch.float)
                real_B = real_B.to(self.device, dtype=torch.float)
                mask_B = mask_B.to(self.device, dtype=torch.float)

                # Train Generators
                fake_B = self.generator_A2B(real_A, mask_A)
                cycle_A = self.generator_B2A(fake_B, torch.ones_like(fake_B))
                fake_A = self.generator_B2A(real_B, mask_B)
                cycle_B = self.generator_A2B(fake_A, torch.ones_like(fake_A))
                identity_A = self.generator_B2A(real_A,
                                                torch.ones_like(real_A))
                identity_B = self.generator_A2B(real_B,
                                                torch.ones_like(real_B))
                d_fake_A = self.discriminator_A(fake_A)
                d_fake_B = self.discriminator_B(fake_B)

                # For the two-step adversarial loss
                d_fake_cycle_A = self.discriminator_A2(cycle_A)
                d_fake_cycle_B = self.discriminator_B2(cycle_B)

                # Generator cycle loss
                cycleLoss = torch.mean(
                    torch.abs(real_A - cycle_A)) + torch.mean(
                    torch.abs(real_B - cycle_B))

                # Generator identity loss
                identityLoss = torch.mean(
                    torch.abs(real_A - identity_A)) + torch.mean(
                    torch.abs(real_B - identity_B))

                # Generator adversarial loss
                g_loss_A2B = torch.mean((1 - d_fake_B) ** 2)
                g_loss_B2A = torch.mean((1 - d_fake_A) ** 2)

                # Generator two-step adversarial loss
                generator_loss_A2B_2nd = torch.mean((1 - d_fake_cycle_B) ** 2)
                generator_loss_B2A_2nd = torch.mean((1 - d_fake_cycle_A) ** 2)

                # Total generator loss
                g_loss = g_loss_A2B + g_loss_B2A + \
                    generator_loss_A2B_2nd + generator_loss_B2A_2nd + \
                    self.cycle_loss_lambda * cycleLoss + \
                    self.identity_loss_lambda * identityLoss
                # self.generator_loss_store.append(generator_loss.item())

                # Backprop for generators
                self.reset_grad()
                g_loss.backward()
                self.generator_optimizer.step()

                # Train Discriminators
                # Discriminator feed-forward
                d_real_A = self.discriminator_A(real_A)
                d_real_B = self.discriminator_B(real_B)
                d_real_A2 = self.discriminator_A2(real_A)
                d_real_B2 = self.discriminator_B2(real_B)
                generated_A = self.generator_B2A(real_B, mask_B)
                d_fake_A = self.discriminator_A(generated_A)

                # For the two-step adversarial loss A->B
                cycled_B = self.generator_A2B(generated_A,
                                              torch.ones_like(generated_A))
                d_cycled_B = self.discriminator_B2(cycled_B)

                generated_B = self.generator_A2B(real_A, mask_A)
                d_fake_B = self.discriminator_B(generated_B)

                # For the two-step adversarial loss B->A
                cycled_A = self.generator_B2A(generated_B,
                                              torch.ones_like(generated_B))
                d_cycled_A = self.discriminator_A2(cycled_A)

                # Discriminator losses
                d_loss_A_real = torch.mean((1 - d_real_A) ** 2)
                d_loss_A_fake = torch.mean((0 - d_fake_A) ** 2)
                d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0
                d_loss_B_real = torch.mean((1 - d_real_B) ** 2)
                d_loss_B_fake = torch.mean((0 - d_fake_B) ** 2)
                d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

                # Two-step adversarial loss
                d_loss_A_cycled = torch.mean((0 - d_cycled_A) ** 2)
                d_loss_B_cycled = torch.mean((0 - d_cycled_B) ** 2)
                d_loss_A2_real = torch.mean((1 - d_real_A2) ** 2)
                d_loss_B2_real = torch.mean((1 - d_real_B2) ** 2)
                d_loss_A_2nd = (d_loss_A2_real + d_loss_A_cycled) / 2.0
                d_loss_B_2nd = (d_loss_B2_real + d_loss_B_cycled) / 2.0

                # Final discriminator loss, including the two-step adversarial loss
                d_loss = (d_loss_A + d_loss_B) / 2.0 + \
                    (d_loss_A_2nd + d_loss_B_2nd) / 2.0
                # self.discriminator_loss_store.append(d_loss.item())

                # Backprop for discriminators
                self.reset_grad()
                d_loss.backward()
                self.discriminator_optimizer.step()

                # if num_iterations % args.steps_per_print == 0:
                #     print(f"Epoch: {epoch} Step: {num_iterations} Generator Loss: {g_loss.item()} Discriminator Loss: {d_loss.item()}")

                self.logger.log_iter(loss_dict={
                    'g_loss': g_loss.item(),
                    'd_loss': d_loss.item()
                })
                self.logger.end_iter()

                # Adjust learning rates
                if self.logger.global_step > self.decay_after:
                    self.identity_loss_lambda = 0
                    self.adjust_lr_rate(self.generator_optimizer,
                                        name='generator')
                    self.adjust_lr_rate(self.discriminator_optimizer,
                                        name='discriminator')

            if self.logger.epoch % self.epochs_per_plot == 0:
                # Log spectrograms
                real_mel_A_fig = get_mel_spectrogram_fig(
                    real_A[0].detach().cpu())
                fake_mel_A_fig = get_mel_spectrogram_fig(
                    generated_A[0].detach().cpu())
                real_mel_B_fig = get_mel_spectrogram_fig(
                    real_B[0].detach().cpu())
                fake_mel_B_fig = get_mel_spectrogram_fig(
                    generated_B[0].detach().cpu())
                self.logger.visualize_outputs({
                    "real_voc_spec": real_mel_A_fig,
                    "fake_coraal_spec": fake_mel_B_fig,
                    "real_coraal_spec": real_mel_B_fig,
                    "fake_voc_spec": fake_mel_A_fig
                })

                # Decode spec -> wav
                real_wav_A = decode_melspectrogram(
                    self.vocoder, real_A[0].detach().cpu(),
                    self.dataset_A_mean, self.dataset_A_std).cpu()
                fake_wav_A = decode_melspectrogram(
                    self.vocoder, generated_A[0].detach().cpu(),
                    self.dataset_A_mean, self.dataset_A_std).cpu()
                real_wav_B = decode_melspectrogram(
                    self.vocoder, real_B[0].detach().cpu(),
                    self.dataset_B_mean, self.dataset_B_std).cpu()
                fake_wav_B = decode_melspectrogram(
                    self.vocoder, generated_B[0].detach().cpu(),
                    self.dataset_B_mean, self.dataset_B_std).cpu()

                # # Log wav
                # real_wav_A_fig = get_waveform_fig(real_wav_A, self.sample_rate)
                # fake_wav_A_fig = get_waveform_fig(fake_wav_A, self.sample_rate)
                # real_wav_B_fig = get_waveform_fig(real_wav_B, self.sample_rate)
                # fake_wav_B_fig = get_waveform_fig(fake_wav_B, self.sample_rate)
                # self.logger.visualize_outputs({"real_voc_wav": real_wav_A_fig, "fake_coraal_wav": fake_wav_B_fig,
                #                                "real_coraal_wav": real_wav_B_fig, "fake_voc_wav": fake_wav_A_fig})

                # Convert spectrograms from the validation set to wav and log to Tensorboard
                real_mel_full_A, real_mel_full_B = next(
                    iter(self.validation_dataloader))
                real_mel_full_A = real_mel_full_A.to(self.device,
                                                     dtype=torch.float)
                real_mel_full_B = real_mel_full_B.to(self.device,
                                                     dtype=torch.float)
                fake_mel_full_B = self.generator_A2B(
                    real_mel_full_A, torch.ones_like(real_mel_full_A))
                fake_mel_full_A = self.generator_B2A(
                    real_mel_full_B, torch.ones_like(real_mel_full_B))
                real_wav_full_A = decode_melspectrogram(
                    self.vocoder, real_mel_full_A[0].detach().cpu(),
                    self.dataset_A_mean, self.dataset_A_std).cpu()
                fake_wav_full_A = decode_melspectrogram(
                    self.vocoder, fake_mel_full_A[0].detach().cpu(),
                    self.dataset_A_mean, self.dataset_A_std).cpu()
                real_wav_full_B = decode_melspectrogram(
                    self.vocoder, real_mel_full_B[0].detach().cpu(),
                    self.dataset_B_mean, self.dataset_B_std).cpu()
                fake_wav_full_B = decode_melspectrogram(
                    self.vocoder, fake_mel_full_B[0].detach().cpu(),
                    self.dataset_B_mean, self.dataset_B_std).cpu()
                self.logger.log_audio(real_wav_full_A.T, "real_voc_audio",
                                      self.sample_rate)
                self.logger.log_audio(fake_wav_full_A.T, "fake_voc_audio",
                                      self.sample_rate)
                self.logger.log_audio(real_wav_full_B.T, "real_coraal_audio",
                                      self.sample_rate)
                self.logger.log_audio(fake_wav_full_B.T, "fake_coraal_audio",
                                      self.sample_rate)

            if self.logger.epoch % self.epochs_per_save == 0:
                self.saver.save(self.logger.epoch, self.generator_A2B,
                                self.generator_optimizer, None,
                                self.device, "generator_A2B")
                self.saver.save(self.logger.epoch, self.generator_B2A,
                                self.generator_optimizer, None,
                                self.device, "generator_B2A")
                self.saver.save(self.logger.epoch, self.discriminator_A,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_A")
                self.saver.save(self.logger.epoch, self.discriminator_B,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_B")
                self.saver.save(self.logger.epoch, self.discriminator_A2,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_A2")
                self.saver.save(self.logger.epoch, self.discriminator_B2,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_B2")

            self.logger.end_epoch()
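# Both trainers and the tester above call decode_melspectrogram, which is
# defined elsewhere in the repo. A minimal sketch is given below, under the
# assumption that it de-normalizes the Mel-spectrogram with the speaker's
# training statistics and inverts it with MelGAN's inverse() method; the
# real helper may differ in detail.
import torch


def decode_melspectrogram_sketch(vocoder, melspectrogram, mel_mean, mel_std):
    """De-normalizes a Mel-spectrogram and decodes it to audio with MelGAN."""
    # norm stats come from np.load, so convert them to tensors explicitly
    mel_mean = torch.as_tensor(mel_mean, dtype=melspectrogram.dtype)
    mel_std = torch.as_tensor(mel_std, dtype=melspectrogram.dtype)
    mel_denormalized = melspectrogram * mel_std + mel_mean
    # MelGAN's inverse() expects (batch, n_mels, n_frames) and returns a
    # (batch, num_samples) waveform tensor.
    return vocoder.inverse(mel_denormalized.unsqueeze(0))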