def __init__(self, n_tiers: int, layers: List[int], hidden_size: int, gmm_size: int,
             freq: int):
    """
    Args:
        n_tiers (int): number of tiers the model is composed of
        layers (List[int]): list with the number of layers of every tier
        hidden_size (int): size of the hidden state of the Delayed Stack Layers and of other
            related parameters
        gmm_size (int): number of mixture components of the GMM
        freq (int): size of the frequency axis of the spectrogram to generate. See note in
            the documentation of the file.
    """
    super(MelNet, self).__init__()

    self.n_tiers = n_tiers
    self.layers = layers
    self.hidden_size = hidden_size
    self.gmm_size = gmm_size
    assert freq >= 2 ** (self.n_tiers / 2), "Size of frequency axis is too small for " \
                                            "being generated with the number of tiers " \
                                            "of this model"
    self.freq = freq

    self.tiers = nn.ModuleList([
        # Tier 1 is unconditional and has its own class
        Tier1(tier=1,
              n_layers=layers[0],
              hidden_size=hidden_size,
              gmm_size=gmm_size,
              # Calculate size of FREQ dimension for this tier
              freq=tierutil.get_size_freqdim_of_tier(n_mels=self.freq,
                                                     n_tiers=self.n_tiers,
                                                     tier=1))
    ] + [
        # Tiers 2..n_tiers (tier_idx indexes the layers list, so the tier number is tier_idx + 1)
        Tier(tier=tier_idx + 1,
             n_layers=layers[tier_idx],
             hidden_size=hidden_size,
             gmm_size=gmm_size,
             # Calculate size of FREQ dimension for this tier
             freq=tierutil.get_size_freqdim_of_tier(n_mels=self.freq,
                                                    n_tiers=self.n_tiers,
                                                    tier=tier_idx + 1))
        for tier_idx in range(1, n_tiers)
    ])
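# -----------------------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): building a MelNet instance.
# The concrete values below (6 tiers, the layer counts, hidden size, 10 mixture components,
# 80 mel channels) are example placeholders, not values prescribed by this code.
#
#   model = MelNet(n_tiers=6,
#                  layers=[12, 5, 4, 3, 2, 2],
#                  hidden_size=512,
#                  gmm_size=10,
#                  freq=80)   # 80 >= 2 ** (6 / 2) = 8, so the assert above passes
#   assert len(model.tiers) == 6
# -----------------------------------------------------------------------------------------------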
def train_tier(args: argparse.Namespace, hp: HParams, tier: int, extension_architecture: str,
               timestamp: str, tensorboardwriter: TensorboardWriter,
               logger: logging.Logger) -> None:
    """
    Trains one tier of MelNet.

    Args:
        args (argparse.Namespace): parameters to set up the training. At least, args must
            contain: args = {"path_config": ..., "tier": ..., "checkpoint_path": ...}
        hp (HParams): hyperparameters for the model and other parameters (training, dataset, ...)
        tier (int): number of the tier to train.
        extension_architecture (str): information about the network architecture of this run
            (training), used to identify the logs and weights of the model.
        timestamp (str): information that completely identifies this run (training).
        tensorboardwriter (TensorboardWriter): to log information about training to tensorboard.
        logger (logging.Logger): to log general information about the training of the model.
    """
    logger.info(f"Start training of tier {tier}/{hp.network.n_tiers}")

    # Setup the data ready to be consumed
    train_dataloader, test_dataloader, num_samples = get_dataloader(hp)

    # Setup tier
    # Calculate size of FREQ dimension for this tier
    tier_freq = tierutil.get_size_freqdim_of_tier(n_mels=hp.audio.mel_channels,
                                                  n_tiers=hp.network.n_tiers,
                                                  tier=tier)
    if tier == 1:
        model = Tier1(tier=tier,
                      n_layers=hp.network.layers[tier - 1],
                      hidden_size=hp.network.hidden_size,
                      gmm_size=hp.network.gmm_size,
                      freq=tier_freq)
    else:
        model = Tier(tier=tier,
                     n_layers=hp.network.layers[tier - 1],
                     hidden_size=hp.network.hidden_size,
                     gmm_size=hp.network.gmm_size,
                     freq=tier_freq)
    model = model.to(hp.device)
    model.train()
    parameters = model.parameters()

    # Setup loss criterion and optimizer
    criterion = GMMLoss()
    optimizer = torch.optim.RMSprop(params=parameters,
                                    lr=hp.training.lr,
                                    momentum=hp.training.momentum)

    # Check if training has to be resumed from previous checkpoint
    if args.checkpoint_path is not None:
        model, optimizer = resume_training(args, hp, tier, model, optimizer, logger)
    else:
        logger.info(f"Starting new training on dataset {hp.data.dataset} with configuration "
                    f"file name {hp.name}")

    # Train the tier
    total_iterations = 0
    loss_logging = 0  # accumulated loss between logging iterations
    loss_save = 0  # accumulated loss between saving iterations
    # used to compare between saving iterations and decide whether or not to save the model
    prev_loss_onesample = 1e8
    gradients = []

    for epoch in range(hp.training.epochs):
        logger.info(f"Epoch: {epoch}/{hp.training.epochs} - Starting")
        for i, (waveform, utterance) in enumerate(train_dataloader):

            # 1.1 Transform waveform input to melspectrogram and apply preprocessing to normalize
            waveform = waveform.to(device=hp.device, non_blocking=True)
            spectrogram = transforms.wave_to_melspectrogram(waveform, hp)
            spectrogram = audio_normalizing.preprocessing(spectrogram, hp)
            # 1.2 Get input and output from the original spectrogram for this tier
            input_spectrogram, output_spectrogram = tierutil.split(spectrogram=spectrogram,
                                                                   tier=tier,
                                                                   n_tiers=hp.network.n_tiers)
            length_spectrogram = input_spectrogram.size(2)
            # If the item is too long, jump to the next one
            if length_spectrogram > 1000:
                continue

            # 2. Compute the model output
            if tier == 1:
                # generation is unconditional so there is only one input
                mu_hat, std_hat, pi_hat = model(spectrogram=input_spectrogram)
            else:
                # generation is conditional on the spectrogram generated by previous tiers
                mu_hat, std_hat, pi_hat = model(spectrogram=output_spectrogram,
                                                spectrogram_prev_tier=input_spectrogram)
            # gpumemory.stat_cuda("Forward")

            # 3. Calculate the loss
            loss = criterion(mu=mu_hat, std=std_hat, pi=pi_hat, target=output_spectrogram)
            # gpumemory.stat_cuda("Loss")
            del spectrogram
            del mu_hat, std_hat, pi_hat

            # 3.1 Check if loss has exploded
            if torch.isnan(loss) or torch.isinf(loss):
                error_msg = f"Loss exploded at Epoch: {epoch}/{hp.training.epochs} - " \
                            f"Iteration: {i * hp.training.batch_size}/{num_samples}"
                logger.error(error_msg)
                raise Exception(error_msg)

            # 4. Compute gradients
            loss_cpu = loss.item()
            loss = loss / hp.training.accumulation_steps
            loss.backward()

            # 5. Perform backpropagation (using gradient accumulation so the effective batch size
            # is the same as in the paper)
            if (total_iterations + 1) \
                    % (hp.training.accumulation_steps / hp.training.batch_size) == 0:
                gradients.append(gradient_norm(model))
                avg_gradient = sum(gradients) / len(gradients)
                logger.info(f"Gradient norm: {gradients[-1]} - "
                            f"Avg gradient: {avg_gradient}")
                # Clip gradients using a fresh parameter iterator (the `parameters` generator
                # above was already consumed by the optimizer constructor)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 2200)
                optimizer.step()
                model.zero_grad()

            # 6. Logging and saving model
            loss_oneframe = loss_cpu / (length_spectrogram * hp.training.batch_size)
            loss_logging += loss_oneframe  # accumulated loss between logging iterations
            loss_save += loss_oneframe  # accumulated loss between saving iterations

            # 6.1 Save model (if it is better than the previous best)
            if (total_iterations + 1) % hp.training.save_iterations == 0:
                # Calculate average loss of one sample of a batch
                loss_onesample = loss_save / hp.training.save_iterations
                # if loss_onesample of these iterations is lower, the tier is better and we save it
                if loss_onesample <= prev_loss_onesample:
                    path = f"{hp.training.dir_chkpt}/" \
                           f"tier{tier}_{timestamp}_loss{loss_onesample:.2f}.pt"
                    torch.save(obj={'dataset': hp.data.dataset,
                                    'tier_idx': tier,
                                    'hp': hp,
                                    'epoch': epoch,
                                    'iterations': i,
                                    'total_iterations': total_iterations,
                                    'tier': model.state_dict(),
                                    'optimizer': optimizer.state_dict()},
                               f=path)
                    logger.info(f"Model saved to: {path}")
                    prev_loss_onesample = loss_onesample
                loss_save = 0

            # 6.2 Logging
            if (total_iterations + 1) % hp.logging.log_iterations == 0:
                # Calculate average loss of one sample of a batch
                loss_onesample = loss_logging / hp.logging.log_iterations
                tensorboardwriter.log_training(hp, loss_onesample, total_iterations)
                logger.info(f"Epoch: {epoch}/{hp.training.epochs} - "
                            f"Iteration: {i * hp.training.batch_size}/{num_samples} - "
                            f"Loss: {loss_onesample:.4f}")
                loss_logging = 0

            # 6.3 Evaluate
            if (total_iterations + 1) % hp.training.evaluation_iterations == 0:
                evaluation(hp, tier, test_dataloader, model, criterion, logger)

            total_iterations += 1

    # After finishing training: save model, hyperparameters and total loss
    path = f"{hp.training.dir_chkpt}/tier{tier}_{timestamp}_epoch{epoch}_final.pt"
    torch.save(obj={'dataset': hp.data.dataset,
                    'tier_idx': tier,
                    'hp': hp,
                    'epoch': epoch,
                    'iterations': evaluation(hp, tier, test_dataloader, model, criterion, logger),
                    'total_iterations': total_iterations,
                    'tier': model.state_dict(),
                    'optimizer': optimizer.state_dict()},
               f=path)
    logger.info(f"Model saved to: {path}")
    tensorboardwriter.log_end_training(hp=hp, loss=-1)
    logger.info("Finished training")
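# The training loop above logs gradient_norm(model), a helper defined elsewhere in the
# repository. A minimal sketch of what such a helper typically computes (the global L2 norm of
# all parameter gradients, i.e. the same quantity torch.nn.utils.clip_grad_norm_ measures) is
# given below; the project's actual implementation may differ.
def gradient_norm(model: torch.nn.Module) -> float:
    """Return the global L2 norm of the gradients of all model parameters (sketch)."""
    total_norm_sq = 0.0
    for p in model.parameters():
        if p.grad is not None:
            # L2 norm of this parameter's gradient, accumulated as a squared sum
            total_norm_sq += p.grad.detach().norm(2).item() ** 2
    return total_norm_sq ** 0.5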
def resume_training(args: argparse.Namespace, hp: HParams, tier: int, model: Tier,
                    optimizer: torch.optim.Optimizer, logger: logging.Logger) \
        -> Tuple[Tier, torch.optim.Optimizer]:
    """
    Loads the model specified in args.checkpoint_path to resume training from that point.

    Args:
        args (argparse.Namespace): parameters to set up the training. At least, args must
            contain: args = {"path_config": ..., "tier": ..., "checkpoint_path": ...}
        hp (HParams): hyperparameters for the model and other parameters (training, dataset, ...)
        tier (int): number of the tier to load.
        model (Tier): model where the weights will be loaded.
        optimizer (torch.optim.Optimizer): optimizer where the information will be loaded.
        logger (logging.Logger): to log general information about resuming the training.

    Returns:
        model (Tier) and optimizer (torch.optim.Optimizer)
    """
    if not Path(args.checkpoint_path).exists():
        logger.error(f"Path for resuming training {args.checkpoint_path} does not exist.")
        raise Exception(f"Path for resuming training {args.checkpoint_path} does not exist.")

    logger.info(f"Resuming training with weights from: {args.checkpoint_path}")
    checkpoint = torch.load(args.checkpoint_path)
    hp_chkpt = checkpoint["hp"]

    # Check if current hyperparameters and the ones from the saved model are the same
    if hp_chkpt.audio != hp.audio:
        logger.warning("New params for audio are different from checkpoint. "
                       "It will use new params.")

    if hp_chkpt.network != hp.network:
        logger.error("New params for network structure are different from checkpoint.")
        # raise Exception("New params for network structure are different from checkpoint.")

    if checkpoint["tier_idx"] != tier:
        logger.error(f"New tier to train ({tier}) is different from checkpoint "
                     f"({checkpoint['tier_idx']}).")
        raise Exception(f"New tier to train ({tier}) is different from checkpoint "
                        f"({checkpoint['tier_idx']}).")

    if hp_chkpt.data != hp.data:
        logger.warning("New params for dataset are different from checkpoint. "
                       "It will use new params.")

    if hp_chkpt.training != hp.training:
        logger.warning("New params for training are different from checkpoint. "
                       "It will use new params.")

    # epoch_chkpt = checkpoint["epoch"]
    # iterations_chkpt = checkpoint["iterations"]
    # total_iterations_chkpt = checkpoint["total_iterations"]
    model.load_state_dict(checkpoint["tier"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    return model, optimizer
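# -----------------------------------------------------------------------------------------------
# Usage sketch (illustrative): resuming an interrupted run. train_tier() calls resume_training()
# whenever args.checkpoint_path is set. The paths and values below are example placeholders;
# hp, tensorboardwriter and logger are assumed to be created by the project's entry point.
#
#   args = argparse.Namespace(path_config="config/example.yaml",
#                             tier=2,
#                             checkpoint_path="chkpt/tier2_20200101-000000_loss1.23.pt")
#   train_tier(args, hp, tier=args.tier,
#              extension_architecture="example_arch",
#              timestamp="20200102-000000",
#              tensorboardwriter=tensorboardwriter,
#              logger=logger)
# -----------------------------------------------------------------------------------------------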