def add_args(parser):
    CtcCriterion.add_args(parser)
    parser.add_argument('--ctc-encoder-layer', default=6, type=int, metavar='LAYER_NUM',
                        help='The encoder layer whose features are used to compute the CTC loss')
    parser.add_argument('--ctc-weight', default=1.0, type=float, metavar='W',
                        help='The relative weight to assign to the CTC loss')
    parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
                        help='underlying criterion to use for the model output loss')
def __init__(self, task, cfg=UnispeechCriterionConfig):
    super().__init__(task)
    log_keys = [] if cfg.log_keys is None else cfg.log_keys
    self.mtlalpha = cfg.mtlalpha
    self.w2v_criterion = Wav2vecCriterion(task, cfg.infonce, cfg.loss_weights, log_keys)
    if self.mtlalpha > 0:
        self.ctc_criterion = CtcCriterion(cfg, task)
def test_loss(fsq_model, example_wav, attention_mask, target):
    import torch
    from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig
    from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask

    audio_cfg = AudioPretrainingConfig(labels="ltr", data="./data")
    task = AudioPretrainingTask.setup_task(audio_cfg)
    ctc = CtcCriterion(CtcCriterionConfig(), task)
    # fsq_model.train()

    # `processor` is assumed to be defined at module scope (e.g. a test fixture);
    # its tokenizer converts the transcript into label ids.
    labels_dict = processor.tokenizer(target, padding="longest", return_tensors="pt")
    labels = labels_dict.input_ids
    target_lengths = labels_dict.attention_mask.sum(-1)

    sample = {
        "net_input": {
            "source": example_wav,
            "padding_mask": attention_mask.ne(1),
        },
        "target": labels,
        "target_lengths": target_lengths,
        "id": torch.zeros((1,)),
    }
    loss, _, _ = ctc(fsq_model, sample)
    print("Loss", loss)
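One possible way to construct the `processor` assumed in the test above, using the Hugging Face transformers library; this fixture is an assumption, and the checkpoint name is only illustrative:

from transformers import Wav2Vec2Processor

# Assumed fixture: a pretrained processor whose tokenizer produces letter-level
# label ids matching the "ltr" labels configured in the test.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")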
def __init__(self, multitask_tasks):
    self.multitask_criterion = {}
    self.multitask_loss_weight = {}
    for task_name, task_obj in multitask_tasks.items():
        if task_obj.args.decoder_type == "ctc":
            self.multitask_criterion[task_name] = CtcCriterion(
                task_obj.args.criterion_cfg, task_obj
            )
        else:
            self.multitask_criterion[task_name] = LabelSmoothedCrossEntropyCriterion(
                task_obj,
                task_obj.args.criterion_cfg.sentence_avg,
                label_smoothing=task_obj.args.criterion_cfg.label_smoothing,
            )
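A hypothetical sketch of how the per-task criteria built above might be applied during training. The per-task decoder lookup (`model.multitask_decoders`), the sample layout, and the weight handling are assumptions, not part of the snippet:

def compute_multitask_loss(self, model, multitask_samples, reduce=True):
    """Hypothetical helper: sum each auxiliary task's (weighted) loss."""
    total_loss, logging_outputs = 0.0, {}
    for task_name, criterion in self.multitask_criterion.items():
        # Assumed: the model exposes one sub-module per auxiliary task and the
        # batch carries one sub-sample per task.
        loss, _, log = criterion(
            model.multitask_decoders[task_name],
            multitask_samples[task_name],
            reduce=reduce,
        )
        weight = self.multitask_loss_weight.get(task_name, 1.0)
        total_loss = total_loss + weight * loss
        logging_outputs[task_name] = log
    return total_loss, logging_outputs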
def __init__(
    self,
    task,
    sentence_avg,
    label_smoothing,
    report_accuracy=False,
    zero_infinity=False,
    post_process=None,
):
    super().__init__(task)
    self.xent = SpeechTextPreTrainCrossEntCriterion(
        task, sentence_avg, label_smoothing, report_accuracy
    )
    cfg_dict = {
        "zero_infinity": zero_infinity,
        "sentence_avg": sentence_avg,
        "post_process": post_process,
    }
    cfg_ctc = CtcCriterionConfig(**cfg_dict)
    self.ctc = CtcCriterion(cfg_ctc, task)
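A hypothetical `forward` for the class this constructor belongs to, sketching one plausible way to combine the two criteria. The equal weighting and the extra logging key are assumptions; depending on the setup, the two criteria might instead be dispatched per batch type rather than summed:

def forward(self, model, sample, reduce=True):
    """Hypothetical sketch: joint cross-entropy + CTC loss."""
    xent_loss, sample_size, logging_output = self.xent(model, sample, reduce=reduce)
    ctc_loss, _, ctc_logging_output = self.ctc(model, sample, reduce=reduce)
    loss = xent_loss + ctc_loss  # assumption: equal weighting of the two terms
    logging_output["ctc_loss"] = ctc_logging_output.get("loss", 0)
    return loss, sample_size, logging_output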
def __init__(self, task, cfg=CifGPT2CriterionConfig):
    super().__init__(task)
    self.ctc_criterion = CtcCriterion(cfg, task)
    self.padding_idx = task.target_dictionary.pad()
    self.cfg = cfg
    self.task = task
class CIF_GPT2_Criterion(FairseqCriterion):

    def __init__(self, task, cfg=CifGPT2CriterionConfig):
        super().__init__(task)
        self.ctc_criterion = CtcCriterion(cfg, task)
        self.padding_idx = task.target_dictionary.pad()
        self.cfg = cfg
        self.task = task

    def cif_loss(self, model, sample, net_output, reduce):
        # N, T -> N * T
        target = sample["target"]
        target = target.view(-1)
        # N, T, D -> N * T, D
        lprobs = model.get_normalized_probs_cif(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        ce_loss, _ = label_smoothed_nll_loss(
            lprobs,
            target.long(),
            0.1,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return ce_loss, lprobs

    def quantity_loss(self, sample, net_output):
        _number = net_output["num_output"]
        number = sample["target_lengths"].float()
        diff = torch.sqrt(torch.pow(_number - number, 2) + 1e-6).sum()
        qua_loss = diff
        return qua_loss

    def embedding_loss(self, net_output):
        cif_embs = net_output["cif_embeds"]
        target_embs = net_output["targets_embs"]
        pair_dist = F.pairwise_distance(
            cif_embs.view(-1, cif_embs.size(-1)),
            target_embs.view(-1, target_embs.size(-1))).sum()
        return pair_dist

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample)
        ctc_loss, sample_size, ctc_logging_output = self.ctc_criterion.get_loss(
            model, sample, net_output, reduce)
        if model.training:
            cif_loss, lprobs = self.cif_loss(model, sample, net_output, reduce)
            qua_loss = self.quantity_loss(sample, net_output)
            embedding_loss = self.embedding_loss(net_output)
            loss = (self.cfg.lambda_ctc * ctc_loss
                    + self.cfg.lambda_cif * cif_loss
                    + self.cfg.lambda_qua * qua_loss
                    + self.cfg.lambda_emb * embedding_loss)
        else:
            loss = cif_loss = qua_loss = ctc_loss = embedding_loss = 0

        mask = sample["target"] != self.padding_idx
        logging_output = {
            'loss': loss,
            'cif_loss': cif_loss,
            'ctc_loss': ctc_loss,
            'qua_loss': qua_loss,
            'emb_loss': embedding_loss,
            'ntokens': ctc_logging_output.get('ntokens', 0),
            'sample_size': sample_size,
            'nsentences': ctc_logging_output.get('nsentences', 0),
            'ctc': ctc_logging_output,
        }

        if not model.training:
            import editdistance

            num_output = torch.round(net_output["num_output"]).int()  # sample["target_lengths"].int()
            # sample_size = 0.0
            logging_output['sample_size'] = sample_size
            c_err = 0
            c_len = 0
            w_errs = 0
            w_len = 0
            with torch.no_grad():
                logits = net_output['logits']
                targets = sample["target"]
                for _logits, _targets, _num_out in zip(logits, targets, num_output):
                    print(_logits.size(), _targets.size(), _num_out, _targets)
                    p = _targets != self.task.target_dictionary.pad()
                    decoded = _logits.argmax(dim=-1)[:_num_out]
                    target = _targets[p]
                    targ_units_arr = target.tolist()
                    pred_units_arr = decoded.tolist()
                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)
                    pred_w = self.task.tokenizer.decode(pred_units_arr)
                    target_w = self.task.tokenizer.decode(targ_units_arr)
                    dist = editdistance.eval(pred_w, target_w)
                    w_errs += dist
                    w_len += len(target_w.split())
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(
            sum(log.get("loss", 0) for log in logging_outputs))
        ctc_loss = utils.item(
sum(log.get("ctc_loss", 0) for log in logging_outputs)) ce_loss = utils.item( sum(log.get("cif_loss", 0) for log in logging_outputs)) qua_loss = utils.item( sum(log.get("qua_loss", 0) for log in logging_outputs)) emb_loss = utils.item( sum(log.get("emb_loss", 0) for log in logging_outputs)) ntokens = utils.item( sum(log.get("ntokens", 0) for log in logging_outputs)) sample_size = utils.item( sum(log.get("sample_size", 0) for log in logging_outputs)) nsentences = utils.item( sum(log.get("nsentences", 0) for log in logging_outputs)) if sample_size > 0: # training metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3) metrics.log_scalar("ctc_loss", ctc_loss / ntokens / math.log(2), ntokens, round=3) metrics.log_scalar("cif_loss", ce_loss / ntokens / math.log(2), ntokens, round=3) metrics.log_scalar("qua_loss", qua_loss / nsentences / math.log(2), nsentences, round=3) metrics.log_scalar("emb_loss", emb_loss / ntokens / math.log(2), ntokens, round=3) else: ctc_c_errors = sum(log['ctc'].get("c_errors", 0) for log in logging_outputs) metrics.log_scalar("ctc_c_errors", ctc_c_errors) ctc_c_total = sum(log['ctc'].get("c_total", 0) for log in logging_outputs) metrics.log_scalar("ctc_c_total", ctc_c_total) if ctc_c_total > 0: metrics.log_derived( "ctc_uer", lambda meters: safe_round( meters["ctc_c_errors"].sum * 100.0 / meters[ "ctc_c_total"].sum, 3) if meters["ctc_c_total"].sum > 0 else float("nan"), ) c_errors = sum(log.get("c_errors", 0) for log in logging_outputs) metrics.log_scalar("_c_errors", c_errors) c_total = sum(log.get("c_total", 0) for log in logging_outputs) metrics.log_scalar("_c_total", c_total) w_errors = sum(log.get("w_errors", 0) for log in logging_outputs) metrics.log_scalar("_w_errors", w_errors) w_total = sum(log.get("w_total", 0) for log in logging_outputs) metrics.log_scalar("_w_total", w_total) if c_total > 0: metrics.log_derived( "uer", lambda meters: safe_round( meters["_c_errors"].sum * 100.0 / meters["_c_total"]. sum, 3) if meters["_c_total"].sum > 0 else float("nan"), ) if w_total > 0: metrics.log_derived( "wer", lambda meters: safe_round( meters["_w_errors"].sum * 100.0 / meters["_w_total"]. sum, 3) if meters["_w_total"].sum > 0 else float("nan"), ) #@staticmethod #def logging_outputs_can_be_summed() -> bool: def logging_outputs_can_be_summed(self) -> bool: """ Whether the logging outputs returned by `forward` can be summed across workers prior to calling `reduce_metrics`. Setting this to True will improves distributed training speed. """ # XXX: Gather based reduction not implemented for xla yet. # So we fall to sum based reduction for xla. return False
class UnispeechCriterion(FairseqCriterion):

    def __init__(self, task, cfg=UnispeechCriterionConfig):
        super().__init__(task)
        log_keys = [] if cfg.log_keys is None else cfg.log_keys
        self.mtlalpha = cfg.mtlalpha
        self.w2v_criterion = Wav2vecCriterion(task, cfg.infonce, cfg.loss_weights, log_keys)
        if self.mtlalpha > 0:
            self.ctc_criterion = CtcCriterion(cfg, task)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])

        if self.mtlalpha > 0.0 and 'target' in sample:
            ctc_loss, ctc_sample_size, ctc_logging_output = self.ctc_criterion.get_loss(
                model, sample, net_output, reduce)
        else:
            ctc_loss = 0
            ctc_sample_size = 0
            ctc_logging_output = {}

        infonce_loss, infonce_sample_size, infonce_logging_output = self.w2v_criterion.get_loss(
            model.w2v_encoder.w2v_model, sample, net_output['contrastive_res'], reduce)

        loss = self.mtlalpha * ctc_loss + (1.0 - self.mtlalpha) * infonce_loss
        sample_size = infonce_sample_size

        logging_output = {
            'loss': loss,
            'ntokens': ctc_logging_output.get('ntokens', 0),
            'nsentences': ctc_logging_output.get('nsentences', 0),
            'ctc': ctc_logging_output,
            'infonce': infonce_logging_output,
        }

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(
            sum(log.get("loss", 0) for log in logging_outputs))
        ctc_loss_sum = utils.item(
            sum(log['ctc'].get('loss', 0) for log in logging_outputs))
        ctc_sample_size = utils.item(
            sum(log['ctc'].get('sample_size', 0) for log in logging_outputs))
        ctc_ntokens = utils.item(
            sum(log['ctc'].get('ntokens', 0) for log in logging_outputs))
        ctc_nsentences = utils.item(
            sum(log['ctc'].get('nsentences', 0) for log in logging_outputs))
        ctras_loss_sum = utils.item(
            sum(log['infonce'].get('loss', 0) for log in logging_outputs))
        ctras_sample_size = utils.item(
            sum(log['infonce'].get('sample_size', 0) for log in logging_outputs))
        ctras_ntokens = utils.item(
            sum(log['infonce'].get('ntokens', 0) for log in logging_outputs))
        ctras_nsentences = utils.item(
            sum(log['infonce'].get('nsentences', 0) for log in logging_outputs))

        metrics.log_scalar("loss", loss_sum, 1, round=3)
        metrics.log_scalar("contrastive_loss",
                           ctras_loss_sum / ctras_sample_size / math.log(2),
                           ctras_sample_size, round=3)
        if ctc_sample_size == 0:
            metrics.log_scalar("ctc_loss", 0, ctc_sample_size, round=3)
        else:
            metrics.log_scalar("ctc_loss",
                               ctc_loss_sum / ctc_sample_size / math.log(2),
                               ctc_sample_size, round=3)
            if ctc_sample_size != ctc_ntokens:
                metrics.log_scalar("nll_loss",
                                   ctc_loss_sum / ctc_ntokens / math.log(2),
                                   ctc_ntokens, round=3)

        c_errors = sum(log['ctc'].get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log['ctc'].get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log['ctc'].get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log['ctc'].get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log['ctc'].get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3)
                if
meters["_c_total"].sum > 0 else float("nan"), ) if w_total > 0: metrics.log_derived( "wer", lambda meters: safe_round( meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 ) if meters["_w_total"].sum > 0 else float("nan"), ) metrics.log_derived( "raw_wer", lambda meters: safe_round( meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3) if meters["_w_total"].sum > 0 else float("nan"), ) metrics.log_scalar("nsentences", ctras_nsentences) metrics.log_scalar("ctc_sample_size", ctc_sample_size) metrics.log_scalar("contrastive_sample_size", ctras_sample_size) correct = sum(log['infonce'].get("correct", 0) for log in logging_outputs) metrics.log_scalar("_correct", correct) total = sum(log['infonce'].get("count", 0) for log in logging_outputs) metrics.log_scalar("_total", total) if total > 0: metrics.log_derived( "accuracy", lambda meters: safe_round( meters["_correct"].sum / meters["_total"].sum, 5) if meters["_total"].sum > 0 else float("nan"), ) builtin_keys = { 'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count' } for k in logging_outputs[0]['infonce']: if k not in builtin_keys: val = sum(log['infonce'].get(k, 0) for log in logging_outputs) / len(logging_outputs) if k.startswith('loss'): metrics.log_scalar(k, val / ctras_sample_size / math.log(2), ctras_sample_size) else: metrics.log_scalar(k, val, round=3) # FIXME: revert when gather based xla reduction is implemented #@staticmethod #def logging_outputs_can_be_summed() -> bool: def logging_outputs_can_be_summed(self) -> bool: """ Whether the logging outputs returned by `forward` can be summed across workers prior to calling `reduce_metrics`. Setting this to True will improves distributed training speed. """ # XXX: Gather based reduction not implemented for xla yet. # So we fall to sum based reduction for xla. return False