def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    multi_gpu = args.local_rank is not None

    if args.cpu:
        assert not multi_gpu
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark
        print("CUDNN BENCHMARK ", args.cudnn_benchmark)

    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        # Query the world size only after the process group is initialized.
        print("DISTRIBUTED with ", torch.distributed.get_world_size())

    optim_level = 3 if args.amp else 0

    jasper_model_definition = toml.load(args.model_toml)
    dataset_vocab = jasper_model_definition['labels']['labels']
    ctc_vocab = add_ctc_labels(dataset_vocab)

    val_manifest = args.val_manifest
    featurizer_config = jasper_model_definition['input_eval']
    featurizer_config["optimization_level"] = optim_level
    featurizer_config["fp16"] = args.amp

    args.use_conv_mask = jasper_model_definition['encoder'].get('convmask', True)
    if args.use_conv_mask and args.export_model:
        print('WARNING: Masked convs currently not supported for TorchScript. Disabling.')
        jasper_model_definition['encoder']['convmask'] = False

    if args.max_duration is not None:
        featurizer_config['max_duration'] = args.max_duration
    if args.pad_to is not None:
        featurizer_config['pad_to'] = args.pad_to
    if featurizer_config['pad_to'] == "max":
        featurizer_config['pad_to'] = -1

    print('=== model_config ===')
    print_dict(jasper_model_definition)
    print()
    print('=== feature_config ===')
    print_dict(featurizer_config)
    print()

    data_layer = None
    if args.wav is None:
        data_layer = AudioToTextDataLayer(
            dataset_dir=args.dataset_dir,
            featurizer_config=featurizer_config,
            manifest_filepath=val_manifest,
            labels=dataset_vocab,
            batch_size=args.batch_size,
            pad_to_max=featurizer_config['pad_to'] == -1,
            shuffle=False,
            multi_gpu=multi_gpu)
    audio_preprocessor = AudioPreprocessing(**featurizer_config)
    encoderdecoder = JasperEncoderDecoder(
        jasper_model_definition=jasper_model_definition,
        feat_in=1024,
        num_classes=len(ctc_vocab))

    if args.ckpt is not None:
        print("loading model from ", args.ckpt)
        if os.path.isdir(args.ckpt):
            # Directory checkpoints are not handled here; bail out.
            exit(0)
        else:
            checkpoint = torch.load(args.ckpt, map_location="cpu")
            if args.ema and 'ema_state_dict' in checkpoint:
                print('Loading EMA state dict')
                sd = 'ema_state_dict'
            else:
                sd = 'state_dict'
            # Strip the "audio_preprocessor." prefix so the keys match the
            # standalone preprocessor module.
            for k in audio_preprocessor.state_dict().keys():
                checkpoint[sd][k] = checkpoint[sd].pop("audio_preprocessor." + k)
            audio_preprocessor.load_state_dict(checkpoint[sd], strict=False)
            encoderdecoder.load_state_dict(checkpoint[sd], strict=False)

    greedy_decoder = GreedyCTCDecoder()

    # print("Number of parameters in encoder: {0}".format(model.jasper_encoder.num_weights()))

    if args.wav is None:
        N = len(data_layer)
        step_per_epoch = math.ceil(N / (args.batch_size * (
            1 if not torch.distributed.is_initialized()
            else torch.distributed.get_world_size())))

        if args.steps is not None:
            print('-----------------')
            print('Have {0} examples to eval on.'.format(
                args.steps * args.batch_size * (
                    1 if not torch.distributed.is_initialized()
                    else torch.distributed.get_world_size())))
            print('Have {0} steps / (gpu * epoch).'.format(args.steps))
            print('-----------------')
        else:
            print('-----------------')
            print('Have {0} examples to eval on.'.format(N))
            print('Have {0} steps / (gpu * epoch).'.format(step_per_epoch))
            print('-----------------')

    print("audio_preprocessor.normalize: ", audio_preprocessor.featurizer.normalize)

    audio_preprocessor.to(device)
    encoderdecoder.to(device)

    if args.amp:
        encoderdecoder = amp.initialize(
            models=encoderdecoder, opt_level='O' + str(optim_level))
    encoderdecoder = model_multi_gpu(encoderdecoder, multi_gpu)

    audio_preprocessor.eval()
    encoderdecoder.eval()
    greedy_decoder.eval()

    eval(
        data_layer=data_layer,
        audio_processor=audio_preprocessor,
        encoderdecoder=encoderdecoder,
        greedy_decoder=greedy_decoder,
        labels=ctc_vocab,
        args=args,
        device=device,
        multi_gpu=multi_gpu)
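# A minimal sketch of calling the evaluation main() above programmatically.
# The attribute names are exactly those main() reads; every path and value
# below is a placeholder assumption, not something defined by this repository.
import argparse

example_eval_args = argparse.Namespace(
    seed=42,
    local_rank=None,          # single process; the launcher sets this for multi-GPU runs
    cpu=False,
    cudnn_benchmark=True,
    amp=False,
    model_toml='configs/jasper10x5dr.toml',   # hypothetical model config path
    val_manifest='data/dev-clean.json',       # hypothetical evaluation manifest
    dataset_dir='data/LibriSpeech',           # hypothetical dataset root
    batch_size=16,
    max_duration=None,
    pad_to=None,
    export_model=False,
    wav=None,
    ckpt='checkpoints/jasper.pt',             # hypothetical checkpoint path
    ema=False,
    steps=None)
# main(example_eval_args)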
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    assert args.steps is None or args.steps > 5

    if args.cpu:
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark
        print("CUDNN BENCHMARK ", args.cudnn_benchmark)

    optim_level = 3 if args.amp else 0
    batch_size = args.batch_size

    jasper_model_definition = toml.load(args.model_toml)
    dataset_vocab = jasper_model_definition['labels']['labels']
    ctc_vocab = add_ctc_labels(dataset_vocab)

    val_manifest = args.val_manifest
    featurizer_config = jasper_model_definition['input_eval']
    featurizer_config["optimization_level"] = optim_level

    if args.max_duration is not None:
        featurizer_config['max_duration'] = args.max_duration

    # TORCHSCRIPT: Can't use mixed types. Using -1 for "max".
    if args.pad_to is not None:
        featurizer_config['pad_to'] = args.pad_to if args.pad_to >= 0 else -1
    if featurizer_config['pad_to'] == "max":
        featurizer_config['pad_to'] = -1

    args.use_conv_mask = jasper_model_definition['encoder'].get('convmask', True)
    if args.use_conv_mask and args.torch_script:
        print('WARNING: Masked convs currently not supported for TorchScript. Disabling.')
        jasper_model_definition['encoder']['convmask'] = False

    print('model_config')
    print_dict(jasper_model_definition)
    print('feature_config')
    print_dict(featurizer_config)

    data_layer = AudioToTextDataLayer(
        dataset_dir=args.dataset_dir,
        featurizer_config=featurizer_config,
        manifest_filepath=val_manifest,
        labels=dataset_vocab,
        batch_size=batch_size,
        pad_to_max=featurizer_config['pad_to'] == -1,
        shuffle=False,
        multi_gpu=False)
    audio_preprocessor = AudioPreprocessing(**featurizer_config)
    encoderdecoder = JasperEncoderDecoder(
        jasper_model_definition=jasper_model_definition,
        feat_in=1024,
        num_classes=len(ctc_vocab))

    if args.ckpt is not None:
        print("loading model from ", args.ckpt)
        checkpoint = torch.load(args.ckpt, map_location="cpu")
        # Strip the "audio_preprocessor." prefix so the keys match the
        # standalone preprocessor module.
        for k in audio_preprocessor.state_dict().keys():
            checkpoint['state_dict'][k] = checkpoint['state_dict'].pop(
                "audio_preprocessor." + k)
        audio_preprocessor.load_state_dict(checkpoint['state_dict'], strict=False)
        encoderdecoder.load_state_dict(checkpoint['state_dict'], strict=False)

    greedy_decoder = GreedyCTCDecoder()

    # print("Number of parameters in encoder: {0}".format(model.jasper_encoder.num_weights()))

    N = len(data_layer)
    step_per_epoch = math.ceil(N / args.batch_size)

    print('-----------------')
    if args.steps is None:
        print('Have {0} examples to eval on.'.format(N))
        print('Have {0} steps / (epoch).'.format(step_per_epoch))
    else:
        print('Have {0} examples to eval on.'.format(args.steps * args.batch_size))
        print('Have {0} steps / (epoch).'.format(args.steps))
    print('-----------------')

    audio_preprocessor.to(device)
    encoderdecoder.to(device)

    if args.amp:
        encoderdecoder = amp.initialize(
            models=encoderdecoder, opt_level='O' + str(optim_level))

    eval(
        data_layer=data_layer,
        audio_processor=audio_preprocessor,
        encoderdecoder=encoderdecoder,
        greedy_decoder=greedy_decoder,
        labels=ctc_vocab,
        device=device,
        args=args)
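# A hypothetical command-line entry point for the single-device main() above,
# given here only as a sketch. The flag names mirror the attributes that
# main() reads; the types, defaults, and help text are assumptions and do not
# reproduce the repository's actual argument parser.
def parse_args_sketch():
    parser = argparse.ArgumentParser(description='Jasper CTC evaluation (sketch)')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--cudnn_benchmark', action='store_true')
    parser.add_argument('--amp', action='store_true')
    parser.add_argument('--torch_script', action='store_true')
    parser.add_argument('--model_toml', type=str, required=True)
    parser.add_argument('--val_manifest', type=str, required=True)
    parser.add_argument('--dataset_dir', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--max_duration', type=float, default=None)
    parser.add_argument('--pad_to', type=int, default=None)
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--steps', type=int, default=None)
    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args_sketch())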