Beispiel #1
0
    def __init__(
        self,
        task,
        silence_token,
        asg_transitions_init,
        max_replabel,
        linseg_updates,
        hide_linseg_messages,
    ):
        """Create the ASG loss module together with LinSeg warm-up state.

        Args:
            task: object exposing ``target_dictionary``; that dictionary
                defines the label set and the EOS index.
            silence_token: symbol looked up in the target dictionary; when it
                is missing, ``self.silence`` is set to ``None``.
            asg_transitions_init: scalar placed on the diagonal of the
                initial (trainable) transition matrix.
            max_replabel: stored unchanged as ``self.max_replabel``.
            linseg_updates: stored as ``self.linseg_maximum``.
            hide_linseg_messages: when true, the message state starts at
                ``"none"``; otherwise at ``"start"``.
        """
        from wav2letter.criterion import ASGLoss, CriterionScaleMode

        super().__init__(task)
        self.tgt_dict = task.target_dictionary
        self.eos = self.tgt_dict.eos()
        if silence_token in self.tgt_dict:
            self.silence = self.tgt_dict.index(silence_token)
        else:
            self.silence = None
        self.max_replabel = max_replabel

        label_count = len(self.tgt_dict)
        self.asg = ASGLoss(
            label_count, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT
        )
        diagonal_init = asg_transitions_init * torch.eye(label_count)
        self.asg.trans = torch.nn.Parameter(diagonal_init, requires_grad=True)

        # Integer counter held as a frozen Parameter (not a plain int);
        # presumably so it travels with the module state -- verify.
        progress_counter = torch.tensor([0], dtype=torch.int)
        self.linseg_progress = torch.nn.Parameter(
            progress_counter, requires_grad=False
        )
        self.linseg_maximum = linseg_updates
        if hide_linseg_messages:
            self.linseg_message_state = "none"
        else:
            self.linseg_message_state = "start"
