def init_acoustic_model(args):
    """Build a Jasper acoustic model and optionally load checkpoint weights.

    The model config is read from ``args.model_config`` and the overrides
    carried on ``args`` are applied via ``config.apply_config_overrides``.
    Masked convolutions are forced off (with a printed warning) because they
    are not supported here.

    Args:
        args: parsed CLI namespace; reads ``model_config``, ``ckpt`` and
            ``ema`` (plus whatever ``apply_config_overrides`` consumes).

    Returns:
        The constructed ``Jasper`` model. When ``args.ckpt`` is set, weights
        from ``ema_state_dict`` (if ``args.ema``) or ``state_dict`` are loaded
        strictly.
    """
    from common.helpers import add_ctc_blank
    from jasper.model import Jasper
    from jasper import config

    model_cfg = config.load(args.model_config)
    config.apply_config_overrides(model_cfg, args)

    encoder_cfg = model_cfg['jasper']['encoder']
    if encoder_cfg['use_conv_masks'] == True:  # noqa: E712 - exact-True check kept
        print("[Jasper module]: Warning: setting 'use_conv_masks' to False; "
              "masked convolutions are not supported.")
        encoder_cfg['use_conv_masks'] = False

    # The CTC blank is appended to the label set, so the decoder emits one
    # extra class beyond the configured labels.
    symbols = add_ctc_blank(model_cfg['labels'])
    model = Jasper(encoder_kw=config.encoder(model_cfg),
                   decoder_kw=config.decoder(model_cfg, n_classes=len(symbols)))

    if args.ckpt is None:
        return model

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(args.ckpt, map_location="cpu")
    weights_key = 'ema_state_dict' if args.ema else 'state_dict'
    model.load_state_dict(checkpoint[weights_key], strict=True)
    return model
def init_acoustic_model(args):
    """Build a Jasper acoustic model and optionally load checkpoint weights.

    The model config is read from ``args.model_config``. The evaluation-time
    input overrides (``--max_duration`` / ``--pad_to_max_duration``) are
    applied to the ``input_val`` section. Masked convolutions are forced off
    (with a printed warning) because they are not supported here.

    Args:
        args: parsed CLI namespace; reads ``model_config``, ``max_duration``,
            ``pad_to_max_duration``, ``ckpt`` and ``ema``.

    Returns:
        The constructed ``Jasper`` model. When ``args.ckpt`` is set, weights
        from ``ema_state_dict`` (if ``args.ema``) or ``state_dict`` are loaded
        strictly.

    Raises:
        AssertionError: ``pad_to_max_duration`` requested without a positive
            validation ``max_duration``.
        KeyError: the checkpoint lacks the requested state-dict entry.
    """
    from common.helpers import add_ctc_blank
    from jasper.model import Jasper
    from jasper import config

    cfg = config.load(args.model_config)
    val_cfg = cfg['input_val']

    if args.max_duration is not None:
        val_cfg['audio_dataset']['max_duration'] = args.max_duration
        val_cfg['filterbank_features']['max_duration'] = args.max_duration

    if args.pad_to_max_duration:
        # Validate and update the same (validation) section that the
        # max_duration override above writes to; checking the training
        # section here would never see the overridden value.
        assert val_cfg['audio_dataset']['max_duration'] > 0
        val_cfg['audio_dataset']['pad_to_max_duration'] = True
        val_cfg['filterbank_features']['pad_to_max_duration'] = True

    encoder_cfg = cfg['jasper']['encoder']
    if encoder_cfg['use_conv_masks']:
        print("[Jasper module]: Warning: setting 'use_conv_masks' to False; "
              "masked convolutions are not supported.")
        encoder_cfg['use_conv_masks'] = False

    # The CTC blank is appended to the label set, so the decoder emits one
    # extra class beyond the configured labels.
    symbols = add_ctc_blank(cfg['labels'])
    model = Jasper(encoder_kw=config.encoder(cfg),
                   decoder_kw=config.decoder(cfg, n_classes=len(symbols)))

    if args.ckpt is not None:
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        key = 'ema_state_dict' if args.ema else 'state_dict'
        if key not in checkpoint:
            raise KeyError(f"checkpoint {args.ckpt!r} has no {key!r} entry "
                           f"(available: {sorted(checkpoint)})")
        model.load_state_dict(checkpoint[key], strict=True)

    return model
def _latency_key(stage, ratio):
    """Return the dllogger key for one latency statistic.

    The stage name is lower-cased and the ratio's dot becomes an underscore,
    e.g. ``('data+DNN', 0.99) -> 'data+dnn_latency_0_99'``.
    """
    return f"{stage.lower()}_latency_{str(ratio).replace('.', '_')}"


