コード例 #1
0
ファイル: ctc_multi_loss.py プロジェクト: sarapapi/fairseq
 def add_args(parser):
     """Register the command-line options of the multi-output CTC criterion.

     The plain ``CtcCriterion`` options are registered first. On top of them
     this adds: the encoder layer whose output is used for the CTC loss, the
     relative weight of that loss, and the (mandatory) name of the criterion
     applied to the model's main output.
     """
     CtcCriterion.add_args(parser)
     add_option = parser.add_argument
     add_option('--ctc-encoder-layer', type=int, default=6, metavar='LAYER_NUM',
                help='The encoder layer whose feature are used to compute the CTC loss')
     add_option('--ctc-weight', type=float, default=1.0, metavar='W',
                help='The relative weight to assign to the CTC loss')
     add_option('--underlying-criterion', type=str, required=True, metavar='VAL',
                help='underlying criterion to use for the model output loss')
コード例 #2
0
 def __init__(self, task, cfg=UnispeechCriterionConfig):
     """Build the UniSpeech criterion from its config.

     A wav2vec contrastive criterion is always created. A CTC criterion is
     created only when ``cfg.mtlalpha`` (the CTC weight of the multi-task
     loss) is positive.
     """
     # NOTE(review): the default is the config *class*, not an instance, so
     # fields are read from the class itself -- confirm this is intended.
     super().__init__(task)
     keys_to_log = cfg.log_keys if cfg.log_keys is not None else []
     self.mtlalpha = cfg.mtlalpha
     self.w2v_criterion = Wav2vecCriterion(
         task, cfg.infonce, cfg.loss_weights, keys_to_log)
     if self.mtlalpha > 0:
         self.ctc_criterion = CtcCriterion(cfg, task)
コード例 #3
0
def test_loss(fsq_model, example_wav, attention_mask, target):
    """Run fairseq's ``CtcCriterion`` on one batch and print the resulting loss.

    The arguments look like pytest fixtures (presumably defined in a conftest;
    verify). ``target`` is tokenized with ``processor.tokenizer``, where
    ``processor`` is a module-level global not defined in this function.
    """
    from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig
    from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask

    task = AudioPretrainingTask.setup_task(
        AudioPretrainingConfig(labels="ltr", data="./data"))
    criterion = CtcCriterion(CtcCriterionConfig(), task)
    # The model's train/eval mode is deliberately left unchanged here
    # (a ``fsq_model.train()`` call was commented out in the original).

    encoded = processor.tokenizer(target, padding="longest", return_tensors="pt")
    token_ids = encoded.input_ids
    lengths = encoded.attention_mask.sum(-1)

    # fairseq expects a padding mask that is True at padded positions, i.e.
    # the inverse of a Hugging Face style attention mask.
    net_input = {
        "source": example_wav,
        "padding_mask": attention_mask.ne(1),
    }
    sample = {
        "net_input": net_input,
        "target": token_ids,
        "target_lengths": lengths,
        "id": torch.zeros((1,)),
    }

    loss, _sample_size, _logging_output = criterion(fsq_model, sample)

    print("Loss", loss)
コード例 #4
0
 def __init__(self, multitask_tasks):
     """Create one auxiliary criterion per multitask head.

     Tasks whose ``args.decoder_type`` is ``"ctc"`` get a ``CtcCriterion``.
     Every other task gets a ``LabelSmoothedCrossEntropyCriterion`` built from
     its ``args.criterion_cfg`` (``sentence_avg`` and ``label_smoothing``).
     ``multitask_loss_weight`` starts empty and is not filled in here.
     """
     self.multitask_criterion = {}
     self.multitask_loss_weight = {}
     for name, sub_task in multitask_tasks.items():
         is_ctc = sub_task.args.decoder_type == "ctc"
         crit_cfg = sub_task.args.criterion_cfg
         if is_ctc:
             criterion = CtcCriterion(crit_cfg, sub_task)
         else:
             criterion = LabelSmoothedCrossEntropyCriterion(
                 sub_task,
                 crit_cfg.sentence_avg,
                 label_smoothing=crit_cfg.label_smoothing,
             )
         self.multitask_criterion[name] = criterion
