コード例 #1
0
 def aggregate_logging_outputs(logging_outputs):
     """Aggregate logging outputs from data parallel training.

     Starts from the dict produced by
     ``LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs`` and adds
     summed word- and character-level error statistics to it.

     Each error/count pair is only written when its summed count is positive;
     entries missing from an individual logging output count as 0.
     """
     merged = LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs(logging_outputs)

     def _total(key):
         # Sum one statistic across all logging outputs, defaulting to 0.
         return sum(entry.get(key, 0) for entry in logging_outputs)

     stat_pairs = (
         ('word_error', 'word_count'),
         ('char_error', 'char_count'),
     )
     for error_key, count_key in stat_pairs:
         errors = _total(error_key)
         count = _total(count_key)
         # Per the original note, counts are only expected to be non-zero
         # when model.training == False.
         if count > 0:
             merged[error_key] = errors
             merged[count_key] = count
     return merged
コード例 #2
0
ファイル: fairseq_criterion.py プロジェクト: xtinkt/editable
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Starts from the dict returned by
        ``LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs`` and,
        when the first logging output carries an ``'editability_loss'`` entry,
        adds the editability-related statistics to it.

        Args:
            logging_outputs: sequence of logging dicts. It is indexed, measured
                with ``len`` and iterated several times, so it must not be a
                one-shot iterator.

        Returns:
            The aggregated dict. ``'editability_loss'``, ``'stability_loss'``
            and ``'edit_complexity'`` are plain means over the logging outputs.
            ``'main_loss'`` is the summed value divided by the total
            ``'sample_size'`` and by ``math.log(2)``, or ``0.`` when the total
            sample size is not positive.

        Raises:
            KeyError: if the first logging output has ``'editability_loss'``
                but some output lacks ``'editability_loss'``,
                ``'stability_loss'`` or ``'edit_complexity'``.
        """
        xent_outputs_dict = LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs(logging_outputs)

        # Nothing to add when there are no outputs at all (indexing [0] would
        # raise IndexError) or when editability statistics were not logged.
        if not logging_outputs or 'editability_loss' not in logging_outputs[0]:
            return xent_outputs_dict

        num_outputs = len(logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        xent_outputs_dict['editability_loss'] = sum(
            log['editability_loss'] for log in logging_outputs) / num_outputs

        # Normalised by total sample size and by ln(2) -- presumably to report
        # the loss in bits rather than nats; guarded against a zero sample size.
        if sample_size > 0:
            main_loss = sum(
                log.get('main_loss', 0) for log in logging_outputs) / sample_size / math.log(2)
        else:
            main_loss = 0.
        xent_outputs_dict['main_loss'] = main_loss

        xent_outputs_dict['stability_loss'] = sum(
            log['stability_loss'] for log in logging_outputs) / num_outputs
        xent_outputs_dict['edit_complexity'] = sum(
            log['edit_complexity'] for log in logging_outputs) / num_outputs

        return xent_outputs_dict
コード例 #3
0
 def aggregate_logging_outputs(cls, logging_outputs):
     """Aggregate logging outputs from data parallel training.

     Delegates entirely to
     ``LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs`` and
     returns its result unchanged; ``cls`` is not used.
     """
     aggregated = LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs(logging_outputs)
     return aggregated