示例#1
0
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Delegates the standard loss metrics to
        ``LabelSmoothedCrossEntropyCriterion.reduce_metrics`` and then logs
        the extra mask statistics:

        * ``mask_loss``: summed ``'mask_loss'`` divided by the total
          ``sample_size`` and by ``ln 2`` (a base-2 conversion, assuming the
          loss is accumulated in natural-log units).
        * ``p_2``: summed ``'p2'`` divided by ``sample_size``.
        * ``mask_ave``: summed ``'mask_ave'`` divided by ``sample_size``.
        * ``new_weight``: the FIRST logging output's ``'new_weight'`` divided
          by 4, weighted by ``len(logging_outputs)``.

        Missing keys count as 0. The per-sample metrics are skipped when the
        total ``sample_size`` is not positive, and ``new_weight`` is skipped
        when ``logging_outputs`` is empty, instead of raising
        ``ZeroDivisionError`` / ``IndexError``.

        Args:
            logging_outputs: list of logging dicts (presumably one per
                data-parallel worker). It is indexed with ``[0]`` and iterated
                several times, so it must be a sequence, not a one-shot
                iterator.
        """
        LabelSmoothedCrossEntropyCriterion.reduce_metrics(logging_outputs)

        def _total(key):
            # Sum one statistic across all logging outputs; absent keys add 0.
            return sum(log.get(key, 0) for log in logging_outputs)

        sample_size = _total('sample_size')
        # Same guard style as the WER/CER reducer: never divide by an empty count.
        if sample_size > 0:
            metrics.log_scalar('mask_loss',
                               _total('mask_loss') / sample_size / math.log(2),
                               sample_size,
                               round=6)
            metrics.log_scalar('p_2', _total('p2') / sample_size, sample_size, round=5)
            metrics.log_scalar('mask_ave',
                               _total('mask_ave') / sample_size,
                               sample_size,
                               round=3)

        if logging_outputs:
            # NOTE(review): only the first output's value is read; presumably
            # 'new_weight' is identical on every worker. Confirm this.
            # The divisor 4 is unexplained here (a hard-coded worker/GPU
            # count?). Confirm it before changing.
            new_weight_divisor = 4
            metrics.log_scalar('new_weight',
                               logging_outputs[0].get("new_weight", 0) / new_weight_divisor,
                               len(logging_outputs),
                               round=3)
示例#2
0
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        First lets ``LabelSmoothedCrossEntropyCriterion.reduce_metrics`` log
        the standard loss metrics. Then, for words and characters, it sums the
        error and count statistics over every logging output and logs the
        error rate as a percentage (``'wer'`` / ``'cer'``), weighted by the
        count, rounded to 4 places.

        A rate is only logged when its total count is positive. A positive
        count is taken to indicate evaluation mode
        (``model.training == False``). Skipping otherwise also avoids a
        division by zero.
        """
        LabelSmoothedCrossEntropyCriterion.reduce_metrics(logging_outputs)

        def _total(key):
            # Absent keys contribute 0 to the aggregate.
            return sum(log.get(key, 0) for log in logging_outputs)

        # Sum every statistic first, in the same order as before, before
        # logging anything.
        totals = [
            (metric_name, _total(error_key), _total(count_key))
            for metric_name, error_key, count_key in (
                ('wer', 'word_error', 'word_count'),
                ('cer', 'char_error', 'char_count'),
            )
        ]

        for metric_name, n_errors, n_units in totals:
            if not n_units > 0:
                continue  # training step: no error statistics were collected
            rate = float(n_errors) / n_units * 100
            metrics.log_scalar(metric_name, rate, n_units, round=4)