def _init_dllogger(args):
    """Initialise dllogger, log every CLI argument and register metric metadata."""
    log_fpath = args.log_file or str(Path(args.output_dir, 'nvlog_infer.json'))
    log_fpath = unique_log_fpath(log_fpath)
    dllogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
        StdOutBackend(Verbosity.VERBOSE, metric_format=stdout_metric_format),
    ])

    for k, v in vars(args).items():
        dllogger.log("PARAMETER", {k: v})

    for step in ['DNN', 'data+DNN', 'data']:
        for c in [0.99, 0.95, 0.9, 0.5]:
            cs = 'avg' if c == 0.5 else f'{int(100*c)}%'
            # Register under exactly the key _log_latencies() emits so the
            # name/format/unit metadata applies to the logged values.
            dllogger.metadata(_latency_key(step, c), {
                'name': f'{step} latency {cs}',
                'format': ':>7.2f',
                'unit': 'ms',
            })

    dllogger.metadata('eval_wer', {'name': 'WER', 'format': ':>3.2f', 'unit': '%'})


def _build_data_pipeline(args, cfg, symbols, multi_gpu, measure_perf):
    """Create the inference data loader and the host-side feature processor.

    Sources: a single wav or a file list (transcription mode, never DALI), a
    DALI pipeline over ``args.val_manifests``, or a PyTorch ``AudioDataset``
    over the same manifests.

    Returns:
        ``(data_loader, feat_proc, use_dali)``. ``feat_proc`` is ``None`` when
        DALI produces the final features and no extra padding is needed.
    """
    use_dali = args.dali_device in ('cpu', 'gpu')
    dataset_kw, features_kw = config.input(cfg, 'val')

    if args.transcribe_wav or args.transcribe_filelist:
        if use_dali:
            print("DALI supported only with input .json files; disabling")
            use_dali = False

        assert not args.pad_to_max_duration
        assert not (args.transcribe_wav and args.transcribe_filelist)

        if args.transcribe_wav:
            dataset = SingleAudioDataset(args.transcribe_wav)
        else:
            dataset = FilelistDataset(args.transcribe_filelist)

        data_loader = get_data_loader(
            dataset, batch_size=1, multi_gpu=multi_gpu, shuffle=False,
            num_workers=0, drop_last=measure_perf)
        feat_proc = FilterbankFeatures(**features_kw)

    elif use_dali:
        # pad_to_max_duration is not supported by DALI - have simple padders
        if features_kw['pad_to_max_duration']:
            feat_proc = BaseFeatures(
                pad_align=features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=features_kw['max_duration'],
                sample_rate=features_kw['sample_rate'],
                window_size=features_kw['window_size'],
                window_stride=features_kw['window_stride'])
            features_kw['pad_to_max_duration'] = False
        else:
            feat_proc = None

        data_loader = DaliDataLoader(
            gpu_id=args.local_rank or 0,
            dataset_path=args.dataset_dir,
            config_data=dataset_kw,
            config_features=features_kw,
            json_names=args.val_manifests,
            batch_size=args.batch_size,
            pipeline_type=("train" if measure_perf else "val"),  # no drop_last
            device_type=args.dali_device,
            symbols=symbols)

    else:
        dataset = AudioDataset(args.dataset_dir, args.val_manifests, symbols,
                               **dataset_kw)
        data_loader = get_data_loader(
            dataset, args.batch_size, multi_gpu=multi_gpu, shuffle=False,
            num_workers=4, drop_last=False)
        feat_proc = FilterbankFeatures(**features_kw)

    return data_loader, feat_proc, use_dali


def _log_latencies(dur):
    """Log per-stage latency percentiles, or a notice when too few samples exist."""
    if len(dur['data']) < 20:
        print_once('Not enough samples to measure latencies.')
        return

    ratios = [0.9, 0.95, 0.99]
    for stage, durations in dur.items():
        lat = durs_to_percentiles(durations, ratios)
        for k in [0.99, 0.95, 0.9, 0.5]:
            dllogger.log(step=(), data={_latency_key(stage, k): lat[k]})


