def summarize(self, field=None):
    """Summarize the error_rate and return relevant statistics.

    The accumulated ``self.scores`` are passed to ``wer_summary``; its
    result is cached on ``self.summary`` and gains an extra, metric-agnostic
    ``"error_rate"`` entry that aliases the ``"WER"`` entry.

    * See MetricStats.summarize()

    Arguments
    ---------
    field : str, optional
        Key of a single statistic to return. When omitted (``None``), the
        whole summary is returned.

    Returns
    -------
    The value stored under ``field`` in the summary, or the full summary
    object when ``field`` is ``None``. An unknown ``field`` raises whatever
    the summary's ``__getitem__`` raises (presumably ``KeyError``).
    """
    # Bind the fresh summary to the instance first, then work on a local alias.
    summary = self.summary = wer_summary(self.scores)

    # Generic key so callers need not know this metric is specifically WER.
    summary["error_rate"] = summary["WER"]

    if field is None:
        return summary
    return summary[field]
default="<eps>", help="When printing alignments, empty spaces are filled with this.", ) parser.add_argument( "--utt2spk", help="Provide a mapping from utterance ids to speaker ids." "If provided, print a list of speakers with highest WER.", ) args = parser.parse_args() details_by_utterance = edit_distance.wer_details_by_utterance( _plain_text_keydict(args.ref), _plain_text_keydict(args.hyp), compute_alignments=args.print_alignments, scoring_mode=args.mode, ) summary_details = edit_distance.wer_summary(details_by_utterance) wer_io.print_wer_summary(summary_details) if args.print_top_wer: top_non_empty, top_empty = edit_distance.top_wer_utts( details_by_utterance) wer_io._print_top_wer_utts(top_non_empty, top_empty) if args.utt2spk: utt2spk = _utt2spk_keydict(args.utt2spk) details_by_speaker = edit_distance.wer_details_by_speaker( details_by_utterance, utt2spk) top_spks = edit_distance.top_wer_spks(details_by_speaker) wer_io._print_top_wer_spks(top_spks) if args.print_alignments: wer_io.print_alignments( details_by_utterance, empty_symbol=args.align_empty,
) train_set = params.train_loader() valid_set = params.valid_loader() first_x, first_y = next(iter(train_set)) if hasattr(params, "augmentation"): modules.append(params.augmentation) asr_brain = ASR( modules=modules, optimizer=params.optimizer, first_inputs=[first_x, first_y], ) # Load latest checkpoint to resume training checkpointer.recover_if_possible() asr_brain.fit(params.epoch_counter, train_set, valid_set) # Load best checkpoint for evaluation checkpointer.recover_if_possible(lambda c: -c.meta["PER"]) test_stats = asr_brain.evaluate(params.test_loader()) params.train_logger.log_stats( stats_meta={"Epoch loaded": params.epoch_counter.current}, test_stats=test_stats, ) # Write alignments to file per_summary = edit_distance.wer_summary(test_stats["PER"]) with open(params.wer_file, "w") as fo: wer_io.print_wer_summary(per_summary, fo) wer_io.print_alignments(test_stats["PER"], fo)