class CER(Metric):
    """Character error rate metric based on greedy CTC decoding.

    On every update the model output is decoded with a ``GreedyCTCDecoder``,
    the target tensor is converted back to strings, and one normalized CER
    value per utterance (edit distance divided by the reference length) is
    passed on to the base class ``_update``. How the base class aggregates
    those values is defined by ``Metric`` (presumably an average -- confirm).

    Args:
        name: Metric name forwarded to the base class.
        output_transform: Forwarded to the base class. The default drops the
            first element (``x[1:]``); ``_update`` below expects to receive a
            4-tuple ``(output, target, output_lengths, target_lengths)``.
        fmt: Display format forwarded to the base class.
        alphabet: Required keyword argument used to build the decoder.
        **kwargs: Any other keyword arguments, forwarded to the base class.

    Raises:
        ValueError: If ``alphabet`` is not supplied.
    """

    def __init__(self,
                 name='CER',
                 output_transform=lambda x: x[1:],
                 fmt='.02%',
                 **kwargs):
        # Pop ``alphabet`` so it is not forwarded to the base class.
        alphabet = kwargs.pop('alphabet', None)
        # Validate before initializing the base class so a misconfigured
        # metric fails fast instead of leaving a half-initialized object.
        if alphabet is None:
            raise ValueError('Arg. `alphabet` is required.')
        super().__init__(name, output_transform, fmt=fmt, **kwargs)
        self.decoder = GreedyCTCDecoder(alphabet)

    def _update(self, output):
        """Decode one batch and forward its per-utterance CERs to the base class.

        Args:
            output: 4-tuple ``(output, target, output_lengths, target_lengths)``.
        """
        output, target, output_lengths, target_lengths = output
        # ``decode`` returns a tuple; its first element holds the transcripts.
        transcripts = self.decoder.decode(output, output_lengths)[0]
        references = self.decoder.tensor2str(target, target_lengths)
        # An empty reference cannot be normalized by its length; such
        # utterances are scored as 1 (i.e. 100% error).
        errors = [
            self.decoder.cer(t, r) / float(len(r)) if len(r) else 1
            for t, r in zip(transcripts, references)
        ]
        super()._update(torch.tensor(errors, device=output.device))
def __init__(self,
             name='CER',
             output_transform=lambda x: x[1:],
             fmt='.02%',
             **kwargs):
    """Initialize the CER metric and its greedy CTC decoder.

    Args:
        name: Metric name forwarded to the base class.
        output_transform: Forwarded to the base class; the default drops the
            first element of the engine output (``x[1:]``).
        fmt: Display format forwarded to the base class.
        alphabet: Required keyword argument used to build the decoder.
        **kwargs: Any other keyword arguments, forwarded to the base class.

    Raises:
        ValueError: If ``alphabet`` is not supplied.
    """
    # Pop ``alphabet`` so it is not forwarded to the base class.
    alphabet = kwargs.pop('alphabet', None)
    # Validate before initializing the base class so a misconfigured metric
    # fails fast instead of leaving a half-initialized object.
    if alphabet is None:
        raise ValueError('Arg. `alphabet` is required.')
    super().__init__(name, output_transform, fmt=fmt, **kwargs)
    self.decoder = GreedyCTCDecoder(alphabet)
