Example 1
def infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16):

    # WaveGlow upsamples each mel frame to `stride` audio samples and models
    # the waveform in groups of `n_group` samples, so the latent noise z has
    # shape (batch, n_group, mel_size * stride // n_group).
    mel_size = mel.size(2)
    batch_size = mel.size(0)
    stride = 256
    n_group = 8
    z_size = mel_size * stride // n_group
    z = torch.randn(batch_size, n_group, z_size).cuda()
    audios = torch.zeros(batch_size, mel_size * stride).cuda()

    # The exported engine expects a trailing singleton dimension on its inputs.
    mel = mel.unsqueeze(3)
    z = z.unsqueeze(3)

    if fp16:
        z = z.half()
        mel = mel.half()
        audios = audios.half()

    waveglow_tensors = {
        "inputs": {
            'mel': mel,
            'z': z
        },
        "outputs": {
            'audio': audios
        }
    }

    print("Running WaveGlow with TensorRT")
    with MeasureTime(measurements, "waveglow_time"):
        run_trt_engine(waveglow_context, waveglow, waveglow_tensors)

    return audios
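
For context, here is a minimal sketch of how this helper might be driven: deserialize a prebuilt WaveGlow TensorRT engine, create an execution context, and hand it an 80-band mel spectrogram. The engine path, the load_engine helper and the dummy mel tensor are illustrative assumptions; only infer_waveglow_trt itself comes from the example above.

import torch
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def load_engine(path):
    # Deserialize a prebuilt TensorRT engine from disk (the path is assumed).
    runtime = trt.Runtime(TRT_LOGGER)
    with open(path, "rb") as f:
        return runtime.deserialize_cuda_engine(f.read())

waveglow = load_engine("waveglow_fp16.engine")          # hypothetical file name
waveglow_context = waveglow.create_execution_context()

mel = torch.rand(1, 80, 620, device="cuda")             # (batch, n_mels, frames)
measurements = {}
audio = infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16=True)
print(audio.shape, measurements["waveglow_time"])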
Example 2
def infer_tacotron2_trt(encoder, decoder_iter, postnet, encoder_context,
                        decoder_context, postnet_context, sequences,
                        sequence_lengths, measurements, fp16, loop):

    # sequence_lengths is expected to be sorted in decreasing order,
    # so the first entry is the longest sequence in the batch.
    batch_size = len(sequence_lengths)
    max_sequence_len = sequence_lengths[0]
    memory = torch.zeros((batch_size, max_sequence_len, 512)).cuda()
    if fp16:
        memory = memory.half()
    device = memory.device
    dtype = memory.dtype

    processed_memory = torch.zeros((batch_size, max_sequence_len, 128),
                                   device=device,
                                   dtype=dtype)
    lens = torch.zeros_like(sequence_lengths)
    print(f"batch_size: {batch_size}, max sequence length: {max_sequence_len}")

    encoder_tensors = {
        "inputs": {
            'sequences': sequences,
            'sequence_lengths': sequence_lengths
        },
        "outputs": {
            'memory': memory,
            'lens': lens,
            'processed_memory': processed_memory
        }
    }

    print("Running Tacotron2 Encoder")
    with MeasureTime(measurements, "tacotron2_encoder_time"):
        run_trt_engine(encoder_context, encoder, encoder_tensors)
    max_decoder_steps = 1024
    device = memory.device
    mel_lengths = torch.zeros([memory.size(0)],
                              dtype=torch.int32,
                              device=device)
    not_finished = torch.ones([memory.size(0)],
                              dtype=torch.int32,
                              device=device)
    # Preallocate the full 80-band mel buffer for the decoder-with-loop engine.
    mel_outputs = torch.ones((batch_size, 80, max_decoder_steps),
                             device=device,
                             dtype=dtype)
    gate_threshold = 0.5
    first_iter = True

    decoder_inputs = init_decoder_inputs(memory, processed_memory,
                                         sequence_lengths)
    decoder_outputs = init_decoder_outputs(memory, sequence_lengths)

    if loop:
        if decoder_context is None:
            # No TensorRT context was given: run the exported decoder-with-loop
            # model through ONNX Runtime instead (decoder_iter is then the ONNX
            # model rather than a TensorRT engine).
            print("Running Tacotron2 Decoder with loop with ONNX-RT")
            decoder_inputs_onnxrt = [
                x.cpu().numpy().copy() for x in decoder_inputs
            ]
            import onnxruntime
            sess = onnxruntime.InferenceSession(decoder_iter)

            with MeasureTime(measurements, "tacotron2_decoder_time"):
                result = sess.run(
                    ["mel_outputs", "mel_lengths_t"], {
                        'decoder_input_0': decoder_inputs_onnxrt[0],
                        'attention_hidden_0': decoder_inputs_onnxrt[1],
                        'attention_cell_0': decoder_inputs_onnxrt[2],
                        'decoder_hidden_0': decoder_inputs_onnxrt[3],
                        'decoder_cell_0': decoder_inputs_onnxrt[4],
                        'attention_weights_0': decoder_inputs_onnxrt[5],
                        'attention_weights_cum_0': decoder_inputs_onnxrt[6],
                        'attention_context_0': decoder_inputs_onnxrt[7],
                        'memory': decoder_inputs_onnxrt[8],
                        'processed_memory': decoder_inputs_onnxrt[9],
                        'mask': decoder_inputs_onnxrt[10]
                    })

            mel_outputs = torch.tensor(result[0], device=device)
            mel_lengths = torch.tensor(result[1], device=device)
        else:
            print("Running Tacotron2 Decoder with loop")
            decoder_tensors = {
                "inputs": {
                    'decoder_input_0': decoder_inputs[0],
                    'attention_hidden_0': decoder_inputs[1],
                    'attention_cell_0': decoder_inputs[2],
                    'decoder_hidden_0': decoder_inputs[3],
                    'decoder_cell_0': decoder_inputs[4],
                    'attention_weights_0': decoder_inputs[5],
                    'attention_weights_cum_0': decoder_inputs[6],
                    'attention_context_0': decoder_inputs[7],
                    'memory': decoder_inputs[8],
                    'processed_memory': decoder_inputs[9],
                    'mask': decoder_inputs[10]
                },
                "outputs": {
                    'mel_outputs': mel_outputs,
                    'mel_lengths_t': mel_lengths
                }
            }

            with MeasureTime(measurements, "tacotron2_decoder_time"):
                run_trt_engine(decoder_context, decoder_iter, decoder_tensors)
            mel_outputs = mel_outputs[:, :, :torch.max(mel_lengths)]

    else:
        print("Running Tacotron2 Decoder")
        measurements_decoder = {}
        while True:
            decoder_tensors = init_decoder_tensors(decoder_inputs,
                                                   decoder_outputs)
            with MeasureTime(measurements_decoder, "step"):
                run_trt_engine(decoder_context, decoder_iter, decoder_tensors)

            if first_iter:
                # decoder_outputs: index 7 holds the mel frame, index 8 the gate
                # logit, and index 4 the attention alignment for this step.
                mel_outputs = torch.unsqueeze(decoder_outputs[7], 2)
                gate_outputs = torch.unsqueeze(decoder_outputs[8], 2)
                alignments = torch.unsqueeze(decoder_outputs[4], 2)
                measurements['tacotron2_decoder_time'] = measurements_decoder[
                    'step']
                first_iter = False
            else:
                mel_outputs = torch.cat(
                    (mel_outputs, torch.unsqueeze(decoder_outputs[7], 2)), 2)
                gate_outputs = torch.cat(
                    (gate_outputs, torch.unsqueeze(decoder_outputs[8], 2)), 2)
                alignments = torch.cat(
                    (alignments, torch.unsqueeze(decoder_outputs[4], 2)), 2)
                measurements['tacotron2_decoder_time'] += measurements_decoder[
                    'step']

            # A sequence counts as finished once the sigmoid of its gate output
            # exceeds gate_threshold; mel_lengths only keeps growing for
            # sequences that are still being decoded.
            dec = torch.le(torch.sigmoid(decoder_outputs[8]),
                           gate_threshold).to(torch.int32).squeeze(1)
            not_finished = not_finished * dec
            mel_lengths += not_finished

            if torch.sum(not_finished) == 0:
                print("Stopping after", mel_outputs.size(2), "decoder steps")
                break
            if mel_outputs.size(2) == max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            # The outputs of this step become the inputs of the next step.
            decoder_inputs, decoder_outputs = swap_inputs_outputs(
                decoder_inputs, decoder_outputs)

    mel_outputs = mel_outputs.clone().detach()
    mel_outputs_postnet = torch.zeros_like(mel_outputs,
                                           device=device,
                                           dtype=dtype)

    postnet_tensors = {
        "inputs": {
            'mel_outputs': mel_outputs
        },
        "outputs": {
            'mel_outputs_postnet': mel_outputs_postnet
        }
    }
    print("Running Tacotron2 Postnet")
    with MeasureTime(measurements, "tacotron2_postnet_time"):
        run_trt_engine(postnet_context, postnet, postnet_tensors)

    print("Tacotron2 Postnet done")

    return mel_outputs_postnet, mel_lengths
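
The two helpers chain naturally into a text-to-speech pipeline: the three Tacotron2 engines produce a mel spectrogram, which WaveGlow then turns into audio. A rough sketch under stated assumptions follows; the engine file names, the load_engine helper from the previous sketch and the random token IDs are placeholders, only the two function calls are taken from the examples above.

import torch

# Assumed engine files, loaded with the load_engine helper sketched earlier.
encoder = load_engine("encoder_fp16.engine")
decoder_iter = load_engine("decoder_with_loop_fp16.engine")
postnet = load_engine("postnet_fp16.engine")
waveglow = load_engine("waveglow_fp16.engine")
encoder_context = encoder.create_execution_context()
decoder_context = decoder_iter.create_execution_context()
postnet_context = postnet.create_execution_context()
waveglow_context = waveglow.create_execution_context()

# Placeholder token IDs; real inputs must be padded and sorted by decreasing length.
sequences = torch.randint(1, 148, (2, 120), dtype=torch.int64, device="cuda")
sequence_lengths = torch.tensor([120, 97], dtype=torch.int64, device="cuda")

measurements = {}
mel, mel_lengths = infer_tacotron2_trt(encoder, decoder_iter, postnet,
                                       encoder_context, decoder_context,
                                       postnet_context, sequences,
                                       sequence_lengths, measurements,
                                       fp16=True, loop=True)
audio = infer_waveglow_trt(waveglow, waveglow_context, mel, measurements, fp16=True)

# Each mel frame corresponds to 256 audio samples, so trim to the true length.
for i, length in enumerate(mel_lengths):
    print(audio[i, :length * 256].shape)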
Example 3
def infer_tacotron2_trt(encoder, decoder_iter, postnet, encoder_context,
                        decoder_context, postnet_context, sequences,
                        sequence_lengths, measurements, fp16):

    memory = torch.zeros(
        (len(sequence_lengths), sequence_lengths[0], 512)).cuda()
    if fp16:
        memory = memory.half()
    device = memory.device
    dtype = memory.dtype

    processed_memory = torch.zeros(
        (len(sequence_lengths), sequence_lengths[0], 128),
        device=device,
        dtype=dtype)
    lens = torch.zeros_like(sequence_lengths)

    encoder_tensors = {
        "inputs": {
            'sequences': sequences,
            'sequence_lengths': sequence_lengths
        },
        "outputs": {
            'memory': memory,
            'lens': lens,
            'processed_memory': processed_memory
        }
    }

    print("Running Tacotron2 Encoder")
    with MeasureTime(measurements, "tacotron2_encoder_time"):
        run_trt_engine(encoder_context, encoder, encoder_tensors)

    device = memory.device
    mel_lengths = torch.zeros([memory.size(0)],
                              dtype=torch.int32,
                              device=device)
    not_finished = torch.ones([memory.size(0)],
                              dtype=torch.int32,
                              device=device)
    mel_outputs, gate_outputs, alignments = (torch.zeros(1, device=device),
                                             torch.zeros(1, device=device),
                                             torch.zeros(1, device=device))
    gate_threshold = 0.5
    max_decoder_steps = 1664
    first_iter = True

    decoder_inputs = init_decoder_inputs(memory, processed_memory,
                                         sequence_lengths)
    decoder_outputs = init_decoder_outputs(memory, sequence_lengths)

    print("Running Tacotron2 Decoder")
    measurements_decoder = {}
    while True:
        decoder_tensors = init_decoder_tensors(decoder_inputs, decoder_outputs)
        with MeasureTime(measurements_decoder, "step"):
            run_trt_engine(decoder_context, decoder_iter, decoder_tensors)

        if first_iter:
            mel_outputs = torch.unsqueeze(decoder_outputs[7], 2)
            gate_outputs = torch.unsqueeze(decoder_outputs[8], 2)
            alignments = torch.unsqueeze(decoder_outputs[4], 2)
            measurements['tacotron2_decoder_time'] = measurements_decoder[
                'step']
            first_iter = False
        else:
            mel_outputs = torch.cat(
                (mel_outputs, torch.unsqueeze(decoder_outputs[7], 2)), 2)
            gate_outputs = torch.cat(
                (gate_outputs, torch.unsqueeze(decoder_outputs[8], 2)), 2)
            alignments = torch.cat(
                (alignments, torch.unsqueeze(decoder_outputs[4], 2)), 2)
            measurements['tacotron2_decoder_time'] += measurements_decoder[
                'step']

        dec = torch.le(torch.sigmoid(decoder_outputs[8]),
                       gate_threshold).to(torch.int32).squeeze(1)
        not_finished = not_finished * dec
        mel_lengths += not_finished

        if torch.sum(not_finished) == 0:
            print("Stopping after", mel_outputs.size(2), "decoder steps")
            break
        if mel_outputs.size(2) == max_decoder_steps:
            print("Warning! Reached max decoder steps")
            break

        decoder_inputs, decoder_outputs = swap_inputs_outputs(
            decoder_inputs, decoder_outputs)

    mel_outputs_postnet = torch.zeros_like(mel_outputs,
                                           device=device,
                                           dtype=dtype)

    postnet_tensors = {
        "inputs": {
            'mel_outputs': mel_outputs
        },
        "outputs": {
            'mel_outputs_postnet': mel_outputs_postnet
        }
    }
    print("Running Tacotron2 Postnet")
    with MeasureTime(measurements, "tacotron2_postnet_time"):
        run_trt_engine(postnet_context, postnet, postnet_tensors)

    print("Tacotron2 Postnet done")

    return mel_outputs_postnet, mel_lengths
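
Both variants read sequence_lengths[0] as the longest length in the batch, so callers must sort their inputs by decreasing length before invoking the encoder. A small sketch of that preparation step, assuming a text_to_sequence tokenizer like the one bundled with Tacotron2 (the helper name, its cleaner argument and the padding value are assumptions):

import torch

def prepare_input_sequences(texts, text_to_sequence):
    # Tokenize each utterance; text_to_sequence is assumed to return a list of IDs.
    encoded = [torch.tensor(text_to_sequence(t, ['english_cleaners']),
                            dtype=torch.int64) for t in texts]
    # Sort by decreasing length, as both infer_tacotron2_trt variants expect.
    order = sorted(range(len(encoded)), key=lambda i: len(encoded[i]), reverse=True)
    encoded = [encoded[i] for i in order]
    lengths = torch.tensor([len(seq) for seq in encoded], dtype=torch.int64)
    # Zero-pad to the longest sequence and move everything to the GPU.
    padded = torch.zeros(len(encoded), int(lengths[0]), dtype=torch.int64)
    for i, seq in enumerate(encoded):
        padded[i, :len(seq)] = seq
    return padded.cuda(), lengths.cuda(), order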