def forward(self, input, target, length):
        """Masked MSE loss averaged over the valid (unpadded) positions.

        Args:
            input: A Variable containing a FloatTensor of size
                (batch, max_len, dim) with the predicted values.
            target: A FloatTensor of the same size (batch, max_len, dim)
                with the regression targets (NOT class indices -- this is
                an element-wise MSE, not a classification loss).
            length: A Variable containing a LongTensor of size (batch,)
                which contains the length of each data in a batch.
        Returns:
            loss: A scalar tensor -- squared error summed over valid
                positions, divided by the number of valid elements
                (sum of lengths * dim).
        """
        input = input.contiguous()
        target = target.contiguous()

        # Flatten to (batch * max_len, dim) for the element-wise loss.
        input = input.view(-1, input.shape[-1])
        target_flat = target.view(-1, target.shape[-1])
        # reduction='none' keeps per-element losses; the old
        # size_average=False / reduce=False keyword pair is deprecated
        # and removed in current PyTorch.
        losses_flat = functional.mse_loss(input,
                                          target_flat,
                                          reduction='none')
        # losses: (batch, max_len, dim)
        losses = losses_flat.view(*target.size())

        # mask: (batch, max_len, 1) -- 1 for valid steps, 0 for padding.
        mask = sequence_mask(sequence_length=length,
                             max_len=target.size(1)).unsqueeze(2)
        losses = losses * mask.float()
        # Normalise by the true number of contributing elements.
        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
        return loss
Example #2
0
 def forward(self, input, target, length):
     """Compute a length-masked MSE loss.

     Args:
         input: A Variable containing a FloatTensor of size
             (batch, max_len, dim) with the predicted values.
         target: A tensor of the same size holding the regression
             targets.
         length: A Variable containing a LongTensor of size (batch,)
             which contains the length of each data in a batch.
     Returns:
         loss: A scalar -- masked squared error normalised by the
             mask sum.
     """
     # Scale factor applied to the mask under float16; a larger value
     # (0.01) reportedly produced NaNs, hence this smaller constant.
     half_mask_scale = 0.005
     target.requires_grad = False
     # mask: (batch, max_len, 1), broadcast to input's shape and dtype.
     mask = sequence_mask(sequence_length=length,
                          max_len=target.size(1)).unsqueeze(2).float()
     mask = mask.expand_as(input).type_as(input)
     if input.dtype == torch.float16:
         mask = mask * half_mask_scale
     loss = functional.mse_loss((input * mask), (target * mask),
                                reduction="sum")
     return loss / mask.sum()
Example #3
0
 def forward(self, characters, text_lengths, mel_specs):
     """Full text-to-spectrogram pass.

     Returns decoder mel frames, post-net linear spectrogram,
     attention alignments and stop-token predictions.
     """
     batch_size = characters.size(0)
     # Padding mask over the input character sequence.
     mask = sequence_mask(text_lengths).to(characters.device)
     embedded = self.embedding(characters)
     encoded = self.encoder(embedded)
     mel_outputs, alignments, stop_tokens = self.decoder(
         encoded, mel_specs, mask)
     mel_outputs = mel_outputs.view(batch_size, -1, self.mel_dim)
     postnet_out = self.postnet(mel_outputs)
     linear_outputs = self.last_linear(postnet_out)
     return mel_outputs, linear_outputs, alignments, stop_tokens
Example #4
0
 def forward(self, text, text_lengths, mel_specs=None):
     """Tacotron2-style forward pass.

     Encodes the text, decodes mel frames, refines them with a
     residual post-net, and reshapes the outputs.
     """
     # compute mask for padding
     mask = sequence_mask(text_lengths).to(text.device)
     embedded = self.embedding(text).transpose(1, 2)
     encoded = self.encoder(embedded, text_lengths)
     mel_outputs, stop_tokens, alignments = self.decoder(
         encoded, mel_specs, mask)
     # Residual post-net refinement.
     postnet_out = self.postnet(mel_outputs)
     mel_outputs_postnet = mel_outputs + postnet_out
     mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
         mel_outputs, mel_outputs_postnet, alignments)
     return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