コード例 #5
0
 def __init__(
     self,
     task,
     sentence_avg,
     label_smoothing,
     report_accuracy=False,
     zero_infinity=False,
     post_process=None,
 ):
     """Combine a cross-entropy criterion with a CTC criterion.

     ``self.xent`` is a ``SpeechTextPreTrainCrossEntCriterion`` configured with
     ``sentence_avg``, ``label_smoothing`` and ``report_accuracy``.
     ``self.ctc`` is a ``CtcCriterion`` whose config receives
     ``zero_infinity``, ``sentence_avg`` and ``post_process``.
     """
     super().__init__(task)
     self.xent = SpeechTextPreTrainCrossEntCriterion(
         task, sentence_avg, label_smoothing, report_accuracy)
     ctc_cfg = CtcCriterionConfig(
         zero_infinity=zero_infinity,
         sentence_avg=sentence_avg,
         post_process=post_process,
     )
     self.ctc = CtcCriterion(ctc_cfg, task)
コード例 #6
0
 def __init__(self, task, cfg=CifGPT2CriterionConfig):
     """Store the task and config and build the wrapped CTC criterion.

     The padding index of the task's target dictionary is cached for masking
     targets.
     """
     # NOTE(review): the default is the config *class*, not an instance --
     # confirm this is intended.
     super().__init__(task)
     self.ctc_criterion = CtcCriterion(cfg, task)
     target_dict = task.target_dictionary
     self.padding_idx = target_dict.pad()
     self.cfg, self.task = cfg, task
