Exemplo n.º 1
0
def load_model(args):
    """Build the SVS model selected by ``args`` and load checkpoint weights.

    Args:
        args: Parsed configuration namespace. It must provide ``gpu_id``,
            ``model_type``, ``model_file`` and every hyper-parameter read by
            the constructor of the selected ``model_type`` (see
            ``_build_svs_model``).

    Returns:
        tuple: ``(model, device)``. ``model`` has been moved to ``device``
        and put in eval mode. ``device`` is the ``torch.device`` in use:
        CUDA when available, CPU otherwise.

    Raises:
        ValueError: If ``args.model_type`` is not a supported model type.
    """
    # Only touch the CUDA runtime when it exists. Calling
    # torch.cuda.set_device unconditionally crashes on CPU-only hosts,
    # which defeats the CPU fallback below.
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_id)
        logging.info(f"GPU {args.gpu_id} is used")
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.info("Warning: CPU is used")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.enabled = False

    model = _build_svs_model(args, device)
    logging.info(f"{model}")
    logging.info(
        f"The model has {count_parameters(model):,} trainable parameters")

    _load_checkpoint_weights(model, args.model_file, device)
    logging.info(f"Loaded checkpoint {args.model_file}")
    model = model.to(device)
    model.eval()

    return model, device


def _build_svs_model(args, device):
    """Instantiate the (untrained) network named by ``args.model_type``.

    Raises:
        ValueError: If ``args.model_type`` is not a supported model type.
    """
    if args.model_type == "GLU_Transformer":
        if args.db_joint:
            return GLU_TransformerSVS_combine(
                phone_size=args.phone_size,
                singer_size=args.singer_size,
                embed_size=args.embedding_size,
                hidden_size=args.hidden_size,
                glu_num_layers=args.glu_num_layers,
                dropout=args.dropout,
                output_dim=args.feat_dim,
                dec_nhead=args.dec_nhead,
                dec_num_block=args.dec_num_block,
                n_mels=args.n_mels,
                double_mel_loss=args.double_mel_loss,
                local_gaussian=args.local_gaussian,
                Hz2semitone=args.Hz2semitone,
                semitone_size=args.semitone_size,
                device=device,
            )
        return GLU_TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
        )
    if args.model_type == "LSTM":
        if args.db_joint:
            return LSTMSVS_combine(
                phone_size=args.phone_size,
                singer_size=args.singer_size,
                embed_size=args.embedding_size,
                d_model=args.hidden_size,
                num_layers=args.num_rnn_layers,
                dropout=args.dropout,
                d_output=args.feat_dim,
                n_mels=args.n_mels,
                double_mel_loss=args.double_mel_loss,
                Hz2semitone=args.Hz2semitone,
                semitone_size=args.semitone_size,
                device=device,
                use_asr_post=args.use_asr_post,
            )
        return LSTMSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    if args.model_type == "GRU_gs":
        # NOTE(review): the training code builds GRUSVS_gs without
        # ``double_mel_loss``; confirm the constructor accepts it here.
        return GRUSVS_gs(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    if args.model_type == "PureTransformer":
        return TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            device=device,
        )
    if args.model_type == "Conformer":
        return ConformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            enc_attention_dim=args.enc_attention_dim,
            enc_attention_heads=args.enc_attention_heads,
            enc_linear_units=args.enc_linear_units,
            enc_num_blocks=args.enc_num_blocks,
            enc_dropout_rate=args.enc_dropout_rate,
            enc_positional_dropout_rate=args.enc_positional_dropout_rate,
            enc_attention_dropout_rate=args.enc_attention_dropout_rate,
            enc_input_layer=args.enc_input_layer,
            enc_normalize_before=args.enc_normalize_before,
            enc_concat_after=args.enc_concat_after,
            enc_positionwise_layer_type=args.enc_positionwise_layer_type,
            enc_positionwise_conv_kernel_size=(
                args.enc_positionwise_conv_kernel_size),
            enc_macaron_style=args.enc_macaron_style,
            enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
            enc_selfattention_layer_type=args.enc_selfattention_layer_type,
            enc_activation_type=args.enc_activation_type,
            enc_use_cnn_module=args.enc_use_cnn_module,
            enc_cnn_module_kernel=args.enc_cnn_module_kernel,
            enc_padding_idx=args.enc_padding_idx,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            dec_dropout=args.dec_dropout,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
        )
    # The "Comformer_full" spelling (sic) is matched against configured
    # model_type values, so it must not be "corrected" here alone.
    if args.model_type == "Comformer_full":
        conformer_full_kwargs = dict(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            output_dim=args.feat_dim,
            n_mels=args.n_mels,
            enc_attention_dim=args.enc_attention_dim,
            enc_attention_heads=args.enc_attention_heads,
            enc_linear_units=args.enc_linear_units,
            enc_num_blocks=args.enc_num_blocks,
            enc_dropout_rate=args.enc_dropout_rate,
            enc_positional_dropout_rate=args.enc_positional_dropout_rate,
            enc_attention_dropout_rate=args.enc_attention_dropout_rate,
            enc_input_layer=args.enc_input_layer,
            enc_normalize_before=args.enc_normalize_before,
            enc_concat_after=args.enc_concat_after,
            enc_positionwise_layer_type=args.enc_positionwise_layer_type,
            enc_positionwise_conv_kernel_size=(
                args.enc_positionwise_conv_kernel_size),
            enc_macaron_style=args.enc_macaron_style,
            enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
            enc_selfattention_layer_type=args.enc_selfattention_layer_type,
            enc_activation_type=args.enc_activation_type,
            enc_use_cnn_module=args.enc_use_cnn_module,
            enc_cnn_module_kernel=args.enc_cnn_module_kernel,
            enc_padding_idx=args.enc_padding_idx,
            dec_attention_dim=args.dec_attention_dim,
            dec_attention_heads=args.dec_attention_heads,
            dec_linear_units=args.dec_linear_units,
            dec_num_blocks=args.dec_num_blocks,
            dec_dropout_rate=args.dec_dropout_rate,
            dec_positional_dropout_rate=args.dec_positional_dropout_rate,
            dec_attention_dropout_rate=args.dec_attention_dropout_rate,
            dec_input_layer=args.dec_input_layer,
            dec_normalize_before=args.dec_normalize_before,
            dec_concat_after=args.dec_concat_after,
            dec_positionwise_layer_type=args.dec_positionwise_layer_type,
            dec_positionwise_conv_kernel_size=(
                args.dec_positionwise_conv_kernel_size),
            dec_macaron_style=args.dec_macaron_style,
            dec_pos_enc_layer_type=args.dec_pos_enc_layer_type,
            dec_selfattention_layer_type=args.dec_selfattention_layer_type,
            dec_activation_type=args.dec_activation_type,
            dec_use_cnn_module=args.dec_use_cnn_module,
            dec_cnn_module_kernel=args.dec_cnn_module_kernel,
            dec_padding_idx=args.dec_padding_idx,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
        )
        if args.db_joint:
            # The multi-singer variant additionally takes the singer count.
            return ConformerSVS_FULL_combine(
                singer_size=args.singer_size, **conformer_full_kwargs)
        return ConformerSVS_FULL(**conformer_full_kwargs)
    raise ValueError("Not Support Model Type %s" % args.model_type)


