コード例 #1
0
    def __init__(self, n_tiers: int, layers: List[int], hidden_size: int,
                 gmm_size: int, freq: int):
        """
        Builds MelNet as a stack of `n_tiers` tiers: one Tier1 (first tier) followed by
        `n_tiers - 1` Tier modules.

        Args:
            n_tiers (int): number of tiers the model is composed of
            layers (List[int]): number of layers of every tier. Entry `k` is used for tier
                                `k + 1`, so it must contain at least `n_tiers` entries.
            hidden_size (int): parameter for the hidden state of the Delayed Stack Layers and
                               other internal sizes of every tier
            gmm_size (int): number of mixture components of the GMM
            freq (int): size of the frequency axis of the spectrogram to generate. See note in the
                        documentation of the file.

        Raises:
            AssertionError: if `freq` is smaller than 2 ** (n_tiers / 2).
        """
        super(MelNet, self).__init__()

        self.n_tiers = n_tiers
        self.layers = layers
        self.hidden_size = hidden_size
        self.gmm_size = gmm_size

        assert freq >= 2 ** (self.n_tiers / 2), "Size of frequency axis is too small for " \
                                                "being generated with the number of tiers " \
                                                "of this model"
        self.freq = freq

        def freq_of_tier(tier_number):
            # Size of the FREQ dimension for the given (1-based) tier number
            return tierutil.get_size_freqdim_of_tier(
                n_mels=self.freq, n_tiers=self.n_tiers, tier=tier_number)

        # Tiers are numbered from 1. Each tier receives the same 1-based number both as its
        # `tier` argument and for computing its FREQ dimension. This matches how a single tier
        # is built for training; previously the tiers after the first one were given a `tier`
        # number one lower than the one used for their FREQ size.
        self.tiers = nn.ModuleList([
            Tier1(
                tier=1,
                n_layers=layers[0],
                hidden_size=hidden_size,
                gmm_size=gmm_size,
                freq=freq_of_tier(1))
        ] + [
            Tier(
                tier=tier_number,
                n_layers=layers[tier_number - 1],
                hidden_size=hidden_size,
                gmm_size=gmm_size,
                freq=freq_of_tier(tier_number))
            for tier_number in range(2, n_tiers + 1)
        ])
