Example 1
def data_filter(args):
    """Filter the test set with a pretrained discriminator."""

    model, device = load_model(args)

    start_t_test = time.time()

    # Decode
    test_set = SVSDataset_filter(
        align_root_path=args.test_align,
        pitch_beat_root_path=args.test_pitch,
        wav_root_path=args.test_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        standard=args.standard,
        sing_quality=args.sing_quality,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
    )
    collate_fn_svs = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        False,  # random crop
        -1,  # crop_min_length
        args.Hz2semitone,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs,
        pin_memory=True,
    )

    with torch.no_grad():
        for step, data_step in enumerate(test_loader, 1):
            if args.db_joint:
                (phone, beat, pitch, spec, real, imag, length, chars,
                 char_len_list, mel, singer_id, semitone,
                 filename_list) = data_step

            else:
                print(
                    "No support for data filtering with args.db_joint == False")
                quit()

            singer_id = np.array(singer_id).reshape(np.shape(phone)[0],
                                                    -1)  # [batch size, 1]
            singer_vec = singer_id.repeat(np.shape(phone)[1],
                                          axis=1)  # [batch size, length]
            singer_vec = torch.from_numpy(singer_vec).to(device)
            singer_id = torch.from_numpy(singer_id).to(device)

            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            if semitone is not None:
                semitone = semitone.to(device)
            spec = spec.to(device).float()
            mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = (length > 0).int().unsqueeze(2)
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length_mel_mask = length_mel_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.Hz2semitone:
                pitch = semitone

            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)
                spec, _ = sepc_normalizer(spec, length)
                mel, _ = mel_normalizer(mel, length)

            len_list, _ = torch.max(length, dim=1)  # [len1, len2, len3, ...]
            len_list = len_list.cpu().detach().numpy()

            singer_out, phone_out, semitone_out = model(spec, len_list)

            # count correct predictions per utterance
            batch_size = np.shape(spec)[0]

            singer_id = singer_id.view(-1)  # [batch size]
            _, singer_predict = torch.max(singer_out, dim=1)  # [batch size]
            singer_correct = singer_predict.eq(singer_id).cpu().sum().numpy()

            for i in range(batch_size):
                phone_i = phone[i, :len_list[i], :].view(-1)  # [valid seq len]
                phone_out_i = phone_out[
                    i, :len_list[i], :]  # [valid seq len, phone_size]
                _, phone_predict = torch.max(phone_out_i, dim=1)
                phone_correct = phone_predict.eq(phone_i).cpu().sum().numpy()

                semitone_i = semitone[i, :len_list[i], :].view(
                    -1)  # [valid seq len]
                semitone_out_i = semitone_out[
                    i, :len_list[i], :]  # [valid seq len, semitone_size]
                _, semitone_predict = torch.max(semitone_out_i, dim=1)
                semitone_correct = semitone_predict.eq(
                    semitone_i).cpu().sum().numpy()

                with open(os.path.join(args.prediction_path, "filter_res.txt"),
                          "a+") as f:
                    f.write(
                        f"{filename_list[i]}|{singer_predict[i].item()}|"
                        f"{phone_correct}|{semitone_correct}|{len_list[i]}\n"
                    )

                end = time.time()

                logging.info(
                    f"{filename_list[i]} -- sum_time: {(end - start_t_test)}s")
Example 2
def validate(dev_loader, model, device, criterion, perceptual_entropy, epoch,
             args, voc_model):
    """validate."""
    losses = AverageMeter()
    spec_losses = AverageMeter()
    if args.perceptual_loss > 0:
        pe_losses = AverageMeter()
    if args.n_mels > 0:
        mel_losses = AverageMeter()
        mcd_metric = AverageMeter()
        if args.double_mel_loss:
            double_mel_losses = AverageMeter()
    model.eval()

    log_save_dir = os.path.join(args.model_save_dir,
                                "epoch{}/log_val_figure".format(epoch))
    if not os.path.exists(log_save_dir):
        os.makedirs(log_save_dir)

    start = time.time()

    with torch.no_grad():
        for (step, data_step) in enumerate(dev_loader, 1):
            if args.db_joint:
                (
                    phone,
                    beat,
                    pitch,
                    spec,
                    real,
                    imag,
                    length,
                    chars,
                    char_len_list,
                    mel,
                    singer_id,
                ) = data_step

                singer_id = np.array(singer_id).reshape(
                    np.shape(phone)[0], -1)  # [batch size, 1]
                singer_vec = singer_id.repeat(np.shape(phone)[1],
                                              axis=1)  # [batch size, length]
                singer_vec = torch.from_numpy(singer_vec).to(device)

            else:
                (
                    phone,
                    beat,
                    pitch,
                    spec,
                    real,
                    imag,
                    length,
                    chars,
                    char_len_list,
                    mel,
                ) = data_step

            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            spec = spec.to(device).float()
            if mel is not None:
                mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = length.unsqueeze(2)
            if mel is not None:
                length_mel_mask = length_mask.repeat(1, 1,
                                                     mel.shape[2]).float()
                length_mel_mask = length_mel_mask.to(device)
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)
            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.model_type == "GLU_Transformer":
                if args.db_joint:
                    output, att, output_mel, output_mel2 = model(
                        chars,
                        phone,
                        pitch,
                        beat,
                        singer_vec,
                        pos_char=char_len_list,
                        pos_spec=length,
                    )
                else:
                    output, att, output_mel, output_mel2 = model(
                        chars,
                        phone,
                        pitch,
                        beat,
                        pos_char=char_len_list,
                        pos_spec=length,
                    )
            elif args.model_type == "LSTM":
                if args.db_joint:
                    output, hidden, output_mel, output_mel2 = model(
                        phone, pitch, beat, singer_vec)
                else:
                    output, hidden, output_mel, output_mel2 = model(
                        phone, pitch, beat)
                att = None

            elif args.model_type == "GRU_gs":
                output, att, output_mel = model(spec, phone, pitch, beat,
                                                length, args)
                att = None
            elif args.model_type == "PureTransformer":
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    pos_char=char_len_list,
                    pos_spec=length)
            elif args.model_type == "Conformer":
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    pos_char=char_len_list,
                    pos_spec=length)
            elif args.model_type == "Comformer_full":
                if args.db_joint:
                    output, att, output_mel, output_mel2 = model(
                        chars,
                        phone,
                        pitch,
                        beat,
                        singer_vec,
                        pos_char=char_len_list,
                        pos_spec=length,
                    )
                else:
                    output, att, output_mel, output_mel2 = model(
                        chars,
                        phone,
                        pitch,
                        beat,
                        pos_char=char_len_list,
                        pos_spec=length,
                    )
            elif args.model_type == "USTC_DAR":
                output_mel = model(phone, pitch, beat, length, args)
                att = None

            spec_origin = spec.clone()
            mel_origin = mel.clone()
            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)
                output_mel_normalizer = GlobalMVN(args.stats_mel_file)
                spec, _ = sepc_normalizer(spec, length)
                mel, _ = mel_normalizer(mel, length)

            if args.model_type == "USTC_DAR":
                spec_loss = 0
            else:
                spec_loss = criterion(output, spec, length_mask)

            if args.n_mels > 0:
                mel_loss = criterion(output_mel, mel, length_mel_mask)

                if args.double_mel_loss:
                    double_mel_loss = criterion(output_mel2, mel,
                                                length_mel_mask)
                else:
                    double_mel_loss = 0
            else:
                mel_loss = 0
                double_mel_loss = 0

            if args.vocoder_category == "wavernn":
                dev_loss = mel_loss + double_mel_loss
            else:
                dev_loss = mel_loss + double_mel_loss + spec_loss

            if args.perceptual_loss > 0:
                pe_loss = perceptual_entropy(output, real, imag)
                final_loss = (args.perceptual_loss * pe_loss +
                              (1 - args.perceptual_loss) * dev_loss)
            else:
                final_loss = dev_loss

            losses.update(final_loss.item(), phone.size(0))
            if args.model_type != "USTC_DAR":
                spec_losses.update(spec_loss.item(), phone.size(0))

            if args.perceptual_loss > 0:
                # pe_loss = perceptual_entropy(output, real, imag)
                pe_losses.update(pe_loss.item(), phone.size(0))
            if args.n_mels > 0:
                mel_losses.update(mel_loss.item(), phone.size(0))
                if args.double_mel_loss:
                    double_mel_losses.update(double_mel_loss.item(),
                                             phone.size(0))

            if args.model_type == "USTC_DAR":
                # normalize inverse stage
                if args.normalize and args.stats_file:
                    # output_mel, _ = mel_normalizer.inverse(output_mel, length)
                    mel, _ = mel_normalizer.inverse(mel, length)
                    output_mel, _ = output_mel_normalizer.inverse(
                        output_mel, length)
                # FIX ME! use Calculate_melcd_fromMelSpectrum here
                mcd_value, length_sum = 0, 1
            else:
                # normalize inverse stage
                if args.normalize and args.stats_file:
                    output, _ = sepc_normalizer.inverse(output, length)
                    # output_mel, _ = mel_normalizer.inverse(output_mel, length)
                    mel, _ = mel_normalizer.inverse(mel, length)
                    output_mel, _ = output_mel_normalizer.inverse(
                        output_mel, length)
                (mcd_value,
                 length_sum) = Metrics.Calculate_melcd_fromLinearSpectrum(
                     output, spec_origin, length, args)
            mcd_metric.update(mcd_value, length_sum)

            if step % args.dev_step_log == 0:
                if args.model_type == "USTC_DAR":
                    log_figure_mel(step, output_mel, mel_origin, att, length,
                                   log_save_dir, args)
                else:
                    if args.vocoder_category == "wavernn":
                        for i in range(output_mel.shape[0]):
                            one_batch_output_mel = output_mel[i].unsqueeze(0)
                            one_batch_mel = mel[i].unsqueeze(0)
                            log_mel(
                                step,
                                one_batch_output_mel,
                                one_batch_mel,
                                att,
                                length,
                                log_save_dir,
                                args,
                                voc_model,
                            )
                    else:
                        log_figure(step, output, spec_origin, att, length,
                                   log_save_dir, args)
                out_log = ("step {}: train_loss {:.4f}; "
                           "spec_loss {:.4f}; mcd_value {:.4f};".format(
                               step, losses.avg, spec_losses.avg,
                               mcd_metric.avg))
                if args.perceptual_loss > 0:
                    out_log += "pe_loss {:.4f}; ".format(pe_losses.avg)
                if args.n_mels > 0:
                    out_log += "mel_loss {:.4f}; ".format(mel_losses.avg)
                    if args.double_mel_loss:
                        out_log += "dmel_loss {:.4f}; ".format(
                            double_mel_losses.avg)
                end = time.time()
                print("{} -- sum_time: {}s".format(out_log, (end - start)))

    info = {
        "loss": losses.avg,
        "spec_loss": spec_losses.avg,
        "mcd_value": mcd_metric.avg,
    }
    if args.perceptual_loss > 0:
        info["pe_loss"] = pe_losses.avg
    if args.n_mels > 0:
        info["mel_loss"] = mel_losses.avg
    return info
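
validate (and train_one_epoch below) only relies on an AverageMeter utility exposing update(value, n) and an avg attribute. The project ships its own implementation; the following is a minimal sketch consistent with that usage, shown as an assumption rather than the original class.

class AverageMeter:
    """Running weighted average; update(value, n) adds value counted n times."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)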
Example 3
def train_one_epoch(
    train_loader,
    model,
    device,
    optimizer,
    criterion,
    perceptual_entropy,
    epoch,
    args,
    voc_model,
):
    """train_one_epoch."""
    losses = AverageMeter()
    spec_losses = AverageMeter()
    if args.perceptual_loss > 0:
        pe_losses = AverageMeter()
    if args.n_mels > 0:
        mel_losses = AverageMeter()
        # mcd_metric = AverageMeter()
        # f0_distortion_metric, vuv_error_metric =
        # AverageMeter(), AverageMeter()
        if args.double_mel_loss:
            double_mel_losses = AverageMeter()
    model.train()

    log_save_dir = os.path.join(args.model_save_dir,
                                "epoch{}/log_train_figure".format(epoch))
    if not os.path.exists(log_save_dir):
        os.makedirs(log_save_dir)

    start = time.time()

    # f0_ground_truth_all = np.reshape(np.array([]), (-1, 1))
    # f0_synthesis_all = np.reshape(np.array([]), (-1, 1))

    for (step, data_step) in enumerate(train_loader, 1):
        if args.db_joint:
            (
                phone,
                beat,
                pitch,
                spec,
                real,
                imag,
                length,
                chars,
                char_len_list,
                mel,
                singer_id,
            ) = data_step

            singer_id = np.array(singer_id).reshape(np.shape(phone)[0],
                                                    -1)  # [batch size, 1]
            singer_vec = singer_id.repeat(np.shape(phone)[1],
                                          axis=1)  # [batch size, length]
            singer_vec = torch.from_numpy(singer_vec).to(device)

        else:
            (
                phone,
                beat,
                pitch,
                spec,
                real,
                imag,
                length,
                chars,
                char_len_list,
                mel,
            ) = data_step
        phone = phone.to(device)
        beat = beat.to(device)
        pitch = pitch.to(device).float()
        spec = spec.to(device).float()
        if mel is not None:
            mel = mel.to(device).float()
        real = real.to(device).float()
        imag = imag.to(device).float()
        length_mask = length.unsqueeze(2)
        if mel is not None:
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mel_mask = length_mel_mask.to(device)
        length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
        length_mask = length_mask.to(device)
        length = length.to(device)
        char_len_list = char_len_list.to(device)

        if not args.use_asr_post:
            chars = chars.to(device)
            char_len_list = char_len_list.to(device)
        else:
            phone = phone.float()

        # output = [batch size, num frames, feat_dim]
        # output_mel = [batch size, num frames, n_mels dimension]
        if args.model_type == "GLU_Transformer":
            if args.db_joint:
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    singer_vec,
                    pos_char=char_len_list,
                    pos_spec=length,
                )
            else:
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    pos_char=char_len_list,
                    pos_spec=length)
        elif args.model_type == "LSTM":
            if args.db_joint:
                output, hidden, output_mel, output_mel2 = model(
                    phone, pitch, beat, singer_vec)
            else:
                output, hidden, output_mel, output_mel2 = model(
                    phone, pitch, beat)
            att = None
        elif args.model_type == "GRU_gs":
            output, att, output_mel = model(spec, phone, pitch, beat, length,
                                            args)
            att = None
        elif args.model_type == "PureTransformer":
            output, att, output_mel, output_mel2 = model(
                chars,
                phone,
                pitch,
                beat,
                pos_char=char_len_list,
                pos_spec=length)
        elif args.model_type == "Conformer":
            # print(f"chars: {np.shape(chars)}, phone:
            # {np.shape(phone)}, length: {np.shape(length)}")
            output, att, output_mel, output_mel2 = model(
                chars,
                phone,
                pitch,
                beat,
                pos_char=char_len_list,
                pos_spec=length)
        elif args.model_type == "Comformer_full":
            if args.db_joint:
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    singer_vec,
                    pos_char=char_len_list,
                    pos_spec=length,
                )
            else:
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    pos_char=char_len_list,
                    pos_spec=length)
        elif args.model_type == "USTC_DAR":
            # USTC_DAR predicts mel only; the mel loss below takes the
            # place of the spec loss
            output_mel = model(phone, pitch, beat, length, args)
            att = None

        spec_origin = spec.clone()
        mel_origin = mel.clone()
        if args.normalize:
            sepc_normalizer = GlobalMVN(args.stats_file)
            mel_normalizer = GlobalMVN(args.stats_mel_file)
            output_mel_normalizer = GlobalMVN(args.stats_mel_file)
            spec, _ = sepc_normalizer(spec, length)
            mel, _ = mel_normalizer(mel, length)

        if args.model_type == "USTC_DAR":
            spec_loss = 0
        else:
            spec_loss = criterion(output, spec, length_mask)

        if args.n_mels > 0:
            mel_loss = criterion(output_mel, mel, length_mel_mask)
            if args.double_mel_loss:
                double_mel_loss = criterion(output_mel2, mel, length_mel_mask)
            else:
                double_mel_loss = 0
        else:
            mel_loss = 0
            double_mel_loss = 0
        if args.vocoder_category == "wavernn":
            train_loss = mel_loss + double_mel_loss
        else:
            train_loss = mel_loss + double_mel_loss + spec_loss
        if args.perceptual_loss > 0:
            pe_loss = perceptual_entropy(output, real, imag)
            final_loss = (args.perceptual_loss * pe_loss +
                          (1 - args.perceptual_loss) * train_loss)
        else:
            final_loss = train_loss

        final_loss = final_loss / args.accumulation_steps
        final_loss.backward()

        if args.gradclip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradclip)

        if step % args.accumulation_steps == 0:
            if args.optimizer == "noam":
                optimizer.step_and_update_lr()
            else:
                optimizer.step()
            # zero the accumulated gradients
            optimizer.zero_grad()

        losses.update(final_loss.item(), phone.size(0))
        if args.model_type != "USTC_DAR":
            spec_losses.update(spec_loss.item(), phone.size(0))

        if args.perceptual_loss > 0:
            pe_losses.update(pe_loss.item(), phone.size(0))
        if args.n_mels > 0:
            mel_losses.update(mel_loss.item(), phone.size(0))
            if args.double_mel_loss:
                double_mel_losses.update(double_mel_loss.item(), phone.size(0))

        if step % args.train_step_log == 0:
            end = time.time()

            if args.model_type == "USTC_DAR":
                # The normalize inverse is only applied here for logging,
                # since the log step converts to wav and computes metrics
                # such as MCD.
                if args.normalize and args.stats_file:
                    mel, _ = mel_normalizer.inverse(mel, length)
                    output_mel, _ = output_mel_normalizer.inverse(
                        output_mel, length)
                log_figure_mel(step, output_mel, mel_origin, att, length,
                               log_save_dir, args)
                out_log = "step {}: train_loss {:.4f}; spec_loss {:.4f};".format(
                    step, losses.avg, spec_losses.avg)
            else:
                # The normalize inverse is only applied here for logging,
                # since the log step converts to wav and computes metrics
                # such as MCD.
                if args.normalize and args.stats_file:
                    output, _ = sepc_normalizer.inverse(output, length)
                    mel, _ = mel_normalizer.inverse(mel, length)
                    output_mel, _ = output_mel_normalizer.inverse(
                        output_mel, length)

                if args.vocoder_category == "wavernn":
                    for i in range(output_mel.shape[0]):
                        one_batch_output_mel = output_mel[i].unsqueeze(0)
                        one_batch_mel = mel[i].unsqueeze(0)
                        log_mel(
                            step,
                            one_batch_output_mel,
                            one_batch_mel,
                            att,
                            length,
                            log_save_dir,
                            args,
                            voc_model,
                        )
                else:
                    log_figure(step, output, spec_origin, att, length,
                               log_save_dir, args)
                out_log = "step {}: train_loss {:.4f}; spec_loss {:.4f};".format(
                    step, losses.avg, spec_losses.avg)

            if args.perceptual_loss > 0:
                out_log += "pe_loss {:.4f}; ".format(pe_losses.avg)
            if args.n_mels > 0:
                out_log += "mel_loss {:.4f}; ".format(mel_losses.avg)
                if args.double_mel_loss:
                    out_log += "dmel_loss {:.4f}; ".format(
                        double_mel_losses.avg)
            print("{} -- sum_time: {:.2f}s".format(out_log, (end - start)))

    info = {"loss": losses.avg, "spec_loss": spec_losses.avg}
    if args.perceptual_loss > 0:
        info["pe_loss"] = pe_losses.avg
    if args.n_mels > 0:
        info["mel_loss"] = mel_losses.avg
    return info
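
Both training and validation normalize features with GlobalMVN(stats_file) and undo the transform with .inverse(); each call returns a (features, length) pair. Below is a minimal sketch of such a normalizer, under the assumption that the stats file is an .npz archive holding per-dimension "mean" and "std" arrays (the project's actual file format may differ).

import numpy as np
import torch


class GlobalMVN:
    """Global mean/variance normalization over the feature dimension."""

    def __init__(self, stats_file):
        stats = np.load(stats_file)  # assumed keys: "mean", "std"
        self.mean = torch.from_numpy(stats["mean"]).float()
        self.std = torch.from_numpy(stats["std"]).float()

    def __call__(self, feats, length):
        # feats: [batch, frames, feat_dim]; padded frames are normalized
        # too, which is harmless because they are masked in the loss.
        mean, std = self.mean.to(feats.device), self.std.to(feats.device)
        return (feats - mean) / std, length

    def inverse(self, feats, length):
        mean, std = self.mean.to(feats.device), self.std.to(feats.device)
        return feats * std + mean, length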
Example 4
def infer_predictor(args):
    """infer."""
    torch.cuda.set_device(args.gpu_id)
    logging.info(f"GPU {args.gpu_id} is used")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare model
    model = RNN_Discriminator(
        embed_size=128,
        d_model=128,
        hidden_size=128,
        num_layers=2,
        n_specs=1025,
        singer_size=7,
        phone_size=43,
        simitone_size=59,
        dropout=0.1,
        bidirectional=True,
        device=device,
    )
    logging.info(f"{model}")
    model = model.to(device)
    logging.info(
        f"The model has {count_parameters(model):,} trainable parameters")

    # Load model weights
    logging.info(f"Loading pretrained weights from {args.model_file}")
    checkpoint = torch.load(args.model_file, map_location=device)
    state_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()
    state_dict_new = {}
    para_list = []

    for k, v in state_dict.items():
        # assert k in model_dict
        if (k == "normalizer.mean" or k == "normalizer.std"
                or k == "mel_normalizer.mean" or k == "mel_normalizer.std"):
            continue
        if model_dict[k].size() == state_dict[k].size():
            state_dict_new[k] = v
        else:
            para_list.append(k)

    logging.info(f"Total {len(state_dict)} parameter sets, "
                 f"loaded {len(state_dict_new)} parameter set")

    if len(para_list) > 0:
        logging.warning(f"Not loading {para_list} because of different sizes")
    model.load_state_dict(state_dict_new)
    logging.info(f"Loaded checkpoint {args.model_file}")
    model = model.to(device)
    model.eval()

    # Decode
    test_set = SVSDataset(
        align_root_path=args.test_align,
        pitch_beat_root_path=args.test_pitch,
        wav_root_path=args.test_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        standard=args.standard,
        sing_quality=args.sing_quality,
        db_joint=args.db_joint,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
        phone_shift_size=-1,
        semitone_shift=False,
    )
    collate_fn_svs = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        False,  # random crop
        -1,  # crop_min_length
        args.Hz2semitone,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs,
        pin_memory=True,
    )

    criterion = nn.CrossEntropyLoss(reduction="sum")

    start_t_test = time.time()

    singer_losses = AverageMeter()
    phone_losses = AverageMeter()
    semitone_losses = AverageMeter()

    singer_count = AverageMeter()
    phone_count = AverageMeter()
    semitone_count = AverageMeter()

    with torch.no_grad():
        for (step, data_step) in enumerate(test_loader, 1):
            if args.db_joint:
                (
                    phone,
                    beat,
                    pitch,
                    spec,
                    real,
                    imag,
                    length,
                    chars,
                    char_len_list,
                    mel,
                    singer_id,
                    semitone,
                ) = data_step

                singer_id = np.array(singer_id).reshape(
                    np.shape(phone)[0], -1)  # [batch size, 1]
                singer_vec = singer_id.repeat(np.shape(phone)[1],
                                              axis=1)  # [batch size, length]
                singer_vec = torch.from_numpy(singer_vec).to(device)
                singer_id = torch.from_numpy(singer_id).to(device)
            else:
                (
                    phone,
                    beat,
                    pitch,
                    spec,
                    real,
                    imag,
                    length,
                    chars,
                    char_len_list,
                    mel,
                    semitone,
                ) = data_step

            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            if semitone is not None:
                semitone = semitone.to(device)
            spec = spec.to(device).float()
            mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = (length > 0).int().unsqueeze(2)
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length_mel_mask = length_mel_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.Hz2semitone:
                pitch = semitone

            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)
                spec, _ = sepc_normalizer(spec, length)
                mel, _ = mel_normalizer(mel, length)

            len_list, _ = torch.max(length, dim=1)  # [len1, len2, len3, ...]
            len_list = len_list.cpu().detach().numpy()

            singer_out, phone_out, semitone_out = model(spec, len_list)

            # calculate CrossEntropy loss (defined with reduction="sum")
            phone_loss = 0
            semitone_loss = 0
            phone_correct = 0
            semitone_correct = 0
            batch_size = np.shape(spec)[0]
            for i in range(batch_size):
                phone_i = phone[i, :len_list[i], :].view(-1)  # [valid seq len]
                phone_out_i = phone_out[
                    i, :len_list[i], :]  # [valid seq len, phone_size]
                phone_loss += criterion(phone_out_i, phone_i)

                _, phone_predict = torch.max(phone_out_i, dim=1)
                phone_correct += phone_predict.eq(phone_i).cpu().sum().numpy()

                semitone_i = semitone[i, :len_list[i], :].view(
                    -1)  # [valid seq len]
                semitone_out_i = semitone_out[
                    i, :len_list[i], :]  # [valid seq len, semitone_size]
                semitone_loss += criterion(semitone_out_i, semitone_i)

                _, semitone_predict = torch.max(semitone_out_i, dim=1)
                semitone_correct += semitone_predict.eq(
                    semitone_i).cpu().sum().numpy()

            singer_id = singer_id.view(-1)  # [batch size]
            _, singer_predict = torch.max(singer_out, dim=1)
            singer_correct = singer_predict.eq(singer_id).cpu().sum().numpy()

            phone_loss /= np.sum(len_list)
            semitone_loss /= np.sum(len_list)
            singer_loss = criterion(singer_out, singer_id) / batch_size

            # record loss info
            singer_losses.update(singer_loss.item(), batch_size)
            phone_losses.update(phone_loss.item(), np.sum(len_list))
            semitone_losses.update(semitone_loss.item(), np.sum(len_list))

            singer_count.update(singer_correct.item() / batch_size, batch_size)
            phone_count.update(phone_correct.item() / np.sum(len_list),
                               np.sum(len_list))
            semitone_count.update(semitone_correct.item() / np.sum(len_list),
                                  np.sum(len_list))

            if step % 1 == 0:
                end = time.time()

                out_log = "step {}: loss {:.6f}, ".format(
                    step,
                    singer_losses.avg + phone_losses.avg + semitone_losses.avg)
                out_log += "\t singer_loss: {:.4f} ".format(singer_losses.avg)
                out_log += "phone_loss: {:.4f} ".format(phone_losses.avg)
                out_log += "semitone_loss: {:.4f} \n".format(
                    semitone_losses.avg)

                out_log += "\t singer_accuracy: {:.4f}% ".format(
                    singer_count.avg * 100)
                out_log += "phone_accuracy: {:.4f}% ".format(phone_count.avg *
                                                             100)
                out_log += "semitone_accuracy: {:.4f}% ".format(
                    semitone_count.avg * 100)

                print("{} -- sum_time: {:.2f}s".format(out_log,
                                                       (end - start_t_test)))

    end_t_test = time.time()

    out_log = "\nTest Stage: "
    out_log += "loss: {:.4f}, ".format(singer_losses.avg + phone_losses.avg +
                                       semitone_losses.avg)
    out_log += "singer_loss: {:.4f}, ".format(singer_losses.avg)
    out_log += "phone_loss: {:.4f}, semitone_loss: {:.4f} \n".format(
        phone_losses.avg,
        semitone_losses.avg,
    )
    out_log += "singer_accuracy: {:.4f}%, ".format(singer_count.avg * 100)
    out_log += "phone_accuracy: {:.4f}%, semitone_accuracy: {:.4f}% ".format(
        phone_count.avg * 100, semitone_count.avg * 100)
    logging.info("{} time: {:.2f}s".format(out_log, end_t_test - start_t_test))
Example 5
def augmentation(args, target_singer_id, output_path):
    """Resynthesize the test set with a target singer id."""
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    model, device = load_model(args)

    start_t_test = time.time()

    # Decode
    test_set = SVSDataset(
        align_root_path=args.test_align,
        pitch_beat_root_path=args.test_pitch,
        wav_root_path=args.test_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        standard=args.standard,
        sing_quality=args.sing_quality,
        db_joint=args.db_joint,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
        phone_shift_size=-1,
        semitone_shift=False,
    )
    collate_fn_svs = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        False,  # random crop
        -1,  # crop_min_length
        args.Hz2semitone,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs,
        pin_memory=True,
    )

    with torch.no_grad():
        for step, data_step in enumerate(test_loader, 1):
            if args.db_joint:
                (phone, beat, pitch, spec, real, imag, length, chars,
                 char_len_list, mel, singer_id, semitone, filename_list,
                 flag_filter_list) = data_step

            else:
                print(
                    "No support for augmentation with args.db_joint == False")
                quit()

            t_singer_id = np.array(target_singer_id).reshape(
                np.shape(phone)[0], -1)  # [batch size, 1]
            singer_vec = t_singer_id.repeat(np.shape(phone)[1],
                                            axis=1)  # [batch size, length]
            singer_vec = torch.from_numpy(singer_vec).to(device)

            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            if semitone is not None:
                semitone = semitone.to(device)
            spec = spec.to(device).float()
            mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = (length > 0).int().unsqueeze(2)
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length_mel_mask = length_mel_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.Hz2semitone:
                pitch = semitone

            if args.model_type == "GLU_Transformer":
                if args.db_joint:
                    output, att, output_mel, output_mel2 = model(
                        chars,
                        phone,
                        pitch,
                        beat,
                        singer_vec,
                        pos_char=char_len_list,
                        pos_spec=length,
                    )
            elif args.model_type == "LSTM":
                if args.db_joint:
                    output, hidden, output_mel, output_mel2 = model(
                        phone, pitch, beat, singer_vec)

            elif args.model_type == "Comformer_full":
                if args.db_joint:
                    output, att, output_mel, output_mel2 = model(
                        chars,
                        phone,
                        pitch,
                        beat,
                        singer_vec,
                        pos_char=char_len_list,
                        pos_spec=length,
                    )

            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)

            # normalize inverse stage
            if args.normalize and args.stats_file:
                output, _ = sepc_normalizer.inverse(output, length)

            # write wav
            output = output.cpu().detach().numpy()[0]
            length = np.max(length.cpu().detach().numpy()[0])
            output = output[:length]

            wav = spectrogram2wav(
                output,
                args.max_db,
                args.ref_db,
                args.preemphasis,
                args.power,
                args.sampling_rate,
                args.frame_shift,
                args.frame_length,
                args.nfft,
            )

            wr_fname = filename_list[0] + "-" + str(
                target_singer_id)  # batch_size = 1
            # Compare parsed version numbers, not raw strings, so that
            # releases like librosa 0.10.x take the soundfile branch below.
            librosa_version = tuple(
                int(v) for v in librosa.__version__.split(".")[:2])
            if librosa_version < (0, 8):
                librosa.output.write_wav(
                    os.path.join(output_path, "{}.wav".format(wr_fname)),
                    wav,
                    args.sampling_rate,
                )
            else:
                # librosa >= 0.8 removed librosa.output.write_wav
                sf.write(
                    os.path.join(output_path, "{}.wav".format(wr_fname)),
                    wav,
                    args.sampling_rate,
                    format="wav",
                    subtype="PCM_24",
                )

            end = time.time()
            out_log = os.path.join(output_path, "{}.wav".format(wr_fname))
            logging.info(f"{out_log} -- sum_time: {(end - start_t_test)}s")
Example 6
def infer(args):
    """infer."""
    torch.cuda.set_device(args.gpu_id)
    logging.info(f"GPU {args.gpu_id} is used")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare model
    if args.model_type == "GLU_Transformer":
        model = GLU_TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            device=device,
        )
    elif args.model_type == "LSTM":
        model = LSTMSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    elif args.model_type == "GRU_gs":
        model = GRUSVS_gs(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    elif args.model_type == "PureTransformer":
        model = TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            device=device,
        )
    elif args.model_type == "Conformer":
        model = ConformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            enc_attention_dim=args.enc_attention_dim,
            enc_attention_heads=args.enc_attention_heads,
            enc_linear_units=args.enc_linear_units,
            enc_num_blocks=args.enc_num_blocks,
            enc_dropout_rate=args.enc_dropout_rate,
            enc_positional_dropout_rate=args.enc_positional_dropout_rate,
            enc_attention_dropout_rate=args.enc_attention_dropout_rate,
            enc_input_layer=args.enc_input_layer,
            enc_normalize_before=args.enc_normalize_before,
            enc_concat_after=args.enc_concat_after,
            enc_positionwise_layer_type=args.enc_positionwise_layer_type,
            enc_positionwise_conv_kernel_size=(args.enc_positionwise_conv_kernel_size),
            enc_macaron_style=args.enc_macaron_style,
            enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
            enc_selfattention_layer_type=args.enc_selfattention_layer_type,
            enc_activation_type=args.enc_activation_type,
            enc_use_cnn_module=args.enc_use_cnn_module,
            enc_cnn_module_kernel=args.enc_cnn_module_kernel,
            enc_padding_idx=args.enc_padding_idx,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            dec_dropout=args.dec_dropout,
            device=device,
        )
    elif args.model_type == "Comformer_full":
        model = ConformerSVS_FULL(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            output_dim=args.feat_dim,
            n_mels=args.n_mels,
            enc_attention_dim=args.enc_attention_dim,
            enc_attention_heads=args.enc_attention_heads,
            enc_linear_units=args.enc_linear_units,
            enc_num_blocks=args.enc_num_blocks,
            enc_dropout_rate=args.enc_dropout_rate,
            enc_positional_dropout_rate=args.enc_positional_dropout_rate,
            enc_attention_dropout_rate=args.enc_attention_dropout_rate,
            enc_input_layer=args.enc_input_layer,
            enc_normalize_before=args.enc_normalize_before,
            enc_concat_after=args.enc_concat_after,
            enc_positionwise_layer_type=args.enc_positionwise_layer_type,
            enc_positionwise_conv_kernel_size=(args.enc_positionwise_conv_kernel_size),
            enc_macaron_style=args.enc_macaron_style,
            enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
            enc_selfattention_layer_type=args.enc_selfattention_layer_type,
            enc_activation_type=args.enc_activation_type,
            enc_use_cnn_module=args.enc_use_cnn_module,
            enc_cnn_module_kernel=args.enc_cnn_module_kernel,
            enc_padding_idx=args.enc_padding_idx,
            dec_attention_dim=args.dec_attention_dim,
            dec_attention_heads=args.dec_attention_heads,
            dec_linear_units=args.dec_linear_units,
            dec_num_blocks=args.dec_num_blocks,
            dec_dropout_rate=args.dec_dropout_rate,
            dec_positional_dropout_rate=args.dec_positional_dropout_rate,
            dec_attention_dropout_rate=args.dec_attention_dropout_rate,
            dec_input_layer=args.dec_input_layer,
            dec_normalize_before=args.dec_normalize_before,
            dec_concat_after=args.dec_concat_after,
            dec_positionwise_layer_type=args.dec_positionwise_layer_type,
            dec_positionwise_conv_kernel_size=(args.dec_positionwise_conv_kernel_size),
            dec_macaron_style=args.dec_macaron_style,
            dec_pos_enc_layer_type=args.dec_pos_enc_layer_type,
            dec_selfattention_layer_type=args.dec_selfattention_layer_type,
            dec_activation_type=args.dec_activation_type,
            dec_use_cnn_module=args.dec_use_cnn_module,
            dec_cnn_module_kernel=args.dec_cnn_module_kernel,
            dec_padding_idx=args.dec_padding_idx,
            device=device,
        )
    else:
        raise ValueError("Not Support Model Type %s" % args.model_type)
    logging.info(f"{model}")
    logging.info(f"The model has {count_parameters(model):,} trainable parameters")

    # Load model weights
    logging.info(f"Loading pretrained weights from {args.model_file}")
    checkpoint = torch.load(args.model_file, map_location=device)
    state_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()
    state_dict_new = {}
    para_list = []

    for k, v in state_dict.items():
        # assert k in model_dict
        if (
            k == "normalizer.mean"
            or k == "normalizer.std"
            or k == "mel_normalizer.mean"
            or k == "mel_normalizer.std"
        ):
            continue
        if model_dict[k].size() == state_dict[k].size():
            state_dict_new[k] = v
        else:
            para_list.append(k)

    logging.info(
        f"Total {len(state_dict)} parameter sets, "
        f"loaded {len(state_dict_new)} parameter sets"
    )

    if len(para_list) > 0:
        logging.warning(f"Not loading {para_list} because of different sizes")
    model.load_state_dict(state_dict_new)
    logging.info(f"Loaded checkpoint {args.model_file}")
    model = model.to(device)
    model.eval()

    # Decode
    test_set = SVSDataset(
        align_root_path=args.test_align,
        pitch_beat_root_path=args.test_pitch,
        wav_root_path=args.test_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        standard=args.standard,
        sing_quality=args.sing_quality,
    )
    collate_fn_svs = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs,
        pin_memory=True,
    )

    if args.loss == "l1":
        criterion = MaskedLoss("l1")
    elif args.loss == "mse":
        criterion = MaskedLoss("mse")
    else:
        raise ValueError("Not Support Loss Type")

    losses = AverageMeter()
    spec_losses = AverageMeter()
    if args.perceptual_loss > 0:
        pe_losses = AverageMeter()
    if args.n_mels > 0:
        mel_losses = AverageMeter()
        mcd_metric = AverageMeter()
        f0_distortion_metric, vuv_error_metric = (
            AverageMeter(),
            AverageMeter(),
        )
        if args.double_mel_loss:
            double_mel_losses = AverageMeter()
    model.eval()

    if not os.path.exists(args.prediction_path):
        os.makedirs(args.prediction_path)

    f0_ground_truth_all = np.reshape(np.array([]), (-1, 1))
    f0_synthesis_all = np.reshape(np.array([]), (-1, 1))
    start_t_test = time.time()

    # preload vocoder model
    if args.vocoder_category == "wavernn":
        voc_model = WaveRNN(
            rnn_dims=512,
            fc_dims=512,
            bits=9,
            pad=2,
            upsample_factors=(5, 5, 11),
            feat_dims=80,
            compute_dims=128,
            res_out_dims=128,
            res_blocks=10,
            hop_length=275,  # 12.5ms - in line with Tacotron 2 paper
            sample_rate=22050,
            mode="MOL",
        ).to(device)

        voc_model.load("./weights/wavernn/latest_weights.pyt")

    with torch.no_grad():
        for (
            step,
            (
                phone,
                beat,
                pitch,
                spec,
                real,
                imag,
                length,
                chars,
                char_len_list,
                mel,
            ),
        ) in enumerate(test_loader, 1):
            # if step >= args.decode_sample:
            #     break
            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            spec = spec.to(device).float()
            mel = mel.to(device).float()
            real = real.to(device).float()
            imag = imag.to(device).float()
            length_mask = length.unsqueeze(2)
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float()
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float()
            length_mask = length_mask.to(device)
            length_mel_mask = length_mel_mask.to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
                char_len_list = char_len_list.to(device)
            else:
                phone = phone.float()

            if args.model_type == "GLU_Transformer":
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    pos_char=char_len_list,
                    pos_spec=length,
                )
            elif args.model_type == "LSTM":
                output, hidden, output_mel, output_mel2 = model(phone, pitch, beat)
                att = None
            elif args.model_type == "GRU_gs":
                output, att, output_mel = model(spec, phone, pitch, beat, length, args)
                att = None
            elif args.model_type == "PureTransformer":
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    pos_char=char_len_list,
                    pos_spec=length,
                )
            elif args.model_type == "Conformer":
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    pos_char=char_len_list,
                    pos_spec=length,
                )
            elif args.model_type == "Comformer_full":
                output, att, output_mel, output_mel2 = model(
                    chars,
                    phone,
                    pitch,
                    beat,
                    pos_char=char_len_list,
                    pos_spec=length,
                )

            spec_origin = spec.clone()
            # spec_origin = spec
            if args.normalize:
                sepc_normalizer = GlobalMVN(args.stats_file)
                mel_normalizer = GlobalMVN(args.stats_mel_file)
                spec, _ = sepc_normalizer(spec, length)
                mel, _ = mel_normalizer(mel, length)

            spec_loss = criterion(output, spec, length_mask)
            if args.n_mels > 0:
                mel_loss = criterion(output_mel, mel, length_mel_mask)
            else:
                mel_loss = 0

            final_loss = mel_loss + spec_loss

            losses.update(final_loss.item(), phone.size(0))
            spec_losses.update(spec_loss.item(), phone.size(0))
            if args.n_mels > 0:
                mel_losses.update(mel_loss.item(), phone.size(0))

            # normalize inverse stage
            if args.normalize and args.stats_file:
                output, _ = sepc_normalizer.inverse(output, length)
                # spec,_ = sepc_normalizer.inverse(spec,length)

            (mcd_value, length_sum,) = Metrics.Calculate_melcd_fromLinearSpectrum(
                output, spec_origin, length, args
            )
            (
                f0_distortion_value,
                voiced_frame_number_step,
                vuv_error_value,
                frame_number_step,
                f0_ground_truth_step,
                f0_synthesis_step,
            ) = Metrics.Calculate_f0RMSE_VUV_CORR_fromWav(
                output, spec_origin, length, args, "test"
            )
            f0_ground_truth_all = np.concatenate(
                (f0_ground_truth_all, f0_ground_truth_step), axis=0
            )
            f0_synthesis_all = np.concatenate(
                (f0_synthesis_all, f0_synthesis_step), axis=0
            )

            mcd_metric.update(mcd_value, length_sum)
            f0_distortion_metric.update(f0_distortion_value, voiced_frame_number_step)
            vuv_error_metric.update(vuv_error_value, frame_number_step)

            if step % 1 == 0:
                if args.vocoder_category == "griffin":
                    log_figure(
                        step,
                        output,
                        spec_origin,
                        att,
                        length,
                        args.prediction_path,
                        args,
                    )
                elif args.vocoder_category == "wavernn":
                    log_mel(
                        step,
                        output_mel,
                        spec_origin,
                        att,
                        length,
                        args.prediction_path,
                        args,
                        voc_model,
                    )
                out_log = (
                    "step {}: loss {:.4f}; "
                    "spec_loss {:.4f}; mcd_value {:.4f};".format(
                        step,
                        losses.avg,
                        spec_losses.avg,
                        mcd_metric.avg,
                    )
                )
                if args.perceptual_loss > 0:
                    out_log += " pe_loss {:.4f}; ".format(pe_losses.avg)
                if args.n_mels > 0:
                    out_log += " mel_loss {:.4f}; ".format(mel_losses.avg)
                    if args.double_mel_loss:
                        out_log += " dmel_loss {:.4f}; ".format(double_mel_losses.avg)
                end = time.time()
                logging.info(f"{out_log} -- sum_time: {(end - start_t_test)}s")

    end_t_test = time.time()

    out_log = "Test Stage: "
    out_log += "spec_loss: {:.4f} ".format(spec_losses.avg)
    if args.n_mels > 0:
        out_log += "mel_loss: {:.4f}, ".format(mel_losses.avg)
    # if args.perceptual_loss > 0:
    #     out_log += 'pe_loss: {:.4f}, '.format(train_info['pe_loss'])

    f0_corr = Metrics.compute_f0_corr(f0_ground_truth_all, f0_synthesis_all)

    out_log += "\n\t mcd_value {:.4f} dB ".format(mcd_metric.avg)
    out_log += (
        " f0_rmse_value {:.4f} Hz, "
        "vuv_error_value {:.4f} %, F0_CORR {:.4f}; ".format(
            np.sqrt(f0_distortion_metric.avg),
            vuv_error_metric.avg * 100,
            f0_corr,
        )
    )
    logging.info("{} time: {:.2f}s".format(out_log, end_t_test - start_t_test))