Example #5
0
 def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
     """Forward pass conditioned on speaker embeddings and global
     style tokens (GST)."""
     batch_size = characters.size(0)
     # Padding mask over the input character sequence.
     mask = sequence_mask(text_lengths).to(characters.device)
     embedded = self.embedding(characters)
     encoded = self.encoder(embedded)
     # Condition the encoder outputs on the speaker identity.
     encoded = self._add_speaker_embedding(encoded, speaker_ids)
     # Add the style embedding, broadcast along the time axis.
     style = self.gst(mel_specs)
     style = style.expand(-1, encoded.size(1), -1)
     encoded = encoded + style
     mel_outputs, alignments, stop_tokens = self.decoder(
         encoded, mel_specs, mask)
     mel_outputs = mel_outputs.view(batch_size, -1, self.mel_dim)
     postnet_out = self.postnet(mel_outputs)
     linear_outputs = self.last_linear(postnet_out)
     return mel_outputs, linear_outputs, alignments, stop_tokens
Example #6
0
    def test_in_out(self):
        """L1LossMasked sanity checks: zero loss on identical tensors,
        unit loss for ones-vs-zeros, and padded positions must not
        influence the result."""
        layer = L1LossMasked()

        # Identical input and target -> loss must be exactly 0.
        preds = T.ones(4, 8, 128).float()
        targets = T.ones(4, 8, 128).float()
        lengths = (T.ones(4) * 8).long()
        result = layer(preds, targets, lengths)
        assert result.item() == 0.0

        # Ones vs zeros with no padding -> mean absolute error is 1.
        preds = T.ones(4, 8, 128).float()
        targets = T.zeros(4, 8, 128).float()
        lengths = (T.ones(4) * 8).long()
        result = layer(preds, targets, lengths)
        assert result.item() == 1.0, "1.0 vs {}".format(result.data[0])

        # Corrupt only the padded region; the masked loss must ignore it.
        preds = T.ones(4, 8, 128).float()
        targets = T.zeros(4, 8, 128).float()
        lengths = (T.arange(5, 9)).long()
        pad_noise = (
            (sequence_mask(lengths).float() - 1.0) * 100.0).unsqueeze(2)
        result = layer(preds + pad_noise, targets, lengths)
        assert result.item() == 1.0, "1.0 vs {}".format(result.data[0])
 def forward(self, input, target, length):
     """Length-masked L1 loss.

     Args:
         input: A Variable containing a FloatTensor of size
             (batch, max_len, dim) with the predicted values.
         target: A tensor of the same size with the regression targets.
         length: A Variable containing a LongTensor of size (batch,)
             which contains the length of each data in a batch.
     Returns:
         loss: A scalar -- absolute error summed over valid positions,
             divided by the number of valid elements.
     """
     # Validity mask: (batch, max_len, 1), broadcast to the input shape.
     valid = sequence_mask(sequence_length=length,
                           max_len=target.size(1)).unsqueeze(2).float()
     valid = valid.expand_as(input)
     # Zero out padded positions before summing, then normalise.
     total = functional.l1_loss(input * valid, target * valid,
                                reduction="sum")
     return total / valid.sum()