def evaluate_from_args(args):
    """Generate (or reuse) cached logits for ``args.input_file`` and report WER/CER.

    The model configuration and weights are loaded from
    ``args.serialization_dir``. Logits are cached as a list of
    ``(logits[:length], reference_string)`` pairs in
    ``<serialization_dir>/logits/<input basename>.pth``. When that file already
    exists nothing is recomputed and no metrics are printed. Otherwise the
    model is run over the whole dataset, the logits are saved, and the
    greedy-decoding WER/CER are printed. These are computed as summed edit
    distances divided by the total number of reference words/characters.

    Args:
        args: Parsed command-line namespace; this function reads
            ``serialization_dir``, ``overrides``, ``weights_file``,
            ``input_file``, ``batch_size`` and ``device``.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)
    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)
    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))
    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, "alphabet", "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)
    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logger.debug("Evaluation basename: %s", basename)
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if os.path.exists(logits_file):
        logger.info(f'Logits file `{logits_file}` already generated.')
        return

    model = models.from_params(alphabet=alphabet, params=params.pop('model'))
    model.load_state_dict(
        torch.load(weights_file,
                   map_location=lambda storage, loc: storage)['model'])
    model.eval()

    decoder = GreedyCTCDecoder(alphabet)

    loader_params = params.pop("val_data_loader", params.get("data_loader"))
    batch_sampler = samplers.BucketingSampler(dataset,
                                              batch_size=args.batch_size)
    loader = loaders.from_params(loader_params,
                                 dataset=dataset,
                                 batch_sampler=batch_sampler)

    logger.info(f'Logits file `{logits_file}` not found. Generating...')

    with torch.no_grad():
        model.to(args.device)
        logits = []
        total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
        for batch in tqdm.tqdm(loader):
            sample, target, sample_lengths, target_lengths = batch
            sample = sample.to(args.device)
            sample_lengths = sample_lengths.to(args.device)

            output, output_lengths = model(sample, sample_lengths)
            # Work on the CPU from here on: logits are saved to disk and
            # decoded on the CPU, so keep the lengths on the same device.
            output = output.to('cpu')
            output_lengths = torch.as_tensor(output_lengths, device='cpu')

            references = decoder.tensor2str(target, target_lengths)
            # Pass the valid lengths so padded frames past the end of each
            # utterance are not decoded into spurious characters.
            transcripts = decoder.decode(output, output_lengths)[0]

            logits.extend(
                (o[:l, ...], r)
                for o, l, r in zip(output, output_lengths, references))
            del sample, sample_lengths, output

            for reference, transcript in zip(references, transcripts):
                total_wer += decoder.wer(transcript, reference)
                total_cer += decoder.cer(transcript, reference)
                num_tokens += float(len(reference.split()))
                num_chars += float(len(reference))

        torch.save(logits, logits_file)

    # Guard against an empty evaluation set instead of dividing by zero.
    wer = float(total_wer) / num_tokens if num_tokens else float('nan')
    cer = float(total_cer) / num_chars if num_chars else float('nan')
    print(f'WER: {wer:.02%}\nCER: {cer:.02%}')
    del model
def tune_from_args(args):
    """Grid-search decoder ``alpha``/``beta`` values on cached logits.

    Logits for ``args.input_file`` are generated first if the cache file
    ``<serialization_dir>/logits/<input basename>.pth`` does not exist yet.
    The cache holds a list of ``(logits[:length], reference_string)`` pairs.
    Every pair from ``linspace(alpha_from, alpha_to, alpha_steps)`` x
    ``linspace(beta_from, beta_to, beta_steps)`` is then evaluated by
    ``tune_step`` in a ``multiprocessing.Pool`` whose workers are set up by
    ``init``. The ``(alpha, beta, wer, cer)`` results are written to
    ``<serialization_dir>/tune/<input basename>.csv``. (``alpha``/``beta`` are
    presumably the LM weight and word-insertion bonus of the beam decoder;
    confirm against ``tune_step``.)

    Args:
        args: Parsed command-line namespace; this function reads
            ``serialization_dir``, ``overrides``, ``weights_file``,
            ``input_file``, ``batch_size``, ``device``, the ``alpha_*`` and
            ``beta_*`` grid bounds, ``num_workers``, ``lm_workers``,
            ``lm_path``, ``cutoff_top_n``, ``cutoff_prob`` and ``beam_width``.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)
    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)
    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))
    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, "alphabet", "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)
    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet,
                                   params=params.pop('model'))
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop("val_data_loader",
                                   params.get("data_loader"))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)
            logits = []
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)
                output, output_lengths = model(sample, sample_lengths)
                # Move to the CPU once; the cached logits are saved from here.
                output = output.to('cpu')
                references = decoder.tensor2str(target, target_lengths)
                logits.extend(
                    (o[:l, ...], r)
                    for o, l, r in zip(output, output_lengths, references))
                del sample, sample_lengths, output
            torch.save(logits, logits_file)
        del model
    else:
        logger.info(f'Logits file `{logits_file}` already generated.')

    tune_dir = os.path.join(args.serialization_dir, 'tune')
    os.makedirs(tune_dir, exist_ok=True)

    params_grid = list(
        product(
            torch.linspace(args.alpha_from, args.alpha_to, args.alpha_steps),
            torch.linspace(args.beta_from, args.beta_to, args.beta_steps)))

    print(
        'Scheduling {} jobs for alphas=linspace({}, {}, {}), betas=linspace({}, {}, {})'
        .format(len(params_grid), args.alpha_from, args.alpha_to,
                args.alpha_steps, args.beta_from, args.beta_to,
                args.beta_steps))

    # start worker processes
    logger.info(
        f"Using {args.num_workers} processes and {args.lm_workers} for each CTCDecoder."
    )
    extract_start = default_timer()
    scores = []
    best = None
    best_wer = float('inf')
    # The pool context manager always tears the worker processes down (the
    # original never closed the pool). All results are consumed inside the
    # block, so the terminate() it performs on exit loses nothing.
    with Pool(args.num_workers, init, [
            logits_file, alphabet, args.lm_path, args.cutoff_top_n,
            args.cutoff_prob, args.beam_width, args.lm_workers
    ]) as pool, tqdm.tqdm(pool.imap(tune_step, params_grid),
                          total=len(params_grid),
                          desc='Grid search') as pbar:
        for alpha, beta, wer, cer in pbar:
            scores.append([alpha, beta, wer, cer])
            if wer < best_wer:
                best_wer = wer
                best = (alpha, beta, wer, cer)
                pbar.set_postfix(alpha=alpha, beta=beta, wer=wer, cer=cer)

    logger.info(
        f"Finished {len(params_grid)} processes in {default_timer() - extract_start:.1f}s"
    )
    if best is not None:
        logger.info("Best parameters: alpha=%s, beta=%s (WER=%s, CER=%s)",
                    *best)

    df = pd.DataFrame(scores, columns=['alpha', 'beta', 'wer', 'cer'])
    df.to_csv(os.path.join(tune_dir, basename + '.csv'), index=False)
