def __init__(self, args, task):
    """Initialise the criterion's WER scorer and greedy validation decoder.

    Args:
        args: arguments forwarded unchanged to the parent ``__init__``.
        task: task object; ``task.target_dictionary`` and
            ``task.args.wer_output_filter`` are read here.
    """
    super().__init__(args, task)

    tgt_dict = task.target_dictionary
    output_filter = task.args.wer_output_filter

    # Scorer used to compute WER against ``tgt_dict``.
    self.scorer = wer.Scorer(tgt_dict, wer_output_filter=output_filter)
    # Greedy decoder constructed in validation mode.
    self.decoder_for_validation = SimpleGreedyDecoder(tgt_dict, for_validation=True)

    # Progress counters start at -1 / 0; presumably updated elsewhere in the
    # class (not visible here) -- verify against the rest of the file.
    self.num_updates, self.epoch = -1, 0
def build_model(self, args):
    """Build the model and attach a greedy decoder for WER validation.

    Args:
        args: arguments forwarded unchanged to the parent ``build_model``.

    Returns:
        The model produced by the parent class. It is also passed, wrapped in
        a one-element list, to the ``SimpleGreedyDecoder`` stored on
        ``self.decoder_for_validation``.
    """
    built_model = super().build_model(args)

    # Function-local import kept from the original; presumably to avoid an
    # import cycle or load-time cost -- not verifiable from this file.
    from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder

    model_list = [built_model]
    self.decoder_for_validation = SimpleGreedyDecoder(
        model_list,
        self.target_dictionary,
        for_validation=True,
    )
    return built_model
def __init__(self, args, task):
    """Initialise WER scoring, the validation decoder and optional unigram prior.

    Args:
        args: arguments forwarded to the parent ``__init__``. This method also
            reads ``args.smoothing_type``. When it equals ``'unigram'``, it
            additionally reads ``args.cpu`` (only if CUDA is available) and
            ``args.unigram_pseudo_count``.
        task: task object; ``task.target_dictionary`` and
            ``task.args.wer_output_filter`` are read here.

    After this call ``self.unigram_tensor`` is either ``None`` (no unigram
    smoothing) or a float tensor of shape ``(len(dictionary.count), 1)``.
    That tensor holds ``dictionary.count`` plus the pseudo count, normalised
    to sum to 1.
    """
    super().__init__(args, task)

    dictionary = task.target_dictionary
    self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter)
    self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True)
    self.num_updates = -1
    self.epoch = 0

    self.unigram_tensor = None
    if args.smoothing_type == 'unigram':
        # Same placement rule as before: GPU when CUDA is available and CPU
        # was not forced, otherwise CPU. ``torch.tensor`` replaces the
        # deprecated legacy ``torch.(cuda.)FloatTensor`` constructors.
        use_cuda = torch.cuda.is_available() and not args.cpu
        device = "cuda" if use_cuda else "cpu"
        # Assumes ``dictionary.count`` is a per-symbol sequence of counts
        # (fairseq Dictionary) -- confirm if a custom dictionary is used.
        unigram = torch.tensor(
            dictionary.count, dtype=torch.float, device=device
        ).unsqueeze(-1)
        unigram += args.unigram_pseudo_count  # for further backoff
        unigram.div_(unigram.sum())
        self.unigram_tensor = unigram
def build_model(self, cfg: DictConfig, from_checkpoint=False):
    """Build the model and attach a greedy decoder for WER validation.

    The decoder is selected by ``self.cfg.criterion_name``:

    * ``"transducer_loss"``: ``TransducerGreedyDecoder``, configured with
      ``self.cfg.max_num_expansions_per_step``;
    * ``"ctc"``: not supported (see Raises);
    * anything else: assumed to be an attention-based encoder-decoder model
      and decoded with ``SimpleGreedyDecoder(for_validation=True)``.

    Args:
        cfg: model configuration forwarded to the parent ``build_model``.
        from_checkpoint: forwarded unchanged to the parent ``build_model``.

    Returns:
        The model built by the parent class.

    Raises:
        NotImplementedError: if the criterion is ``"ctc"``. This is checked
            before the model is built, so no model is constructed only to be
            discarded.
    """
    criterion_name = self.cfg.criterion_name

    # Fail fast: there is no greedy validation decoder for CTC models, so
    # building the (possibly large) model first would be wasted work.
    if criterion_name == "ctc":
        raise NotImplementedError(
            "a greedy decoder for WER validation is not implemented for "
            "CTC models (criterion 'ctc')"
        )

    model = super().build_model(cfg, from_checkpoint)

    # Decoder modules are imported lazily, as in the original code.
    if criterion_name == "transducer_loss":  # a transducer model
        from espresso.tools.transducer_greedy_decoder import TransducerGreedyDecoder

        self.decoder_for_validation = TransducerGreedyDecoder(
            [model],
            self.target_dictionary,
            max_num_expansions_per_step=self.cfg.max_num_expansions_per_step,
        )
    else:  # assume it is an attention-based encoder-decoder model
        from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder

        self.decoder_for_validation = SimpleGreedyDecoder(
            [model],
            self.target_dictionary,
            for_validation=True,
        )
    return model