コード例 #7
0
class CIF_GPT2_Criterion(FairseqCriterion):
    """Criterion combining CTC with CIF-based objectives for a GPT-2 style model.

    ``cfg`` must provide ``lambda_ctc``, ``lambda_cif``, ``lambda_qua`` and
    ``lambda_emb``. In training mode the returned loss is the weighted sum of:

    * the CTC loss from the wrapped ``CtcCriterion``;
    * a label-smoothed (epsilon 0.1) NLL on the CIF output (``cif_loss``);
    * a quantity loss between the predicted output count and the target
      length (``quantity_loss``);
    * a pairwise-distance loss between CIF embeddings and target embeddings
      (``embedding_loss``).

    In evaluation mode the loss is 0, ``sample_size`` is 0.0, and unit/word
    error counts of greedy decoding are logged instead.
    """

    def __init__(self, task, cfg=CifGPT2CriterionConfig):
        super().__init__(task)
        # NOTE(review): the default is the config *class*, not an instance --
        # confirm this is intended.
        self.ctc_criterion = CtcCriterion(cfg, task)
        self.padding_idx = task.target_dictionary.pad()
        self.cfg = cfg
        self.task = task

    def cif_loss(self, model, sample, net_output, reduce):
        """Label-smoothed NLL of the CIF outputs against ``sample["target"]``.

        Targets are flattened to ``(N * T,)`` and log-probabilities to
        ``(N * T, V)``. Padding positions are ignored.

        Returns:
            ``(loss, flattened_lprobs)``.
        """
        target = sample["target"].view(-1)
        lprobs = model.get_normalized_probs_cif(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        ce_loss, _ = label_smoothed_nll_loss(
            lprobs,
            target.long(),
            0.1,  # label-smoothing epsilon (hard-coded)
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return ce_loss, lprobs

    def quantity_loss(self, sample, net_output):
        """Smoothed absolute difference between predicted and true output counts.

        Computes ``sqrt((n_pred - n_true) ** 2 + 1e-6)`` summed over the batch.
        The epsilon keeps the gradient of ``sqrt`` finite when the difference
        is zero.
        """
        predicted = net_output["num_output"]
        expected = sample["target_lengths"].float()
        return torch.sqrt(torch.pow(predicted - expected, 2) + 1e-6).sum()

    def embedding_loss(self, net_output):
        """Sum of Euclidean distances between CIF and target embeddings.

        Both tensors are flattened to ``(-1, D)`` along their last dimension,
        so they must contain the same number of vectors.
        """
        cif_embs = net_output["cif_embeds"]
        target_embs = net_output["targets_embs"]
        return F.pairwise_distance(
            cif_embs.view(-1, cif_embs.size(-1)),
            target_embs.view(-1, target_embs.size(-1))).sum()

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
           (0.0 in evaluation mode)
        3) logging outputs to display while training
        """
        # The whole sample (not only ``net_input``) is passed to the model.
        net_output = model(**sample)

        ctc_loss, sample_size, ctc_logging_output = self.ctc_criterion.get_loss(
            model, sample, net_output, reduce)
        if model.training:
            cif_loss, _ = self.cif_loss(model, sample, net_output, reduce)
            qua_loss = self.quantity_loss(sample, net_output)
            embedding_loss = self.embedding_loss(net_output)
            loss = (self.cfg.lambda_ctc * ctc_loss
                    + self.cfg.lambda_cif * cif_loss
                    + self.cfg.lambda_qua * qua_loss
                    + self.cfg.lambda_emb * embedding_loss)
        else:
            # No losses are reported in evaluation; error counts are logged
            # below instead (the CTC logging output is still kept).
            loss = cif_loss = qua_loss = ctc_loss = embedding_loss = 0

        logging_output = {
            'loss': loss,
            'cif_loss': cif_loss,
            'ctc_loss': ctc_loss,
            'qua_loss': qua_loss,
            'emb_loss': embedding_loss,
            'ntokens': ctc_logging_output.get('ntokens', 0),
            'sample_size': sample_size,
            'nsentences': ctc_logging_output.get('nsentences', 0),
            'ctc': ctc_logging_output
        }

        if not model.training:
            # sample_size == 0 is how reduce_metrics recognises eval outputs.
            sample_size = 0.0
            logging_output['sample_size'] = sample_size
            logging_output.update(self._error_counts(net_output, sample))

        return loss, sample_size, logging_output

    def _error_counts(self, net_output, sample):
        """Count unit and word edit errors of greedy decoding of ``net_output``.

        Each utterance's prediction is the argmax over ``net_output['logits']``,
        truncated to its rounded ``num_output`` length.

        Returns:
            dict with ``c_errors``/``c_total`` (edit distance over token ids)
            and ``w_errors``/``w_total`` (edit distance over the
            whitespace-split words of the detokenized strings).
        """
        import editdistance

        num_output = torch.round(net_output["num_output"]).int()
        c_err = 0
        c_len = 0
        w_errs = 0
        w_len = 0
        with torch.no_grad():
            logits = net_output['logits']
            targets = sample["target"]
            for _logits, _targets, _num_out in zip(logits, targets, num_output):
                decoded = _logits.argmax(dim=-1)[:_num_out]
                target = _targets[_targets != self.padding_idx]

                targ_units_arr = target.tolist()
                pred_units_arr = decoded.tolist()
                c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                c_len += len(targ_units_arr)

                pred_words = self.task.tokenizer.decode(pred_units_arr).split()
                targ_words = self.task.tokenizer.decode(targ_units_arr).split()
                # WER numerator must be an edit distance over *word*
                # sequences, matching the word-count denominator below.
                w_errs += editdistance.eval(pred_words, targ_words)
                w_len += len(targ_words)

        return {
            "w_errors": w_errs,
            "w_total": w_len,
            "c_errors": c_err,
            "c_total": c_len,
        }

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Outputs with a positive total ``sample_size`` are treated as training
        steps and produce per-token / per-sentence losses in base 2. Otherwise
        they are treated as evaluation steps and produce CTC UER plus CIF
        UER/WER.
        """
        loss_sum = utils.item(
            sum(log.get("loss", 0) for log in logging_outputs))
        ctc_loss = utils.item(
            sum(log.get("ctc_loss", 0) for log in logging_outputs))
        ce_loss = utils.item(
            sum(log.get("cif_loss", 0) for log in logging_outputs))
        qua_loss = utils.item(
            sum(log.get("qua_loss", 0) for log in logging_outputs))
        emb_loss = utils.item(
            sum(log.get("emb_loss", 0) for log in logging_outputs))
        ntokens = utils.item(
            sum(log.get("ntokens", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs))

        if sample_size > 0:  # training
            metrics.log_scalar("loss",
                               loss_sum / sample_size / math.log(2),
                               sample_size,
                               round=3)
            # (metric name, summed loss, normaliser); a metric is skipped when
            # its normaliser is 0 instead of raising ZeroDivisionError.
            for name, total, denom in (
                ("ctc_loss", ctc_loss, ntokens),
                ("cif_loss", ce_loss, ntokens),
                ("qua_loss", qua_loss, nsentences),
                ("emb_loss", emb_loss, ntokens),
            ):
                if denom > 0:
                    metrics.log_scalar(name,
                                       total / denom / math.log(2),
                                       denom,
                                       round=3)

        else:
            ctc_c_errors = sum(log['ctc'].get("c_errors", 0)
                               for log in logging_outputs)
            metrics.log_scalar("ctc_c_errors", ctc_c_errors)
            ctc_c_total = sum(log['ctc'].get("c_total", 0)
                              for log in logging_outputs)
            metrics.log_scalar("ctc_c_total", ctc_c_total)
            if ctc_c_total > 0:
                metrics.log_derived(
                    "ctc_uer",
                    lambda meters: safe_round(
                        meters["ctc_c_errors"].sum * 100.0 / meters[
                            "ctc_c_total"].sum, 3)
                    if meters["ctc_c_total"].sum > 0 else float("nan"),
                )

            c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
            metrics.log_scalar("_c_errors", c_errors)
            c_total = sum(log.get("c_total", 0) for log in logging_outputs)
            metrics.log_scalar("_c_total", c_total)
            w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
            metrics.log_scalar("_w_errors", w_errors)
            w_total = sum(log.get("w_total", 0) for log in logging_outputs)
            metrics.log_scalar("_w_total", w_total)
            if c_total > 0:
                metrics.log_derived(
                    "uer",
                    lambda meters: safe_round(
                        meters["_c_errors"].sum * 100.0 / meters["_c_total"].
                        sum, 3)
                    if meters["_c_total"].sum > 0 else float("nan"),
                )
            if w_total > 0:
                metrics.log_derived(
                    "wer",
                    lambda meters: safe_round(
                        meters["_w_errors"].sum * 100.0 / meters["_w_total"].
                        sum, 3)
                    if meters["_w_total"].sum > 0 else float("nan"),
                )

    def logging_outputs_can_be_summed(self) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True would improve distributed training speed.

        Returns False; presumably because each output carries a nested
        ``'ctc'`` dict (read by ``reduce_metrics``) that cannot simply be
        summed -- confirm before changing.
        """
        return False
コード例 #8
0
class UnispeechCriterion(FairseqCriterion):
    """Multi-task criterion mixing a CTC loss with a wav2vec contrastive loss.

    The returned loss is ``mtlalpha * ctc_loss + (1 - mtlalpha) * infonce_loss``.
    The CTC term is computed only when ``mtlalpha > 0`` and the sample has a
    ``"target"``; otherwise it is 0. The reported ``sample_size`` is always
    that of the contrastive criterion.
    """

    def __init__(self, task, cfg=UnispeechCriterionConfig):
        super().__init__(task)
        # NOTE(review): the default is the config *class*, not an instance --
        # confirm this is intended.
        log_keys = [] if cfg.log_keys is None else cfg.log_keys
        self.mtlalpha = cfg.mtlalpha
        self.w2v_criterion = Wav2vecCriterion(task, cfg.infonce,
                                              cfg.loss_weights, log_keys)
        if self.mtlalpha > 0:
            self.ctc_criterion = CtcCriterion(cfg, task)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        net_output = model(**sample["net_input"])

        if self.mtlalpha > 0.0 and 'target' in sample:
            ctc_loss, ctc_sample_size, ctc_logging_output = self.ctc_criterion.get_loss(
                model, sample, net_output, reduce)
        else:
            # No CTC supervision for this batch: contribute nothing.
            ctc_loss = 0
            ctc_sample_size = 0
            ctc_logging_output = {}

        # The contrastive loss is computed on the inner wav2vec model using the
        # ``contrastive_res`` part of the network output.
        infonce_loss, infonce_sample_size, infonce_logging_output = self.w2v_criterion.get_loss(
            model.w2v_encoder.w2v_model, sample, net_output['contrastive_res'],
            reduce)

        loss = self.mtlalpha * ctc_loss + (1.0 - self.mtlalpha) * infonce_loss

        sample_size = infonce_sample_size
        logging_output = {
            'loss': loss,
            'ntokens': ctc_logging_output.get('ntokens', 0),
            'nsentences': ctc_logging_output.get('nsentences', 0),
            'ctc': ctc_logging_output,
            'infonce': infonce_logging_output
        }

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Each output is expected to hold nested ``'ctc'`` and ``'infonce'``
        dicts as produced by ``forward``. Losses are reported in base 2.
        Any normaliser that sums to 0 is guarded instead of raising
        ZeroDivisionError.
        """
        loss_sum = utils.item(
            sum(log.get("loss", 0) for log in logging_outputs))

        ctc_loss_sum = utils.item(
            sum(log['ctc'].get('loss', 0) for log in logging_outputs))
        ctc_sample_size = utils.item(
            sum(log['ctc'].get('sample_size', 0) for log in logging_outputs))
        ctc_ntokens = utils.item(
            sum(log['ctc'].get('ntokens', 0) for log in logging_outputs))
        ctc_nsentences = utils.item(
            sum(log['ctc'].get('nsentences', 0) for log in logging_outputs))

        ctras_loss_sum = utils.item(
            sum(log['infonce'].get('loss', 0) for log in logging_outputs))
        ctras_sample_size = utils.item(
            sum(log['infonce'].get('sample_size', 0)
                for log in logging_outputs))
        ctras_ntokens = utils.item(
            sum(log['infonce'].get('ntokens', 0) for log in logging_outputs))
        ctras_nsentences = utils.item(
            sum(log['infonce'].get('nsentences', 0)
                for log in logging_outputs))

        metrics.log_scalar("loss", loss_sum, 1, round=3)
        if ctras_sample_size > 0:
            metrics.log_scalar("contrastive_loss",
                               ctras_loss_sum / ctras_sample_size / math.log(2),
                               ctras_sample_size,
                               round=3)
        else:
            # Avoid 0/0; logged the same way as an empty CTC branch below.
            metrics.log_scalar("contrastive_loss", 0, ctras_sample_size, round=3)

        if ctc_sample_size == 0:
            metrics.log_scalar("ctc_loss", 0, ctc_sample_size, round=3)
        else:
            metrics.log_scalar("ctc_loss",
                               ctc_loss_sum / ctc_sample_size / math.log(2),
                               ctc_sample_size,
                               round=3)

            if ctc_sample_size != ctc_ntokens and ctc_ntokens > 0:
                metrics.log_scalar("nll_loss",
                                   ctc_loss_sum / ctc_ntokens / math.log(2),
                                   ctc_ntokens,
                                   round=3)
        c_errors = sum(log['ctc'].get("c_errors", 0)
                       for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log['ctc'].get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log['ctc'].get("w_errors", 0)
                       for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log['ctc'].get("wv_errors", 0)
                        for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log['ctc'].get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                ) if meters["_c_total"].sum > 0 else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                ) if meters["_w_total"].sum > 0 else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum,
                    3) if meters["_w_total"].sum > 0 else float("nan"),
            )

        metrics.log_scalar("nsentences", ctras_nsentences)
        metrics.log_scalar("ctc_sample_size", ctc_sample_size)
        metrics.log_scalar("contrastive_sample_size", ctras_sample_size)

        correct = sum(log['infonce'].get("correct", 0)
                      for log in logging_outputs)
        metrics.log_scalar("_correct", correct)

        total = sum(log['infonce'].get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)

        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(
                    meters["_correct"].sum / meters["_total"].sum, 5)
                if meters["_total"].sum > 0 else float("nan"),
            )

        # Forward any extra keys reported by the contrastive criterion. Keys
        # are taken from the first output; an empty list has none.
        builtin_keys = {
            'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'
        }
        extra_keys = logging_outputs[0]['infonce'] if logging_outputs else {}
        for k in extra_keys:
            if k in builtin_keys:
                continue
            val = sum(log['infonce'].get(k, 0)
                      for log in logging_outputs) / len(logging_outputs)
            if k.startswith('loss'):
                if ctras_sample_size > 0:
                    metrics.log_scalar(k,
                                       val / ctras_sample_size / math.log(2),
                                       ctras_sample_size)
            else:
                metrics.log_scalar(k, val, round=3)

    def logging_outputs_can_be_summed(self) -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True would improve distributed training speed.

        Returns False; presumably because each output carries nested
        ``'ctc'``/``'infonce'`` dicts (read by ``reduce_metrics``) that cannot
        simply be summed -- confirm before changing.
        """
        # FIXME: revert when gather based xla reduction is implemented
        return False