raw_wav = configs['raw_wav']

# Init dataset and data loader
test_collate_conf = copy.deepcopy(configs['collate_conf'])
test_collate_conf['spec_aug'] = False
test_collate_conf['spec_sub'] = False
test_collate_conf['feature_dither'] = False
test_collate_conf['speed_perturb'] = False
test_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
test_collate_func = CollateFunc(**test_collate_conf, raw_wav=raw_wav)

dataset_conf = configs.get('dataset_conf', {})
dataset_conf['batch_size'] = args.batch_size
dataset_conf['batch_type'] = 'static'
dataset_conf['sort'] = False
test_dataset = AudioDataset(args.test_data, **dataset_conf, raw_wav=raw_wav)
test_data_loader = DataLoader(test_dataset,
                              collate_fn=test_collate_func,
                              shuffle=False,
                              batch_size=1,
                              num_workers=0)

# Init asr model from configs
model = init_asr_model(configs)

# Load dict
char_dict = {}
with open(args.dict, 'r') as fin:
    for line in fin:
        arr = line.strip().split()
train_collate_func = CollateFunc(**configs['collate_conf'], raw_wav=raw_wav)

cv_collate_conf = copy.deepcopy(configs['collate_conf'])
# no augmentation on cv set
cv_collate_conf['spec_aug'] = False
cv_collate_conf['spec_sub'] = False
if raw_wav:
    cv_collate_conf['feature_dither'] = 0.0
    cv_collate_conf['speed_perturb'] = False
    cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)

dataset_conf = configs.get('dataset_conf', {})
train_dataset = AudioDataset(args.train_data, **dataset_conf, raw_wav=raw_wav)
cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)

if distributed:
    logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
    dist.init_process_group(args.dist_backend,
                            init_method=args.init_method,
                            world_size=args.world_size,
                            rank=args.rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, shuffle=True)
    cv_sampler = torch.utils.data.distributed.DistributedSampler(
        cv_dataset, shuffle=False)
else:
    train_sampler = None
raw_wav = configs['raw_wav']

# Init dataset and data loader
ali_collate_conf = copy.deepcopy(configs['collate_conf'])
ali_collate_conf['spec_aug'] = False
ali_collate_conf['spec_sub'] = False
ali_collate_conf['feature_dither'] = False
ali_collate_conf['speed_perturb'] = False
if raw_wav:
    ali_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
ali_collate_func = CollateFunc(**ali_collate_conf, raw_wav=raw_wav)

dataset_conf = configs.get('dataset_conf', {})
dataset_conf['batch_size'] = args.batch_size
dataset_conf['batch_type'] = 'static'
dataset_conf['sort'] = False
ali_dataset = AudioDataset(args.input_file, **dataset_conf, raw_wav=raw_wav)
ali_data_loader = DataLoader(ali_dataset,
                             collate_fn=ali_collate_func,
                             shuffle=False,
                             batch_size=1,
                             num_workers=0)

# Init asr model from configs
model = init_asr_model(configs)
load_checkpoint(model, args.checkpoint)
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
model.eval()
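
# Hedged sketch, not part of the original script: one common way to consume the
# alignment loader once the model is in eval mode. The unpacking order
# (key, feat, target, feats_length, target_length) is an assumption about what
# CollateFunc yields, and the forced-alignment computation itself is omitted.
with torch.no_grad():
    for batch_idx, batch in enumerate(ali_data_loader):
        key, feat, target, feats_length, target_length = batch
        feat = feat.to(device)
        target = target.to(device)
        feats_length = feats_length.to(device)
        target_length = target_length.to(device)
        # forced alignment over the model's CTC outputs would go here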
logging.fatal(
    'decoding mode {} must be running with batch_size == 1'.format(
        args.mode))
sys.exit(1)

with open(args.config, 'r') as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)

# Init dataset and data loader
test_collate_conf = copy.copy(configs['collate_conf'])
test_collate_conf['spec_aug'] = False
test_collate_func = CollateFunc(**test_collate_conf, cmvn=args.cmvn)
dataset_conf = configs.get('dataset_conf', {})
dataset_conf['batch_size'] = args.batch_size
dataset_conf['sort'] = False
test_dataset = AudioDataset(args.test_data, **dataset_conf)
test_data_loader = DataLoader(test_dataset,
                              collate_fn=test_collate_func,
                              shuffle=False,
                              batch_size=1,
                              num_workers=0)

# Init transformer model
input_dim = test_dataset.input_dim
vocab_size = test_dataset.output_dim
encoder_type = configs.get('encoder', 'conformer')
if encoder_type == 'conformer':
    encoder = ConformerEncoder(input_dim, **configs['encoder_conf'])
else:
    encoder = TransformerEncoder(input_dim, **configs['encoder_conf'])
decoder = TransformerDecoder(vocab_size, encoder.output_size(),
torch.manual_seed(777)
with open(args.config, 'r') as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)

distributed = args.world_size > 1

# Init dataset and data loader
collate_func = CollateFunc(**configs['collate_conf'],
                           **configs['spec_aug_conf'],
                           cmvn=args.cmvn)
cv_collate_conf = copy.copy(configs['collate_conf'])
cv_collate_conf['spec_aug'] = False
cv_collate_func = CollateFunc(**cv_collate_conf, cmvn=args.cmvn)
dataset_conf = configs.get('dataset_conf', {})
train_dataset = AudioDataset(args.train_data, **dataset_conf)
cv_dataset = AudioDataset(args.cv_data, **dataset_conf)

if distributed:
    logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
    dist.init_process_group(args.dist_backend,
                            init_method=args.init_method,
                            world_size=args.world_size,
                            rank=args.rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, shuffle=True)
    cv_sampler = torch.utils.data.distributed.DistributedSampler(
        cv_dataset, shuffle=False)
else:
    train_sampler = None
    cv_sampler = None
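
# Hedged sketch, not in the excerpt above: the samplers and collate functions
# would typically be wired into PyTorch DataLoaders like this. args.num_workers
# is an assumed command-line option; shuffle is only enabled when no
# DistributedSampler is responsible for shuffling.
train_data_loader = DataLoader(train_dataset,
                               collate_fn=collate_func,
                               sampler=train_sampler,
                               shuffle=(train_sampler is None),
                               batch_size=1,
                               num_workers=args.num_workers)
cv_data_loader = DataLoader(cv_dataset,
                            collate_fn=cv_collate_func,
                            sampler=cv_sampler,
                            shuffle=False,
                            batch_size=1,
                            num_workers=args.num_workers)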