Example #1
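import torch

# run_trt_engine and MeasureTime come from the surrounding TensorRT
# inference utilities and are not shown in this listing.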
def infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16):

    # Add a trailing unit dimension: this engine's bindings are 4D
    # (the latent z below is allocated as 4D as well).
    mel = mel.unsqueeze(3)
    mel_size = mel.size(2)
    batch_size = mel.size(0)
    stride = 256
    kernel_size = 1024
    n_group = 8
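    # z_size: length of the transposed-conv upsampled mel (stride 256, kernel 1024),
    # trimmed by (kernel_size - stride) and folded into n_group channels;
    # this works out to mel_size*stride // n_group values per group.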
    z_size = (mel_size-1)*stride+(kernel_size-1)+1
    z_size = z_size - (kernel_size-stride)
    z_size = z_size//n_group
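    # Gaussian latent input for WaveGlow and the buffer that receives the
    # generated audio from the TRT engine.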
    z = torch.randn(batch_size, n_group, z_size, 1).cuda()
    audios = torch.zeros(batch_size, mel_size*stride).cuda()

    if fp16:
        z = z.half()
        mel = mel.half()
        audios = audios.half()

    waveglow_tensors = {
        # inputs
        'mel': mel, 'z': z,
        # outputs
        'audio': audios
    }
    print("Running WaveGlow")
    with MeasureTime(measurements, "waveglow_time"):
        run_trt_engine(waveglow_context, waveglow, waveglow_tensors)

    return audios
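
The three-step z_size computation above is the output length of WaveGlow's transposed-convolution upsampling of the mel (stride 256, kernel 1024), minus the trimmed (kernel_size - stride) tail, folded into n_group channels; it collapses to mel_size*stride // n_group. A standalone sanity check of that arithmetic (not part of the original listing):

# Standalone check: the step-by-step z_size above equals mel_size*stride // n_group.
stride, kernel_size, n_group = 256, 1024, 8

def z_size_stepwise(mel_size):
    z = (mel_size - 1)*stride + (kernel_size - 1) + 1   # transposed-conv output length
    z = z - (kernel_size - stride)                      # drop the convolution tail
    return z//n_group                                   # fold into n_group channels

for mel_size in (1, 80, 620):
    assert z_size_stepwise(mel_size) == mel_size*stride//n_group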
Example #2
def infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16):

    mel_size = mel.size(2)
    batch_size = mel.size(0)
    stride = 256
    n_group = 8
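    # One latent value per output audio sample: mel_size*stride samples in
    # total, folded into n_group channels.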
    z_size = mel_size*stride
    z_size = z_size//n_group
    z = torch.randn(batch_size, n_group, z_size).cuda()
    audios = torch.zeros(batch_size, mel_size*stride).cuda()

    if fp16:
        z = z.half()
        mel = mel.half()
        audios = audios.half()

    waveglow_tensors = {
        "inputs" :
        {'mel': mel, 'z': z},
        "outputs" :
        {'audio': audios}
    }
    print("Running WaveGlow")
    with MeasureTime(measurements, "waveglow_time"):
        run_trt_engine(waveglow_context, waveglow, waveglow_tensors)

    return audios
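
Both WaveGlow variants time the engine run with MeasureTime, which stores the elapsed time of the `with` block in the measurements dict under the given key (Example #3 below additionally accumulates these per-step values). A minimal sketch of such a helper, assuming a plain wall-clock timer; the actual utility ships with the surrounding code and may also synchronize CUDA before timing:

import time

class MeasureTime:
    # Records the elapsed wall-clock time of the `with` block under `key`.
    def __init__(self, measurements, key):
        self.measurements = measurements
        self.key = key

    def __enter__(self):
        self.t0 = time.perf_counter()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.measurements[self.key] = time.perf_counter() - self.t0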
Example #3
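import torch

# run_trt_engine, MeasureTime and the decoder helpers used below
# (init_decoder_inputs, init_decoder_outputs, init_decoder_tensors,
# swap_inputs_outputs) come from the surrounding TensorRT inference
# utilities and are not shown in this listing.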
def infer_tacotron2_trt(encoder, decoder_iter, postnet,
                        encoder_context, decoder_context, postnet_context,
                        sequences, sequence_lengths, measurements, fp16):

    # Output buffer for the encoder engine: encoder outputs with 512 channels per input timestep.
    memory = torch.zeros((len(sequence_lengths), sequence_lengths[0], 512)).cuda()
    if fp16:
        memory = memory.half()
    device = memory.device
    dtype = memory.dtype

    # Further encoder outputs: the attention-layer projection of memory (128-dim)
    # and the per-sequence output lengths.
    processed_memory = torch.zeros((len(sequence_lengths), sequence_lengths[0], 128), device=device, dtype=dtype)
    lens = torch.zeros_like(sequence_lengths)

    encoder_tensors = {
        # inputs
        'sequences': sequences, 'sequence_lengths': sequence_lengths,
        # outputs
        'memory': memory, 'lens': lens, 'processed_memory': processed_memory
    }

    print("Running Tacotron2 Encoder")
    with MeasureTime(measurements, "tacotron2_encoder_time"):
        run_trt_engine(encoder_context, encoder, encoder_tensors)

    device = memory.device
    # Decoder bookkeeping: per-utterance mel lengths, a mask of sequences that
    # have not yet emitted a stop token, accumulators for the decoder outputs,
    # the stop-token (gate) threshold and a cap on the number of decoder steps.
    mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32, device=device)
    not_finished = torch.ones([memory.size(0)], dtype=torch.int32, device=device)
    mel_outputs, gate_outputs, alignments = (torch.zeros(1, device=device),
                                             torch.zeros(1, device=device),
                                             torch.zeros(1, device=device))
    gate_threshold = 0.5
    max_decoder_steps = 1664
    first_iter = True

    decoder_inputs = init_decoder_inputs(memory, processed_memory, sequence_lengths)
    decoder_outputs = init_decoder_outputs(memory, sequence_lengths)

    print("Running Tacotron2 Decoder")
    while True:
        decoder_tensors = init_decoder_tensors(decoder_inputs, decoder_outputs)
        with MeasureTime(measurements, "step"):
            run_trt_engine(decoder_context, decoder_iter, decoder_tensors)

        # decoder_outputs[7], [8] and [4] hold the current mel frame, the
        # gate (stop-token) prediction and the attention alignment, respectively.
        if first_iter:
            mel_outputs = torch.unsqueeze(decoder_outputs[7], 2)
            gate_outputs = torch.unsqueeze(decoder_outputs[8], 2)
            alignments = torch.unsqueeze(decoder_outputs[4], 2)
            measurements['tacotron2_decoder_time'] = measurements['step']
            first_iter = False
        else:
            mel_outputs = torch.cat((mel_outputs, torch.unsqueeze(decoder_outputs[7], 2)), 2)
            gate_outputs = torch.cat((gate_outputs, torch.unsqueeze(decoder_outputs[8], 2)), 2)
            alignments = torch.cat((alignments, torch.unsqueeze(decoder_outputs[4], 2)), 2)
            measurements['tacotron2_decoder_time'] += measurements['step']

        # A sequence is finished once sigmoid(gate) exceeds the threshold;
        # only still-unfinished sequences keep growing their mel length.
        dec = torch.le(torch.sigmoid(decoder_outputs[8]), gate_threshold).to(torch.int32).squeeze(1)
        not_finished = not_finished*dec
        mel_lengths += not_finished

        if torch.sum(not_finished) == 0:
            print("Stopping after", mel_outputs.size(2), "decoder steps")
            break
        if mel_outputs.size(2) == max_decoder_steps:
            print("Warning! Reached max decoder steps")
            break

        # Ping-pong the buffers so this step's outputs become the next step's inputs.
        decoder_inputs, decoder_outputs = swap_inputs_outputs(decoder_inputs, decoder_outputs)

    # Output buffer for the postnet refinement of the decoder mel spectrogram.
    mel_outputs_postnet = torch.zeros_like(mel_outputs, device=device, dtype=dtype)

    postnet_tensors = {
        # inputs
        'mel_outputs': mel_outputs,
        # outputs
        'mel_outputs_postnet': mel_outputs_postnet
    }
    print("Running Tacotron2 Postnet")
    with MeasureTime(measurements, "tacotron2_postnet_time"):
        run_trt_engine(postnet_context, postnet, postnet_tensors)

    print("Tacotron2 Postnet done")

    return mel_outputs_postnet, mel_lengths
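
For context, a hedged sketch of how the two inference functions are typically chained; the engines, execution contexts and the padded, length-sorted input sequences are placeholder names prepared elsewhere in the surrounding script:

# Hypothetical call site (all engine/context/input names are placeholders).
measurements = {}

# sequences: padded token ids sorted by length; sequence_lengths: their lengths.
mel, mel_lengths = infer_tacotron2_trt(encoder_engine, decoder_engine, postnet_engine,
                                       encoder_context, decoder_context, postnet_context,
                                       sequences, sequence_lengths, measurements, fp16=True)

audio = infer_waveglow_trt(waveglow_engine, waveglow_context, mel, measurements, fp16=True)
print("audio shape:", tuple(audio.size()),
      "| decoder time:", measurements["tacotron2_decoder_time"],
      "| waveglow time:", measurements["waveglow_time"])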