def _load_checkpoint_weights(model, model_file, device):
    """Copy shape-compatible weights from ``model_file`` into ``model``.

    The checkpoint is expected to be a dict with a ``"state_dict"`` entry.
    Normalizer statistics are never copied. Entries whose shape differs from
    the model's, and entries the model does not have, are skipped with a
    warning. Every parameter that is not loaded keeps its current value.
    """
    logging.info(f"Loading pretrained weights from {model_file}")
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints.
    checkpoint = torch.load(model_file, map_location=device)
    state_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()
    # These entries are deliberately never taken from the checkpoint.
    skipped_keys = {
        "normalizer.mean",
        "normalizer.std",
        "mel_normalizer.mean",
        "mel_normalizer.std",
    }
    state_dict_new = {}
    para_list = []  # in both, but with different tensor sizes
    unexpected = []  # in the checkpoint only
    for k, v in state_dict.items():
        if k in skipped_keys:
            continue
        if k not in model_dict:
            unexpected.append(k)
        elif model_dict[k].size() == v.size():
            state_dict_new[k] = v
        else:
            para_list.append(k)

    logging.info(f"Total {len(state_dict)} parameter sets, "
                 f"loaded {len(state_dict_new)} parameter set")
    if len(para_list) > 0:
        logging.warning(f"Not loading {para_list} because of different sizes")
    if unexpected:
        logging.warning(
            f"Not loading {unexpected} because they are not in the model")

    # Merge into the model's full state dict before loading. A strict
    # load_state_dict() of only the filtered entries would fail with
    # "Missing key(s)" whenever anything above was skipped.
    model_dict.update(state_dict_new)
    model.load_state_dict(model_dict)