def main():
    """Run Jasper inference: evaluate WER, transcribe audio and/or benchmark latency.

    Behaviour is driven by the CLI arguments from ``get_parser()``. WER and
    latency statistics are logged through dllogger. Predictions and logits are
    optionally written to ``args.save_predictions`` / ``args.save_logits``.

    With ``--steps > 0`` the run is a benchmark of ``warmup_steps + steps``
    batches (the loader is cycled if needed). Otherwise exactly one pass over
    the data loader is made.
    """
    parser = get_parser()
    args = parser.parse_args()
    _init_dllogger(args)

    if args.cpu:
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if args.seed is not None:
        seed = args.seed + args.local_rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    # set up distributed inference
    multi_gpu = not args.cpu and int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        distrib.init_process_group(backend='nccl', init_method='env://')
        print_once(f'Inference with {distrib.get_world_size()} GPUs')

    cfg = config.load(args.model_config)

    if args.max_duration is not None:
        cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
        cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration

    if args.pad_to_max_duration:
        assert cfg['input_val']['audio_dataset']['max_duration'] > 0
        cfg['input_val']['audio_dataset']['pad_to_max_duration'] = True
        cfg['input_val']['filterbank_features']['pad_to_max_duration'] = True

    symbols = helpers.add_ctc_blank(cfg['labels'])
    measure_perf = args.steps > 0

    data_loader, feat_proc, use_dali = _build_data_pipeline(
        args, cfg, symbols, multi_gpu, measure_perf)

    model = Jasper(encoder_kw=config.encoder(cfg),
                   decoder_kw=config.decoder(cfg, n_classes=len(symbols)))

    if args.ckpt is not None:
        print(f'Loading the model from {args.ckpt} ...')
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        key = 'ema_state_dict' if args.ema else 'state_dict'
        state_dict = helpers.convert_v1_state_dict(checkpoint[key])
        model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()

    if feat_proc is not None:
        feat_proc.to(device)
        feat_proc.eval()

    if args.amp:
        model = model.half()

    # Create the decoder once. When exporting to TorchScript, the exported
    # decoder replaces it and must not be overwritten afterwards.
    greedy_decoder = GreedyCTCDecoder()
    if args.torchscript:
        feat_proc, model, greedy_decoder = torchscript_export(
            data_loader, feat_proc, model, greedy_decoder, args.output_dir,
            use_amp=args.amp, use_conv_masks=True, model_toml=args.model_toml,
            device=device, save=args.torchscript_export)

    # Read before wrapping: DistributedDataParallel does not forward attribute
    # access to the wrapped module (it lives under ``.module``).
    use_conv_masks = model.encoder.use_conv_masks
    if multi_gpu:
        model = DistributedDataParallel(model)

    # NOTE(review): logits of every batch are kept (on the model's device)
    # for the whole run; consider collecting them only when save_logits is set
    # if process_evaluation_epoch does not need them.
    agg = {'txts': [], 'preds': [], 'logits': []}
    dur = {'data': [], 'dnn': [], 'data+dnn': []}

    # Cycle the loader so benchmarks can run more steps than one epoch holds.
    looped_loader = chain.from_iterable(repeat(data_loader))

    def sync():
        if device.type == 'cuda':
            torch.cuda.synchronize()

    # Benchmark mode runs warmup + measured steps; otherwise exactly one epoch.
    steps = args.steps + args.warmup_steps if measure_perf else len(data_loader)
    if steps <= 0:
        # An empty loader would make the cycled iterator spin forever.
        raise ValueError('No batches to run inference on (empty data loader)')

    t0 = time.perf_counter()  # overwritten before first use (it >= 1)
    with torch.no_grad():
        for it, batch in enumerate(tqdm(looped_loader, initial=1, total=steps)):

            if use_dali:
                feats, feat_lens, txt, txt_lens = batch
                if feat_proc is not None:
                    feats, feat_lens = feat_proc(feats, feat_lens)
            else:
                batch = [t.to(device, non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feats, feat_lens = feat_proc(audio, audio_lens)

            sync()
            t1 = time.perf_counter()

            if args.amp:
                feats = feats.half()

            if use_conv_masks:
                log_probs, _ = model(feats, feat_lens)
            else:
                log_probs = model(feats, feat_lens)

            preds = greedy_decoder(log_probs)

            sync()
            t2 = time.perf_counter()

            # burn-in period; wait for a new loader due to num_workers
            if it >= 1 and (args.steps == 0 or it >= args.warmup_steps):
                dur['data'].append(t1 - t0)
                dur['dnn'].append(t2 - t1)
                dur['data+dnn'].append(t2 - t0)

            if txt is not None:
                agg['txts'] += helpers.gather_transcripts([txt], [txt_lens],
                                                          symbols)
            agg['preds'] += helpers.gather_predictions([preds], symbols)
            agg['logits'].append(log_probs)

            if it + 1 == steps:
                break

            sync()
            t0 = time.perf_counter()

    # communicate the results
    if args.transcribe_wav:
        for idx, p in enumerate(agg['preds']):
            print_once(f'Prediction {idx+1: >3}: {p}')

    elif args.transcribe_filelist:
        pass

    elif not multi_gpu or distrib.get_rank() == 0:
        wer, _ = process_evaluation_epoch(agg)
        dllogger.log(step=(), data={'eval_wer': 100 * wer})

    if args.save_predictions:
        with open(args.save_predictions, 'w', encoding='utf-8') as f:
            f.write('\n'.join(agg['preds']))

    if args.save_logits:
        logits = torch.cat(agg['logits'], dim=0).cpu()
        torch.save(logits, args.save_logits)

    # report timings
    _log_latencies(dur)