Beispiel #2
0
class ASGCriterion(FairseqCriterion):
    """ASG loss criterion wrapping ``wav2letter.criterion.ASGLoss``.

    Before scoring, each target is post-processed:

    * a trailing EOS is replaced by the silence token, or dropped;
    * repeated labels are packed with ``pack_replabels``;
    * the sequence is truncated to the number of emission frames.

    During the first ``linseg_updates`` training-mode ``forward`` calls, each
    target is also stretched uniformly over all frames. This is the "LinSeg"
    initialization.
    """

    @staticmethod
    def add_args(parser):
        """Register the ASG-specific command-line options on ``parser``."""
        group = parser.add_argument_group("ASG Loss")
        group.add_argument(
            "--asg-transitions-init",
            help="initial diagonal value of transition matrix",
            type=float,
            default=0.0,
        )
        group.add_argument(
            "--max-replabel",
            help="maximum # of replabels",
            type=int,
            default=2,
        )
        group.add_argument(
            "--linseg-updates",
            help="# of training updates to use LinSeg initialization",
            type=int,
            default=0,
        )
        group.add_argument(
            "--hide-linseg-messages",
            help="hide messages about LinSeg initialization",
            action="store_true",
        )

    def __init__(
        self,
        task,
        silence_token,
        asg_transitions_init,
        max_replabel,
        linseg_updates,
        hide_linseg_messages,
        sentence_avg=False,
    ):
        """Build the ASG loss module and the LinSeg bookkeeping.

        Args:
            task: task whose ``target_dictionary`` defines the label set.
            silence_token: symbol substituted for a trailing EOS. If it is
                not in the dictionary, a trailing EOS is simply dropped.
            asg_transitions_init: value placed on the diagonal of the initial
                (trainable) transition matrix.
            max_replabel: maximum repeat count passed to ``pack_replabels``.
            linseg_updates: number of training-mode ``forward`` calls that
                use LinSeg-stretched targets.
            hide_linseg_messages: if true, suppress the LinSeg start/finish
                messages.
            sentence_avg: if true, ``forward`` reports the number of
                sentences as the sample size instead of ``ntokens``.
        """
        from wav2letter.criterion import ASGLoss, CriterionScaleMode

        super().__init__(task)
        self.tgt_dict = task.target_dictionary
        self.eos = self.tgt_dict.eos()
        self.silence = (
            self.tgt_dict.index(silence_token)
            if silence_token in self.tgt_dict
            else None
        )
        self.max_replabel = max_replabel
        # forward() used to read ``self.args.sentence_avg``, but nothing in
        # this class ever assigns ``self.args``; keep the flag locally.
        self.sentence_avg = sentence_avg

        num_labels = len(self.tgt_dict)
        self.asg = ASGLoss(
            num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT
        )
        self.asg.trans = torch.nn.Parameter(
            asg_transitions_init * torch.eye(num_labels), requires_grad=True
        )

        # Integer counter held as a frozen Parameter (not a plain int);
        # presumably so LinSeg progress travels with the module state -- verify.
        self.linseg_progress = torch.nn.Parameter(
            torch.tensor([0], dtype=torch.int), requires_grad=False
        )
        self.linseg_maximum = linseg_updates
        self.linseg_message_state = "none" if hide_linseg_messages else "start"

    @classmethod
    def build_criterion(cls, args, task):
        """Construct the criterion from parsed command-line ``args``."""
        return cls(
            task,
            args.silence_token,
            args.asg_transitions_init,
            args.max_replabel,
            args.linseg_updates,
            args.hide_linseg_messages,
            # Not registered by add_args above, so tolerate its absence.
            sentence_avg=getattr(args, "sentence_avg", False),
        )

    def linseg_step(self):
        """Advance the LinSeg counter and report whether LinSeg applies now.

        Calls made outside training mode return False and do not count.
        Each training-mode call increments the counter and returns True
        while fewer than ``linseg_maximum`` calls have been counted.

        A one-time message is printed when LinSeg starts and another when it
        finishes, unless messages were hidden at construction.
        """
        if not self.training:
            return False
        if self.linseg_progress.item() < self.linseg_maximum:
            if self.linseg_message_state == "start":
                print("| using LinSeg to initialize ASG")
                self.linseg_message_state = "finish"
            self.linseg_progress.add_(1)
            return True
        if self.linseg_message_state == "finish":
            print("| finished LinSeg initialization")
            self.linseg_message_state = "none"
        return False

    def replace_eos_with_silence(self, tgt):
        """Return ``tgt`` with a trailing EOS swapped for silence.

        The EOS is dropped instead of replaced when there is no silence token
        or when the label before it is already silence. Lists that do not end
        in EOS, including the empty list, are returned unchanged.
        """
        if not tgt or tgt[-1] != self.eos:
            return tgt
        if self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
            return tgt[:-1]
        return tgt[:-1] + [self.silence]

    def _prepare_target(self, raw_target, num_frames, using_linseg):
        """Convert one raw target id list into the label list fed to ASG."""
        tgt = self.replace_eos_with_silence(raw_target)
        tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
        tgt = tgt[:num_frames]
        # An empty target (e.g. it was only EOS and there is no silence
        # token) cannot be stretched; leave it as-is like the non-LinSeg path
        # instead of raising IndexError.
        if using_linseg and tgt:
            n = len(tgt)
            tgt = [tgt[t * n // num_frames] for t in range(num_frames)]
        return tgt

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Raises:
            ValueError: if any entry of ``sample["target_lengths"]`` is zero.
        """
        net_output = model(**sample["net_input"])
        # Dims 0/1 are swapped so that dim 0 indexes the batch below
        # (presumably encoder_out is time-major -- confirm).
        emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
        B = emissions.size(0)
        T = emissions.size(1)
        device = emissions.device

        # Zero-initialised so entries past target_size[b] are deterministic.
        target = torch.zeros((B, T), dtype=torch.int)
        target_size = torch.zeros(B, dtype=torch.int)
        using_linseg = self.linseg_step()

        for b in range(B):
            initial_target_size = sample["target_lengths"][b].item()
            if initial_target_size == 0:
                raise ValueError("target size cannot be zero")

            raw = sample["target"][b, :initial_target_size].tolist()
            tgt = self._prepare_target(raw, T, using_linseg)
            target[b, : len(tgt)] = torch.IntTensor(tgt)
            target_size[b] = len(tgt)

        loss = self.asg.forward(
            emissions, target.to(device), target_size.to(device)
        )

        if reduce:
            loss = torch.sum(loss)

        nsentences = sample["target"].size(0)
        sample_size = nsentences if self.sentence_avg else sample["ntokens"]
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        The reported loss is the summed loss divided by the number of
        sentences. It is 0.0 when no sentences were logged, rather than
        raising ZeroDivisionError.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_output = {
            "loss": loss_sum / nsentences if nsentences > 0 else 0.0,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return agg_output
Beispiel #3
0
    parser.add_argument("--cpu",
                        action="store_true",
                        help="Use cpu backend, otherwise use CUDA backend")
    parser.add_argument(
        "--double",
        action="store_true",
        help="store tensors in double, otherwise in float",
    )
    args = parser.parse_args()

    device = torch.device("cpu" if args.cpu else "cuda")
    float_type = torch.double if args.double else torch.float

    # create ASG loss with scaling the loss to the sqrt of target size
    # and 6 tokens (6 tokens scores predicted by some network for each frame)
    asg = ASGLoss(6, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT).to(device)
    # define the input to the loss (scores for tokens came from
    # some network for each frame) size is [batch, time, ntokens]
    input = torch.tensor(
        [
            [
                [-0.4340, -0.0254, 0.3667, 0.4180, -0.3805, -0.1707],
                [0.1060, 0.3631, -0.1122, -0.3825, -0.0031, -0.3801],
                [0.0443, -0.3795, 0.3194, -0.3130, 0.0094, 0.1560],
                [0.1252, 0.2877, 0.1997, -0.4554, 0.2774, -0.2526],
                [-0.4001, -0.2402, 0.1295, 0.0172, 0.1805, -0.3299],
            ],
            [
                [0.3298, -0.2259, -0.0959, 0.4909, 0.2996, -0.2543],
                [-0.2863, 0.3239, -0.3988, 0.0732, -0.2107, -0.4739],
                [-0.0906, 0.0480, -0.1301, 0.3975, -0.3317, -0.1967],