Exemplo n.º 2
0
def train(args):
    """train."""
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if torch.cuda.is_available() and args.auto_select_gpu is True:
        cvd = use_single_gpu()
        logging.info(f"GPU {cvd} is used")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # torch.backends.cudnn.enabled = False
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    elif torch.cuda.is_available() and args.auto_select_gpu is False:
        torch.cuda.set_device(args.gpu_id)
        logging.info(f"GPU {args.gpu_id} is used")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # torch.backends.cudnn.enabled = False
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    else:
        device = torch.device("cpu")
        logging.info("Warning: CPU is used")

    train_set = SVSDataset(
        align_root_path=args.train_align,
        pitch_beat_root_path=args.train_pitch,
        wav_root_path=args.train_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        sing_quality=args.sing_quality,
        standard=args.standard,
        db_joint=args.db_joint,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
        phone_shift_size=args.phone_shift_size,
        semitone_shift=args.semitone_shift,
    )

    dev_set = SVSDataset(
        align_root_path=args.val_align,
        pitch_beat_root_path=args.val_pitch,
        wav_root_path=args.val_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        sing_quality=args.sing_quality,
        standard=args.standard,
        db_joint=args.db_joint,
        Hz2semitone=args.Hz2semitone,
        semitone_min=args.semitone_min,
        semitone_max=args.semitone_max,
        phone_shift_size=-1,
        semitone_shift=False,
    )
    collate_fn_svs_train = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        args.random_crop,
        args.crop_min_length,
        args.Hz2semitone,
    )
    collate_fn_svs_val = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
        args.n_mels,
        args.db_joint,
        False,  # random crop
        -1,  # crop_min_length
        args.Hz2semitone,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs_train,
        pin_memory=True,
    )
    dev_loader = torch.utils.data.DataLoader(
        dataset=dev_set,
        batch_size=args.batchsize,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs_val,
        pin_memory=True,
    )

    assert (args.feat_dim == dev_set[0]["spec"].shape[1]
            or args.feat_dim == dev_set[0]["mel"].shape[1])

    if args.collect_stats:
        collect_stats(train_loader, args)
        logging.info("collect_stats finished !")
        quit()
    # prepare model
    if args.model_type == "GLU_Transformer":
        if args.db_joint:
            model = GLU_TransformerSVS_combine(
                phone_size=args.phone_size,
                singer_size=args.singer_size,
                embed_size=args.embedding_size,
                hidden_size=args.hidden_size,
                glu_num_layers=args.glu_num_layers,
                dropout=args.dropout,
                output_dim=args.feat_dim,
                dec_nhead=args.dec_nhead,
                dec_num_block=args.dec_num_block,
                n_mels=args.n_mels,
                double_mel_loss=args.double_mel_loss,
                local_gaussian=args.local_gaussian,
                Hz2semitone=args.Hz2semitone,
                semitone_size=args.semitone_size,
                device=device,
            )
        else:
            model = GLU_TransformerSVS(
                phone_size=args.phone_size,
                embed_size=args.embedding_size,
                hidden_size=args.hidden_size,
                glu_num_layers=args.glu_num_layers,
                dropout=args.dropout,
                output_dim=args.feat_dim,
                dec_nhead=args.dec_nhead,
                dec_num_block=args.dec_num_block,
                n_mels=args.n_mels,
                double_mel_loss=args.double_mel_loss,
                local_gaussian=args.local_gaussian,
                Hz2semitone=args.Hz2semitone,
                semitone_size=args.semitone_size,
                device=device,
            )
    elif args.model_type == "LSTM":
        if args.db_joint:
            model = LSTMSVS_combine(
                phone_size=args.phone_size,
                singer_size=args.singer_size,
                embed_size=args.embedding_size,
                d_model=args.hidden_size,
                num_layers=args.num_rnn_layers,
                dropout=args.dropout,
                d_output=args.feat_dim,
                n_mels=args.n_mels,
                double_mel_loss=args.double_mel_loss,
                Hz2semitone=args.Hz2semitone,
                semitone_size=args.semitone_size,
                device=device,
                use_asr_post=args.use_asr_post,
            )
        else:
            model = LSTMSVS(
                phone_size=args.phone_size,
                embed_size=args.embedding_size,
                d_model=args.hidden_size,
                num_layers=args.num_rnn_layers,
                dropout=args.dropout,
                d_output=args.feat_dim,
                n_mels=args.n_mels,
                double_mel_loss=args.double_mel_loss,
                Hz2semitone=args.Hz2semitone,
                semitone_size=args.semitone_size,
                device=device,
                use_asr_post=args.use_asr_post,
            )
    elif args.model_type == "GRU_gs":
        model = GRUSVS_gs(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    elif args.model_type == "PureTransformer":
        model = TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            device=device,
        )
    elif args.model_type == "Conformer":
        model = ConformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            enc_attention_dim=args.enc_attention_dim,
            enc_attention_heads=args.enc_attention_heads,
            enc_linear_units=args.enc_linear_units,
            enc_num_blocks=args.enc_num_blocks,
            enc_dropout_rate=args.enc_dropout_rate,
            enc_positional_dropout_rate=args.enc_positional_dropout_rate,
            enc_attention_dropout_rate=args.enc_attention_dropout_rate,
            enc_input_layer=args.enc_input_layer,
            enc_normalize_before=args.enc_normalize_before,
            enc_concat_after=args.enc_concat_after,
            enc_positionwise_layer_type=args.enc_positionwise_layer_type,
            enc_positionwise_conv_kernel_size=(
                args.enc_positionwise_conv_kernel_size),
            enc_macaron_style=args.enc_macaron_style,
            enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
            enc_selfattention_layer_type=args.enc_selfattention_layer_type,
            enc_activation_type=args.enc_activation_type,
            enc_use_cnn_module=args.enc_use_cnn_module,
            enc_cnn_module_kernel=args.enc_cnn_module_kernel,
            enc_padding_idx=args.enc_padding_idx,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            dec_dropout=args.dec_dropout,
            Hz2semitone=args.Hz2semitone,
            semitone_size=args.semitone_size,
            device=device,
        )
    elif args.model_type == "Comformer_full":
        if args.db_joint:
            model = ConformerSVS_FULL_combine(
                phone_size=args.phone_size,
                singer_size=args.singer_size,
                embed_size=args.embedding_size,
                output_dim=args.feat_dim,
                n_mels=args.n_mels,
                enc_attention_dim=args.enc_attention_dim,
                enc_attention_heads=args.enc_attention_heads,
                enc_linear_units=args.enc_linear_units,
                enc_num_blocks=args.enc_num_blocks,
                enc_dropout_rate=args.enc_dropout_rate,
                enc_positional_dropout_rate=args.enc_positional_dropout_rate,
                enc_attention_dropout_rate=args.enc_attention_dropout_rate,
                enc_input_layer=args.enc_input_layer,
                enc_normalize_before=args.enc_normalize_before,
                enc_concat_after=args.enc_concat_after,
                enc_positionwise_layer_type=args.enc_positionwise_layer_type,
                enc_positionwise_conv_kernel_size=(
                    args.enc_positionwise_conv_kernel_size),
                enc_macaron_style=args.enc_macaron_style,
                enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
                enc_selfattention_layer_type=args.enc_selfattention_layer_type,
                enc_activation_type=args.enc_activation_type,
                enc_use_cnn_module=args.enc_use_cnn_module,
                enc_cnn_module_kernel=args.enc_cnn_module_kernel,
                enc_padding_idx=args.enc_padding_idx,
                dec_attention_dim=args.dec_attention_dim,
                dec_attention_heads=args.dec_attention_heads,
                dec_linear_units=args.dec_linear_units,
                dec_num_blocks=args.dec_num_blocks,
                dec_dropout_rate=args.dec_dropout_rate,
                dec_positional_dropout_rate=args.dec_positional_dropout_rate,
                dec_attention_dropout_rate=args.dec_attention_dropout_rate,
                dec_input_layer=args.dec_input_layer,
                dec_normalize_before=args.dec_normalize_before,
                dec_concat_after=args.dec_concat_after,
                dec_positionwise_layer_type=args.dec_positionwise_layer_type,
                dec_positionwise_conv_kernel_size=(
                    args.dec_positionwise_conv_kernel_size),
                dec_macaron_style=args.dec_macaron_style,
                dec_pos_enc_layer_type=args.dec_pos_enc_layer_type,
                dec_selfattention_layer_type=args.dec_selfattention_layer_type,
                dec_activation_type=args.dec_activation_type,
                dec_use_cnn_module=args.dec_use_cnn_module,
                dec_cnn_module_kernel=args.dec_cnn_module_kernel,
                dec_padding_idx=args.dec_padding_idx,
                Hz2semitone=args.Hz2semitone,
                semitone_size=args.semitone_size,
                device=device,
            )
        else:
            model = ConformerSVS_FULL(
                phone_size=args.phone_size,
                embed_size=args.embedding_size,
                output_dim=args.feat_dim,
                n_mels=args.n_mels,
                enc_attention_dim=args.enc_attention_dim,
                enc_attention_heads=args.enc_attention_heads,
                enc_linear_units=args.enc_linear_units,
                enc_num_blocks=args.enc_num_blocks,
                enc_dropout_rate=args.enc_dropout_rate,
                enc_positional_dropout_rate=args.enc_positional_dropout_rate,
                enc_attention_dropout_rate=args.enc_attention_dropout_rate,
                enc_input_layer=args.enc_input_layer,
                enc_normalize_before=args.enc_normalize_before,
                enc_concat_after=args.enc_concat_after,
                enc_positionwise_layer_type=args.enc_positionwise_layer_type,
                enc_positionwise_conv_kernel_size=(
                    args.enc_positionwise_conv_kernel_size),
                enc_macaron_style=args.enc_macaron_style,
                enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
                enc_selfattention_layer_type=args.enc_selfattention_layer_type,
                enc_activation_type=args.enc_activation_type,
                enc_use_cnn_module=args.enc_use_cnn_module,
                enc_cnn_module_kernel=args.enc_cnn_module_kernel,
                enc_padding_idx=args.enc_padding_idx,
                dec_attention_dim=args.dec_attention_dim,
                dec_attention_heads=args.dec_attention_heads,
                dec_linear_units=args.dec_linear_units,
                dec_num_blocks=args.dec_num_blocks,
                dec_dropout_rate=args.dec_dropout_rate,
                dec_positional_dropout_rate=args.dec_positional_dropout_rate,
                dec_attention_dropout_rate=args.dec_attention_dropout_rate,
                dec_input_layer=args.dec_input_layer,
                dec_normalize_before=args.dec_normalize_before,
                dec_concat_after=args.dec_concat_after,
                dec_positionwise_layer_type=args.dec_positionwise_layer_type,
                dec_positionwise_conv_kernel_size=(
                    args.dec_positionwise_conv_kernel_size),
                dec_macaron_style=args.dec_macaron_style,
                dec_pos_enc_layer_type=args.dec_pos_enc_layer_type,
                dec_selfattention_layer_type=args.dec_selfattention_layer_type,
                dec_activation_type=args.dec_activation_type,
                dec_use_cnn_module=args.dec_use_cnn_module,
                dec_cnn_module_kernel=args.dec_cnn_module_kernel,
                dec_padding_idx=args.dec_padding_idx,
                Hz2semitone=args.Hz2semitone,
                semitone_size=args.semitone_size,
                device=device,
            )

    elif args.model_type == "USTC_DAR":
        model = USTC_SVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            middle_dim_fc=args.middle_dim_fc,
            output_dim=args.feat_dim,
            multi_history_num=args.multi_history_num,
            middle_dim_prenet=args.middle_dim_prenet,
            n_blocks_prenet=args.n_blocks_prenet,
            n_heads_prenet=args.n_heads_prenet,
            kernel_size_prenet=args.kernel_size_prenet,
            bi_d_model=args.bi_d_model,
            bi_num_layers=args.bi_num_layers,
            uni_d_model=args.uni_d_model,
            uni_num_layers=args.uni_num_layers,
            dropout=args.dropout,
            feedbackLink_drop_rate=args.feedbackLink_drop_rate,
            device=device,
        )

    else:
        raise ValueError("Not Support Model Type %s" % args.model_type)
    logging.info(f"{model}")
    model = model.to(device)
    logging.info(
        f"The model has {count_parameters(model):,} trainable parameters")

    model_load_dir = ""
    pretrain_encoder_dir = ""
    start_epoch = 1  # FIX ME
    if args.pretrain_encoder != "":
        pretrain_encoder_dir = args.pretrain_encoder
    if args.initmodel != "":
        model_load_dir = args.initmodel
    if args.resume and os.path.exists(args.model_save_dir):
        checks = os.listdir(args.model_save_dir)
        start_epoch = max(
            list(
                map(
                    lambda x: int(x.split(".")[0].split("_")[-1])
                    if x.endswith("pth.tar") else -1,
                    checks,
                )))
        model_temp_load_dir = "{}/epoch_loss_{}.pth.tar".format(
            args.model_save_dir, start_epoch)
        if start_epoch < 0:
            model_load_dir = ""
        elif os.path.isfile(model_temp_load_dir):
            model_load_dir = model_temp_load_dir
        else:
            model_load_dir = "{}/epoch_spec_loss_{}.pth.tar".format(
                args.model_save_dir, start_epoch)

    # load encoder parm from Transformer-TTS
    if pretrain_encoder_dir != "":
        pretrain = torch.load(pretrain_encoder_dir, map_location=device)
        pretrain_dict = pretrain["model"]
        model_dict = model.state_dict()
        state_dict_new = {}
        para_list = []
        i = 0
        for k, v in pretrain_dict.items():
            k_new = k[7:]
            if (k_new in model_dict
                    and model_dict[k_new].size() == pretrain_dict[k].size()):
                i += 1
                state_dict_new[k_new] = v
            model_dict.update(state_dict_new)
        model.load_state_dict(model_dict)
        logging.info(f"Load {i} layers total. Load pretrain encoder success !")

    # load weights for pre-trained model
    if model_load_dir != "":
        logging.info(f"Model Start to Load, dir: {model_load_dir}")
        model_load = torch.load(model_load_dir, map_location=device)
        loading_dict = model_load["state_dict"]
        model_dict = model.state_dict()
        state_dict_new = {}
        para_list = []
        for k, v in loading_dict.items():
            # assert k in model_dict
            if (k == "normalizer.mean" or k == "normalizer.std"
                    or k == "mel_normalizer.mean"
                    or k == "mel_normalizer.std"):
                continue
            if model_dict[k].size() == loading_dict[k].size():
                state_dict_new[k] = v
            else:
                para_list.append(k)
        logging.info(f"Total {len(loading_dict)} parameter sets, "
                     f"Loaded {len(state_dict_new)} parameter sets")
        if len(para_list) > 0:
            logging.warning("Not loading {} because of different sizes".format(
                ", ".join(para_list)))
        model_dict.update(state_dict_new)
        model.load_state_dict(model_dict)
        logging.info(f"Loaded checkpoint {args.initmodel}")

    # setup optimizer
    if args.optimizer == "noam":
        optimizer = ScheduledOptim(
            torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             betas=(0.9, 0.98),
                             eps=1e-09),
            args.hidden_size,
            args.noam_warmup_steps,
            args.noam_scale,
        )
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, 0.98),
                                     eps=1e-09)
        if args.scheduler == "OneCycleLR":
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=args.lr,
                steps_per_epoch=len(train_loader),
                epochs=args.max_epochs,
            )
        elif args.scheduler == "ReduceLROnPlateau":
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, "min", verbose=True, patience=10, factor=0.5)
        elif args.scheduler == "ExponentialLR":
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                               verbose=True,
                                                               gamma=0.9886)
    else:
        raise ValueError("Not Support Optimizer")

    # Setup tensorborad logger
    if args.use_tfboard:
        from tensorboardX import SummaryWriter

        logger = SummaryWriter("{}/log".format(args.model_save_dir))
    else:
        logger = None

    if args.loss == "l1":
        loss = MaskedLoss("l1", mask_free=args.mask_free)
    elif args.loss == "mse":
        loss = MaskedLoss("mse", mask_free=args.mask_free)
    else:
        raise ValueError("Not Support Loss Type")

    if args.perceptual_loss > 0:
        win_length = int(args.sampling_rate * args.frame_length)
        psd_dict, bark_num = cal_psd2bark_dict(fs=args.sampling_rate,
                                               win_len=win_length)
        sf = cal_spread_function(bark_num)
        loss_perceptual_entropy = PerceptualEntropy(bark_num, sf,
                                                    args.sampling_rate,
                                                    win_length, psd_dict)
    else:
        loss_perceptual_entropy = None

    # Training
    total_loss_epoch_to_save = {}
    total_loss_counter = 0
    spec_loss_epoch_to_save = {}
    spec_loss_counter = 0
    total_learning_step = 0

    # args.num_saved_model = 5

    # preload vocoder model
    voc_model = []
    if args.vocoder_category == "wavernn":
        logging.info("load voc_model from {}".format(args.wavernn_voc_model))
        voc_model = WaveRNN(
            rnn_dims=args.voc_rnn_dims,
            fc_dims=args.voc_fc_dims,
            bits=args.voc_bits,
            pad=args.voc_pad,
            upsample_factors=(
                args.voc_upsample_factors_0,
                args.voc_upsample_factors_1,
                args.voc_upsample_factors_2,
            ),
            feat_dims=args.n_mels,
            compute_dims=args.voc_compute_dims,
            res_out_dims=args.voc_res_out_dims,
            res_blocks=args.voc_res_blocks,
            hop_length=args.hop_length,
            sample_rate=args.sampling_rate,
            mode=args.voc_mode,
        ).to(device)

        voc_model.load(args.wavernn_voc_model)

    for epoch in range(start_epoch + 1, 1 + args.max_epochs):
        """Train Stage"""
        start_t_train = time.time()
        train_info = train_one_epoch(
            train_loader,
            model,
            device,
            optimizer,
            loss,
            loss_perceptual_entropy,
            epoch,
            args,
            voc_model,
        )
        end_t_train = time.time()

        out_log = "Train epoch: {:04d}, ".format(epoch)
        if args.optimizer == "noam":
            out_log += "lr: {:.6f}, ".format(
                optimizer._optimizer.param_groups[0]["lr"])
        elif args.optimizer == "adam":
            out_log += "lr: {:.6f}, ".format(optimizer.param_groups[0]["lr"])

        if args.vocoder_category == "wavernn":
            out_log += "loss: {:.4f} ".format(train_info["loss"])
        else:
            out_log += "loss: {:.4f}, spec_loss: {:.4f} ".format(
                train_info["loss"], train_info["spec_loss"])

        if args.n_mels > 0:
            out_log += "mel_loss: {:.4f}, ".format(train_info["mel_loss"])
        if args.perceptual_loss > 0:
            out_log += "pe_loss: {:.4f}, ".format(train_info["pe_loss"])
        logging.info("{} time: {:.2f}s".format(out_log,
                                               end_t_train - start_t_train))
        """Dev Stage"""
        torch.backends.cudnn.enabled = False  # 莫名的bug,关掉才可以跑

        # start_t_dev = time.time()
        dev_info = validate(
            dev_loader,
            model,
            device,
            loss,
            loss_perceptual_entropy,
            epoch,
            args,
            voc_model,
        )
        end_t_dev = time.time()

        dev_log = "Dev epoch: {:04d}, loss: {:.4f}, spec_loss: {:.4f}, ".format(
            epoch, dev_info["loss"], dev_info["spec_loss"])
        dev_log += "mcd_value: {:.4f}, ".format(dev_info["mcd_value"])
        if args.n_mels > 0:
            dev_log += "mel_loss: {:.4f}, ".format(dev_info["mel_loss"])
        if args.perceptual_loss > 0:
            dev_log += "pe_loss: {:.4f}, ".format(dev_info["pe_loss"])
        logging.info("{} time: {:.2f}s".format(dev_log,
                                               end_t_dev - start_t_train))

        sys.stdout.flush()

        torch.backends.cudnn.enabled = True

        if args.scheduler == "OneCycleLR":
            scheduler.step()
        elif args.scheduler == "ReduceLROnPlateau":
            scheduler.step(dev_info["loss"])
        elif args.scheduler == "ExponentialLR":
            before = total_learning_step // args.lr_decay_learning_steps
            total_learning_step += len(train_loader)
            after = total_learning_step // args.lr_decay_learning_steps
            if after > before:  # decay per 250 learning steps
                scheduler.step()
        """Save model Stage"""
        if not os.path.exists(args.model_save_dir):
            os.makedirs(args.model_save_dir)

        (total_loss_counter, total_loss_epoch_to_save) = Auto_save_model(
            args,
            epoch,
            model,
            optimizer,
            train_info,
            dev_info,
            logger,
            total_loss_counter,
            total_loss_epoch_to_save,
            save_loss_select="loss",
        )

        if (dev_info["spec_loss"] !=
                0):  # spec_loss 有意义时再存模型,比如 USTC DAR model 不需要计算线性谱spec loss
            (spec_loss_counter, spec_loss_epoch_to_save) = Auto_save_model(
                args,
                epoch,
                model,
                optimizer,
                train_info,
                dev_info,
                logger,
                spec_loss_counter,
                spec_loss_epoch_to_save,
                save_loss_select="spec_loss",
            )

    if args.use_tfboard:
        logger.close()