Example #8
0
File: train.py  Project: wurde/TTS
def train(model, criterion, criterion_st, optimizer, optimizer_st,
          scheduler, ap, epoch):
    """Run one training epoch.

    Args:
        model: Tacotron-style model returning
            (mel_output, linear_output, alignments, stop_tokens).
        criterion: masked spectrogram loss, called as
            criterion(output, target, lengths).
        criterion_st: stop-token loss.
        optimizer: optimizer for the main model parameters.
        optimizer_st: separate optimizer for the stopnet.
        scheduler: LR scheduler, stepped once per iteration when
            ``c.lr_decay`` is enabled.
        ap: audio processor used for spectrogram plotting/inversion.
        epoch: current epoch index (used for logging and global step).

    Returns:
        tuple: (avg_linear_loss, current_step).

    NOTE(review): relies on module-level globals (``c`` config, ``args``,
    ``use_cuda``, ``tb`` tensorboard writer, ``OUT_PATH`` and helpers) not
    visible in this chunk -- confirm they are initialised before calling.
    """
    data_loader = setup_loader(is_val=False)
    model.train()
    epoch_time = 0
    avg_linear_loss = 0
    avg_mel_loss = 0
    avg_stop_loss = 0
    avg_step_time = 0
    print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    # Number of low-frequency bins (below 3 kHz) that receive an extra
    # share of the linear loss.
    n_priority_freq = int(
        3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
    batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[2]
        mel_input = data[3]
        mel_lengths = data[4]
        stop_targets = data[5]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())

        # set stop targets view, we predict a single stop token per r frames prediction
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

        # Global step resumes from any restored checkpoint.
        current_step = num_iter + args.restore_step + \
            epoch * len(data_loader) + 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(non_blocking=True)
            stop_targets = stop_targets.cuda(non_blocking=True)

        # compute mask for padding
        mask = sequence_mask(text_lengths)

        # forward pass
        if use_cuda:
            mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
                model, (text_input, mel_input, mask))
        else:
            mel_output, linear_output, alignments, stop_tokens = model(
                text_input, mel_input, mask)

        # loss computation: decoder mel loss plus a linear loss that
        # up-weights the priority (low-frequency) bins.
        stop_loss = criterion_st(stop_tokens, stop_targets)
        mel_loss = criterion(mel_output, mel_input, mel_lengths)
        linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\
            + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                              linear_input[:, :, :n_priority_freq],
                              mel_lengths)
        loss = mel_loss + linear_loss

        # backpass and check the grad norm for spec losses
        # retain_graph=True because stop_loss.backward() below reuses
        # part of the same graph.
        loss.backward(retain_graph=True)
        # custom weight decay: p <- p - wd * lr * p
        for group in optimizer.param_groups:
            current_lr = group['lr']  # invariant per group; hoisted
            for param in group['params']:
                # keyword `alpha` form; the positional (Number, Tensor)
                # overload of Tensor.add is deprecated/removed in newer
                # PyTorch.
                param.data = param.data.add(param.data,
                                            alpha=-c.wd * group['lr'])
        grad_norm, skip_flag = check_update(model, 1)
        if skip_flag:
            optimizer.zero_grad()
            print("   | > Iteration skipped!!", flush=True)
            continue
        optimizer.step()

        # backpass and check the grad norm for stop loss
        stop_loss.backward()
        # custom weight decay for the stopnet optimizer
        for group in optimizer_st.param_groups:
            for param in group['params']:
                param.data = param.data.add(param.data,
                                            alpha=-c.wd * group['lr'])
        grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
        if skip_flag:
            optimizer_st.zero_grad()
            # (typo "fro" fixed)
            print("   | > Iteration skipped for stopnet!!", flush=True)
            continue
        optimizer_st.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        if current_step % c.print_step == 0:
            print(
                "   | > Step:{}/{}  GlobalStep:{}  TotalLoss:{:.5f}  LinearLoss:{:.5f}  "
                "MelLoss:{:.5f}  StopLoss:{:.5f}  GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  LR:{:.6f}".format(
                    num_iter, batch_n_iter, current_step, loss.item(),
                    linear_loss.item(), mel_loss.item(), stop_loss.item(),
                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
                flush=True)

        avg_linear_loss += linear_loss.item()
        avg_mel_loss += mel_loss.item()
        avg_stop_loss += stop_loss.item()
        avg_step_time += step_time

        # Plot Training Iter Stats
        tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
        tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.item(),
                      current_step)
        tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.item(), current_step)
        tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                      current_step)
        tb.add_scalar('Params/GradNorm', grad_norm, current_step)
        tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
        tb.add_scalar('Time/StepTime', step_time, current_step)

        if current_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(model, optimizer, optimizer_st,
                                linear_loss.item(), OUT_PATH, current_step,
                                epoch)

            # Diagnostic visualizations
            const_spec = linear_output[0].data.cpu().numpy()
            gt_spec = linear_input[0].data.cpu().numpy()

            const_spec = plot_spectrogram(const_spec, ap)
            gt_spec = plot_spectrogram(gt_spec, ap)
            tb.add_figure('Visual/Reconstruction', const_spec, current_step)
            tb.add_figure('Visual/GroundTruth', gt_spec, current_step)

            align_img = alignments[0].data.cpu().numpy()
            align_img = plot_alignment(align_img)
            tb.add_figure('Visual/Alignment', align_img, current_step)

            # Sample audio
            audio_signal = linear_output[0].data.cpu().numpy()
            ap.griffin_lim_iters = 60
            audio_signal = ap.inv_spectrogram(audio_signal.T)
            try:
                tb.add_audio(
                    'SampleAudio',
                    audio_signal,
                    current_step,
                    sample_rate=c.sample_rate)
            except Exception:
                # Best-effort logging only: a tensorboard audio failure
                # must not abort training. (Was a bare `except:`, which
                # also swallowed KeyboardInterrupt/SystemExit.)
                pass

    avg_linear_loss /= (num_iter + 1)
    avg_mel_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)

    # print epoch stats
    print(
        "   | > EPOCH END -- GlobalStep:{}  AvgTotalLoss:{:.5f}  "
        "AvgLinearLoss:{:.5f}  AvgMelLoss:{:.5f}  "
        "AvgStopLoss:{:.5f}  EpochTime:{:.2f}  "
        "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
                                    avg_linear_loss, avg_mel_loss,
                                    avg_stop_loss, epoch_time, avg_step_time),
        flush=True)

    # Plot Training Epoch Stats
    tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
    tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
    tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
    tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
    tb.add_scalar('Time/EpochTime', epoch_time, epoch)
    epoch_time = 0
    return avg_linear_loss, current_step