コード例 #2
0
def train_tier(args: argparse.Namespace, hp: HParams, tier: int,
               extension_architecture: str, timestamp: str,
               tensorboardwriter: TensorboardWriter,
               logger: logging.Logger) -> None:
    """
    Trains one tier of MelNet.

    For every epoch the tier is trained on the training set. During training the loss is
    logged periodically, and a checkpoint is saved whenever the average loss is not worse than
    the best one seen so far. The tier is also evaluated on the test set periodically. At the
    end of every epoch the tier is evaluated and a checkpoint of that epoch is saved.

    Args:
        args (argparse.Namespace): parameters to set up the training. At least, args must contain:
                                   args = {"path_config": ...,
                                           "tier": ...,
                                           "checkpoint_path": ...}
        hp (HParams): hyperparameters for the model and other parameters (training, dataset, ...)
        tier (int): number of the tier to train.
        extension_architecture (str): information about the network's architecture of this run
                                      (training) to identify the logs and weights of the model.
                                      Currently not used inside this function.
        timestamp (str): information that identifies completely this run (training).
        tensorboardwriter (TensorboardWriter): to log information about training to tensorboard.
        logger (logging.Logger): to log general information about the training of the model.

    Raises:
        Exception: if the loss becomes NaN or infinite.
    """
    logger.info(f"Start training of tier {tier}/{hp.network.n_tiers}")

    # Items whose input spectrogram for this tier has more frames than this are skipped
    max_length_spectrogram = 1000
    # Maximum total norm of the gradients; larger gradients are clipped to this norm
    max_grad_norm = 2200

    # Setup the data ready to be consumed
    train_dataloader, test_dataloader, num_samples = get_dataloader(hp)

    # Setup tier
    # Calculate size of FREQ dimension for this tier
    tier_freq = tierutil.get_size_freqdim_of_tier(n_mels=hp.audio.mel_channels,
                                                  n_tiers=hp.network.n_tiers,
                                                  tier=tier)
    # The first tier is a Tier1; every other tier is a Tier
    tier_class = Tier1 if tier == 1 else Tier
    model = tier_class(tier=tier,
                       n_layers=hp.network.layers[tier - 1],
                       hidden_size=hp.network.hidden_size,
                       gmm_size=hp.network.gmm_size,
                       freq=tier_freq)
    model = model.to(hp.device)
    model.train()
    # `model.parameters()` returns a generator. It is consumed twice (by the optimizer and by
    # gradient clipping), so materialize it. With the raw generator, clipping received an
    # exhausted iterator and silently did nothing.
    parameters = list(model.parameters())

    # Setup loss criterion and optimizer
    criterion = GMMLoss()
    optimizer = torch.optim.RMSprop(params=parameters,
                                    lr=hp.training.lr,
                                    momentum=hp.training.momentum)

    # Check if training has to be resumed from previous checkpoint
    if args.checkpoint_path is not None:
        model, optimizer = resume_training(args, hp, tier, model, optimizer,
                                           logger)
    else:
        logger.info(
            f"Starting new training on dataset {hp.data.dataset} with configuration file "
            f"name {hp.name}")

    # Number of batches whose gradients are accumulated before each optimizer step. With this
    # interval the effective batch size is `hp.training.accumulation_steps` samples (same as in
    # the paper). Integer division avoids float modulo, which may never be exactly 0 when
    # accumulation_steps is not a multiple of batch_size.
    batches_per_step = max(1, hp.training.accumulation_steps // hp.training.batch_size)

    # Train the tier
    total_iterations = 0
    loss_logging = 0  # accumulated loss between logging iterations
    loss_save = 0  # accumulated loss between saving iterations
    # Best average loss among saving iterations, used to decide whether or not to save the model
    prev_loss_onesample = 1e8
    # Running sum/count of gradient norms, to log their average without keeping every value
    grad_norm_sum = 0
    grad_norm_count = 0

    for epoch in range(hp.training.epochs):
        logger.info(f"Epoch: {epoch}/{hp.training.epochs} - Starting")
        i = 0  # index of the last batch of this epoch (stays 0 if the dataloader is empty)
        for i, (waveform, _utterance) in enumerate(train_dataloader):

            # 1.1 Transform waveform input to melspectrogram and apply preprocessing to normalize
            waveform = waveform.to(device=hp.device, non_blocking=True)
            spectrogram = transforms.wave_to_melspectrogram(waveform, hp)
            spectrogram = audio_normalizing.preprocessing(spectrogram, hp)
            # 1.2 Get input and output from the original spectrogram for this tier
            input_spectrogram, output_spectrogram = tierutil.split(
                spectrogram=spectrogram, tier=tier, n_tiers=hp.network.n_tiers)
            length_spectrogram = input_spectrogram.size(2)
            # if item is too long, we jump to the next one
            if length_spectrogram > max_length_spectrogram:
                continue

            # 2. Compute the model output
            if tier == 1:
                # generation is unconditional so there is only one input
                mu_hat, std_hat, pi_hat = model(spectrogram=input_spectrogram)
            else:
                # generation is conditional on the spectrogram generated by previous tiers
                mu_hat, std_hat, pi_hat = model(
                    spectrogram=output_spectrogram,
                    spectrogram_prev_tier=input_spectrogram)

            # 3. Calculate the loss
            loss = criterion(mu=mu_hat,
                             std=std_hat,
                             pi=pi_hat,
                             target=output_spectrogram)
            del spectrogram
            del mu_hat, std_hat, pi_hat

            # 3.1 Check if loss has exploded
            if torch.isnan(loss) or torch.isinf(loss):
                error_msg = f"Loss exploded at Epoch: {epoch}/{hp.training.epochs} - " \
                            f"Iteration: {i * hp.training.batch_size}/{num_samples}"
                logger.error(error_msg)
                raise Exception(error_msg)

            # 4. Compute gradients (scaled because they are accumulated over several batches)
            loss_cpu = loss.item()
            loss = loss / hp.training.accumulation_steps
            loss.backward()

            # 5. Perform backpropagation (using gradient accumulation so effective batch size is
            # the same as in the paper)
            if (total_iterations + 1) % batches_per_step == 0:
                grad_norm = gradient_norm(model)
                grad_norm_sum += grad_norm
                grad_norm_count += 1
                logger.info(f"Gradient norm: {grad_norm} - "
                            f"Avg gradient: {grad_norm_sum / grad_norm_count}")
                torch.nn.utils.clip_grad_norm_(parameters, max_grad_norm)
                optimizer.step()
                model.zero_grad()

            # 6. Logging and saving model
            loss_oneframe = loss_cpu / (length_spectrogram *
                                        hp.training.batch_size)
            loss_logging += loss_oneframe  # accumulated loss between logging iterations
            loss_save += loss_oneframe  # accumulated loss between saving iterations

            # 6.1 Save model (if it is better than the previously saved one)
            if (total_iterations + 1) % hp.training.save_iterations == 0:
                # Calculate average loss of one sample of a batch
                loss_onesample = loss_save / hp.training.save_iterations
                # if loss_onesample of these iterations is lower, the tier is better and we save it
                if loss_onesample <= prev_loss_onesample:
                    path = f"{hp.training.dir_chkpt}/tier{tier}_{timestamp}_loss{loss_onesample:.2f}.pt"
                    _save_checkpoint(path, hp, tier, model, optimizer, epoch=epoch,
                                     iterations=i, total_iterations=total_iterations)
                    logger.info(f"Model saved to: {path}")
                    prev_loss_onesample = loss_onesample
                loss_save = 0

            # 6.2 Logging
            if (total_iterations + 1) % hp.logging.log_iterations == 0:
                # Calculate average loss of one sample of a batch
                loss_onesample = loss_logging / hp.logging.log_iterations
                tensorboardwriter.log_training(hp, loss_onesample,
                                               total_iterations)
                logger.info(
                    f"Epoch: {epoch}/{hp.training.epochs} - "
                    f"Iteration: {i * hp.training.batch_size}/{num_samples} - "
                    f"Loss: {loss_onesample:.4f}")
                loss_logging = 0

            # 6.3 Evaluate
            if (total_iterations + 1) % hp.training.evaluation_iterations == 0:
                evaluation(hp, tier, test_dataloader, model, criterion, logger)
                # Back to training mode, in case `evaluation` (not visible here) switched the
                # model to eval mode; a no-op otherwise.
                model.train()
            total_iterations += 1

        # End of epoch: evaluate the tier and save a checkpoint of this epoch. The evaluation
        # used to be stored under the 'iterations' key of the checkpoint by mistake.
        evaluation(hp, tier, test_dataloader, model, criterion, logger)
        model.train()
        path = f"{hp.training.dir_chkpt}/tier{tier}_{timestamp}_epoch{epoch}_final.pt"
        _save_checkpoint(path, hp, tier, model, optimizer, epoch=epoch,
                         iterations=i, total_iterations=total_iterations)
        logger.info(f"Model saved to: {path}")

    # After finishing all the epochs (previously this ran after every epoch)
    tensorboardwriter.log_end_training(hp=hp, loss=-1)
    logger.info("Finished training")


def _save_checkpoint(path: str, hp: HParams, tier: int, model: torch.nn.Module,
                     optimizer: torch.optim.Optimizer, epoch: int, iterations: int,
                     total_iterations: int) -> None:
    """
    Saves the tier weights, the optimizer state, the hyperparameters and the training
    progress to `path`, using the checkpoint keys read back by `resume_training`.
    """
    torch.save(obj={
        'dataset': hp.data.dataset,
        'tier_idx': tier,
        'hp': hp,
        'epoch': epoch,
        'iterations': iterations,
        'total_iterations': total_iterations,
        'tier': model.state_dict(),
        'optimizer': optimizer.state_dict()
    },
               f=path)
コード例 #3
0
def resume_training(args: argparse.Namespace, hp: HParams, tier: int, model: Tier,
                    optimizer: torch.optim.Optimizer, logger: logging.Logger) \
        -> Tuple[Tier, torch.optim.Optimizer]:
    """
    Loads the model specified in args.checkpoint_path to resume training from that point.

    Differences between the current hyperparameters and the ones stored in the checkpoint are
    logged. Only a different tier number aborts the resume; for the other sections the new
    params are used.

    Args:
        args (argparse.Namespace): parameters to set up the training. At least, args must contain:
                                   args = {"path_config": ...,
                                           "tier": ...,
                                           "checkpoint_path": ...}
        hp (HParams): hyperparameters for the model and other parameters (training, dataset, ...)
        tier (int): number of the tier to load.
        model (Tier): model where the weights will be loaded.
        optimizer (torch.optim.Optimizer): optimizer where the information will be loaded.
        logger (logging.Logger): to log general information about resuming the training.

    Returns:
        model (Tier) and optimizer (torch.optim.Optimizer), with the checkpoint state loaded

    Raises:
        FileNotFoundError: if args.checkpoint_path does not exist.
        ValueError: if the checkpoint belongs to a different tier than `tier`.
    """
    checkpoint_path = args.checkpoint_path
    if not Path(checkpoint_path).exists():
        error_msg = f"Path for resuming training {checkpoint_path} does not exist."
        logger.error(error_msg)
        raise FileNotFoundError(error_msg)

    logger.info(f"Resuming training with weights from: {checkpoint_path}")
    # map_location loads the tensors directly onto the device used for training, so a checkpoint
    # saved on GPU can be resumed on CPU (and vice versa).
    # NOTE(review): torch.load unpickles the file (the checkpoint stores the `hp` object), so
    # only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=hp.device)
    hp_chkpt = checkpoint["hp"]

    # Check if current hyperparameters and the ones from saved model are the same
    if hp_chkpt.audio != hp.audio:
        logger.warning("New params for audio are different from checkpoint. "
                       "It will use new params.")

    if hp_chkpt.network != hp.network:
        # Only reported: an incompatible structure makes model.load_state_dict below fail.
        logger.error(
            "New params for network structure are different from checkpoint.")

    tier_chkpt = checkpoint["tier_idx"]
    if tier_chkpt != tier:
        # Report the stored tier number ("tier_idx"), not "tier", which holds the state_dict
        error_msg = f"New tier to train ({tier}) is different from checkpoint ({tier_chkpt})."
        logger.error(error_msg)
        raise ValueError(error_msg)

    if hp_chkpt.data != hp.data:
        logger.warning("New params for dataset are different from checkpoint. "
                       "It will use new params.")

    if hp_chkpt.training != hp.training:
        logger.warning(
            "New params for training are different from checkpoint. "
            "It will use new params.")

    # NOTE: the progress counters stored in the checkpoint ("epoch", "iterations",
    # "total_iterations") are not restored; training restarts its counters from 0.
    model.load_state_dict(checkpoint["tier"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    return model, optimizer