Exemplo n.º 3
0
def _build_infer_model(args, device):
    """Instantiate the acoustic model selected by ``args.model_type``.

    Raises:
        ValueError: if ``args.model_type`` is not one of the supported types.
    """
    if args.model_type == "GLU_Transformer":
        return GLU_TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            device=device,
        )
    if args.model_type == "LSTM":
        return LSTMSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    if args.model_type == "GRU_gs":
        return GRUSVS_gs(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            d_model=args.hidden_size,
            num_layers=args.num_rnn_layers,
            dropout=args.dropout,
            d_output=args.feat_dim,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            device=device,
            use_asr_post=args.use_asr_post,
        )
    if args.model_type == "PureTransformer":
        return TransformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            hidden_size=args.hidden_size,
            glu_num_layers=args.glu_num_layers,
            dropout=args.dropout,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            device=device,
        )
    if args.model_type == "Conformer":
        return ConformerSVS(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            enc_attention_dim=args.enc_attention_dim,
            enc_attention_heads=args.enc_attention_heads,
            enc_linear_units=args.enc_linear_units,
            enc_num_blocks=args.enc_num_blocks,
            enc_dropout_rate=args.enc_dropout_rate,
            enc_positional_dropout_rate=args.enc_positional_dropout_rate,
            enc_attention_dropout_rate=args.enc_attention_dropout_rate,
            enc_input_layer=args.enc_input_layer,
            enc_normalize_before=args.enc_normalize_before,
            enc_concat_after=args.enc_concat_after,
            enc_positionwise_layer_type=args.enc_positionwise_layer_type,
            enc_positionwise_conv_kernel_size=args.enc_positionwise_conv_kernel_size,
            enc_macaron_style=args.enc_macaron_style,
            enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
            enc_selfattention_layer_type=args.enc_selfattention_layer_type,
            enc_activation_type=args.enc_activation_type,
            enc_use_cnn_module=args.enc_use_cnn_module,
            enc_cnn_module_kernel=args.enc_cnn_module_kernel,
            enc_padding_idx=args.enc_padding_idx,
            output_dim=args.feat_dim,
            dec_nhead=args.dec_nhead,
            dec_num_block=args.dec_num_block,
            n_mels=args.n_mels,
            double_mel_loss=args.double_mel_loss,
            local_gaussian=args.local_gaussian,
            dec_dropout=args.dec_dropout,
            device=device,
        )
    if args.model_type == "Comformer_full":
        return ConformerSVS_FULL(
            phone_size=args.phone_size,
            embed_size=args.embedding_size,
            output_dim=args.feat_dim,
            n_mels=args.n_mels,
            enc_attention_dim=args.enc_attention_dim,
            enc_attention_heads=args.enc_attention_heads,
            enc_linear_units=args.enc_linear_units,
            enc_num_blocks=args.enc_num_blocks,
            enc_dropout_rate=args.enc_dropout_rate,
            enc_positional_dropout_rate=args.enc_positional_dropout_rate,
            enc_attention_dropout_rate=args.enc_attention_dropout_rate,
            enc_input_layer=args.enc_input_layer,
            enc_normalize_before=args.enc_normalize_before,
            enc_concat_after=args.enc_concat_after,
            enc_positionwise_layer_type=args.enc_positionwise_layer_type,
            enc_positionwise_conv_kernel_size=args.enc_positionwise_conv_kernel_size,
            enc_macaron_style=args.enc_macaron_style,
            enc_pos_enc_layer_type=args.enc_pos_enc_layer_type,
            enc_selfattention_layer_type=args.enc_selfattention_layer_type,
            enc_activation_type=args.enc_activation_type,
            enc_use_cnn_module=args.enc_use_cnn_module,
            enc_cnn_module_kernel=args.enc_cnn_module_kernel,
            enc_padding_idx=args.enc_padding_idx,
            dec_attention_dim=args.dec_attention_dim,
            dec_attention_heads=args.dec_attention_heads,
            dec_linear_units=args.dec_linear_units,
            dec_num_blocks=args.dec_num_blocks,
            dec_dropout_rate=args.dec_dropout_rate,
            dec_positional_dropout_rate=args.dec_positional_dropout_rate,
            dec_attention_dropout_rate=args.dec_attention_dropout_rate,
            dec_input_layer=args.dec_input_layer,
            dec_normalize_before=args.dec_normalize_before,
            dec_concat_after=args.dec_concat_after,
            dec_positionwise_layer_type=args.dec_positionwise_layer_type,
            dec_positionwise_conv_kernel_size=args.dec_positionwise_conv_kernel_size,
            dec_macaron_style=args.dec_macaron_style,
            dec_pos_enc_layer_type=args.dec_pos_enc_layer_type,
            dec_selfattention_layer_type=args.dec_selfattention_layer_type,
            dec_activation_type=args.dec_activation_type,
            dec_use_cnn_module=args.dec_use_cnn_module,
            dec_cnn_module_kernel=args.dec_cnn_module_kernel,
            dec_padding_idx=args.dec_padding_idx,
            device=device,
        )
    raise ValueError("Not Support Model Type %s" % args.model_type)


