def run_trt(engine, pyt_components):
    '''Runs TRT inference for accuracy evaluation.

    Iterates over every batch of ``pyt_components['data_layer'].data_iterator``.
    Each batch is moved to CUDA, preprocessed with
    ``pyt_components['audio_preprocessor']``, padded/cut to the engine's
    baked-in length, and run through the TRT engine. The TRT output is
    then scored with ``perfutils.get_results``.

    Returns:
        (wers, preds): two lists with one entry per batch, as returned by
        ``perfutils.get_results``.
    '''
    # Dim 1 of binding 0 is the length the engine was built for
    # (presumably the time/sequence axis; see the variable name).
    baked_seq_len = engine.get_binding_shape(0)[1]
    cuda = torch.device("cuda")
    preprocess = pyt_components['audio_preprocessor']
    wers, preds = [], []
    with engine.create_execution_context() as context, torch.no_grad():
        for batch in tqdm(pyt_components['data_layer'].data_iterator):
            gpu_tensors = [t.to(cuda) for t in batch]
            # Only the first two tensors of each batch feed the preprocessor.
            features = preprocess(x=(gpu_tensors[0], gpu_tensors[1]))
            # Pad or cut to the necessary engine length.
            features = perfutils.adjust_shape(features, baked_seq_len)
            n = features[0].shape[0]
            torch.cuda.synchronize()
            trt_out, _, _, _ = do_inference(context=context, inp=features,
                                            batch_size=n)
            trt_out = perfutils.torchify_trt_out(trt_out, batch_size=n)
            wer, pred = perfutils.get_results(log_probs=trt_out,
                                              original_tensors=gpu_tensors,
                                              batch_size=n)
            wers.append(wer)
            preds.append(pred)
    return wers, preds
def _run_wav_transcription(engine, pyt_components, args):
    '''Transcribe the single wav in ``pyt_components['input_wav']`` with TRT.

    Prints the TRT inference time (in ms) and the first decoded hypothesis.
    '''
    audio_processor = pyt_components['audio_preprocessor']
    audio_processor.eval()
    greedy_decoder = GreedyCTCDecoder()
    input_wav, seq_len = pyt_components['input_wav']
    # Inference only: don't build an autograd graph for the features.
    with torch.no_grad():
        features = audio_processor((input_wav, seq_len))
    # Pad or cut to the sequence length requested on the command line.
    features = perfutils.adjust_shape(features, args.seq_len)
    with engine.create_execution_context() as context:
        t_log_probs_e, _copyto, inference, _copyfrom = perfprocedures.do_inference(
            context, features[0], 1)
        log_probs = perfutils.torchify_trt_out(t_log_probs_e, 1)
        t_predictions_e = greedy_decoder(log_probs=log_probs)
        hypotheses = __ctc_decoder_predictions_tensor(
            t_predictions_e, labels=perfutils.get_vocab())
        print("INFERENCE TIME: {} ms".format(inference * 1000.0))
        print("TRANSCRIPT: ", hypotheses[0])


def _append_results_csv(csv_path, string_header, string_data):
    '''Append ``string_data`` as one row of the CSV at ``csv_path``.

    Writes ``string_header`` first if the file is empty or new.

    Raises:
        Exception: if the file already starts with a different header.
    '''
    with open(csv_path, 'a+') as f:
        # 'a+' opens positioned at the end; rewind to read the header line.
        f.seek(0)
        existing_header = f.readline()
        if existing_header == "":
            f.write(string_header)
            f.write("\n")
        # rstrip (not [:-1]) so a header line lacking a trailing newline
        # is not truncated into a false mismatch.
        elif existing_header.rstrip("\n") != string_header:
            raise Exception(f"Writing to existing CSV with incorrect format\nProduced:\n{string_header}\nFound:\n{existing_header}\nIf you intended to write to a new results csv, please change the csv_path argument")
        f.seek(0, 2)  # Write to end of file
        f.write(string_data)
        f.write("\n")


def _write_prediction_file(path, predictions):
    '''Overwrite ``path`` with ``predictions`` joined by newlines.'''
    with open(path, 'w') as fp:
        fp.write('\n'.join(predictions))


def main(args):
    '''Benchmark the TRT engine against PyTorch, or transcribe one wav.

    If ``args.wav`` is set, only a single-file TRT transcription is run and
    printed. Otherwise ``perfprocedures.compare_times_trt_pyt_exhaustive``
    is run. Its results are appended to ``args.csv_path`` (or printed when
    that is None), and the TRT/PyTorch predictions are optionally written to
    ``args.trt_prediction_path`` / ``args.pyt_prediction_path``.
    '''
    # Get shared utility across PyTorch and TRT
    pyt_components, _saved_onnx = perfutils.get_pytorch_components_and_onnx(args)
    # Get a TRT engine. See function for argument parsing logic
    engine = get_engine(args)

    if args.wav:
        _run_wav_transcription(engine, pyt_components, args)
        return

    wer, preds, times = perfprocedures.compare_times_trt_pyt_exhaustive(
        engine, pyt_components, num_steps=args.num_steps)
    string_header, string_data = perfutils.do_csv_export(
        wer, times, args.batch_size, args.seq_len)

    if args.csv_path is not None:
        _append_results_csv(args.csv_path, string_header, string_data)
    else:
        print(string_header)
        print(string_data)

    if args.trt_prediction_path is not None:
        _write_prediction_file(args.trt_prediction_path, preds['trt'])
    if args.pyt_prediction_path is not None:
        _write_prediction_file(args.pyt_prediction_path, preds['pyt'])