def setup(args):
    torch.set_grad_enabled(False)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    args.sample_rate, args.window_size, args.window_stride, args.window, args.num_input_features = map(
        checkpoint['args'].get,
        ['sample_rate', 'window_size', 'window_stride', 'window', 'num_input_features']
    )

    frontend = models.LogFilterBankFrontend(
        args.num_input_features,
        args.sample_rate,
        args.window_size,
        args.window_stride,
        args.window,
        dither=args.dither,
        dither0=args.dither0,
        #eps = 1e-6,
        normalize_signal=args.normalize_signal,
        debug_short_long_records_normalize_signal_multiplier=args.debug_short_long_records_normalize_signal_multiplier
    )

    # for legacy compat
    text_config = json.load(open(checkpoint['args'].get('text_config', args.text_config)))
    text_pipeline = text_processing.ProcessingPipeline.make(
        text_config,
        checkpoint['args'].get('text_pipelines', args.text_pipelines)[0]
    )

    model = getattr(models, args.model or checkpoint['args']['model'])(
        args.num_input_features,
        [text_pipeline.tokenizer.vocab_size],
        frontend=frontend if args.frontend_in_model else None,
        check_time_dim_padded=False,
        dict=lambda logits, log_probs, olen, **kwargs: (log_probs[0], logits[0], olen[0])
    )
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model = model.to(args.device)
    model.eval()
    model.fuse_conv_bn_eval()
    if args.device != 'cpu':
        model, *_ = models.data_parallel_and_autocast(model, opt_level=args.fp16)

    generator = transcript_generators.GreedyCTCGenerator()
    return text_pipeline, frontend, model, generator
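
# Hedged usage sketch (not part of the original source): builds a minimal argparse.Namespace
# carrying exactly the attributes that setup() above reads. The paths, pipeline name and
# values below are placeholders, not the repository's real defaults.
def _example_setup():
    import argparse
    example_args = argparse.Namespace(
        checkpoint='checkpoints/example.pt',        # placeholder path
        model=None,                                 # None -> use the model name stored in the checkpoint
        device='cpu',                               # 'cuda' would also route through data_parallel_and_autocast
        fp16=None,
        dither=0.0,
        dither0=0.0,
        normalize_signal=True,
        debug_short_long_records_normalize_signal_multiplier=1.0,
        frontend_in_model=True,
        text_config='configs/text_config.json',     # placeholder, only used if the checkpoint lacks one
        text_pipelines=['char_legacy'],             # placeholder, only used if the checkpoint lacks one
    )
    return setup(example_args)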
def setup(args):
    torch.set_grad_enabled(False)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    args.sample_rate, args.window_size, args.window_stride, args.window, args.num_input_features = map(
        checkpoint['args'].get,
        ['sample_rate', 'window_size', 'window_stride', 'window', 'num_input_features']
    )

    frontend = models.LogFilterBankFrontend(
        args.num_input_features,
        args.sample_rate,
        args.window_size,
        args.window_stride,
        args.window,
        eps=1e-6
    )
    labels = datasets.Labels(datasets.Language(checkpoint['args']['lang']), name='char')

    model = getattr(models, args.model or checkpoint['args']['model'])(
        args.num_input_features,
        [len(labels)],
        frontend=frontend,
        dict=lambda logits, log_probs, olen, **kwargs: (logits[0], olen[0])
    )
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model = model.to(args.device)
    model.eval()
    model.fuse_conv_bn_eval()
    if args.device != 'cpu':
        model, *_ = models.data_parallel_and_autocast(model, opt_level=args.fp16)

    decoder = decoders.GreedyDecoder() if args.decoder == 'GreedyDecoder' else decoders.BeamSearchDecoder(
        labels,
        lm_path=args.lm,
        beam_width=args.beam_width,
        beam_alpha=args.beam_alpha,
        beam_beta=args.beam_beta,
        num_workers=args.num_workers,
        topk=args.decoder_topk
    )
    return labels, frontend, model, decoder
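
# Generic illustration (an assumption, not the repository's decoders.GreedyDecoder): greedy CTC
# decoding takes the framewise argmax over the (logits, olen) pair exposed through the model's
# `dict` hook above, merges consecutive repeats, then drops the blank label.
def _greedy_ctc_decode(logits, blank=0):
    # logits: (T, num_labels) framewise scores (or log-probabilities) for a single utterance
    best = logits.argmax(dim=-1)                    # framewise best label, shape (T,)
    collapsed = torch.unique_consecutive(best)      # merge runs of repeated labels
    return collapsed[collapsed != blank].tolist()   # drop CTC blanks, return label indices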
def main(args):
    use_cuda = 'cuda' in args.device
    if use_cuda:
        print(f'CUDA_VISIBLE_DEVICES: {os.environ.get("CUDA_VISIBLE_DEVICES")!r}')
    print('initializing model...')

    if args.onnx:
        # todo: pass a dict with the provider setting once we migrate to onnxruntime>=1.7
        onnxruntime_session = onnxruntime.InferenceSession(args.onnx)
        if args.device == 'cpu':
            onnxruntime_session.set_providers(['CPUExecutionProvider'])
        model = lambda x: onnxruntime_session.run(None, dict(x=x, xlen=[1.0] * len(x)))
        load_batch = lambda x: x.numpy()
    else:
        assert args.checkpoint is not None and os.path.isfile(args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        args.model, args.lang, args.sample_rate, args.window_size, args.window_stride, args.window, args.num_input_features = (
            checkpoint['args'].get(key) for key in [
                'model', 'lang', 'sample_rate', 'window_size', 'window_stride', 'window', 'num_input_features'
            ]
        )

        text_config = json.load(open(args.text_config))
        text_pipelines = []
        for pipeline_name in args.text_pipelines:
            text_pipelines.append(text_processing.ProcessingPipeline.make(text_config, pipeline_name))

        frontend = models.LogFilterBankFrontend(
            out_channels=args.num_input_features,
            sample_rate=args.sample_rate,
            window_size=args.window_size,
            window_stride=args.window_stride,
            window=args.window,
            stft_mode=args.stft_mode,
        )
        model = getattr(models, args.model)(
            num_input_features=args.num_input_features,
            num_classes=[pipeline.tokenizer.vocab_size for pipeline in text_pipelines],
            frontend=frontend,
            dict=lambda logits, log_probs, olen, **kwargs: logits[0]
        )
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        model.to(args.device)
        model.eval()
        model.fuse_conv_bn_eval()
        if use_cuda:
            model, *_ = models.data_parallel_and_autocast(model, opt_level=args.fp16, data_parallel=False)
        load_batch = lambda x: x.to(args.device, non_blocking=True)

    batch_width = int(math.ceil(args.T * args.sample_rate / 128) * 128)
    example_time = batch_width / args.sample_rate
    batch = torch.rand(args.B, batch_width)
    batch = batch.pin_memory()
    print(f'batch [{args.B}, {batch_width}] | audio {args.B * example_time:.2f} sec\n')

    def tictoc():
        if use_cuda:
            torch.cuda.synchronize()
        return time.time()

    print(f'Warming up for {args.warmup_iterations} iterations...')
    tic_wall = tictoc()
    if use_cuda:
        torch.backends.cudnn.benchmark = True
    torch.set_grad_enabled(False)  # no backprop, so no gradients are needed
    for i in range(args.warmup_iterations):
        model(load_batch(batch))
    print(f'Warmup done in {tictoc() - tic_wall:.1f} sec\n')

    n_requests = int(round(args.benchmark_duration * args.rps))
    print(f'Starting {args.benchmark_duration} second benchmark ({n_requests} requests, rps {args.rps:.1f})...')
    # request arrival times are sampled uniformly at random over the benchmark window
    schedule = torch.rand(n_requests, dtype=torch.float64)
    schedule = schedule.sort()[0] * args.benchmark_duration + tictoc()
    gaps = schedule[1:] - schedule[:-1]
    print(f'avg gap between requests: {gaps.mean() * 1e3:.1f} ms')

    latency_times, idle_times = [], []
    slow_warning = False
    for t_request in schedule:
        t_request = t_request.item()
        tic = tictoc()
        if tic < t_request:
            # no request pending yet; a production service would wait here for the next request
            sleep_time = t_request - tic
            idle_times.append(sleep_time)
            time.sleep(sleep_time)
        logits = model(load_batch(batch))
        if not args.onnx:
            logits.cpu()  # force the device-to-host copy so latency includes output transfer
        toc = tictoc()
        if toc > t_request + args.max_latency and not slow_warning:
            print(f'model is too slow and can\'t handle {args.rps} requests per second!')
            slow_warning = True
        latency_times.append(toc - t_request)

    latency_times = torch.FloatTensor(latency_times)
    latency_times *= 1e3  # convert to ms
    print(
        f'Latency mean: {torch.mean(latency_times):.1f} ms, ' +
        f'median: {latency_times.quantile(.50):.1f} ms, ' +
        f'90th percentile: {latency_times.quantile(.90):.1f} ms, ' +
        f'95th percentile: {latency_times.quantile(.95):.1f} ms, ' +
        f'99th percentile: {latency_times.quantile(.99):.1f} ms, ' +
        f'max: {latency_times.max():.1f} ms | ' +
        f'service idle time fraction: {sum(idle_times) / args.benchmark_duration:.1%}'
    )
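
# Hedged CLI sketch (an assumption, not the repository's actual argument parser): the option
# names simply mirror the `args` attributes that main() reads above; every default is a placeholder.
def _example_arg_parser():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', help='path to a .pt checkpoint (PyTorch path)')
    parser.add_argument('--onnx', help='path to an exported .onnx model (ONNX Runtime path)')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--fp16', default='O2')
    parser.add_argument('--stft-mode', default=None)
    parser.add_argument('--text-config', default='configs/text_config.json')
    parser.add_argument('--text-pipelines', nargs='+', default=['char_legacy'])
    parser.add_argument('--sample-rate', type=int, default=8000, help='required on the ONNX path, where no checkpoint args are available')
    parser.add_argument('-B', type=int, default=16, help='batch size, utterances per request')
    parser.add_argument('-T', type=float, default=6.0, help='utterance duration in seconds')
    parser.add_argument('--rps', type=float, default=10.0, help='target request rate')
    parser.add_argument('--benchmark-duration', type=float, default=30.0, help='seconds')
    parser.add_argument('--warmup-iterations', type=int, default=16)
    parser.add_argument('--max-latency', type=float, default=0.5, help='seconds before the "too slow" warning fires')
    return parser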
                device_id = 0,
                element_type = batch.dtype,
                shape = batch_ortvalue.shape(),
                buffer_ptr = batch_ortvalue.data_ptr()
            )
            io_binding.bind_output(args.onnx_output_node_name, device)
            return io_binding

        load_batch = lambda batch: load_batch_to_device(batch, args.device)
        model = lambda io_binding: onnxruntime_session.run_with_iobinding(io_binding)
    else:
        frontend = models.LogFilterBankFrontend(
            args.num_input_features,
            args.sample_rate,
            args.window_size,
            args.window_stride,
            args.window,
            stft_mode = args.stft_mode,
        ) if args.frontend else None
        model = getattr(models, args.model)(
            args.num_input_features,
            [len(labels)],
            frontend = frontend,
            dict = lambda logits, log_probs, olen, **kwargs: logits[0]
        )
        if checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'], strict = False)
        model.to(args.device)
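
# Self-contained ONNX Runtime IO-binding sketch (the model path and the 'x' / 'logits' node
# names are placeholders, not the repository's real ones). It mirrors the fragment above:
# bind the input buffer on the target device, bind the output by name, then run with the
# binding instead of session.run() to avoid an extra host<->device copy of the input.
def _example_run_with_iobinding(onnx_path, batch_numpy, device='cuda', device_id=0):
    import numpy as np
    import onnxruntime
    session = onnxruntime.InferenceSession(onnx_path)
    # place the input tensor on the target device once
    ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(batch_numpy, device, device_id)
    io_binding = session.io_binding()
    io_binding.bind_input(
        name='x',                       # placeholder input node name
        device_type=device,
        device_id=device_id,
        element_type=np.float32,
        shape=ortvalue.shape(),
        buffer_ptr=ortvalue.data_ptr()
    )
    io_binding.bind_output('logits', device)   # placeholder output node name
    session.run_with_iobinding(io_binding)
    return io_binding.copy_outputs_to_cpu()[0]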