def _load_infer_checkpoint(model, model_file, device):
    """Copy shape-compatible weights from the checkpoint ``model_file`` into ``model``.

    Normalizer statistics stored in the checkpoint are skipped, and so are
    entries whose size differs from the model's or that the model does not
    have. Parameters that are not loaded keep their current values.
    """
    logging.info(f"Loading pretrained weights from {model_file}")
    checkpoint = torch.load(model_file, map_location=device)
    state_dict = checkpoint["state_dict"]
    model_dict = model.state_dict()
    skipped_keys = {
        "normalizer.mean",
        "normalizer.std",
        "mel_normalizer.mean",
        "mel_normalizer.std",
    }
    state_dict_new = {}
    para_list = []  # keys present in the model but with a different size
    unknown_keys = []  # keys the model does not have at all

    for k, v in state_dict.items():
        if k in skipped_keys:
            continue
        if k not in model_dict:
            unknown_keys.append(k)
        elif model_dict[k].size() == v.size():
            state_dict_new[k] = v
        else:
            para_list.append(k)

    logging.info(
        f"Total {len(state_dict)} parameter sets, "
        f"loaded {len(state_dict_new)} parameter set"
    )
    if len(para_list) > 0:
        logging.warning(f"Not loading {para_list} because of different sizes")
    if unknown_keys:
        logging.warning(f"Not loading {unknown_keys} because they are not in the model")

    # Merge into the model's full state dict: a strict load_state_dict of only
    # the filtered entries would fail with "Missing key(s)" for every key that
    # was skipped above.
    model_dict.update(state_dict_new)
    model.load_state_dict(model_dict)
    logging.info(f"Loaded checkpoint {model_file}")