Example #9
0
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
          ap, epoch):
    """Run one training epoch (single- or multi-GPU/distributed).

    Args:
        model: Tacotron-style model returning
            (mel_output, linear_output, alignments, stop_tokens).
        criterion: masked spectrogram loss, called as
            criterion(output, target, lengths).
        criterion_st: stop-token loss.
        optimizer: optimizer for the main model parameters.
        optimizer_st: separate optimizer for the stopnet.
        scheduler: LR scheduler, stepped once per iteration when
            ``c.lr_decay`` is enabled.
        ap: audio processor used for spectrogram plotting/inversion.
        epoch: current epoch index.

    Returns:
        tuple: (avg_linear_loss, current_step).

    NOTE(review): relies on module-level globals (``c``, ``args``,
    ``use_cuda``, ``num_gpus``, ``tb_logger``, ``OUT_PATH`` and helper
    functions) not visible in this chunk -- confirm initialisation order.
    """
    data_loader = setup_loader(is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    avg_linear_loss = 0
    avg_mel_loss = 0
    avg_stop_loss = 0
    avg_step_time = 0
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
    # Number of low-frequency bins (below 3 kHz) receiving the extra
    # ``c.loss_weight`` share of the linear loss.
    n_priority_freq = int(3000 / (c.audio['sample_rate'] * 0.5) *
                          c.audio['num_freq'])
    if num_gpus > 0:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[2]
        mel_input = data[3]
        mel_lengths = data[4]
        stop_targets = data[5]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())

        # set stop targets view, we predict a single stop token per r frames prediction
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

        # Global step resumes from any restored checkpoint.
        current_step = num_iter + args.restore_step + \
            epoch * len(data_loader) + 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()
        optimizer_st.zero_grad()

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda(non_blocking=True)
            text_lengths = text_lengths.cuda(non_blocking=True)
            mel_input = mel_input.cuda(non_blocking=True)
            mel_lengths = mel_lengths.cuda(non_blocking=True)
            linear_input = linear_input.cuda(non_blocking=True)
            stop_targets = stop_targets.cuda(non_blocking=True)

        # compute mask for padding
        mask = sequence_mask(text_lengths)

        # forward pass
        mel_output, linear_output, alignments, stop_tokens = model(
            text_input, mel_input, mask)

        # loss computation: decoder mel loss plus a linear loss that
        # up-weights the priority (low-frequency) bins by c.loss_weight.
        stop_loss = criterion_st(stop_tokens, stop_targets)
        mel_loss = criterion(mel_output, mel_input, mel_lengths)
        linear_loss = (1 - c.loss_weight) * criterion(linear_output, linear_input, mel_lengths)\
            + c.loss_weight * criterion(linear_output[:, :, :n_priority_freq],
                              linear_input[:, :, :n_priority_freq],
                              mel_lengths)
        loss = mel_loss + linear_loss

        # backpass and check the grad norm for spec losses
        # retain_graph=True because stop_loss.backward() below reuses
        # part of the same graph.
        loss.backward(retain_graph=True)
        optimizer, current_lr = weight_decay(optimizer, c.wd)
        grad_norm, _ = check_update(model, 1.0)
        optimizer.step()

        # backpass and check the grad norm for stop loss
        stop_loss.backward()
        optimizer_st, _ = weight_decay(optimizer_st, c.wd)
        grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
        optimizer_st.step()

        step_time = time.time() - start_time
        epoch_time += step_time

        # NOTE(review): this prints the local (pre-reduce) loss values;
        # cross-process aggregation happens just below.
        if current_step % c.print_step == 0:
            print(
                " | > Step:{}/{}  GlobalStep:{}  TotalLoss:{:.5f}  LinearLoss:{:.5f}  "
                "MelLoss:{:.5f}  StopLoss:{:.5f}  GradNorm:{:.5f}  "
                "GradNormST:{:.5f}  AvgTextLen:{:.1f}  AvgSpecLen:{:.1f}  StepTime:{:.2f}  LR:{:.6f}"
                .format(num_iter, batch_n_iter, current_step, loss.item(),
                        linear_loss.item(), mel_loss.item(), stop_loss.item(),
                        grad_norm, grad_norm_st, avg_text_length,
                        avg_spec_length, step_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            linear_loss = reduce_tensor(linear_loss.data, num_gpus)
            mel_loss = reduce_tensor(mel_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data, num_gpus)

        # Only rank 0 accumulates stats, logs figures and checkpoints.
        if args.rank == 0:
            avg_linear_loss += float(linear_loss.item())
            avg_mel_loss += float(mel_loss.item())
            # NOTE(review): not wrapped in float() unlike the two lines
            # above -- harmless (item() already returns a Python float)
            # but inconsistent.
            avg_stop_loss += stop_loss.item()
            avg_step_time += step_time

            # Plot Training Iter Stats
            iter_stats = {
                "loss_posnet": linear_loss.item(),
                "loss_decoder": mel_loss.item(),
                "lr": current_lr,
                "grad_norm": grad_norm,
                "grad_norm_st": grad_norm_st,
                "step_time": step_time
            }
            tb_logger.tb_train_iter_stats(current_step, iter_stats)

            if current_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
                                    linear_loss.item(), OUT_PATH, current_step,
                                    epoch)

                # Diagnostic visualizations
                const_spec = linear_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }
                tb_logger.tb_train_figures(current_step, figures)

                # Sample audio
                tb_logger.tb_train_audios(
                    current_step,
                    {'TrainAudio': ap.inv_spectrogram(const_spec.T)},
                    c.audio["sample_rate"])

    avg_linear_loss /= (num_iter + 1)
    avg_mel_loss /= (num_iter + 1)
    avg_stop_loss /= (num_iter + 1)
    avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
    avg_step_time /= (num_iter + 1)

    # print epoch stats
    print(" | > EPOCH END -- GlobalStep:{}  AvgTotalLoss:{:.5f}  "
          "AvgLinearLoss:{:.5f}  AvgMelLoss:{:.5f}  "
          "AvgStopLoss:{:.5f}  EpochTime:{:.2f}  "
          "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
                                      avg_linear_loss, avg_mel_loss,
                                      avg_stop_loss, epoch_time,
                                      avg_step_time),
          flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": avg_linear_loss,
            "loss_decoder": avg_mel_loss,
            "stop_loss": avg_stop_loss,
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, current_step)
    return avg_linear_loss, current_step
