Example No. 1
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(
            sum(log.get('loss', 0) for log in logging_outputs))
        nll_loss_sum = utils.item(
            sum(log.get('nll_loss', 0) for log in logging_outputs))
        alignment_loss_sum = utils.item(
            sum(log.get('alignment_loss', 0) for log in logging_outputs))
        ntokens = utils.item(
            sum(log.get('ntokens', 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get('sample_size', 0) for log in logging_outputs))

        metrics.log_scalar('loss',
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_scalar('nll_loss',
                           nll_loss_sum / ntokens / math.log(2),
                           ntokens,
                           round=3)
        metrics.log_scalar('alignment_loss',
                           alignment_loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_derived(
            'ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
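
The reduction above sums the raw per-worker losses (each forward logs summed, not averaged, values), normalizes by the aggregated sample size, and converts from nats to bits by dividing by math.log(2); perplexity is then derived from the average nll_loss (2 to that power, since the value is in bits). A minimal standalone sketch of the same arithmetic with plain dicts, independent of fairseq's metrics module (the numbers below are arbitrary):

    import math

    def aggregate(logging_outputs):
        # same arithmetic as reduce_metrics above, returning a plain dict
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        loss_bits = loss_sum / sample_size / math.log(2)   # nats -> bits per token
        nll_bits = nll_loss_sum / ntokens / math.log(2)
        return {'loss': loss_bits, 'nll_loss': nll_bits, 'ppl': 2 ** nll_bits}

    # e.g. two data-parallel workers, each contributing one logging output
    print(aggregate([
        {'loss': 690.0, 'nll_loss': 650.0, 'ntokens': 100, 'sample_size': 100},
        {'loss': 710.0, 'nll_loss': 670.0, 'ntokens': 100, 'sample_size': 100},
    ]))
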
Example No. 2
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(sum(log.get('loss', 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get('ntokens', 0) for log in logging_outputs))
        nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs))
        sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs))

        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_scalar('ntokens', ntokens)
        metrics.log_scalar('nsentences', nsentences)

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar("_correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar("_total", total)


        if total > 0:
            metrics.log_derived(
                "accuracy",
                lambda meters: safe_round(meters["_correct"].sum / meters["_total"].sum, 5)
                if meters["_total"].sum > 0
                else float("nan"),
            )

        builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}

        for k in logging_outputs[0]:
            if k not in builtin_keys:
                val = sum(log.get(k, 0) for log in logging_outputs) / len(logging_outputs)
                if k.startswith('loss'):
                    metrics.log_scalar(k, val / sample_size / math.log(2), sample_size)
                else:
                    metrics.log_scalar(k, val, round=3)
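
The underscore-prefixed scalars ("_correct", "_total") are logged only so the derived "accuracy" can be recomputed from their running sums at display time; safe_round is assumed here to behave like round with a guard for empty totals. The derived value itself reduces to:

    def derived_accuracy(correct_sum, total_sum, ndigits=5):
        # illustrative stand-in for the lambda passed to metrics.log_derived above
        return round(correct_sum / total_sum, ndigits) if total_sum > 0 else float("nan")

    assert derived_accuracy(42, 50) == 0.84
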
Example No. 3
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss = self.compute_loss(model,
                                           net_output,
                                           sample,
                                           reduce=reduce)
        sample_size = sample['target'].size(
            0) if self.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }

        alignment_loss = None

        # Compute alignment loss only for the training set and non-dummy batches.
        if 'alignments' in sample and sample['alignments'] is not None:
            alignment_loss = self.compute_alignment_loss(sample, net_output)

        if alignment_loss is not None:
            logging_output['alignment_loss'] = utils.item(alignment_loss.data)
            loss += self.alignment_lambda * alignment_loss

        return loss, sample_size, logging_output
Example No. 4
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs))
        loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        nll_loss = utils.item(
            sum(log.get("nll_loss", 0) for log in logging_outputs))

        metrics.log_scalar('loss',
                           loss / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_scalar('nll_loss',
                           nll_loss / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_derived(
            'ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))

        for key in logging_outputs[0]:
            if key[-5:] == "-loss":
                val = sum(log.get(key, 0) for log in logging_outputs)
                metrics.log_scalar(
                    key[:-5],
                    val / sample_size /
                    math.log(2) if sample_size > 0 else 0.0,
                    sample_size,
                    round=3,
                )
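
The trailing loop picks up any extra objective whose key ends in "-loss" (as logged, for instance, by a forward pass like the one in Example No. 16) and re-logs it under the stripped name, normalized the same way as the main loss. A quick illustration of the key handling alone, with illustrative objective names:

    logging_outputs = [
        {"loss": 3.0, "word_ins-loss": 2.0, "length-loss": 0.5, "sample_size": 1},
        {"loss": 2.5, "word_ins-loss": 1.5, "length-loss": 0.25, "sample_size": 1},
    ]
    for key in logging_outputs[0]:
        if key[-5:] == "-loss":
            val = sum(log.get(key, 0) for log in logging_outputs)
            print(key[:-5], val)    # prints "word_ins 3.5" and "length 0.75"
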
Example No. 5
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        loss_sum = utils.item(
            sum(log.get("loss", 0) for log in logging_outputs))
        ntokens = utils.item(
            sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs))

        metrics.log_scalar("loss",
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar("nll_loss",
                               loss_sum / ntokens / math.log(2),
                               ntokens,
                               round=3)

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                ) if meters["_c_total"].sum > 0 else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                ) if meters["_w_total"].sum > 0 else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum,
                    3) if meters["_w_total"].sum > 0 else float("nan"),
            )
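
uer, wer, and raw_wer are plain error percentages over the summed character/word totals (safe_round is assumed to behave like round). For instance, 37 character errors against 1,000 reference characters surface as a uer of 3.7:

    c_errors, c_total = 37, 1000
    uer = round(c_errors * 100.0 / c_total, 3) if c_total > 0 else float("nan")
    print(uer)   # 3.7
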
Example No. 6
    def forward(self, model, sample, reduce=True):
        encoder_output, logits = model(**sample["net_input"])

        # (B, T, C) from the encoder
        ctc_logits = encoder_output["ctc_logits"]
        ctc_lprobs, lprobs = model.get_normalized_probs(ctc_logits,
                                                        logits,
                                                        log_probs=True)

        len_ctc_logits = (~encoder_output["padding_mask"]).long().sum(-1)
        target = sample["target"]
        target_lengths = sample["target_lengths"]

        ctc_loss = self.ctc_loss(ctc_lprobs, len_ctc_logits, target,
                                 target_lengths)
        ce_loss, correct, total = self.ce_loss(lprobs, target)
        loss = ctc_loss + ce_loss

        ntokens = (sample["ntokens"] if "ntokens" in sample else
                   sample["target_lengths"].sum().item())
        sample_size = ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ctc_loss": utils.item(ctc_loss.data),
            "ce_loss": utils.item(ce_loss.data),
            "ntokens": ntokens,
            "correct": utils.item(correct.data),
            "total": utils.item(total.data),
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        if not model.training:

            ctc_lprobs = ctc_lprobs.float().cpu()

            c_err, c_len = self.ctc_greedy_eval(ctc_lprobs,
                                                len_ctc_logits,
                                                target,
                                                pad_idx=self.pad_idx,
                                                eos_idx=self.eos_idx,
                                                blk_idx=self.blk_idx)

            logging_output["c_errors"] = c_err
            logging_output["c_total"] = c_len

        return loss, sample_size, logging_output
Example No. 7
    def upgrade_state_dict(self, state_dict):
        if utils.item(state_dict.get('decoder.version', torch.Tensor([1]))[0]) < 2:
            # old models use incorrect weight norm dimension
            for i, conv in enumerate(self.convolutions):
                # reconfigure weight norm
                nn.utils.remove_weight_norm(conv)
                self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
            state_dict['decoder.version'] = torch.Tensor([1])
        return state_dict
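
The version check exists because old checkpoints registered weight normalization along a different dimension; the loop drops the old parametrization and re-registers it along dim=0 (one magnitude per output channel). A self-contained illustration on a single convolution, assuming dim=2 as the "old" dimension purely for demonstration:

    import torch.nn as nn

    conv = nn.utils.weight_norm(nn.Conv1d(8, 16, kernel_size=3), dim=2)  # hypothetical old dim
    print(conv.weight_g.shape)         # torch.Size([1, 1, 3]): one magnitude per kernel position

    nn.utils.remove_weight_norm(conv)  # what the upgrade loop does per convolution
    conv = nn.utils.weight_norm(conv, dim=0)
    print(conv.weight_g.shape)         # torch.Size([16, 1, 1]): one magnitude per output channel
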
Example No. 8
    def get_logging_output(self, sample, target, lprobs, loss):
        target = target.view(-1)
        mask = target != self.padding_idx
        correct = torch.sum(
            lprobs.argmax(1).masked_select(mask) == target.masked_select(mask))
        total = torch.sum(mask)
        sample_size = (sample["target"].size(0)
                       if self.sentence_avg else sample["ntokens"])

        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "correct": utils.item(correct.data),
            "total": utils.item(total.data),
            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
        }

        return sample_size, logging_output
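
correct and total count argmax hits only over non-padding positions. A self-contained illustration of the masking (a padding index of 0 is arbitrary here):

    import torch

    lprobs = torch.log_softmax(torch.randn(6, 5), dim=-1)   # 6 positions, 5 classes
    target = torch.tensor([2, 4, 1, 3, 0, 0])               # trailing zeros stand in for padding
    mask = target != 0
    correct = torch.sum(lprobs.argmax(1).masked_select(mask) == target.masked_select(mask))
    total = torch.sum(mask)
    print(correct.item(), "out of", total.item())           # padding positions never counted
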
Example No. 9
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        net_output = model(**sample["net_input"])
        emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
        B = emissions.size(0)
        T = emissions.size(1)
        device = emissions.device

        target = torch.IntTensor(B, T)
        target_size = torch.IntTensor(B)
        using_linseg = self.linseg_step()

        for b in range(B):
            initial_target_size = sample["target_lengths"][b].item()
            if initial_target_size == 0:
                raise ValueError("target size cannot be zero")

            tgt = sample["target"][b, :initial_target_size].tolist()
            tgt = self.replace_eos_with_silence(tgt)
            tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
            tgt = tgt[:T]

            if using_linseg:
                tgt = [tgt[t * len(tgt) // T] for t in range(T)]

            target[b][: len(tgt)] = torch.IntTensor(tgt)
            target_size[b] = len(tgt)

        loss = self.asg.forward(emissions, target.to(device), target_size.to(device))

        if reduce:
            loss = torch.sum(loss)

        sample_size = (
            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output
Example No. 10
    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict["{}.embed_positions._float_tensor".format(
                name)] = torch.FloatTensor(1)

        if f"{name}.output_projection.weight" not in state_dict:
            if self.share_input_output_embed:
                embed_out_key = f"{name}.embed_tokens.weight"
            else:
                embed_out_key = f"{name}.embed_out"
            if embed_out_key in state_dict:
                state_dict[f"{name}.output_projection.weight"] = state_dict[
                    embed_out_key]
                if not self.share_input_output_embed:
                    del state_dict[embed_out_key]

        for i in range(self.num_layers):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(
                        name, i, old, m)
                    if k in state_dict:
                        state_dict["{}.layers.{}.{}.{}".format(
                            name, i, new, m)] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
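
The layer-norm portion of the upgrade is a plain key rename inside the checkpoint dict. Stripped of the model specifics, it amounts to the following (the keys and values are made up to show the before/after layout):

    layer_norm_map = {"0": "self_attn_layer_norm", "1": "encoder_attn_layer_norm", "2": "final_layer_norm"}
    state_dict = {
        "decoder.layers.0.layer_norms.0.weight": "w",   # old-style key
        "decoder.layers.0.layer_norms.2.bias": "b",
    }
    for old, new in layer_norm_map.items():
        for m in ("weight", "bias"):
            k = "decoder.layers.0.layer_norms.{}.{}".format(old, m)
            if k in state_dict:
                state_dict["decoder.layers.0.{}.{}".format(new, m)] = state_dict.pop(k)
    print(sorted(state_dict))
    # ['decoder.layers.0.final_layer_norm.bias', 'decoder.layers.0.self_attn_layer_norm.weight']
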
Example No. 11
    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                print("deleting {0}".format(weights_key))
                del state_dict[weights_key]
            state_dict["{}.embed_positions._float_tensor".format(
                name)] = torch.FloatTensor(1)
        for i in range(self.num_layers):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i))

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Example No. 12
    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore)
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = " ".join(
            token_string(i)
            for i in tensor
            if utils.item(i) not in extra_symbols_to_ignore
        )

        return data_utils.post_process(sent, bpe_symbol)
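
A usage sketch, assuming the standard fairseq.data.Dictionary API (the symbols are arbitrary):

    import torch
    from fairseq.data import Dictionary

    d = Dictionary()                 # starts with <s>, <pad>, </s>, <unk>
    for sym in ("hello", "world"):
        d.add_symbol(sym)

    ids = torch.tensor([d.index("hello"), d.index("world"), d.eos()])
    print(d.string(ids))             # "hello world" -- eos (and bos, if defined) is skipped
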
Example No. 13
            def forward(self, model, sample, reduce=True):
                net_outputs = model(**sample['net_input'])
                targets = sample['target']

                bsz = targets[0].size(0)
                loss = net_outputs[0][0].new(
                    1 if reduce else bsz).float().zero_()

                sample_size = 0
                logging_output = {}
                for o, t in zip(net_outputs[0], targets):
                    m = FakeModel(model, (o, net_outputs[1]), t)
                    sample['target'] = t
                    l, ss, logging_output = self.underlying_criterion(
                        m, sample, reduce)
                    loss += l
                    sample_size += ss

                loss.div_(len(targets))
                sample_size /= len(targets)

                logging_output['loss'] = utils.item(
                    loss.data) if reduce else loss.data
                return loss, sample_size, logging_output
Example No. 14
    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        lprobs = model.get_normalized_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder

        if "src_lengths" in sample["net_input"]:
            input_lengths = sample["net_input"]["src_lengths"]
        else:
            non_padding_mask = ~net_output["padding_mask"]
            input_lengths = non_padding_mask.long().sum(-1)

        pad_mask = (sample["target"] != self.pad_idx) & (
            sample["target"] != self.eos_idx
        )
        targets_flat = sample["target"].masked_select(pad_mask)
        target_lengths = sample["target_lengths"]

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
        )

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        if not model.training:
            import editdistance

            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                    lprobs_t,
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"],
                    input_lengths,
                ):
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        decoded = self.w2l_decoder.decode(lp)
                        if len(decoded) < 1:
                            decoded = None
                        else:
                            decoded = decoded[0]
                            if len(decoded) < 1:
                                decoded = None
                            else:
                                decoded = decoded[0]

                    p = (t != self.task.dictionary.pad()) & (
                        t != self.task.dictionary.eos()
                    )
                    targ = t[p]
                    targ_units = self.task.dictionary.string(targ)
                    targ_units_arr = targ.tolist()

                    toks = lp.argmax(dim=-1).unique_consecutive()
                    pred_units_arr = toks[toks != self.blank_idx].tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                    targ_words = post_process(targ_units, self.post_process).split()

                    pred_units = self.task.dictionary.string(pred_units_arr)
                    pred_words_raw = post_process(pred_units, self.post_process).split()

                    if decoded is not None and "words" in decoded:
                        pred_words = decoded["words"]
                        w_errs += editdistance.eval(pred_words, targ_words)
                        wv_errs += editdistance.eval(pred_words_raw, targ_words)
                    else:
                        dist = editdistance.eval(pred_words_raw, targ_words)
                        w_errs += dist
                        wv_errs += dist

                    w_len += len(targ_words)

                logging_output["wv_errors"] = wv_errs
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output
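
During validation the criterion greedily decodes the CTC output (argmax per frame, collapse repeats, drop blanks) and accumulates edit distances that Example No. 5 later turns into uer/wer. A reduced, self-contained sketch of that decode-and-score step (blank index 0 and the shapes are arbitrary):

    import editdistance
    import torch
    import torch.nn.functional as F

    T, B, C = 50, 1, 10                        # frames, batch, vocab (index 0 = blank)
    lprobs = F.log_softmax(torch.randn(T, B, C), dim=-1)
    target = torch.randint(1, C, (1, 12))      # one reference sequence of 12 units

    toks = lprobs[:, 0].argmax(dim=-1).unique_consecutive()
    pred_units_arr = toks[toks != 0].tolist()  # greedy CTC decode
    targ_units_arr = target[0].tolist()

    c_err = editdistance.eval(pred_units_arr, targ_units_arr)
    c_len = len(targ_units_arr)
    print(c_err, c_len, round(c_err * 100.0 / c_len, 3))   # per-utterance unit error rate
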
Example No. 15
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        lm_logits, output_metadata = model(**sample["net_input"])

        # reshape lm_logits from (N,T,C) to (N*T,C)
        lm_logits = lm_logits.view(-1, lm_logits.size(-1))
        lm_targets = sample['lm_target'].view(-1)
        lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets,
                                             self.padding_idx)

        # compute the number of tokens for which loss is computed. This is used
        # to normalize the loss
        ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens
        nsentences = sample['nsentences']
        # nsentences = 0

        # Compute sentence loss if masked_lm_only is False
        sentence_loss = None
        if not self.masked_lm_only:
            sentence_logits = output_metadata['sentence_logits']
            sentence_targets = sample['sentence_target'].view(-1)
            # This needs to be recomputed due to some differences between
            # TokenBlock and BlockPair dataset. This can be resolved with a
            # refactor of BERTModel which we will do in the future.
            # TODO: Remove this after refactor of BERTModel
            nsentences = sentence_targets.size(0)

            # Check for logits being none which can happen when remove_heads
            # is set to true in the BERT model. Ideally we should set
            # masked_lm_only to true in this case, but that requires some
            # refactor in the BERT model.
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets)

                loss += self.nsp_loss_weight * (sentence_loss / nsentences)

        # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            'loss':
            utils.item(loss.data) if reduce else loss.data,
            'lm_loss':
            utils.item(lm_loss.data) if reduce else lm_loss.data,
            # sentence loss is not always computed
            'sentence_loss':
            ((utils.item(sentence_loss.data) if reduce else sentence_loss.data)
             if sentence_loss is not None else 0.0),
            'ntokens':
            ntokens,
            'nsentences':
            nsentences,
            'sample_size':
            sample_size,
        }
        return loss, sample_size, logging_output
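
compute_cross_entropy_loss is assumed here to be a sum-reduced token-level cross entropy that ignores padding; since the MLM term is divided by ntokens and the NSP term by nsentences inside forward, sample_size stays at 1 and is only used for logging (per the NOTE above). A stand-in with the same contract:

    import torch.nn.functional as F

    def compute_cross_entropy_loss_stub(logits, targets, ignore_index=-100):
        # sum-reduced cross entropy; positions equal to ignore_index contribute nothing
        return F.cross_entropy(logits, targets, ignore_index=ignore_index, reduction="sum")
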
Example No. 16
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        nsentences, ntokens = sample["nsentences"], sample["ntokens"]

        # B x T
        src_tokens, src_lengths = (
            sample["net_input"]["src_tokens"],
            sample["net_input"]["src_lengths"],
        )
        tgt_tokens, prev_output_tokens = sample["target"], sample[
            "prev_target"]

        outputs = model(src_tokens, src_lengths, prev_output_tokens,
                        tgt_tokens)
        losses, nll_loss = [], []

        for obj in outputs:
            if outputs[obj].get("loss", None) is None:
                _losses = self._compute_loss(outputs[obj].get("out"),
                                             outputs[obj].get("tgt"),
                                             outputs[obj].get("mask", None),
                                             outputs[obj].get("ls", 0.0),
                                             name=obj + '-loss',
                                             factor=outputs[obj].get(
                                                 "factor", 1.0))
            else:
                _losses = self._custom_loss(outputs[obj].get("loss"),
                                            name=obj + '-loss',
                                            factor=outputs[obj].get(
                                                "factor", 1.0))

            losses += [_losses]
            if outputs[obj].get("nll_loss", False):
                nll_loss += [_losses.get("nll_loss", 0.0)]

        loss = sum(l["loss"] for l in losses)
        nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 \
            else loss.new_tensor(0)

        # NOTE:
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        for l in losses:
            logging_output[l["name"]] = (utils.item(
                l["loss"].data / l["factor"]) if reduce else l[["loss"]].data /
                                         l["factor"])

        return loss, sample_size, logging_output