Example #1
def setup(args):
    torch.set_grad_enabled(False)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    args.sample_rate, args.window_size, args.window_stride, args.window, args.num_input_features = map(
        checkpoint['args'].get, [
            'sample_rate', 'window_size', 'window_stride', 'window',
            'num_input_features'
        ])
    frontend = models.LogFilterBankFrontend(
        args.num_input_features,
        args.sample_rate,
        args.window_size,
        args.window_stride,
        args.window,
        dither=args.dither,
        dither0=args.dither0,
        #eps = 1e-6,
        normalize_signal=args.normalize_signal,
        debug_short_long_records_normalize_signal_multiplier=args.debug_short_long_records_normalize_signal_multiplier)

    # for legacy compat
    text_config = json.load(
        open(checkpoint['args'].get('text_config', args.text_config)))
    text_pipeline = text_processing.ProcessingPipeline.make(
        text_config, checkpoint['args'].get('text_pipelines',
                                            args.text_pipelines)[0])

    model = getattr(models, args.model or checkpoint['args']['model'])(
        args.num_input_features, [text_pipeline.tokenizer.vocab_size],
        frontend=frontend if args.frontend_in_model else None,
        check_time_dim_padded=False,
        dict=lambda logits, log_probs, olen, **kwargs:
        (log_probs[0], logits[0], olen[0]))
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model = model.to(args.device)
    model.eval()
    model.fuse_conv_bn_eval()
    if args.device != 'cpu':
        model, *_ = models.data_parallel_and_autocast(model,
                                                      opt_level=args.fp16)
    generator = transcript_generators.GreedyCTCGenerator()
    return text_pipeline, frontend, model, generator
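
Since setup() only reads attributes off args, it can be exercised without the project's full CLI. Below is a hypothetical invocation sketch: every path and value is a placeholder, and the attribute names simply mirror what the function accesses (sample_rate, window_size, etc. are pulled from the checkpoint itself).

# Hypothetical driver for the setup() above; all values are placeholders,
# not defaults taken from the original project.
import argparse

args = argparse.Namespace(
    checkpoint='checkpoint.pt',                                # assumed path to a trained checkpoint
    text_config='text_config.json',                            # assumed fallback text-processing config
    text_pipelines=['char'],                                   # assumed fallback pipeline name
    model=None,                                                # None -> use the model name stored in the checkpoint
    frontend_in_model=True,
    dither=0.0,
    dither0=0.0,
    normalize_signal=True,
    debug_short_long_records_normalize_signal_multiplier=1.0,
    device='cuda',
    fp16='O2',                                                 # opt_level for data_parallel_and_autocast
)

text_pipeline, frontend, model, generator = setup(args)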
Example #2
def setup(args):
    torch.set_grad_enabled(False)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    args.sample_rate, args.window_size, args.window_stride, args.window, args.num_input_features = map(
        checkpoint['args'].get, [
            'sample_rate', 'window_size', 'window_stride', 'window',
            'num_input_features'
        ])
    frontend = models.LogFilterBankFrontend(args.num_input_features,
                                            args.sample_rate,
                                            args.window_size,
                                            args.window_stride,
                                            args.window,
                                            eps=1e-6)
    labels = datasets.Labels(datasets.Language(checkpoint['args']['lang']),
                             name='char')
    model = getattr(models, args.model or checkpoint['args']['model'])(
        args.num_input_features, [len(labels)],
        frontend=frontend,
        dict=lambda logits, log_probs, olen, **kwargs: (logits[0], olen[0]))
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model = model.to(args.device)
    model.eval()
    model.fuse_conv_bn_eval()
    if args.device != 'cpu':
        model, *_ = models.data_parallel_and_autocast(model,
                                                      opt_level=args.fp16)
    decoder = decoders.GreedyDecoder() if args.decoder == 'GreedyDecoder' else decoders.BeamSearchDecoder(
        labels,
        lm_path=args.lm,
        beam_width=args.beam_width,
        beam_alpha=args.beam_alpha,
        beam_beta=args.beam_beta,
        num_workers=args.num_workers,
        topk=args.decoder_topk)
    return labels, frontend, model, decoder
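
As in the first example, this variant is driven entirely by attributes on args. A hedged sketch of how it might be called follows; every value is a placeholder, and the beam-search fields are only consumed when the BeamSearchDecoder branch is selected.

# Hypothetical driver for this setup() variant; values are placeholders only.
import argparse

args = argparse.Namespace(
    checkpoint='checkpoint.pt',    # assumed checkpoint path; lang and feature args are read from it
    model=None,                    # None -> use the model name stored in the checkpoint
    device='cpu',
    fp16='O0',                     # opt_level (only used when device != 'cpu')
    decoder='GreedyDecoder',       # 'BeamSearchDecoder' would additionally use the fields below
    lm=None,
    beam_width=500,
    beam_alpha=0.4,
    beam_beta=1.0,
    num_workers=4,
    decoder_topk=1,
)

labels, frontend, model, decoder = setup(args)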
Example #3
def main(args):
    use_cuda = 'cuda' in args.device
    if use_cuda:
        print(
            f'CUDA_VISIBLE_DEVICES: {os.environ.get("CUDA_VISIBLE_DEVICES")!r}'
        )
    print('initializing model...')

    if args.onnx:
        # todo: pass a dict with the provider setting once we migrate to onnxruntime>=1.7
        onnxruntime_session = onnxruntime.InferenceSession(args.onnx)
        if args.device == 'cpu':
            onnxruntime_session.set_providers(['CPUExecutionProvider'])
        model = lambda x: onnxruntime_session.run(
            None, dict(x=x, xlen=[1.0] * len(x)))
        load_batch = lambda x: x.numpy()
    else:
        assert args.checkpoint is not None and os.path.isfile(args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        args.model, args.lang, args.sample_rate, args.window_size, args.window_stride, args.window, args.num_input_features = (
            checkpoint['args'].get(key) for key in [
                'model', 'lang', 'sample_rate', 'window_size', 'window_stride',
                'window', 'num_input_features'
            ])

        text_config = json.load(open(args.text_config))
        text_pipelines = []
        for pipeline_name in args.text_pipelines:
            text_pipelines.append(
                text_processing.ProcessingPipeline.make(
                    text_config, pipeline_name))

        frontend = models.LogFilterBankFrontend(
            out_channels=args.num_input_features,
            sample_rate=args.sample_rate,
            window_size=args.window_size,
            window_stride=args.window_stride,
            window=args.window,
            stft_mode=args.stft_mode,
        )
        model = getattr(models, args.model)(
            num_input_features=args.num_input_features,
            num_classes=[
                pipeline.tokenizer.vocab_size for pipeline in text_pipelines
            ],
            frontend=frontend,
            dict=lambda logits, log_probs, olen, **kwargs: logits[0])
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        model.to(args.device)
        model.eval()
        model.fuse_conv_bn_eval()
        if use_cuda:
            model, *_ = models.data_parallel_and_autocast(model,
                                                          opt_level=args.fp16,
                                                          data_parallel=False)
        load_batch = lambda x: x.to(args.device, non_blocking=True)

    batch_width = int(math.ceil(args.T * args.sample_rate / 128) * 128)
    example_time = batch_width / args.sample_rate
    batch = torch.rand(args.B, batch_width)
    batch = batch.pin_memory()

    print(
        f'batch [{args.B}, {batch_width}] | audio {args.B * example_time:.2f} sec\n'
    )

    def tictoc():
        if use_cuda:
            torch.cuda.synchronize()
        return time.time()

    print(f'Warming up for {args.warmup_iterations} iterations...')
    tic_wall = tictoc()
    if use_cuda:
        torch.backends.cudnn.benchmark = True
    torch.set_grad_enabled(False)  # no back-prop - no gradients
    for i in range(args.warmup_iterations):
        model(load_batch(batch))
    print(f'Warmup done in {tictoc() - tic_wall:.1f} sec\n')

    n_requests = int(round(args.benchmark_duration * args.rps))
    print(
        f'Starting {args.benchmark_duration} second benchmark ({n_requests} requests, rps {args.rps:.1f})...'
    )
    schedule = torch.rand(
        n_requests,
        dtype=torch.float64)  # uniform random distribution of requests
    schedule = schedule.sort()[0] * args.benchmark_duration + tictoc()
    gaps = schedule[1:] - schedule[:-1]
    print(f'avg gap between requests: {gaps.mean() * 1e3:.1f} ms')
    latency_times, idle_times = [], []
    slow_warning = False
    for t_request in schedule:
        t_request = t_request.item()
        tic = tictoc()
        if tic < t_request:  # no requests yet. at this point prod would wait for the next request
            sleep_time = t_request - tic
            idle_times.append(sleep_time)
            time.sleep(sleep_time)
        logits = model(load_batch(batch))
        if not args.onnx:
            logits.cpu()
        toc = tictoc()
        if toc > t_request + args.max_latency and not slow_warning:
            print(
                f'model is too slow and can\'t handle {args.rps} requests per second!'
            )
            slow_warning = True
        latency_times.append(toc - t_request)
    latency_times = torch.FloatTensor(latency_times)
    latency_times *= 1e3  # convert to ms
    print(
        f'Latency mean: {torch.mean(latency_times):.1f} ms, ' +
        f'median: {latency_times.quantile(.50):.1f} ms, ' +
        f'90-th percentile: {latency_times.quantile(.90):.1f} ms, ' +
        f'95-th percentile: {latency_times.quantile(.95):.1f} ms, ' +
        f'99-th percentile: {latency_times.quantile(.99):.1f} ms, ' +
        f'max: {latency_times.max():.1f} ms | ' +
        f'service idle time fraction: {sum(idle_times) / args.benchmark_duration:.1%}'
    )
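
The core of this benchmark is the request schedule: arrival times are drawn uniformly over the benchmark window, the loop sleeps until each simulated request arrives, and latency is measured from the request time rather than from the start of inference. A stripped-down sketch of just that logic, with a dummy 10 ms workload standing in for the real model and purely illustrative numbers:

# Stripped-down sketch of the request-scheduling logic from Example #3;
# the sleep(0.01) below is a stand-in for model inference.
import time
import torch

rps, duration = 10.0, 5.0                        # requests per second, benchmark length in seconds
n_requests = int(round(duration * rps))
schedule = torch.rand(n_requests, dtype=torch.float64).sort()[0] * duration + time.time()

latencies = []
for t_request in schedule.tolist():
    now = time.time()
    if now < t_request:
        time.sleep(t_request - now)              # idle until the simulated request arrives
    time.sleep(0.01)                             # dummy "inference" taking ~10 ms
    latencies.append(time.time() - t_request)    # latency counts any queueing delay, not just compute

latencies = torch.FloatTensor(latencies) * 1e3   # convert to ms
print(f'median latency: {latencies.quantile(.5):.1f} ms, p99: {latencies.quantile(.99):.1f} ms')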
Example #4
			device_id = 0,
			element_type = batch.dtype,
			shape = batch_ortvalue.shape(),
			buffer_ptr = batch_ortvalue.data_ptr()
		)
		io_binding.bind_output(args.onnx_output_node_name, device)
		return io_binding

	load_batch = lambda batch: load_batch_to_device(batch, args.device)
	model = lambda io_binding: onnxruntime_session.run_with_iobinding(io_binding)

else:
	frontend = models.LogFilterBankFrontend(
		args.num_input_features,
		args.sample_rate,
		args.window_size,
		args.window_stride,
		args.window,
		stft_mode = args.stft_mode,
	) if args.frontend else None
	model = getattr(models, args.model)(
		args.num_input_features, [len(labels)],
		frontend = frontend,
		dict = lambda logits, log_probs, olen, **kwargs: logits[0]
	)
	if checkpoint:
		model.load_state_dict(checkpoint['model_state_dict'], strict = False)
	model.to(args.device)
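
The ONNX branch at the top of this fragment follows the standard onnxruntime IO-binding pattern: place the batch in an OrtValue, bind it to the input node by pointer, bind the output node, and run with the binding. A self-contained, hedged sketch of that pattern is below; 'model.onnx' and the node names 'x' / 'logits' are assumptions, not values taken from the project.

# Hedged sketch of the onnxruntime IO-binding pattern used above;
# the model path and node names are placeholders.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession('model.onnx')
batch = np.random.rand(1, 16000).astype(np.float32)        # dummy mono audio batch

# Keep the input as an OrtValue and bind it by pointer so run() does not re-copy it.
ortvalue = onnxruntime.OrtValue.ortvalue_from_numpy(batch, 'cpu', 0)
io_binding = session.io_binding()
io_binding.bind_input(
    name='x',
    device_type='cpu',
    device_id=0,
    element_type=np.float32,
    shape=ortvalue.shape(),
    buffer_ptr=ortvalue.data_ptr()
)
io_binding.bind_output('logits', 'cpu')
session.run_with_iobinding(io_binding)
logits = io_binding.copy_outputs_to_cpu()[0]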