Example #10
0
                # (fragment) Interior of an evaluation/extraction loop whose
                # enclosing function starts outside this chunk: consumes one
                # batch, runs the model, and dumps per-sample alignment plots
                # and duration arrays.
                mel_lengths = data[4]
                stop_targets = data[5]
                idxs = data[6]

                # set stop targets view, we predict a single stop token per r frames prediction
                stop_targets = stop_targets.view(text_input.shape[0],
                                                    stop_targets.size(1) // c.r,
                                                    -1)
                stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

                # dispatch data to GPU
                if use_cuda:
                    text_input = text_input.cuda(non_blocking=True)
                    text_lengths = text_lengths.cuda(non_blocking=True)
                    mel_input = mel_input.cuda(non_blocking=True)
                    mel_lengths = mel_lengths.cuda(non_blocking=True)
                    linear_input = linear_input.cuda(non_blocking=True)
                    stop_targets = stop_targets.cuda(non_blocking=True)
                # padding mask over the text sequence
                mask = sequence_mask(text_lengths)

                # forward pass
                mel_output, linear_output, alignments, stop_tokens =\
                    model.forward(text_input, mel_input, mask)

                # Save one alignment plot and one duration array per sample.
                for i, alignment in enumerate(alignments):
                    alignment = alignment.data.cpu().numpy()
                    # NOTE(review): mismatched parentheses on the next line --
                    # `alignment` is passed to os.path.join and the
                    # plot_alignment call is never closed; this cannot parse
                    # as-is and needs fixing at the source.
                    plot_alignment(os.path.join(plot_folder, 'alignment-{}.png'.format(idxs[i]), alignment)
                    duration = get_duration(alignment.T) * c.r
                    assert (duration.sum() == mel_input.shape[1])
                    np.save(os.path.join(duration_folder, 'duration-{}.npy'.format(idxs[i])), duration)