Example #1
    def __init__(self, args, task):
        super().__init__(args, task)

        dictionary = task.target_dictionary
        self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter)
        self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True)
        self.num_updates = -1
        self.epoch = 0
Example #2
    def __init__(self, args, task):
        super().__init__(args, task)

        dictionary = task.target_dictionary
        self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter)
        self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True)
        self.num_updates = -1
        self.epoch = 0
        self.unigram_tensor = None
        if args.smoothing_type == 'unigram':
            self.unigram_tensor = torch.cuda.FloatTensor(dictionary.count).unsqueeze(-1) \
                if torch.cuda.is_available() and not args.cpu \
                else torch.FloatTensor(dictionary.count).unsqueeze(-1)
            self.unigram_tensor += args.unigram_pseudo_count  # for further backoff
            self.unigram_tensor.div_(self.unigram_tensor.sum())
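
A minimal sketch of the unigram prior built above, assuming `dictionary.count` holds per-token corpus frequencies (toy values here):

import torch

counts = torch.tensor([120.0, 30.0, 0.0, 50.0])  # toy per-token frequencies
pseudo_count = 1.0                               # --unigram-pseudo-count
prior = (counts + pseudo_count).unsqueeze(-1)    # additive smoothing; shape (V, 1)
prior.div_(prior.sum())                          # normalize to a distribution
# `prior` then replaces the uniform 1/V distribution when smoothing labels.
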
Example #3
    def build_model(self, args):
        model = super().build_model(args)
        # build the greedy decoder for validation with WER
        from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder
        self.decoder_for_validation = SimpleGreedyDecoder(
            [model], self.target_dictionary, for_validation=True,
        )
        return model
Example #4
    def build_model(self, cfg: DictConfig, from_checkpoint=False):
        model = super().build_model(cfg, from_checkpoint)
        # build a greedy decoder for validation with WER
        if self.cfg.criterion_name == "transducer_loss":  # a transducer model
            from espresso.tools.transducer_greedy_decoder import TransducerGreedyDecoder

            self.decoder_for_validation = TransducerGreedyDecoder(
                [model],
                self.target_dictionary,
                max_num_expansions_per_step=self.cfg.max_num_expansions_per_step,
            )
        elif self.cfg.criterion_name == "ctc":  # a ctc model
            raise NotImplementedError
        else:  # assume it is an attention-based encoder-decoder model
            from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder

            self.decoder_for_validation = SimpleGreedyDecoder(
                [model],
                self.target_dictionary,
                for_validation=True,
            )

        return model
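
These build_model hooks attach a greedy decoder to the task as decoder_for_validation so that validation can report WER. As a rough illustration of the idea behind such a decoder (step_fn, bsz, bos, eos and max_len are placeholders for this sketch, not the Espresso API), a batched greedy loop looks like:

import torch

@torch.no_grad()
def greedy_decode(step_fn, bsz, bos, eos, max_len):
    # step_fn(tokens) -> (bsz, vocab) log-probs for the next position
    tokens = torch.full((bsz, 1), bos, dtype=torch.long)
    finished = torch.zeros(bsz, dtype=torch.bool)
    for _ in range(max_len):
        next_tok = step_fn(tokens).argmax(dim=-1)  # best next token per hypothesis
        next_tok[finished] = eos                   # freeze finished hypotheses
        tokens = torch.cat([tokens, next_tok.unsqueeze(1)], dim=1)
        finished |= next_tok.eq(eos)
        if finished.all():
            break
    return tokens  # includes the leading BOS column

As the forward methods in Example #5 show, with for_validation=True the Espresso decoder also returns log-probs aligned with the gold target, which is what lets the criteria reuse them for the NLL loss.
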
Example #5
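This example is the full criterion pair. The snippet omits its import header; a plausible reconstruction, assuming the usual Espresso/fairseq layout (label_smoothed_nll_loss and temporal_label_smoothing_prob_mask are defined alongside these classes in Espresso rather than imported):

import numpy as np
import torch
import torch.nn.functional as F

from fairseq import metrics, utils
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from fairseq.data import data_utils

from espresso.tools import wer
from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder
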
class CrossEntropyWithWERCriterion(CrossEntropyCriterion):

    def __init__(self, args, task):
        super().__init__(args, task)

        dictionary = task.target_dictionary
        self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter)
        self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True)
        self.num_updates = -1
        self.epoch = 0

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--print-training-sample-interval', type=int,
                            metavar='N', dest='print_interval', default=500,
                            help='print a training sample (reference + '
                                 'prediction) every this number of updates')
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample; periodically print out
        randomly sampled predictions if model is in training mode, otherwise
        aggregate word error stats for validation.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        dictionary = self.scorer.dictionary
        if model.training:
            net_output = model(**sample['net_input'], epoch=self.epoch)
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            if (
                self.num_updates // self.args.print_interval >
                (self.num_updates - 1) // self.args.print_interval
            ):  # print a randomly sampled result every print_interval updates
                pred = lprobs.argmax(-1).cpu()  # bsz x len
                assert pred.size() == target.size()
                with data_utils.numpy_seed(self.num_updates):
                    i = np.random.randint(0, len(sample['id']))
                ref_tokens = sample['target_raw_text'][i]
                length = utils.strip_pad(target.data[i], self.padding_idx).size(0)
                ref_one = dictionary.tokens_to_sentence(
                    ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe,
                )
                pred_one = dictionary.tokens_to_sentence(
                    dictionary.string(pred.data[i][:length]), use_unk_sym=True,
                    bpe_symbol=self.args.remove_bpe,
                )
                print('| sample REF: ' + ref_one)
                print('| sample PRD: ' + pred_one)
        else:
            tokens, lprobs, _ = self.decoder_for_validation.decode([model], sample)
            pred = tokens[:, 1:].data.cpu()  # bsz x len
            target = sample['target']
            # compute word error stats
            assert pred.size(0) == target.size(0)
            self.scorer.reset()
            for i in range(target.size(0)):
                utt_id = sample['utt_id'][i]
                ref_tokens = sample['target_raw_text'][i]
                pred_tokens = dictionary.string(pred.data[i])
                self.scorer.add_evaluation(
                    utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe,
                )

        lprobs = lprobs.view(-1, lprobs.size(-1))
        loss = F.nll_loss(
            lprobs,
            target.view(-1),
            ignore_index=self.padding_idx,
            reduction='sum' if reduce else 'none',
        )
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        if not model.training:  # do not compute word error in training mode
            logging_output['word_error'] = self.scorer.tot_word_error()
            logging_output['word_count'] = self.scorer.tot_word_count()
            logging_output['char_error'] = self.scorer.tot_char_error()
            logging_output['char_count'] = self.scorer.tot_char_count()
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        CrossEntropyCriterion.reduce_metrics(logging_outputs)

        word_error = sum(log.get('word_error', 0) for log in logging_outputs)
        word_count = sum(log.get('word_count', 0) for log in logging_outputs)
        char_error = sum(log.get('char_error', 0) for log in logging_outputs)
        char_count = sum(log.get('char_count', 0) for log in logging_outputs)
        if word_count > 0:  # model.training == False
            metrics.log_scalar('wer', float(word_error) / word_count * 100, word_count, round=4)
        if char_count > 0:  # model.training == False
            metrics.log_scalar('cer', float(char_error) / char_count * 100, char_count, round=4)

    def set_num_updates(self, num_updates):
        self.num_updates = num_updates

    def set_epoch(self, epoch):
        self.epoch = epoch
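
# The print gate in forward() above, `n // k > (n - 1) // k`, is True exactly
# when n is a non-negative multiple of k: with k=500, n=499 gives 0 > 0 (False)
# while n=500 gives 1 > 0 (True), so a random REF/PRD pair is printed at
# updates 0, 500, 1000, ...
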
class LabelSmoothedCrossEntropyWithWERCriterion(LabelSmoothedCrossEntropyCriterion):

    def __init__(self, args, task):
        super().__init__(args, task)

        dictionary = task.target_dictionary
        self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter)
        self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True)
        self.num_updates = -1
        self.epoch = 0
        self.unigram_tensor = None
        if args.smoothing_type == 'unigram':
            self.unigram_tensor = torch.cuda.FloatTensor(dictionary.count).unsqueeze(-1) \
                if torch.cuda.is_available() and not args.cpu \
                else torch.FloatTensor(dictionary.count).unsqueeze(-1)
            self.unigram_tensor += args.unigram_pseudo_count  # for further backoff
            self.unigram_tensor.div_(self.unigram_tensor.sum())

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        LabelSmoothedCrossEntropyCriterion.add_args(parser)
        parser.add_argument('--print-training-sample-interval', type=int,
                            metavar='N', dest='print_interval', default=500,
                            help='print a training sample (reference + '
                                 'prediction) every this number of updates')
        parser.add_argument('--smoothing-type', type=str, default='uniform',
                            choices=['uniform', 'unigram', 'temporal'],
                            help='label smoothing type. Default: uniform')
        parser.add_argument('--unigram-pseudo-count', type=float, default=1.0,
                            metavar='C', help='pseudo count for unigram label '
                            'smoothing. Only relevant if --smoothing-type=unigram')
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample; periodically print out
        randomly sampled predictions if model is in training mode, otherwise
        aggregate word error stats for validation.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        dictionary = self.scorer.dictionary
        if model.training:
            net_output = model(**sample['net_input'], epoch=self.epoch)
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            if (
                self.num_updates // self.args.print_interval >
                (self.num_updates - 1) // self.args.print_interval
            ):  # print a randomly sampled result every print_interval updates
                pred = lprobs.argmax(-1).cpu()  # bsz x len
                assert pred.size() == target.size()
                with data_utils.numpy_seed(self.num_updates):
                    i = np.random.randint(0, len(sample['id']))
                ref_tokens = sample['target_raw_text'][i]
                length = utils.strip_pad(target.data[i], self.padding_idx).size(0)
                ref_one = dictionary.tokens_to_sentence(
                    ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe,
                )
                pred_one = dictionary.tokens_to_sentence(
                    dictionary.string(pred.data[i][:length]), use_unk_sym=True,
                    bpe_symbol=self.args.remove_bpe,
                )
                print('| sample REF: ' + ref_one)
                print('| sample PRD: ' + pred_one)
        else:
            tokens, lprobs, _ = self.decoder_for_validation.decode([model], sample)
            pred = tokens[:, 1:].data.cpu()  # bsz x len
            target = sample['target']
            # compute word error stats
            assert pred.size(0) == target.size(0)
            self.scorer.reset()
            for i in range(target.size(0)):
                utt_id = sample['utt_id'][i]
                ref_tokens = sample['target_raw_text'][i]
                pred_tokens = dictionary.string(pred.data[i])
                self.scorer.add_evaluation(
                    utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe,
                )

        # temporal smoothing spreads the smoothed probability mass over
        # temporally adjacent target labels rather than uniformly over the vocabulary
        prob_mask = temporal_label_smoothing_prob_mask(
            lprobs, target, padding_index=self.padding_idx,
        ) if self.args.smoothing_type == 'temporal' else None

        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = target.view(-1, 1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
            smoothing_type=self.args.smoothing_type, prob_mask=prob_mask,
            unigram_tensor=self.unigram_tensor,
        )
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        if not model.training:  # do not compute word error in training mode
            logging_output['word_error'] = self.scorer.tot_word_error()
            logging_output['word_count'] = self.scorer.tot_word_count()
            logging_output['char_error'] = self.scorer.tot_char_error()
            logging_output['char_count'] = self.scorer.tot_char_count()
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        agg_output = LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs(logging_outputs)
        word_error = sum(log.get('word_error', 0) for log in logging_outputs)
        word_count = sum(log.get('word_count', 0) for log in logging_outputs)
        char_error = sum(log.get('char_error', 0) for log in logging_outputs)
        char_count = sum(log.get('char_count', 0) for log in logging_outputs)
        if word_count > 0:  # model.training == False
            agg_output['word_error'] = word_error
            agg_output['word_count'] = word_count
        if char_count > 0:  # model.training == False
            agg_output['char_error'] = char_error
            agg_output['char_count'] = char_count
        return agg_output

    def set_num_updates(self, num_updates):
        self.num_updates = num_updates

    def set_epoch(self, epoch):
        self.epoch = epoch
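
For reference, the uniform case of label_smoothed_nll_loss used above roughly follows the standard fairseq formulation; a minimal sketch (the Espresso version additionally handles the 'unigram' and 'temporal' types via unigram_tensor and prob_mask):

import torch

def uniform_label_smoothed_nll_loss(lprobs, target, eps, ignore_index=None):
    # lprobs: (N, V) log-probabilities; target: (N, 1) gold indices
    nll_loss = -lprobs.gather(dim=-1, index=target)      # true-class NLL
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)      # summed over the vocab
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
        smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
    eps_i = eps / lprobs.size(-1)
    loss = (1.0 - eps) * nll_loss + eps_i * smooth_loss  # blend with the uniform prior
    return loss.sum(), nll_loss.sum()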