def test_greedy_decoder():
    """Check ``GreedyCTCDecoder.decode`` on a padded two-utterance batch.

    Test vectors are adapted from TensorFlow's CTC greedy decoder test. Both
    utterances are padded to ``max_time_steps`` frames, and frames past each
    utterance's ``seq_len`` are expected to be ignored. The labels are
    'a', 'b', 'c' and the blank '-' at index 3.
    """
    max_time_steps = 6
    num_labels = 4  # 'a', 'b', 'c' and the blank '-'

    seq_len_0 = 4
    # dimensions are time x depth
    input_prob_matrix_0 = torch.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],  # t=0
            [0.0, 0.0, 0.4, 0.6],  # t=1
            [0.0, 0.0, 0.4, 0.6],  # t=2
            [0.0, 0.9, 0.1, 0.0],  # t=3
            [0.0, 0.0, 0.0, 0.0],  # t=4 (ignored)
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=torch.float32)
    input_log_prob_matrix_0 = input_prob_matrix_0.log()

    seq_len_1 = 5
    # dimensions are time x depth
    input_prob_matrix_1 = torch.tensor(
        [
            [0.1, 0.9, 0.0, 0.0],  # t=0
            [0.0, 0.9, 0.1, 0.0],  # t=1
            [0.0, 0.0, 0.1, 0.9],  # t=2
            [0.0, 0.9, 0.1, 0.1],  # t=3
            [0.9, 0.1, 0.0, 0.0],  # t=4
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=torch.float32)
    input_log_prob_matrix_1 = input_prob_matrix_1.log()

    # batch_size x max_time_steps x depth tensor of log-probabilities
    inputs = torch.stack([input_log_prob_matrix_0, input_log_prob_matrix_1])
    assert inputs.shape == (2, max_time_steps, num_labels)

    # batch_size length vector of sequence_lengths
    seq_lens = torch.tensor([seq_len_0, seq_len_1], dtype=torch.int32)

    # batch_size length vector of negative log probabilities of the greedy path
    log_prob_truth = torch.tensor([
        -(torch.tensor([1.0, 0.6, 0.6, 0.9]).log()).sum().item(),
        -(torch.tensor([0.9, 0.9, 0.9, 0.9, 0.9]).log()).sum().item()
    ])

    decode_truth = ['ab', 'bba']
    # Frame index at which each emitted character starts.
    offsets_truth = [
        torch.tensor([0, 3]),
        torch.tensor([0, 3, 4]),
    ]

    alphabet = Alphabet('abc-', blank_index=3)
    decoder = GreedyCTCDecoder(alphabet)

    out, scores, offsets = decoder.decode(inputs, seq_lens)

    assert out[0] == decode_truth[0]
    assert out[1] == decode_truth[1]
    assert torch.allclose(scores, log_prob_truth)
    assert torch.all(offsets[0] == offsets_truth[0])
    assert torch.all(offsets[1] == offsets_truth[1])