Example #1
0
def run_trt(engine, pyt_components):
    '''Runs TRT inference for accuracy evaluation.

    Iterates over the evaluation data loader, preprocesses each batch,
    pads or trims it to the engine's baked sequence length, runs TensorRT
    inference, and collects per-batch WER and predictions.

    Args:
        engine: deserialized TensorRT engine; binding 0's shape[1] is the
            fixed (baked) sequence length the engine was built with.
        pyt_components: dict with at least 'data_layer' (exposing
            .data_iterator) and 'audio_preprocessor' (a callable).
            # NOTE(review): exact schema assumed from usage here — confirm
            # against perfutils.get_pytorch_components_and_onnx.

    Returns:
        (wers, preds): two lists with one entry per evaluated batch.
    '''
    baked_seq_len = engine.get_binding_shape(0)[1]
    wers = []
    preds = []
    with engine.create_execution_context() as context, torch.no_grad():
        for data in tqdm(pyt_components['data_layer'].data_iterator):
            # Move the whole batch to GPU once, up front.
            tensors = [d.to(torch.device("cuda")) for d in data]
            input_tensor = (tensors[0], tensors[1])
            am_input = pyt_components['audio_preprocessor'](x=input_tensor)
            # Pad or cut to the necessary engine length
            am_input = perfutils.adjust_shape(am_input, baked_seq_len)
            batch_size = am_input[0].shape[0]
            # Ensure preprocessing kernels have completed before inference.
            torch.cuda.synchronize()
            # Run TRT inference
            trt_out, _, _, _ = do_inference(context=context,
                                            inp=am_input,
                                            batch_size=batch_size)
            trt_out = perfutils.torchify_trt_out(trt_out,
                                                 batch_size=batch_size)
            wer, pred = perfutils.get_results(log_probs=trt_out,
                                              original_tensors=tensors,
                                              batch_size=batch_size)
            wers.append(wer)
            preds.append(pred)

    return wers, preds
Example #2
0
def main(args):
    '''Entry point for TRT evaluation.

    Two modes, selected by args.wav:
      * single-WAV mode: transcribe one audio file with the TRT engine and
        print the inference time and transcript;
      * dataset mode: run an exhaustive TRT-vs-PyTorch timing/accuracy
        comparison and emit results to a CSV and/or prediction files.

    Args:
        args: parsed command-line namespace (wav, seq_len, num_steps,
            batch_size, csv_path, trt_prediction_path, pyt_prediction_path).
    '''

    # Get shared utility across PyTorch and TRT
    pyt_components, saved_onnx = perfutils.get_pytorch_components_and_onnx(args)

    # Get a TRT engine. See function for argument parsing logic
    engine = get_engine(args)

    if args.wav:
        # Single-file transcription path.
        audio_processor = pyt_components['audio_preprocessor']
        audio_processor.eval()
        greedy_decoder = GreedyCTCDecoder()
        input_wav, seq_len = pyt_components['input_wav']
        features = audio_processor((input_wav, seq_len))
        # Pad/trim to the sequence length the engine was built for.
        features = perfutils.adjust_shape(features, args.seq_len)
        with engine.create_execution_context() as context:
            t_log_probs_e, copyto, inference, copyfrom = perfprocedures.do_inference(context, features[0], 1)
        log_probs = perfutils.torchify_trt_out(t_log_probs_e, 1)

        t_predictions_e = greedy_decoder(log_probs=log_probs)
        hypotheses = __ctc_decoder_predictions_tensor(t_predictions_e, labels=perfutils.get_vocab())
        # Fixed typo in user-facing output: "INTERENCE" -> "INFERENCE".
        print("INFERENCE TIME: {} ms".format(inference*1000.0))
        print("TRANSCRIPT: ", hypotheses[0])

        return

    # Dataset path: exhaustive timing comparison between TRT and PyTorch.
    wer, preds, times = perfprocedures.compare_times_trt_pyt_exhaustive(engine,
                                                                        pyt_components,
                                                                        num_steps=args.num_steps)
    string_header, string_data = perfutils.do_csv_export(wer, times, args.batch_size, args.seq_len)
    if args.csv_path is not None:
        # 'a+' lets us both inspect an existing header and append new rows.
        with open(args.csv_path, 'a+') as f:
            # See if header is there, if so, check that it matches
            f.seek(0)  # Read from start of file
            existing_header = f.readline()
            if existing_header == "":
                f.write(string_header)
                f.write("\n")
            elif existing_header[:-1] != string_header:
                raise Exception(f"Writing to existing CSV with incorrect format\nProduced:\n{string_header}\nFound:\n{existing_header}\nIf you intended to write to a new results csv, please change the csv_path argument")
            f.seek(0, 2)  # Write to end of file
            f.write(string_data)
            f.write("\n")
    else:
        print(string_header)
        print(string_data)

    if args.trt_prediction_path is not None:
        with open(args.trt_prediction_path, 'w') as fp:
            fp.write('\n'.join(preds['trt']))

    if args.pyt_prediction_path is not None:
        with open(args.pyt_prediction_path, 'w') as fp:
            fp.write('\n'.join(preds['pyt']))