class WER(Metric):
    """Word error rate metric for CTC models, averaged per utterance."""

    def __init__(self, name='WER', output_transform=lambda x: x[1:], fmt='.02%', **kwargs):
        alphabet = kwargs.pop('alphabet', None)
        super().__init__(name, output_transform, fmt=fmt, **kwargs)

        if alphabet is None:
            raise ValueError('Arg. `alphabet` is required.')

        self.decoder = GreedyCTCDecoder(alphabet)

    def _update(self, output):
        output, target, output_lengths, target_lengths = output

        # Greedy-decode the network output and convert the padded targets back to strings.
        transcripts = self.decoder.decode(output, output_lengths)
        references = self.decoder.tensor2str(target, target_lengths)

        # Per-utterance WER: word-level edit distance normalised by the number of
        # reference words; an empty reference counts as a full error (1.0).
        super()._update(
            torch.tensor([
                self.decoder.wer(t, r) / float(len(r.split())) if len(r.split()) else 1
                for t, r in zip(transcripts[0], references)
            ], device=output.device))
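
# Sketch (illustrative only): the per-utterance normalisation used in WER._update
# above, written as a plain standalone function. `word_edit_distance` is a
# hypothetical stand-in for `decoder.wer`, which is assumed to return a word-level
# edit distance between transcript and reference.
def normalised_wer(transcript: str, reference: str, word_edit_distance) -> float:
    n_words = len(reference.split())
    # Mirror the metric above: an empty reference counts as a full error.
    if n_words == 0:
        return 1.0
    return word_edit_distance(transcript, reference) / float(n_words)
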
def evaluate_from_args(args):
    # Disable some of the more verbose logging statements.
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive.
    _, weights_file = load_archive(args.serialization_dir, args.overrides, args.weights_file)

    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME), args.overrides)
    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info('Reading evaluation data from %s', args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    # Prefer the alphabet saved alongside the model; otherwise rebuild it from the config.
    if os.path.exists(os.path.join(args.serialization_dir, 'alphabet')):
        alphabet = Alphabet.from_file(
            os.path.join(args.serialization_dir, 'alphabet', 'tokens'))
    else:
        alphabet = Alphabet.from_params(params.pop('alphabet', {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    # Logits are cached per input file, keyed by the manifest's base name.
    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logger.debug(basename)
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet, params=params.pop('model'))
        model.load_state_dict(
            torch.load(weights_file, map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop('val_data_loader', params.get('data_loader'))
        batch_sampler = samplers.BucketingSampler(dataset, batch_size=args.batch_size)
        loader = loaders.from_params(loader_params, dataset=dataset, batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)
                output = output.to('cpu')

                references = decoder.tensor2str(target, target_lengths)
                transcripts = decoder.decode(output)[0]

                # Cache the per-utterance logits (trimmed to their true lengths)
                # together with their reference transcripts.
                logits.extend(
                    (o[:l, ...], r) for o, l, r in zip(output, output_lengths, references))
                del sample, sample_lengths, output

                # Accumulate corpus-level edit distances and their normalisers.
                for reference, transcript in zip(references, transcripts):
                    total_wer += decoder.wer(transcript, reference)
                    total_cer += decoder.cer(transcript, reference)
                    num_tokens += float(len(reference.split()))
                    num_chars += float(len(reference))

            torch.save(logits, logits_file)

            # Corpus-level rates: total edit distance over total reference words/chars.
            wer = float(total_wer) / num_tokens
            cer = float(total_cer) / num_chars
            print(f'WER: {wer:.02%}\nCER: {cer:.02%}')

        del model
    else:
        logger.info(f'Logits file `{logits_file}` already generated.')
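
# Sketch: a minimal CLI wrapper around evaluate_from_args(). Only the attribute
# names are taken from the function above (args.serialization_dir, args.overrides,
# args.weights_file, args.input_file, args.batch_size, args.device); the flags,
# defaults and help strings here are illustrative assumptions, not the project's
# actual CLI definition.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Evaluate a trained ASR model on a manifest file.')
    parser.add_argument('serialization_dir', help='Directory containing the trained model archive.')
    parser.add_argument('input_file', help='Manifest file with the evaluation utterances.')
    parser.add_argument('--weights-file', dest='weights_file', default=None,
                        help='Optional weights file overriding the archived weights.')
    parser.add_argument('--overrides', default=None, help='Optional config overrides.')
    parser.add_argument('--batch-size', dest='batch_size', type=int, default=16)
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')

    evaluate_from_args(parser.parse_args())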