def _load_infer_wavernn(device):
    """Build the WaveRNN vocoder with fixed hyper-parameters and load its weights."""
    voc_model = WaveRNN(
        rnn_dims=512,
        fc_dims=512,
        bits=9,
        pad=2,
        upsample_factors=(5, 5, 11),
        feat_dims=80,
        compute_dims=128,
        res_out_dims=128,
        res_blocks=10,
        hop_length=275,  # 12.5ms - in line with Tacotron 2 paper
        sample_rate=22050,
        mode="MOL",
    ).to(device)
    voc_model.load("./weights/wavernn/latest_weights.pyt")
    return voc_model


def _infer_forward(model, args, chars, phone, pitch, beat, char_len_list, length, spec):
    """Run one forward pass and return ``(output, att, output_mel)``.

    ``att`` is ``None`` for the recurrent models (LSTM, GRU_gs).
    """
    if args.model_type in (
        "GLU_Transformer",
        "PureTransformer",
        "Conformer",
        "Comformer_full",
    ):
        output, att, output_mel, _ = model(
            chars,
            phone,
            pitch,
            beat,
            pos_char=char_len_list,
            pos_spec=length,
        )
        return output, att, output_mel
    if args.model_type == "LSTM":
        output, _, output_mel, _ = model(phone, pitch, beat)
        return output, None, output_mel
    if args.model_type == "GRU_gs":
        output, _, output_mel = model(spec, phone, pitch, beat, length, args)
        return output, None, output_mel
    raise ValueError("Not Support Model Type %s" % args.model_type)


