예제 #1
0
    def __init__(self, args, task):
        """Set up WER scoring and greedy decoding used during validation.

        Args:
            args: configuration forwarded unchanged to the parent constructor.
            task: task object; its ``target_dictionary`` and
                ``args.wer_output_filter`` are read here.
        """
        super().__init__(args, task)

        tgt_dict = task.target_dictionary
        output_filter = task.args.wer_output_filter
        self.scorer = wer.Scorer(tgt_dict, wer_output_filter=output_filter)
        self.decoder_for_validation = SimpleGreedyDecoder(tgt_dict, for_validation=True)
        # Progress counters; -1 / 0 look like "not started yet" initial values
        # (NOTE(review): confirm how callers update them).
        self.num_updates, self.epoch = -1, 0
예제 #2
0
 def build_model(self, args):
     """Build the model via the parent class and attach a validation decoder.

     A ``SimpleGreedyDecoder`` wrapping the freshly built model is stored in
     ``self.decoder_for_validation`` so that validation can be scored with WER.

     Args:
         args: arguments forwarded unchanged to the parent ``build_model``.

     Returns:
         The model returned by the parent ``build_model``.
     """
     built = super().build_model(args)
     # Function-scope import; presumably avoids a module-level dependency on
     # espresso.tools (unverified).
     from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder

     self.decoder_for_validation = SimpleGreedyDecoder(
         [built], self.target_dictionary, for_validation=True
     )
     return built
    def __init__(self, args, task):
        """Set up WER scoring, greedy validation decoding and an optional unigram prior.

        Args:
            args: configuration forwarded to the parent constructor; this
                method also reads ``smoothing_type``, ``cpu`` and
                ``unigram_pseudo_count`` from it.
            task: task object; its ``target_dictionary`` and
                ``args.wer_output_filter`` are read here.

        When ``args.smoothing_type == 'unigram'``, ``self.unigram_tensor`` is a
        float32 column (``unsqueeze(-1)``) of per-symbol counts plus
        ``args.unigram_pseudo_count``, normalized to sum to 1. Otherwise it is
        ``None``. It assumes ``dictionary.count`` is a flat sequence of
        per-symbol counts (TODO confirm against the dictionary class).
        """
        super().__init__(args, task)

        dictionary = task.target_dictionary
        self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter)
        self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True)
        self.num_updates = -1
        self.epoch = 0
        self.unigram_tensor = None
        if args.smoothing_type == 'unigram':
            # Keep the prior on the GPU unless CUDA is unavailable or CPU
            # execution was requested explicitly.
            use_cuda = torch.cuda.is_available() and not args.cpu
            device = torch.device('cuda' if use_cuda else 'cpu')
            # torch.tensor replaces the deprecated legacy typed constructors
            # (torch.FloatTensor / torch.cuda.FloatTensor) with the same
            # float32 dtype and the same device placement.
            counts = torch.tensor(dictionary.count, dtype=torch.float32, device=device)
            self.unigram_tensor = counts.unsqueeze(-1)
            self.unigram_tensor += args.unigram_pseudo_count  # for further backoff
            self.unigram_tensor.div_(self.unigram_tensor.sum())
예제 #4
0
    def build_model(self, cfg: DictConfig, from_checkpoint=False):
        """Build the model and attach a greedy decoder for WER validation.

        The decoder is chosen from ``self.cfg.criterion_name``:

        * ``"transducer_loss"``: ``TransducerGreedyDecoder``, passing
          ``self.cfg.max_num_expansions_per_step``.
        * ``"ctc"``: not supported yet; raises ``NotImplementedError``.
        * anything else: ``SimpleGreedyDecoder``. The model is assumed to be
          an attention-based encoder-decoder.

        Args:
            cfg: configuration forwarded to the parent ``build_model``.
            from_checkpoint: forwarded unchanged to the parent ``build_model``.

        Returns:
            The model produced by the parent ``build_model``.

        Raises:
            NotImplementedError: if the criterion is ``"ctc"``.
        """
        model = super().build_model(cfg, from_checkpoint)
        criterion = self.cfg.criterion_name

        if criterion == "ctc":  # no greedy validation decoder for CTC models yet
            raise NotImplementedError

        if criterion == "transducer_loss":
            from espresso.tools.transducer_greedy_decoder import TransducerGreedyDecoder

            decoder = TransducerGreedyDecoder(
                [model],
                self.target_dictionary,
                max_num_expansions_per_step=self.cfg.max_num_expansions_per_step,
            )
        else:  # assume an attention-based encoder-decoder model
            from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder

            decoder = SimpleGreedyDecoder(
                [model], self.target_dictionary, for_validation=True
            )

        self.decoder_for_validation = decoder
        return model