def infer(args):
    """infer.

    Build the model selected by ``args.model_type`` and load weights from
    ``args.model_file``. Then decode the test set one utterance at a time,
    accumulating spectrum/mel losses plus MCD, F0 RMSE, V/UV error and F0
    correlation, and log them. ``log_figure`` (Griffin-Lim) or ``log_mel``
    (WaveRNN) is called for every step with ``args.prediction_path``; these
    presumably write the predictions there.
    """
    torch.cuda.set_device(args.gpu_id)
    logging.info(f"GPU {args.gpu_id} is used")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # prepare model
    model = _build_infer_model(args, device)
    logging.info(f"{model}")
    logging.info(f"The model has {count_parameters(model):,} trainable parameters")

    # Load model weights
    _load_infer_checkpoint(model, args.model_file, device)
    model = model.to(device)
    model.eval()

    # Decode
    test_set = SVSDataset(
        align_root_path=args.test_align,
        pitch_beat_root_path=args.test_pitch,
        wav_root_path=args.test_wav,
        char_max_len=args.char_max_len,
        max_len=args.num_frames,
        sr=args.sampling_rate,
        preemphasis=args.preemphasis,
        nfft=args.nfft,
        frame_shift=args.frame_shift,
        frame_length=args.frame_length,
        n_mels=args.n_mels,
        power=args.power,
        max_db=args.max_db,
        ref_db=args.ref_db,
        standard=args.standard,
        sing_quality=args.sing_quality,
    )
    collate_fn_svs = SVSCollator(
        args.num_frames,
        args.char_max_len,
        args.use_asr_post,
        args.phone_size,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn_svs,
        pin_memory=True,
    )

    if args.loss == "l1":
        criterion = MaskedLoss("l1")
    elif args.loss == "mse":
        criterion = MaskedLoss("mse")
    else:
        raise ValueError("Not Support Loss Type")

    losses = AverageMeter()
    spec_losses = AverageMeter()
    # MCD / F0 / V-UV metrics are computed from the linear spectrum on every
    # step, so they must exist even when no mel output is configured.
    mcd_metric = AverageMeter()
    f0_distortion_metric = AverageMeter()
    vuv_error_metric = AverageMeter()
    if args.perceptual_loss > 0:
        pe_losses = AverageMeter()
    if args.n_mels > 0:
        mel_losses = AverageMeter()
        if args.double_mel_loss:
            double_mel_losses = AverageMeter()

    os.makedirs(args.prediction_path, exist_ok=True)

    # The normalizers only depend on the stats files, so build them once
    # rather than on every test batch.
    spec_normalizer = mel_normalizer = None
    if args.normalize:
        spec_normalizer = GlobalMVN(args.stats_file)
        mel_normalizer = GlobalMVN(args.stats_mel_file)

    # Per-step F0 arrays are collected and concatenated once at the end
    # (repeated concatenation would be quadratic in the test-set size).
    f0_ground_truth_steps = []
    f0_synthesis_steps = []
    start_t_test = time.time()

    # preload vocoder model
    voc_model = None
    if args.vocoder_category == "wavernn":
        voc_model = _load_infer_wavernn(device)

    with torch.no_grad():
        for (
            step,
            (
                phone,
                beat,
                pitch,
                spec,
                _real,
                _imag,
                length,
                chars,
                char_len_list,
                mel,
            ),
        ) in enumerate(test_loader, 1):
            phone = phone.to(device)
            beat = beat.to(device)
            pitch = pitch.to(device).float()
            spec = spec.to(device).float()
            mel = mel.to(device).float()
            length_mask = length.unsqueeze(2)
            length_mel_mask = length_mask.repeat(1, 1, mel.shape[2]).float().to(device)
            length_mask = length_mask.repeat(1, 1, spec.shape[2]).float().to(device)
            length = length.to(device)
            char_len_list = char_len_list.to(device)

            if not args.use_asr_post:
                chars = chars.to(device)
            else:
                phone = phone.float()

            output, att, output_mel = _infer_forward(
                model, args, chars, phone, pitch, beat, char_len_list, length, spec
            )

            # Keep the un-normalized target for metrics and plots.
            spec_origin = spec.clone()
            if args.normalize:
                spec, _ = spec_normalizer(spec, length)
                mel, _ = mel_normalizer(mel, length)

            spec_loss = criterion(output, spec, length_mask)
            if args.n_mels > 0:
                mel_loss = criterion(output_mel, mel, length_mel_mask)
            else:
                mel_loss = 0

            final_loss = mel_loss + spec_loss

            batch_size = phone.size(0)
            losses.update(final_loss.item(), batch_size)
            spec_losses.update(spec_loss.item(), batch_size)
            if args.n_mels > 0:
                mel_losses.update(mel_loss.item(), batch_size)

            # normalize inverse stage
            if args.normalize and args.stats_file:
                output, _ = spec_normalizer.inverse(output, length)

            mcd_value, length_sum = Metrics.Calculate_melcd_fromLinearSpectrum(
                output, spec_origin, length, args
            )
            (
                f0_distortion_value,
                voiced_frame_number_step,
                vuv_error_value,
                frame_number_step,
                f0_ground_truth_step,
                f0_synthesis_step,
            ) = Metrics.Calculate_f0RMSE_VUV_CORR_fromWav(
                output, spec_origin, length, args, "test"
            )
            f0_ground_truth_steps.append(f0_ground_truth_step)
            f0_synthesis_steps.append(f0_synthesis_step)

            mcd_metric.update(mcd_value, length_sum)
            f0_distortion_metric.update(f0_distortion_value, voiced_frame_number_step)
            vuv_error_metric.update(vuv_error_value, frame_number_step)

            if args.vocoder_category == "griffin":
                log_figure(
                    step,
                    output,
                    spec_origin,
                    att,
                    length,
                    args.prediction_path,
                    args,
                )
            elif args.vocoder_category == "wavernn":
                log_mel(
                    step,
                    output_mel,
                    spec_origin,
                    att,
                    length,
                    args.prediction_path,
                    args,
                    voc_model,
                )
            out_log = (
                "step {}:train_loss{:.4f};"
                "spec_loss{:.4f};mcd_value{:.4f};".format(
                    step,
                    losses.avg,
                    spec_losses.avg,
                    mcd_metric.avg,
                )
            )
            if args.perceptual_loss > 0:
                out_log += " pe_loss {:.4f}; ".format(pe_losses.avg)
            if args.n_mels > 0:
                out_log += " mel_loss {:.4f}; ".format(mel_losses.avg)
                if args.double_mel_loss:
                    out_log += " dmel_loss {:.4f}; ".format(double_mel_losses.avg)
            end = time.time()
            logging.info(f"{out_log} -- sum_time: {(end - start_t_test)}s")

    end_t_test = time.time()

    out_log = "Test Stage: "
    out_log += "spec_loss: {:.4f} ".format(spec_losses.avg)
    if args.n_mels > 0:
        out_log += "mel_loss: {:.4f}, ".format(mel_losses.avg)

    # Seed with an empty (0, 1) array so the result keeps the same shape and
    # dtype even when the test set is empty.
    empty_f0 = np.reshape(np.array([]), (-1, 1))
    f0_ground_truth_all = np.concatenate([empty_f0] + f0_ground_truth_steps, axis=0)
    f0_synthesis_all = np.concatenate([empty_f0] + f0_synthesis_steps, axis=0)
    f0_corr = Metrics.compute_f0_corr(f0_ground_truth_all, f0_synthesis_all)

    out_log += "\n\t mcd_value {:.4f} dB ".format(mcd_metric.avg)
    out_log += (
        " f0_rmse_value {:.4f} Hz, "
        "vuv_error_value {:.4f} %, F0_CORR {:.4f}; ".format(
            np.sqrt(f0_distortion_metric.avg),
            vuv_error_metric.avg * 100,
            f0_corr,
        )
    )
    logging.info("{} time: {:.2f}s".format(out_log, end_t_test - start_t_test))