Example 1
    def __call__(self, logits, elens, ys, ylens):
        """Forced alignment with references.

        Args:
            logits (FloatTensor): `[B, T, vocab]`
            elens (List): length `[B]`
            ys (List): length `[B]`, each of which contains a list of size `[L]`
            ylens (List): length `[B]`
        Returns:
            trigger_points (IntTensor): `[B, L]`

        """
        with torch.no_grad():
            ys = [
                np2tensor(np.fromiter(y, dtype=np.int64), logits.device)
                for y in ys
            ]
            ys_in_pad = pad_list(ys, 0)

            # zero padding
            mask = make_pad_mask(elens.to(logits.device))
            mask = mask.unsqueeze(2).expand_as(logits)
            logits = logits.masked_fill_(mask == 0, self.log0)
            log_probs = torch.log_softmax(logits, dim=-1).transpose(
                0, 1)  # `[T, B, vocab]`

            trigger_points = self.align(log_probs, elens, ys_in_pad, ylens)
        return trigger_points
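
The helpers `make_pad_mask`, `np2tensor`, and `pad_list` are not shown in these examples. Below is a minimal, self-contained sketch of the masking step, assuming `make_pad_mask` returns True for valid frames and False for padding (consistent with the `mask == 0` fills above); `make_pad_mask_sketch` is a hypothetical stand-in, not the library's implementation.

import torch

def make_pad_mask_sketch(lengths):
    """Hypothetical stand-in: True for valid frames, False for padding."""
    tmax = int(lengths.max())
    return torch.arange(tmax, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(1)

logits = torch.randn(2, 5, 4)                       # `[B, T, vocab]`
elens = torch.tensor([5, 3])                        # valid lengths per utterance
mask = make_pad_mask_sketch(elens).unsqueeze(2).expand_as(logits)
logits = logits.masked_fill(mask == 0, -1e10)       # push padded frames to (near) log-zero
log_probs = torch.log_softmax(logits, dim=-1).transpose(0, 1)  # `[T, B, vocab]`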
Example 2
    def forward(self, xs, xlens):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim (+Δ, ΔΔ)]`
            xlens (IntTensor): `[B]`
        Returns:
            xs (FloatTensor): `[B, T, input_dim]`

        """
        residual = xs
        xs = self.layers(xs)  # `[B, T, input_dim]`

        # padding
        device_id = torch.cuda.device_of(next(self.parameters())).idx
        mask = make_pad_mask(xlens, device_id).unsqueeze(2)  # `[B, T, 1]`
        xs = xs.clone().masked_fill_(mask == 0, 0)

        # time average
        denom = xlens.float().unsqueeze(1)
        if device_id >= 0:
            denom = denom.cuda(device_id)
        xs = xs.sum(1) / denom
        xs = residual + self.proj(xs).unsqueeze(1)
        return xs
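
The masked time average used above (and again in Example 3 and Example 9) as a self-contained sketch; shapes and values are illustrative only.

import torch

xs = torch.randn(2, 6, 8)                                   # `[B, T, D]`
xlens = torch.tensor([6, 4])
mask = (torch.arange(6).unsqueeze(0) < xlens.unsqueeze(1)).unsqueeze(2)  # `[B, T, 1]`
xs = xs.masked_fill(mask == 0, 0.)                          # zero out padded frames
xs_avg = xs.sum(1) / xlens.float().unsqueeze(1)             # `[B, D]`, averaged over valid frames only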
Example 3
    def forward(self, xs, xlens):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim (+Δ, ΔΔ)]`
            xlens (IntTensor): `[B]`
        Returns:
            xs (FloatTensor): `[B, T', input_dim]`

        """
        bs, time = xs.size()[:2]

        s = xs.clone()
        for l in range(self.n_layers - 1):
            s = torch.tanh(self.ssn[l](s))
        s = self.ssn[self.n_layers - 1](s)  # `[B, T, input_dim]`

        # padding
        device_id = torch.cuda.device_of(next(self.parameters())).idx
        mask = make_pad_mask(xlens, device_id).unsqueeze(2)
        s = s.masked_fill_(mask == 0, 0)

        # time average
        s = s.sum(1) / xlens.float().cuda(device_id).unsqueeze(1)
        xs = xs + self.p(s).unsqueeze(1)
        return xs
Example 4
    def forward(self, xs, xlens, task):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): not supported now
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks before RNN layers
            xs, xlens = self.conv(xs, xlens)

        # Create the self-attention mask
        bs, xmax = xs.size()[:2]
        xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(1).expand(
            bs, xmax, xmax)
        xx_mask = xx_mask.unsqueeze(1).expand(bs, self.attn_n_heads, xmax,
                                              xmax)

        xs = self.pos_enc(xs)
        for l in range(self.n_layers):
            xs, xx_aws = self.layers[l](xs, xx_mask)
            if not self.training:
                setattr(self, 'xx_aws_layer%d' % l, tensor2np(xx_aws))
        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        eouts['ys']['xs'] = xs
        eouts['ys']['xlens'] = xlens
        return eouts
Example 5
def time_restricted_mask(xs, xlens, N_l, N_c, N_r, n_chunks):
    xx_mask = make_pad_mask(xlens.to(xs.device))
    xx_mask = xx_mask.unsqueeze(1).repeat(
        [1, xs.size(1), 1])  # `[B, emax (query), emax (key)]`
    xx_mask_first = xx_mask.clone()
    for chunk_idx in range(n_chunks):
        offset = chunk_idx * N_c
        # for first layer
        xx_mask_first[:, offset:offset + N_c, :max(0, offset - N_l)] = 0
        xx_mask_first[:, offset:offset + N_c, offset + (N_c + N_r):] = 0
        # for upper layers
        xx_mask[:, offset:offset + N_c, :max(0, offset - N_l)] = 0
        xx_mask[:, offset:offset + N_c, offset + N_c:] = 0
    return xx_mask_first, xx_mask
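
A standalone sketch of the chunkwise restriction applied above: each block of `N_c` query frames may attend to at most `N_l` frames of left context plus its own chunk (the first-layer variant additionally keeps `N_r` lookahead frames). Sizes below are illustrative.

import torch

B, T = 1, 8
N_l, N_c, N_r = 2, 2, 1
xx_mask = torch.ones(B, T, T, dtype=torch.bool)   # start from "everything visible"
for chunk_idx in range((T + N_c - 1) // N_c):
    offset = chunk_idx * N_c
    xx_mask[:, offset:offset + N_c, :max(0, offset - N_l)] = 0  # drop context left of N_l
    xx_mask[:, offset:offset + N_c, offset + N_c:] = 0          # drop future frames (upper-layer variant)
print(xx_mask[0].int())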
Example 6
def make_san_mask(xs, xlens, unidirectional=False, lookahead=0):
    """Mask self-attention mask.

    Args:
        xs (FloatTensor): `[B, T, d_model]`
        xlens (IntTensor): `[B]` (on CPU)
        unidirectional (bool): mask out future frames (causal attention)
        lookahead (int): number of allowed lookahead frames
    Returns:
        xx_mask (ByteTensor): `[B, T (query), T (key)]`

    """
    xx_mask = make_pad_mask(xlens.to(xs.device))
    xx_mask = xx_mask.unsqueeze(1).repeat([1, xlens.max(), 1])  # `[B, emax (query), emax (key)]`
    if unidirectional:
        xx_mask = causal(xx_mask, lookahead)
    return xx_mask
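
The `causal` helper referenced above is not shown; the following is a minimal sketch of the unidirectional restriction it presumably applies (an assumption, not the library's code).

import torch

def causal_sketch(xx_mask, lookahead=0):
    """Zero out keys more than `lookahead` frames ahead of each query."""
    t = xx_mask.size(-1)
    allowed = torch.tril(torch.ones(t, t, dtype=torch.uint8), diagonal=lookahead).bool()
    return xx_mask & allowed.unsqueeze(0)

xx_mask = torch.ones(2, 5, 5, dtype=torch.bool)
print(causal_sketch(xx_mask, lookahead=1)[0].int())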
Example 7
    def decode(self, ys, state=None, is_asr=False):
        """Decode function.

        Args:
            ys (FloatTensor): `[B, L]`
            state: previous tokens
            is_asr (bool):
        Returns:
            ys_emb (FloatTensor): `[B, L, n_units]`
            state: previous tokens

        """
        # Concatenate previous tokens
        if is_asr and state is not None:
            ys = torch.cat([state, ys], dim=1)
            # NOTE: this is used for ASR decoding

        ys_emb = self.embed(ys.long())

        # Create the self-attention mask
        bs, ymax = ys_emb.size()[:2]
        ylens = torch.IntTensor([ymax] * bs)
        yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(
            bs, ymax, ymax)
        yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax,
                                              ymax)
        subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(),
                                     diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, ymax)
        yy_mask = yy_mask & subsequent_mask

        ys_emb = self.pos_enc(ys_emb)
        for l in range(self.n_layers):
            ys_emb, yy_aws, _ = self.layers[l](ys_emb, yy_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
        ys_emb = self.norm_out(ys_emb)

        if is_asr:
            state = ys

        return ys_emb, state
Example 8
    def decode(self, ys, ys_prev=None, cache=False):
        """Decode function.

        Args:
            ys (LongTensor): `[B, L]`
            ys_prev (LongTensor): previous tokens
            cache (bool): concatenate previous tokens
        Returns:
            logits (FloatTensor): `[B, L, vocab]`
            ys_emb (FloatTensor): `[B, L, d_model]` (for ys_prev)
            ys_prev (LongTensor): previous tokens

        """
        # Concatenate previous tokens
        if cache and ys_prev is not None:
            ys = torch.cat([ys_prev, ys], dim=1)
            # NOTE: this is used for ASR decoding

        # Create the self-attention mask
        bs, ymax = ys.size()[:2]
        ylens = torch.IntTensor([ymax] * bs)
        tgt_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat(
            [1, ymax, 1])
        subsequent_mask = tgt_mask.new_ones(ymax, ymax).byte()
        subsequent_mask = torch.tril(subsequent_mask,
                                     out=subsequent_mask).unsqueeze(0)
        tgt_mask = tgt_mask & subsequent_mask

        out = self.pos_enc(self.embed(ys.long()))
        for l in range(self.n_layers):
            out, yy_aws, _ = self.layers[l](out, tgt_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
        out = self.norm_out(out)
        if self.adaptive_softmax is None:
            logits = self.output(out)
        else:
            logits = out

        return logits, out, ys
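
Examples 7 and 8 both combine a padding mask with a lower-triangular (subsequent) mask; the same pattern in a self-contained form, using made-up token ids and a pad index of 0.

import torch

pad = 0
ys = torch.tensor([[5, 7, 2, pad], [3, 9, pad, pad]])        # `[B, L]`
ymax = ys.size(1)
pad_mask = (ys != pad).unsqueeze(1).repeat(1, ymax, 1)        # `[B, L, L]`
subsequent_mask = torch.tril(torch.ones(ymax, ymax, dtype=torch.uint8)).bool().unsqueeze(0)
tgt_mask = pad_mask & subsequent_mask                         # future and padded keys become False
print(tgt_mask[0].int())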
Example 9
    def forward(self, xs, xlens):
        """Forward pass.

        Args:
            xs (FloatTensor): `[B, T, input_dim (+Δ, ΔΔ)]`
            xlens (IntTensor): `[B]`
        Returns:
            xs (FloatTensor): `[B, T, input_dim]`

        """
        residual = xs
        xs = self.layers(xs)  # `[B, T, input_dim]`

        # padding
        xlens = xlens.to(xs.device)
        mask = make_pad_mask(xlens).unsqueeze(2)  # `[B, T, 1]`
        xs = xs.clone().masked_fill_(mask == 0, 0)

        # time average
        denom = xlens.float().unsqueeze(1)
        xs = xs.sum(1) / denom
        xs = residual + self.proj(xs).unsqueeze(1)
        return xs
Example 10
    def forward_att(self, eouts, elens, ys, return_logits=False):
        """Compute XE loss for the sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
        Returns:
            loss (FloatTensor): `[1]`
            acc (float):
            ppl (float):

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        eos = eouts.new_zeros(1).fill_(self.eos).long()
        ys = [
            np2tensor(np.fromiter(y[::-1] if self.bwd else y, dtype=np.int64),
                      self.device_id) for y in ys
        ]
        ylens = np2tensor(
            np.fromiter([y.size(0) + 1 for y in ys],
                        dtype=np.int32))  # +1 for <eos>
        ys_in_pad = pad_list([torch.cat([eos, y], dim=0) for y in ys],
                             self.pad)
        ys_out_pad = pad_list([torch.cat([y, eos], dim=0) for y in ys],
                              self.pad)

        # Create the self-attention mask
        bs, ymax = ys_in_pad.size()[:2]
        yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).expand(
            bs, ymax, ymax)
        yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, ymax,
                                              ymax)
        subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(),
                                     diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, ymax)
        yy_mask = yy_mask & subsequent_mask

        # Create the source-target mask
        xmax = eouts.size(1)
        x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(
            bs, ymax, xmax)
        y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).expand(
            bs, ymax, xmax)
        xy_mask = (x_mask * y_mask).unsqueeze(1).expand(
            bs, self.attn_n_heads, ymax, xmax)

        ys_emb = self.pos_enc(self.embed(ys_in_pad))
        for l in range(self.n_layers):
            ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts,
                                                    xy_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
                setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
        logits = self.norm_out(ys_emb)
        if self.adaptive_softmax is None:
            logits = self.output(logits)
        if return_logits:
            return logits

        # Compute XE sequence loss
        if self.adaptive_softmax is None:
            if self.lsm_prob > 0 and self.training:
                # Label smoothing
                loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1), self.lsm_prob,
                                         self.pad)
            else:
                loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                       ys_out_pad.view(-1),
                                       ignore_index=self.pad,
                                       reduction='mean')

            # Focal loss
            if self.focal_loss_weight > 0:
                fl = focal_loss(logits,
                                ys_out_pad,
                                ylens,
                                alpha=self.focal_loss_weight,
                                gamma=self.focal_loss_gamma)
                loss = loss * (
                    1 - self.focal_loss_weight) + fl * self.focal_loss_weight
        else:
            loss = self.adaptive_softmax(logits.view((-1, logits.size(2))),
                                         ys_out_pad.view(-1)).loss

        # Compute token-level accuracy in teacher-forcing
        if self.adaptive_softmax is None:
            acc = compute_accuracy(logits, ys_out_pad, self.pad)
        else:
            acc = compute_accuracy(
                self.adaptive_softmax.log_prob(
                    logits.view((-1, logits.size(2)))), ys_out_pad, self.pad)
        ppl = min(np.exp(loss.item()), np.inf)

        # scale loss for CTC
        loss *= ylens.float().mean()

        return loss, acc, ppl
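
The source-target mask above combines encoder-side padding along the key axis with decoder-side padding along the query axis; an illustrative, self-contained sketch of the shape logic (`length_mask` is a hypothetical helper).

import torch

def length_mask(lens, maxlen):
    """Hypothetical helper: True for valid positions, False for padding."""
    return torch.arange(maxlen).unsqueeze(0) < lens.unsqueeze(1)  # `[B, maxlen]`

elens, ylens = torch.tensor([6, 4]), torch.tensor([3, 2])
xmax, ymax = 6, 3
x_mask = length_mask(elens, xmax).unsqueeze(1).repeat(1, ymax, 1)  # `[B, L, T]`, encoder padding
y_mask = length_mask(ylens, ymax).unsqueeze(2).repeat(1, 1, xmax)  # `[B, L, T]`, decoder padding
xy_mask = x_mask & y_mask                                          # masks padded frames and padded tokens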
Example 11
    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): not supported now
            use_cache (bool):
            streaming (bool): streaming encoding
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        N_l = self.chunk_size_left
        N_c = self.chunk_size_current
        N_r = self.chunk_size_right

        bs, xmax, idim = xs.size()

        if self.latency_controlled:
            xs = chunkwise(xs, N_l, N_c, N_r)

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks
            xs, xlens = self.conv(xs, xlens)

        if not self.training:
            self.data_dict['elens'] = tensor2np(xlens)

        if self.latency_controlled:
            # streaming Conformer encoder
            _N_l = max(0, N_l // self.subsampling_factor)
            _N_c = N_c // self.subsampling_factor

            n_chunks = math.ceil(xs.size(0) / bs)
            emax = math.ceil(xmax / self.subsampling_factor)

            xs = xs * self.scale
            pos_idxs = torch.arange(xs.size(1) - 1,
                                    -1,
                                    -1.0,
                                    dtype=torch.float)
            pos_embs = self.pos_emb(pos_idxs, self.device_id)

            xx_mask = None  # NOTE: no mask
            for lth, layer in enumerate(self.layers):
                xs = layer(xs, xx_mask, pos_embs=pos_embs)
                if not self.training:
                    n_heads = layer.xx_aws.size(1)
                    xx_aws = layer.xx_aws[:, :, _N_l:_N_l + _N_c,
                                          _N_l:_N_l + _N_c]
                    xx_aws = xx_aws.view(bs, n_chunks, n_heads, _N_c, _N_c)
                    xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax)
                    for chunk_idx in range(n_chunks):
                        offset = chunk_idx * _N_c
                        emax_blc = xx_aws_center[:, :,
                                                 offset:offset + _N_c].size(2)
                        xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_blc, :
                                              emax_blc]
                        xx_aws_center[:, :, offset:offset + _N_c,
                                      offset:offset + _N_c] = xx_aws_chunk
                    self.aws_dict['xx_aws_layer%d' %
                                  lth] = tensor2np(xx_aws_center)

            # Extract the center region
            xs = xs[:, _N_l:_N_l + _N_c]  # `[B * n_chunks, _N_c, d_model]`
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :emax]

        else:
            bs, xmax, idim = xs.size()
            xs = xs * self.scale

            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat(
                [1, 1, xmax])

            pos_idxs = torch.arange(xmax - 1, -1, -1.0, dtype=torch.float)
            pos_embs = self.pos_emb(pos_idxs, self.device_id)

            for lth, layer in enumerate(self.layers):
                xs = layer(xs, xx_mask, pos_embs=pos_embs)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                        layer.xx_aws)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1 = self.layer_sub1(
                        xs, xx_mask, pos_embs=pos_embs
                    ) if self.task_specific_layer else xs.clone()
                    xs_sub1 = self.norm_out_sub1(xs_sub1)
                    if self.bridge_sub1 is not None:
                        xs_sub1 = self.bridge_sub1(xs_sub1)
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub1, xlens
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.layer_sub2(
                        xs, xx_mask, pos_embs=pos_embs
                    ) if self.task_specific_layer else xs.clone()
                    xs_sub2 = self.norm_out_sub2(xs_sub2)
                    if self.bridge_sub2 is not None:
                        xs_sub2 = self.bridge_sub2(xs_sub2)
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub2, xlens
                        return eouts

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts
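
The `chunkwise` helper is not defined in these examples; below is a minimal sketch, based on the manual chunking in Example 12, of the reshaping it appears to perform: pad, take overlapping `[N_l | N_c | N_r]` windows, and stack them along the batch axis. `chunkwise_sketch` is an assumption, not the library's implementation.

import torch

def chunkwise_sketch(xs, N_l, N_c, N_r):
    bs, xmax, idim = xs.size()
    n_chunks = (xmax + N_c - 1) // N_c
    # pad N_l frames on the left and enough frames on the right for uniform chunks
    xs_pad = torch.cat([xs.new_zeros(bs, N_l, idim), xs,
                        xs.new_zeros(bs, N_r + n_chunks * N_c - xmax, idim)], dim=1)
    chunks = [xs_pad[:, i * N_c:i * N_c + (N_l + N_c + N_r)] for i in range(n_chunks)]
    return torch.stack(chunks, dim=1).view(bs * n_chunks, N_l + N_c + N_r, idim)

xs = torch.randn(2, 10, 4)
print(chunkwise_sketch(xs, N_l=2, N_c=4, N_r=2).shape)  # torch.Size([6, 8, 4])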
Example 12
    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): not supported now
            use_cache (bool):
            streaming (bool): streaming encoding
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        if self.latency_controlled:
            bs, xmax, idim = xs.size()
            n_blocks = xmax // self.N_c
            if xmax % self.N_c != 0:
                n_blocks += 1
            xs_tmp = xs.new_zeros(bs, n_blocks, self.N_l + self.N_c + self.N_r,
                                  idim)
            xs_pad = torch.cat([
                xs.new_zeros(bs, self.N_l, idim), xs,
                xs.new_zeros(bs, self.N_r, idim)
            ],
                               dim=1)
            for blc_id, t in enumerate(
                    range(self.N_l, self.N_l + xmax, self.N_c)):
                xs_chunk = xs_pad[:, t - self.N_l:t + (self.N_c + self.N_r)]
                xs_tmp[:, blc_id, :xs_chunk.size(1), :] = xs_chunk
            xs = xs_tmp.view(bs * n_blocks, self.N_l + self.N_c + self.N_r,
                             idim)

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks
            xs, xlens = self.conv(xs, xlens)
        if not self.training:
            self.data_dict['elens'] = tensor2np(xlens)

        if self.latency_controlled:
            N_l = max(0, self.N_l // self.subsampling_factor)
            N_c = self.N_c // self.subsampling_factor

            emax = xmax // self.subsampling_factor
            if xmax % self.subsampling_factor != 0:
                emax += 1

            xs = self.pos_enc(xs, scale=True)
            xx_mask = None
            for lth, layer in enumerate(self.layers):
                xs, xx_aws = layer(xs, xx_mask)
                if not self.training:
                    n_heads = xx_aws.size(1)
                    xx_aws = xx_aws[:, :, N_l:N_l + N_c, N_l:N_l + N_c]
                    xx_aws = xx_aws.view(bs, n_blocks, n_heads, N_c, N_c)
                    xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax)
                    for blc_id in range(n_blocks):
                        offset = blc_id * N_c
                        emax_blc = xx_aws_center[:, :,
                                                 offset:offset + N_c].size(2)
                        xx_aws_chunk = xx_aws[:,
                                              blc_id, :, :emax_blc, :emax_blc]
                        xx_aws_center[:, :, offset:offset + N_c,
                                      offset:offset + N_c] = xx_aws_chunk
                    self.aws_dict['xx_aws_layer%d' %
                                  lth] = tensor2np(xx_aws_center)

            # Extract the center region
            xs = xs[:, N_l:N_l +
                    N_c]  # `[B * n_blocks, N_c // subsampling_factor, d_model]`
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :emax]

        else:
            bs, xmax, idim = xs.size()
            xs = self.pos_enc(xs, scale=True)

            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat(
                [1, 1, xmax])

            for lth, layer in enumerate(self.layers):
                xs, xx_aws = layer(xs, xx_mask)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1 = self.layer_sub1(
                        xs, xx_mask
                    )[0] if self.task_specific_layer else xs.clone()
                    xs_sub1 = self.norm_out_sub1(xs_sub1)
                    if self.bridge_sub1 is not None:
                        xs_sub1 = self.bridge_sub1(xs_sub1)
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub1, xlens
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.layer_sub2(
                        xs, xx_mask
                    )[0] if self.task_specific_layer else xs.clone()
                    xs_sub2 = self.norm_out_sub2(xs_sub2)
                    if self.bridge_sub2 is not None:
                        xs_sub2 = self.bridge_sub2(xs_sub2)
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub2, xlens
                        return eouts

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts
Example 13
    def align(self, logits, elens, ys, ylens, add_eos=True):
        """Calculte the best CTC alignment with the forward-backward algorithm.
        Args:
            logits (FloatTensor): `[B, T, vocab]`
            elens (FloatTensor): `[B]`
            ys (FloatTensor): `[B, L]`
            ylens (FloatTensor): `[B]`
            add_eos (bool): Use the last time index as a boundary corresponding to <eos>
        Returns:
            trigger_points (IntTensor): `[B, L]`

        """
        bs, xmax, vocab = logits.size()
        device = logits.device

        # zero padding
        mask = make_pad_mask(elens.to(device))
        mask = mask.unsqueeze(2).repeat([1, 1, vocab])
        logits = logits.masked_fill_(mask == 0, self.log0)
        log_probs = torch.log_softmax(logits, dim=-1).transpose(0, 1)  # `[T, B, vocab]`

        path = _label_to_path(ys, self.blank)
        path_lens = 2 * ylens.long() + 1

        ymax = ys.size(1)
        max_path_len = path.size(1)
        assert ys.size() == (bs, ymax), ys.size()
        assert path.size() == (bs, ymax * 2 + 1)

        alpha = log_probs.new_zeros(bs, max_path_len).fill_(self.log0)
        alpha[:, 0] = LOG_1
        beta = alpha.clone()
        gamma = alpha.clone()

        batch_index = torch.arange(bs, dtype=torch.int64).unsqueeze(1)
        seq_index = torch.arange(xmax, dtype=torch.int64).unsqueeze(1).unsqueeze(2)
        log_probs_fwd_bwd = log_probs[seq_index, batch_index, path]

        # forward algorithm
        for t in range(xmax):
            alpha = self._computes_transition(alpha, path, path_lens, log_probs_fwd_bwd[t], log_probs[t])

        # backward algorithm
        r_path = _flip_path(path, path_lens)
        log_probs_inv = _flip_label_probability(log_probs, elens.long())  # `[T, B, vocab]`
        log_probs_fwd_bwd = _flip_path_probability(log_probs_fwd_bwd, elens.long(), path_lens)  # `[T, B, 2*L+1]`
        for t in range(xmax):
            beta = self._computes_transition(beta, r_path, path_lens, log_probs_fwd_bwd[t], log_probs_inv[t])

        # pick up the best CTC path
        best_aligns = log_probs.new_zeros((bs, xmax), dtype=torch.int64)

        # forward algorithm
        log_probs_fwd_bwd = _flip_path_probability(log_probs_fwd_bwd, elens.long(), path_lens)
        for t in range(xmax):
            gamma = self._computes_transition(gamma, path, path_lens, log_probs_fwd_bwd[t], log_probs[t],
                                              skip_accum=True)

            # select paths where gamma is valid
            log_probs_fwd_bwd[t] = log_probs_fwd_bwd[t].masked_fill_(gamma == self.log0, self.log0)

            # pick up the best alignment
            offsets = log_probs_fwd_bwd[t].argmax(1)
            for b in range(bs):
                if t <= elens[b] - 1:
                    token_idx = path[b, offsets[b]]
                    best_aligns[b, t] = token_idx

            # remove the rest of paths
            gamma = log_probs.new_zeros(bs, max_path_len).fill_(self.log0)
            for b in range(bs):
                gamma[b, offsets[b]] = LOG_1

        # pick up trigger points
        trigger_aligns = torch.zeros((bs, xmax), dtype=torch.int64)
        trigger_points = log_probs.new_zeros((bs, ymax + 1), dtype=torch.int32)  # +1 for <eos>
        for b in range(bs):
            n_triggers = 0
            if add_eos:
                trigger_points[b, ylens[b]] = elens[b] - 1
                # NOTE: use the last time index as a boundary corresponding to <eos>
                # Otherwise, index: 0 is used for <eos>
            for t in range(elens[b]):
                token_idx = best_aligns[b, t]
                if token_idx == self.blank:
                    continue
                if not (t == 0 or token_idx != best_aligns[b, t - 1]):
                    continue

                # NOTE: select the left-most trigger points
                trigger_aligns[b, t] = token_idx
                trigger_points[b, n_triggers] = t
                n_triggers += 1

        assert ylens.sum() == (trigger_aligns != 0).sum()
        return trigger_points
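
`_label_to_path` and the `_flip_*` helpers are not shown; the sketch below illustrates the standard CTC path expansion that `_label_to_path` presumably performs, consistent with `path_lens = 2 * ylens + 1` above (an assumption, not the library's code).

import torch

def label_to_path_sketch(ys, blank):
    """Interleave the blank symbol around every label: L labels -> 2L+1 path states."""
    bs, ymax = ys.size()
    path = ys.new_full((bs, 2 * ymax + 1), blank)
    path[:, 1::2] = ys
    return path

ys = torch.tensor([[3, 5, 5], [7, 2, 0]])
print(label_to_path_sketch(ys, blank=0))
# tensor([[0, 3, 0, 5, 0, 5, 0],
#         [0, 7, 0, 2, 0, 0, 0]])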
Example 14
    def forward_att(self, eouts, elens, ys, return_logits=False):
        """Compute XE loss for the sequence-to-sequence model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        ys_in_pad, ys_out_pad, ylens = self.append_sos_eos(ys, self.bwd)

        # Create the self-attention mask
        bs, ymax = ys_in_pad.size()[:2]
        yy_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat(
            [1, ymax, 1])
        yy_mask = yy_mask.unsqueeze(1).repeat([1, self.attn_n_heads, 1, 1])
        subsequent_mask = torch.tril(yy_mask.new_ones((ymax, ymax)).byte(),
                                     diagonal=0)
        subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).repeat(
            [bs, self.attn_n_heads, 1, 1])
        yy_mask = yy_mask & subsequent_mask

        # Create the source-target mask
        xmax = eouts.size(1)
        x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat(
            [1, ymax, 1])
        y_mask = make_pad_mask(ylens, self.device_id).unsqueeze(2).repeat(
            [1, 1, xmax])
        xy_mask = (x_mask * y_mask).unsqueeze(1).repeat(
            [1, self.attn_n_heads, 1, 1])

        ys_emb = self.pos_enc(self.embed(ys_in_pad))
        for l in range(self.n_layers):
            ys_emb, yy_aws, xy_aws = self.layers[l](ys_emb, yy_mask, eouts,
                                                    xy_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
                setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
        ys_emb = self.norm_out(ys_emb)
        logits = self.output(ys_emb)

        # for knowledge distillation
        if return_logits:
            return logits

        # Compute XE sequence loss
        if self.lsm_prob > 0 and self.training:
            # Label smoothing
            loss = cross_entropy_lsm(logits.view((-1, logits.size(2))),
                                     ys_out_pad.view(-1), self.lsm_prob,
                                     self.pad)
        else:
            loss = F.cross_entropy(logits.view((-1, logits.size(2))),
                                   ys_out_pad.view(-1),
                                   ignore_index=self.pad,
                                   reduction='mean')

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out_pad, self.pad)
        ppl = min(np.exp(loss.item()), np.inf)

        # scale loss for CTC
        loss *= ylens.float().mean()

        return loss, acc, ppl
Example 15
    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): not supported now
            use_cache (bool):
            streaming (bool): streaming encoding
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks before RNN layers
            xs, xlens = self.conv(xs, xlens)
        if not self.training:
            self.data_dict['elens'] = tensor2np(xlens)

        bs, xmax, idim = xs.size()
        if self.memory_transformer or self.latency_controlled:
            # streaming Transformer(XL) encoder
            N_l = max(0, self.chunk_size_left // self.subsampling_factor)
            N_c = self.chunk_size_cur // self.subsampling_factor
            N_r = max(0, self.chunk_size_right // self.subsampling_factor)

            xs_chunks = []
            xx_aws = [[] for _ in range(self.n_layers)]
            mems = self.init_memory()
            self.reset_cache()  # for LC-BLSTM

            mlen = 0
            for t in range(0, xmax, N_c):
                clen = min(N_c, xmax - 1 - t + 1)
                rlen = 0
                if xmax - 1 - (t + clen) + 1 > 0:
                    rlen = min(N_r, xmax - 1 - (t + clen) + 1)

                xs_chunk = xs[:, t:t + (clen + rlen)]

                if self.hybrid_rnn:
                    for lth in range(self.n_layers_rnn):
                        self.rnn[lth].flatten_parameters()  # for multi-GPUs
                        self.rnn_bwd[lth].flatten_parameters(
                        )  # for multi-GPUs
                        # bwd
                        xs_chunk_bwd = torch.flip(xs_chunk, dims=[1])
                        xs_chunk_bwd, _ = self.rnn_bwd[lth](xs_chunk_bwd,
                                                            hx=None)
                        xs_chunk_bwd = torch.flip(
                            xs_chunk_bwd,
                            dims=[1])  # `[B, clen+rlen, d_model]`
                        # fwd
                        if xs_chunk.size(1) <= clen:
                            xs_chunk_fwd, self.fwd_states[lth] = self.rnn[lth](
                                xs_chunk, hx=self.fwd_states[lth])
                        else:
                            xs_chunk_fwd1, self.fwd_states[lth] = self.rnn[
                                lth](xs_chunk[:, :clen],
                                     hx=self.fwd_states[lth])
                            xs_chunk_fwd2, _ = self.rnn[lth](
                                xs_chunk[:, clen:], hx=self.fwd_states[lth])
                            xs_chunk_fwd = torch.cat(
                                [xs_chunk_fwd1, xs_chunk_fwd2],
                                dim=1)  # `[B, clen+rlen, d_model]`
                            # NOTE: xs_chunk_fwd2 is for xs_chunk_bwd in the next layer
                        if self.bidir_sum:
                            xs_chunk = xs_chunk_fwd + xs_chunk_bwd
                        else:
                            xs_chunk = torch.cat([xs_chunk_fwd, xs_chunk_bwd],
                                                 dim=-1)
                    xs_chunk = self.dropout_rnn(xs_chunk)
                    if self.proj is not None:
                        xs_chunk = self.proj(xs_chunk)

                xs_chunk = self.pos_enc(xs_chunk, scale=True)  # for scale

                if self.memory_transformer:
                    # adopt zero-centered offset
                    pos_idxs = torch.arange(mlen - 1,
                                            -xs_chunk.size(1) - 1,
                                            -1.0,
                                            dtype=torch.float)
                    pos_embs = self.pos_emb(pos_idxs, self.device_id)

                hidden_states = [xs_chunk[:, :clen][:, -N_l:]]
                for lth, (mem, layer) in enumerate(zip(mems, self.layers)):
                    if self.memory_transformer:
                        xs_chunk, xx_aws_chunk = layer(xs_chunk,
                                                       None,
                                                       pos_embs=pos_embs,
                                                       memory=mem,
                                                       u=self.u,
                                                       v=self.v)  # no mask
                    else:
                        xs_chunk, xx_aws_chunk = layer(xs_chunk,
                                                       None,
                                                       memory=mem)  # no mask

                    if lth < self.n_layers - 1:
                        hidden_states.append(xs_chunk[:, :clen][:, -N_l:])
                    # NOTE: xx_aws_chunk: `[B, H, clen+rlen (query), mlen+clen+rlen (key)]`
                    xx_aws_chunk = xx_aws_chunk[:, :, :clen, mlen:mlen + clen]
                    assert xx_aws_chunk.size(2) == xx_aws_chunk.size(3)
                    xx_aws_chunk_pad = xs.new_zeros(
                        (bs, xx_aws_chunk.size(1), N_c, N_c))
                    xx_aws_chunk_pad[:, :, :xx_aws_chunk.size(2), :xx_aws_chunk
                                     .size(3)] = xx_aws_chunk
                    xx_aws[lth].append(xx_aws_chunk_pad)
                mems = self.update_memory(mems, hidden_states)
                mlen = mems[0].size(1) if mems[0].dim() > 1 else 0
                xs_chunks.append(xs_chunk[:, :clen])
            xs = torch.cat(xs_chunks, dim=1)[:, :xmax]

            if not self.training:
                for lth in range(self.n_layers):
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                        torch.cat(xx_aws[lth], dim=3)[:, :, :xmax, :xmax])

        else:
            # Hybrid RNN-Transformer
            if self.hybrid_rnn:
                for lth in range(self.n_layers_rnn):
                    self.rnn[lth].flatten_parameters()  # for multi-GPUs
                    self.rnn_bwd[lth].flatten_parameters()  # for multi-GPUs
                    # bwd
                    xs_bwd = torch.flip(xs, dims=[1])
                    xs_bwd, _ = self.rnn_bwd[lth](xs_bwd, hx=None)
                    xs_bwd = torch.flip(xs_bwd, dims=[1])
                    # fwd
                    xs_fwd, _ = self.rnn[lth](xs, hx=None)
                    # NOTE: no padding because inputs are not sorted
                    if self.bidir_sum:
                        xs = xs_fwd + xs_bwd
                    else:
                        xs = torch.cat([xs_fwd, xs_bwd], dim=-1)
                    xs = self.dropout_rnn(xs)
                if self.proj is not None:
                    xs = self.proj(xs)

            xs = self.pos_enc(xs, scale=True)

            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat(
                [1, 1, xmax])

            for lth, layer in enumerate(self.layers):
                xs, xx_aws = layer(xs, xx_mask)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(xx_aws)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1 = self.layer_sub1(
                        xs, xx_mask
                    )[0] if self.task_specific_layer else xs.clone()
                    xs_sub1 = self.norm_out_sub1(xs_sub1)
                    if self.bridge_sub1 is not None:
                        xs_sub1 = self.bridge_sub1(xs_sub1)
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub1, xlens
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.layer_sub2(
                        xs, xx_mask
                    )[0] if self.task_specific_layer else xs.clone()
                    xs_sub2 = self.norm_out_sub2(xs_sub2)
                    if self.bridge_sub2 is not None:
                        xs_sub2 = self.bridge_sub2(xs_sub2)
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub2, xlens
                        return eouts

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts
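
A small sketch of the zero-centered relative position indices built above for the TransformerXL-style memory: they run from `mlen - 1` down to `-qlen`, covering the memory followed by the current chunk (illustrative values only).

import torch

mlen, qlen = 4, 3   # memory length, current chunk length
pos_idxs = torch.arange(mlen - 1, -qlen - 1, -1.0, dtype=torch.float)
print(pos_idxs)     # tensor([ 3.,  2.,  1.,  0., -1., -2., -3.])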
Example 16
    def forward_att(self,
                    eouts,
                    elens,
                    ys,
                    return_logits=False,
                    teacher_logits=None,
                    trigger_points=None):
        """Compute XE loss for the Transformer decoder.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): length `B`, each of which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
            teacher_logits (FloatTensor): `[B, L, vocab]`
            trigger_points (IntTensor): `[B, T]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity
            loss_quantity (FloatTensor): `[1]`
            loss_headdiv (FloatTensor): `[1]`
            loss_latency (FloatTensor): `[1]`

        """
        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos,
                                              self.pad, self.bwd)
        if not self.training:
            self.data_dict['elens'] = tensor2np(elens)
            self.data_dict['ylens'] = tensor2np(ylens)
            self.data_dict['ys'] = tensor2np(ys_out)

        # Create target self-attention mask
        xtime = eouts.size(1)
        bs, ymax = ys_in.size()[:2]
        tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
        causal_mask = tgt_mask.new_ones(ymax, ymax).byte()
        causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0)
        tgt_mask = tgt_mask & causal_mask  # `[B, L, L]`

        # Create source-target mask
        src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat(
            [1, ymax, 1])  # `[B, L, T]`

        # external LM integration
        lmout = None
        if self.lm is not None:
            self.lm.eval()
            with torch.no_grad():
                lmout, lmstate, _ = self.lm.predict(ys_in, None)
            lmout = self.lm_output_proj(lmout)

        out = self.embed(ys_in)

        mlen = 0  # TODO: fix later
        if self.memory_transformer:
            # NOTE: TransformerXL does not use positional encoding in the token embedding
            mems = self.init_memory()
            # adopt zero-centered offset
            pos_idxs = torch.arange(mlen - 1,
                                    -ymax - 1,
                                    -1.0,
                                    dtype=torch.float)
            if self.device_id >= 0:
                pos_idxs = pos_idxs.cuda(self.device_id)
            pos_embs = self.dropout_emb(self.pos_emb(pos_idxs))
            out = self.dropout_emb(out)
            hidden_states = [out]
        else:
            out = self.pos_enc(out)

        xy_aws_layers = []
        for l, layer in enumerate(self.layers):
            if self.memory_transformer:
                out, yy_aws, xy_aws, xy_aws_beta, yy_aws_lm = layer(
                    out,
                    tgt_mask,
                    eouts,
                    src_mask,
                    mode='parallel',
                    lmout=lmout,
                    pos_embs=pos_embs,
                    memory=mems[l],
                    u=self.u,
                    v=self.v)
                hidden_states.append(out)
            else:
                out, yy_aws, xy_aws, xy_aws_beta, yy_aws_lm = layer(
                    out,
                    tgt_mask,
                    eouts,
                    src_mask,
                    mode='parallel',
                    lmout=lmout)

            xy_aws_layers.append(xy_aws.clone() if xy_aws is not None else out.
                                 new_zeros(bs, yy_aws.size(1), ymax, xtime))
            if not self.training:
                if yy_aws is not None:
                    self.aws_dict['yy_aws_layer%d' % l] = tensor2np(yy_aws)
                if xy_aws is not None:
                    self.aws_dict['xy_aws_layer%d' % l] = tensor2np(xy_aws)
                if xy_aws_beta is not None:
                    self.aws_dict['xy_aws_beta_layer%d' %
                                  l] = tensor2np(xy_aws_beta)
                if yy_aws_lm is not None:
                    self.aws_dict['yy_aws_lm_layer%d' %
                                  l] = tensor2np(yy_aws_lm)
        logits = self.output(self.norm_out(out))

        # TODO: Update memory
        # if self.memory_transformer:
        #     new_mems = self.update_memory(mems, hidden_states)

        # for knowledge distillation
        if return_logits:
            return logits

        # Compute XE sequence loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad,
                                      self.training)

        # Attention padding
        if self.quantity_loss_weight > 0 or self.headdiv_loss_weight > 0 or self.latency_loss_weight > 0:
            for l in range(self.mocha_first_layer - 1, self.n_layers):
                n_heads = xy_aws_layers[l].size(1)
                xy_aws_layers[l] = xy_aws_layers[l].masked_fill_(
                    src_mask.unsqueeze(1).repeat([1, n_heads, 1, 1]) == 0, 0)
                xy_aws_layers[l] = xy_aws_layers[l].masked_fill_(
                    tgt_mask[:, :,
                             -1:].unsqueeze(1).repeat([1, n_heads, 1,
                                                       xtime]) == 0, 0)
                # NOTE: attention padding is quite effective for quantity loss
        n_heads = xy_aws_layers[-1].size(1)  # mono
        # NOTE: debug for multihead mono + multihead chunk

        # Quantity loss
        loss_quantity = 0.
        if 'mocha' in self.attn_type:
            # Average over all heads across all layers
            n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
            # NOTE: count <eos> tokens
            n_tokens_pred = sum([
                torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                for aws in xy_aws_layers[self.mocha_first_layer - 1:]
            ])  # `[B]`
            n_tokens_pred /= (self.n_layers - self.mocha_first_layer + 1)
            loss_quantity = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

        # Head divergence loss
        loss_headdiv = 0.
        if self.headdiv_loss_weight > 0.:
            # Calculate variance over all heads across all layers
            js = torch.arange(xtime, dtype=torch.float).cuda(self.device_id)
            js = js.repeat([bs, n_heads, ymax, 1])
            avg_head_pos = sum([
                (js * aws).sum(3).sum(1) for aws in xy_aws_layers
            ]) / (n_heads * self.n_layers)  # `[B, L]`
            loss_headdiv = sum([((js * aws).sum(3).sum(1) - avg_head_pos)**2
                                for aws in xy_aws_layers]) / (
                                    n_heads * self.n_layers)  # `[B, L]`
            loss_headdiv = loss_headdiv.sum() / ylens.sum()

        # Latency loss
        loss_latency = 0.
        if self.latency_metric == 'interval':
            raise NotImplementedError
        elif trigger_points is not None:
            assert self.latency_loss_weight > 0
            # Calculate weight average latency
            js = torch.arange(xtime, dtype=torch.float).cuda(self.device_id)
            js = js.repeat([bs, n_heads, ymax, 1])
            weighted_avg_head_pos = torch.cat(
                [(js * aws).sum(3) for aws in xy_aws_layers],
                dim=1)  # `[B, H_mono * n_layers, L]`
            weighted_avg_head_pos *= torch.softmax(
                weighted_avg_head_pos.clone(), dim=1)
            trigger_points = trigger_points.float().cuda(
                self.device_id)  # `[B, L]`
            trigger_points = trigger_points.unsqueeze(1)
            if self.latency_metric == 'ctc_sync':
                loss_latency = torch.abs(
                    weighted_avg_head_pos -
                    trigger_points)  # `[B, H_mono * n_layers, L]`
            else:
                raise NotImplementedError(self.latency_metric)
            # NOTE: trigger_points are padded with 0
            loss_latency = loss_latency.sum() / ylens.sum()

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl, loss_quantity, loss_headdiv, loss_latency
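
A standalone sketch of the quantity-loss idea used above for MoChA-style attention: the expected number of attended tokens (attention mass summed over time and token positions, averaged over heads) is compared with the reference token count. The tensors below are random placeholders.

import torch

bs, n_heads, ymax, xmax = 2, 4, 5, 30
aws = torch.rand(bs, n_heads, ymax, xmax)            # stand-in monotonic attention weights
n_tokens_ref = torch.tensor([5., 3.])                # reference lengths (counting <eos>)
n_tokens_pred = aws.sum(3).sum(2).sum(1) / n_heads   # `[B]`, expected number of attended tokens
loss_quantity = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))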
Example 17
    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Forward pass.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (IntTensor): `[B]` (on CPU)
            task (str): ys/ys_sub1/ys_sub2
            use_cache (bool):
            streaming (bool): streaming encoding
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (IntTensor): `[B]` (on CPU)

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        N_l = self.chunk_size_left
        N_c = self.chunk_size_current
        N_r = self.chunk_size_right
        bs, xmax, idim = xs.size()
        n_chunks = 0

        if self.latency_controlled:
            if self.lc_type == 'reshape':
                xs = chunkwise(xs, N_l, N_c,
                               N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
            elif self.lc_type == 'mask':
                # xs = chunkwise(xs, N_l, N_c, N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
                xs = chunkwise(xs, 0, N_c, 0)  # `[B * n_chunks, N_c, idim]`
            else:
                raise ValueError
            n_chunks = xs.size(0) // bs

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks
            xs, xlens = self.conv(xs, xlens)
            N_l = max(0, N_l // self.conv.subsampling_factor)
            N_c = N_c // self.conv.subsampling_factor
            N_r = N_r // self.conv.subsampling_factor

        if self.lc_type == 'mask':
            # Extract the center region
            emax = xlens.max().item()
            # xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :emax]  # `[B, emax, d_model]`

        if self.latency_controlled:
            # streaming Transformer encoder
            emax = xlens.max().item()

            pos_embs = None
            if self.pe_type == 'relative':
                xs = xs * self.scale
                pos_idxs = torch.arange(xs.size(1) - 1,
                                        -1,
                                        -1.0,
                                        dtype=torch.float,
                                        device=self.device)
                pos_embs = self.pos_emb(pos_idxs)
            else:
                xs = self.pos_enc(xs, scale=True)

            xx_mask = None  # NOTE: no mask to avoid all masked region
            for lth, layer in enumerate(self.layers):
                xs = layer(xs, xx_mask, pos_embs=pos_embs)
                if not self.training:
                    if self.lc_type == 'reshape':
                        n_heads = layer.xx_aws.size(1)
                        xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c,
                                              N_l:N_l + N_c]
                        xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                        xx_aws_center = xx_aws.new_zeros(
                            bs, n_heads, emax, emax)
                        for chunk_idx in range(n_chunks):
                            offset = chunk_idx * N_c
                            emax_chunk = xx_aws_center[:, :, offset:offset +
                                                       N_c].size(2)
                            xx_aws_chunk = xx_aws[:, chunk_idx, :, :
                                                  emax_chunk, :emax_chunk]
                            xx_aws_center[:, :, offset:offset + N_c,
                                          offset:offset + N_c] = xx_aws_chunk
                    elif self.lc_type == 'mask':
                        self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                            layer.xx_aws)
                    else:
                        raise ValueError
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    emax = xlens.max().item()
                    N_l = max(0, N_l // self.subsample[lth].subsampling_factor)
                    N_c = N_c // self.subsample[lth].subsampling_factor
                    N_r = N_r // self.subsample[lth].subsampling_factor
                    if self.lc_type == 'mask':
                        xx_mask = make_pad_mask(xlens.to(self.device))
                        xx_mask = xx_mask.unsqueeze(1).repeat(
                            [1, xs.size(1),
                             1])  # `[B, emax (query), emax (key)]`
                        for chunk_idx in range(n_chunks):
                            offset = chunk_idx * N_c
                            xx_mask[:, offset:offset +
                                    N_c, :max(0, offset - N_l)] = 0
                            xx_mask[:, offset:offset + N_c,
                                    offset + (N_c + N_r):] = 0

            # Extract the center region
            if self.lc_type == 'reshape':
                xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
                xs = xs.contiguous().view(bs, -1, xs.size(2))
                xs = xs[:, :emax]

        else:
            if self.pe_type == 'relative':
                xs = xs * self.scale
                # Create sinusoidal positional embeddings for relative positional encoding
                pos_idxs = torch.arange(xs.size(1) - 1,
                                        -1,
                                        -1.0,
                                        dtype=torch.float,
                                        device=self.device)
                pos_embs = self.pos_emb(pos_idxs)
            else:
                xs = self.pos_enc(xs, scale=True)
                pos_embs = None

            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens.to(self.device)).unsqueeze(1).repeat(
                [1, xs.size(1), 1])

            for lth, layer in enumerate(self.layers):
                xs = layer(xs, xx_mask, pos_embs=pos_embs)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                        layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1 = self.sub_module(xs, xx_mask, lth, pos_embs,
                                              'sub1')
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.sub_module(xs, xx_mask, lth, pos_embs,
                                              'sub2')
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens
                        return eouts

                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    # Create the self-attention mask
                    xx_mask = make_pad_mask(xlens.to(
                        self.device)).unsqueeze(1).repeat([1, xs.size(1), 1])
                    if self.pe_type == 'relative':
                        # Create sinusoidal positional embeddings for relative positional encoding
                        pos_idxs = torch.arange(xs.size(1) - 1,
                                                -1,
                                                -1.0,
                                                dtype=torch.float,
                                                device=self.device)
                        pos_embs = self.pos_emb(pos_idxs)

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts
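
The `'mask'` streaming variant above (`lc_type == 'mask'`) restricts self-attention with a chunked key mask instead of physically reshaping the input. Below is a minimal standalone sketch of that masking rule, assuming chunk sizes `N_l`/`N_c`/`N_r` are given in already-subsampled frames; it is not the library's `make_pad_mask` or mask helpers, just the same idea on a plain boolean tensor.

import torch

def time_restricted_mask_sketch(xlens, N_l, N_c, N_r):
    """Return a [B, T, T] boolean mask: True where a query frame may attend to a key frame.

    Each chunk of N_c query frames may look back N_l frames and ahead N_r frames,
    on top of the usual padding mask derived from xlens.
    """
    T = int(xlens.max())
    valid = torch.arange(T).unsqueeze(0) < xlens.unsqueeze(1)  # [B, T], True on valid frames
    mask = valid.unsqueeze(1).repeat(1, T, 1)                  # [B, T (query), T (key)]
    n_chunks = (T + N_c - 1) // N_c
    for chunk_idx in range(n_chunks):
        offset = chunk_idx * N_c
        # block keys strictly before the allowed left context
        mask[:, offset:offset + N_c, :max(0, offset - N_l)] = False
        # block keys at or beyond the allowed right context
        mask[:, offset:offset + N_c, offset + N_c + N_r:] = False
    return mask

# toy usage
print(time_restricted_mask_sketch(torch.tensor([8, 6]), N_l=2, N_c=2, N_r=2).shape)  # [2, 8, 8]
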
Ejemplo n.º 18
0
    def forward_att(self, eouts, elens, ys,
                    return_logits=False, teacher_logits=None, trigger_points=None):
        """Compute XE loss for the Transformer decoder.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): length `B`, each of which contains a list of size `[L]`
            return_logits (bool): return logits for knowledge distillation
            teacher_logits (FloatTensor): `[B, L, vocab]`
            trigger_points (IntTensor): `[B, T]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity
            losses_auxiliary (dict): auxiliary losses such as loss_quantity, each `[1]`

        """
        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.eos, self.pad, self.bwd)
        if not self.training:
            self.data_dict['elens'] = tensor2np(elens)
            self.data_dict['ylens'] = tensor2np(ylens)
            self.data_dict['ys'] = tensor2np(ys_out)

        # Create target self-attention mask
        xmax = eouts.size(1)
        bs, ymax = ys_in.size()[:2]
        mlen = 0
        tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
        causal_mask = tgt_mask.new_ones(ymax, ymax).byte()
        causal_mask = torch.tril(causal_mask, diagonal=0 + mlen, out=causal_mask).unsqueeze(0)
        tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

        # Create source-target mask
        src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

        # external LM integration
        lmout = None
        if self.lm is not None:
            self.lm.eval()
            with torch.no_grad():
                lmout, lmstate, _ = self.lm.predict(ys_in, None)
            lmout = self.lm_output_proj(lmout)

        out = self.pos_enc(self.embed(ys_in))  # scaled

        mems = self.init_memory()
        pos_embs = None
        if self.memory_transformer:
            out = self.dropout_emb(out)
            # NOTE: TransformerXL does not use positional encoding in the token embedding
            # adopt zero-centered offset
            pos_idxs = torch.arange(mlen - 1, -ymax - 1, -1.0, dtype=torch.float)
            pos_embs = self.pos_emb(pos_idxs, self.device_id)

        hidden_states = [out]
        xy_aws_layers = []
        for lth, (mem, layer) in enumerate(zip(mems, self.layers)):
            out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout,
                        pos_embs=pos_embs, memory=mem, u=self.u, v=self.v)
            if lth < self.n_layers - 1:
                hidden_states.append(out)
                # NOTE: outputs from the last layer are not used for memory
            # Attention padding
            xy_aws = layer.xy_aws
            if xy_aws is not None and 'mocha' in self.attn_type:
                tgt_mask_v2 = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
                xy_aws = xy_aws.masked_fill_(tgt_mask_v2.repeat([1, xy_aws.size(1), 1, xmax]) == 0, 0)
                # NOTE: attention padding is quite effective for quantity loss
                xy_aws_layers.append(xy_aws.clone())
            if not self.training:
                if layer.yy_aws is not None:
                    self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
                if layer.xy_aws is not None:
                    self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
                if layer.xy_aws_beta is not None:
                    self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
                if layer.xy_aws_p_choose is not None:
                    self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
                if layer.yy_aws_lm is not None:
                    self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)
        logits = self.output(self.norm_out(out))

        # for knowledge distillation
        if return_logits:
            return logits

        # Compute XE loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)
        losses_auxiliary = {}

        # Quantity loss
        losses_auxiliary['loss_quantity'] = 0.
        if 'mocha' in self.attn_type:
            # Average over all heads across all layers
            n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
            # NOTE: count <eos> tokens
            n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                                 for aws in xy_aws_layers])  # `[B]`
            n_tokens_pred /= len(xy_aws_layers)
            losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl, losses_auxiliary
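
The quantity loss above encourages the MoChA attention to fire roughly once per output token: the attention weights are summed over encoder frames and target positions, averaged over heads and layers, and compared against the reference length. A minimal sketch of that computation on dummy tensors; the `[B, H, L, T]` layout matches the snippet, but the function name and toy shapes are illustrative only.

import torch

def quantity_loss_sketch(xy_aws_layers, n_tokens_ref):
    """xy_aws_layers: list of [B, H, L, T] attention weights (already zeroed on padding).
    n_tokens_ref: [B] reference token counts (including <eos>)."""
    n_tokens_pred = sum(torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                        for aws in xy_aws_layers)        # [B], head-averaged sums per layer
    n_tokens_pred = n_tokens_pred / len(xy_aws_layers)   # average over layers
    return torch.mean(torch.abs(n_tokens_pred - n_tokens_ref.float()))

# toy usage
aws = [torch.rand(2, 4, 5, 30).softmax(dim=-1) for _ in range(3)]
print(quantity_loss_sketch(aws, torch.tensor([5, 4])).item())
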
Ejemplo n.º 19
0
    def greedy(self,
               eouts,
               elens,
               max_len_ratio,
               exclude_eos=False,
               idx2token=None,
               refs_id=None,
               speakers=None,
               oracle=False):
        """Greedy decoding in the inference stage (used only for evaluation during training).

        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
            elens (IntTensor): `[B]`
            max_len_ratio (float): maximum output length relative to the number of encoder frames
            exclude_eos (bool): exclude <eos> from the returned hypotheses
            idx2token (): converter from token indices to strings
            refs_id (list):
            speakers (list):
            oracle (bool):
        Returns:
            best_hyps (list): A list of length `[B]`, which contains arrays of size `[L]`
            aw (list): A list of length `[B]`, which contains arrays of size `[L, T]`

        """
        bs, xmax = eouts.size()[:2]

        # Start from <sos> (<eos> in case of the backward decoder)
        ys_all = eouts.new_zeros(bs, 1).fill_(self.eos).long()

        # TODO(hirofumi): Create the source-target mask for batch decoding

        best_hyps_batch = []
        ylens = torch.zeros(bs).int()
        yy_aws_tmp = [None] * bs
        xy_aws_tmp = [None] * bs
        eos_flags = [False] * bs
        for t in range(int(np.floor(xmax * max_len_ratio)) + 1):
            # Create the self-attention mask
            yy_mask = make_pad_mask(ylens + 1,
                                    self.device_id).unsqueeze(1).expand(
                                        bs, t + 1, t + 1)
            yy_mask = yy_mask.unsqueeze(1).expand(bs, self.attn_n_heads, t + 1,
                                                  t + 1)
            subsequent_mask = torch.tril(yy_mask.new_ones(
                (t + 1, t + 1)).byte(),
                                         diagonal=0)
            subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1).expand(
                bs, self.attn_n_heads, t + 1, t + 1)
            yy_mask = yy_mask & subsequent_mask

            # Create the source-target mask
            xmax = eouts.size(1)
            x_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).expand(
                bs, t + 1, xmax)
            y_mask = make_pad_mask(ylens + 1,
                                   self.device_id).unsqueeze(2).expand(
                                       bs, t + 1, xmax)
            xy_mask = (x_mask * y_mask).unsqueeze(1).expand(
                bs, self.attn_n_heads, t + 1, xmax)

            out = self.pos_enc(self.embed(ys_all))
            for l in range(self.n_layers):
                out, yy_aws, xy_aws = self.layers[l](out, yy_mask, eouts,
                                                     xy_mask)
            out = self.norm_out(out)

            # Pick up 1-best
            y = self.output(out).argmax(-1)[:, -1:]
            best_hyps_batch += [y]

            # Count lengths of hypotheses
            for b in range(bs):
                if not eos_flags[b]:
                    if y[b].item() == self.eos:
                        eos_flags[b] = True
                        yy_aws_tmp[b] = yy_aws[b:b + 1]  # TODO: fix this
                        xy_aws_tmp[b] = xy_aws[b:b + 1]
                    ylens[b] += 1
                    # NOTE: include <eos>

            # Break if <eos> has been output for all utterances in the mini-batch
            if sum(eos_flags) == bs:
                break

            ys_all = torch.cat([ys_all, y], dim=-1)

        # Concatenate in L dimension
        best_hyps_batch = torch.cat(best_hyps_batch, dim=1)
        # xy_aws_tmp = torch.stack(xy_aws_tmp, dim=0)

        # Convert to numpy
        best_hyps_batch = tensor2np(best_hyps_batch)
        # xy_aws_tmp = tensor2np(xy_aws_tmp)

        # if self.score.attn_n_heads > 1:
        #     xy_aws_tmp = xy_aws_tmp[:, :, :, 0]
        #     # TODO(hirofumi): fix for MHA

        # Truncate by the first <eos> (<sos> in case of the backward decoder)
        if self.bwd:
            # Reverse the order
            best_hyps = [
                best_hyps_batch[b, :ylens[b]][::-1] for b in range(bs)
            ]
            # aws = [xy_aws_tmp[b, :ylens[b]][::-1] for b in range(bs)]
        else:
            best_hyps = [best_hyps_batch[b, :ylens[b]] for b in range(bs)]
            # aws = [xy_aws_tmp[b, :ylens[b]] for b in range(bs)]

        # Exclude <eos> (<sos> in case of the backward decoder)
        if exclude_eos:
            if self.bwd:
                best_hyps = [
                    best_hyps[b][1:] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]
            else:
                best_hyps = [
                    best_hyps[b][:-1] if eos_flags[b] else best_hyps[b]
                    for b in range(bs)
                ]

        # return best_hyps, aws
        return best_hyps, None
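
After the greedy loop, each hypothesis is cut at its first <eos> (reversed first for the backward decoder), and <eos> itself is optionally dropped. A small standalone sketch of that post-processing on plain numpy arrays; it mirrors the logic above but is not the exact library code.

import numpy as np

def postprocess_hyps(best_hyps_batch, ylens, eos_flags, bwd=False, exclude_eos=False):
    """best_hyps_batch: [B, L_max] token ids; ylens: [B] lengths including <eos> when it fired."""
    bs = best_hyps_batch.shape[0]
    if bwd:
        hyps = [best_hyps_batch[b, :ylens[b]][::-1] for b in range(bs)]
        if exclude_eos:
            hyps = [h[1:] if eos_flags[b] else h for b, h in enumerate(hyps)]
    else:
        hyps = [best_hyps_batch[b, :ylens[b]] for b in range(bs)]
        if exclude_eos:
            hyps = [h[:-1] if eos_flags[b] else h for b, h in enumerate(hyps)]
    return hyps

# toy usage: <eos> = 2
hyps = postprocess_hyps(np.array([[5, 7, 2, 0], [9, 2, 0, 0]]),
                        ylens=[3, 2], eos_flags=[True, True], exclude_eos=True)
print([h.tolist() for h in hyps])  # [[5, 7], [9]]
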
Ejemplo n.º 20
0
    def forward_att(self,
                    eouts,
                    elens,
                    ys,
                    ys_hist=[],
                    return_logits=False,
                    teacher_logits=None):
        """Compute XE loss for the Transformer model.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (list): A list of length `[B]`, which contains a list of size `[L]`
            ys_hist (list):
            return_logits (bool): return logits for knowledge distillation
            teacher_logits (FloatTensor): `[B, L, vocab]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity

        """
        bs = eouts.size(0)

        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(eouts, ys, self.eos, self.pad,
                                              self.bwd)

        # Create the self-attention mask
        bs, ytime = ys_in.size()[:2]
        tgt_mask = make_pad_mask(ylens, self.device_id).unsqueeze(1).repeat(
            [1, ytime, 1])
        subsequent_mask = tgt_mask.new_ones(ytime, ytime).byte()
        subsequent_mask = torch.tril(subsequent_mask,
                                     out=subsequent_mask).unsqueeze(0)
        tgt_mask = tgt_mask & subsequent_mask

        # Create the source-target mask
        src_mask = make_pad_mask(elens, self.device_id).unsqueeze(1).repeat(
            [1, ytime, 1])

        out = self.pos_enc(self.embed(ys_in))
        for l in range(self.n_layers):
            out, yy_aws, xy_aws = self.layers[l](out, tgt_mask, eouts,
                                                 src_mask)
            if not self.training:
                setattr(self, 'yy_aws_layer%d' % l, tensor2np(yy_aws))
                setattr(self, 'xy_aws_layer%d' % l, tensor2np(xy_aws))
        logits = self.output(self.norm_out(out))

        # for knowledge distillation
        if return_logits:
            return logits

        # Compute XE sequence loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad,
                                      self.training)

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl
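
`cross_entropy_lsm` above combines the cross-entropy objective with label smoothing over the non-padded target positions and also returns a perplexity. The sketch below shows one common formulation (uniform smoothing mass `lsm_prob` spread over the vocabulary, padding excluded); the library helper may differ in details such as the training-mode flag and normalization.

import torch
import torch.nn.functional as F

def cross_entropy_lsm_sketch(logits, ys_out, lsm_prob, pad):
    """logits: [B, L, vocab]; ys_out: [B, L] with `pad` on ignored positions."""
    vocab = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    mask = (ys_out != pad)
    n_tokens = mask.sum()
    # smoothed target: (1 - p) + p/vocab on the gold label, p/vocab elsewhere
    tgt = torch.full_like(log_probs, lsm_prob / vocab)
    tgt.scatter_(-1, ys_out.clamp(min=0).unsqueeze(-1), 1 - lsm_prob + lsm_prob / vocab)
    loss = -(tgt * log_probs).sum(-1)                # [B, L]
    loss = (loss * mask.float()).sum() / n_tokens
    # perplexity from the plain (unsmoothed) per-token NLL
    nll = F.cross_entropy(logits.reshape(-1, vocab), ys_out.reshape(-1),
                          ignore_index=pad, reduction='mean')
    return loss, torch.exp(nll).item()

# toy usage
ys_out = torch.randint(1, 50, (2, 6))
ys_out[:, -1] = 0                                # pretend the last position is padding (pad=0)
loss, ppl = cross_entropy_lsm_sketch(torch.randn(2, 6, 50), ys_out, lsm_prob=0.1, pad=0)
print(loss.item(), ppl)
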
Ejemplo n.º 21
0
    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): not supported now
            use_cache (bool):
            streaming (bool): streaming encoding
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks before the self-attention layers
            xs, xlens = self.conv(xs, xlens)
        if not self.training:
            self.data_dict['elens'] = tensor2np(xlens)

        bs, xmax, idim = xs.size()
        xs = self.pos_enc(xs)
        if self.chunk_size_left > 0:
            # Time-restricted self-attention for streaming models
            cs_l = self.chunk_size_left
            cs_c = self.chunk_size_cur
            cs_r = self.chunk_size_right
            xs_chunks = []
            xx_aws = [[] for l in range(self.n_layers)]
            xs_pad = torch.cat([
                xs.new_zeros(bs, cs_l, idim), xs,
                xs.new_zeros(bs, cs_r, idim)
            ],
                               dim=1)
            # TODO: remove right padding
            for t in range(cs_l, cs_l + xmax, self.chunk_size_cur):
                xs_chunk = xs_pad[:, t - cs_l:t + cs_c + cs_r]
                for l, layer in enumerate(self.layers):
                    xs_chunk, xx_aws_chunk = layer(xs_chunk, None)  # no mask
                    xx_aws[l].append(xx_aws_chunk[:, :, cs_l:cs_l + cs_c,
                                                  cs_l:cs_l + cs_c])
                xs_chunks.append(xs_chunk[:, cs_l:cs_l + cs_c])
            xs = torch.cat(xs_chunks, dim=1)[:, :xmax]
            if not self.training:
                for l in range(self.n_layers):
                    self.aws_dict['xx_aws_layer%d' % l] = tensor2np(
                        torch.cat(xx_aws[l], dim=3)[:, :, :xmax, :xmax])
        else:
            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat(
                [1, 1, xmax])

            for l, layer in enumerate(self.layers):
                xs, xx_aws = layer(xs, xx_mask)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % l] = tensor2np(xx_aws)

                # Pick up outputs in the sub task before the projection layer
                if l == self.n_layers_sub1 - 1:
                    if self.task_specific_layer:
                        xs_sub1 = self.layer_sub1(xs, xx_mask)[0]
                    else:
                        xs_sub1 = xs.clone()
                    xs_sub1 = self.norm_out_sub1(xs_sub1)
                    if self.bridge_sub1 is not None:
                        xs_sub1 = self.bridge_sub1(xs_sub1)
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens
                        return eouts
                if l == self.n_layers_sub2 - 1:
                    if self.task_specific_layer:
                        xs_sub2 = self.layer_sub2(xs, xx_mask)[0]
                    else:
                        xs_sub2 = xs.clone()
                    xs_sub2 = self.norm_out_sub2(xs_sub2)
                    if self.bridge_sub2 is not None:
                        xs_sub2 = self.bridge_sub2(xs_sub2)
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens
                        return eouts

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts
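
`self.pos_enc(xs)` above adds fixed sinusoidal position information (with optional scaling) before the self-attention layers. The sketch below is the standard Transformer sinusoidal encoding, assuming an even `d_model`; the library's positional-encoding class additionally applies dropout and other options that are omitted here.

import math
import torch

def sinusoidal_positional_encoding(xs, scale=False):
    """Add fixed sinusoidal position encodings to xs of shape [B, T, d_model] (d_model even)."""
    _, T, d_model = xs.size()
    pos = torch.arange(T, dtype=torch.float, device=xs.device).unsqueeze(1)       # [T, 1]
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float, device=xs.device)
                    * (-math.log(10000.0) / d_model))                             # [d_model / 2]
    pe = torch.zeros(T, d_model, device=xs.device)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    if scale:
        xs = xs * math.sqrt(d_model)   # analogous to `scale=True` above
    return xs + pe.unsqueeze(0)

# toy usage
print(sinusoidal_positional_encoding(torch.zeros(2, 10, 8), scale=True).shape)  # [2, 10, 8]
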
Ejemplo n.º 22
0
    def forward(self,
                xs,
                xlens,
                task,
                streaming=False,
                lookback=False,
                lookahead=False):
        """Forward pass.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (IntTensor): `[B]` (on CPU)
            task (str): ys/ys_sub1/ys_sub2
            streaming (bool): streaming encoding
            lookback (bool): truncate leftmost frames for lookback in CNN context
            lookahead (bool): truncate rightmost frames for lookahead in CNN context
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (IntTensor): `[B]` (on CPU)

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        N_l = self.chunk_size_left
        N_c = self.chunk_size_current
        N_r = self.chunk_size_right
        bs, xmax, idim = xs.size()
        n_chunks = 0
        clamp_len = self.clamp_len

        if self.latency_controlled:
            if self.streaming_type == 'reshape':
                xs = chunkwise(xs, N_l, N_c,
                               N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
            elif self.streaming_type == 'mask':
                # xs = chunkwise(xs, N_l, N_c, N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
                xs = chunkwise(xs, 0, N_c, 0)  # `[B * n_chunks, N_c, idim]`
            n_chunks = xs.size(0) // bs

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks
            xs, xlens = self.conv(xs, xlens)
            N_l = max(0, N_l // self.conv.subsampling_factor)
            N_c = N_c // self.conv.subsampling_factor
            N_r = N_r // self.conv.subsampling_factor
            clamp_len = clamp_len // self.conv.subsampling_factor

        if self.streaming_type == 'mask':
            # Extract the center region
            emax = xlens.max().item()
            xs = xs.contiguous().view(
                bs, -1, xs.size(2))[:, :emax]  # `[B, emax, d_model]`

        if self.latency_controlled:
            # streaming Transformer encoder
            emax = xlens.max().item()

            pos_embs = None
            if self.pe_type in ['relative', 'relative_xl']:
                xs = xs * self.scale
                pos_embs = self.pos_emb(xs, zero_center_offset=True)
                # NOTE: no clamp_len for streaming
            else:
                xs = self.pos_enc(xs, scale=True)

            if self.streaming_type == 'reshape':
                xx_mask_first = None
                xx_mask = None  # NOTE: no mask to avoid masking all frames in a chunk
            elif self.streaming_type == 'mask':
                xx_mask_first, xx_mask = time_restricted_mask(
                    xs, xlens, N_l, N_c, N_r, n_chunks)

            for lth, layer in enumerate(self.layers):
                xs = layer(xs,
                           xx_mask if lth >= 1 else xx_mask_first,
                           pos_embs=pos_embs,
                           u_bias=self.u_bias,
                           v_bias=self.v_bias)
                if not self.training:
                    if self.streaming_type == 'reshape':
                        n_heads = layer.xx_aws.size(1)
                        xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c,
                                              N_l:N_l + N_c]
                        xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                        xx_aws_center = xx_aws.new_zeros(
                            bs, n_heads, emax, emax)
                        for chunk_idx in range(n_chunks):
                            offset = chunk_idx * N_c
                            emax_chunk = xx_aws_center[:, :, offset:offset + N_c].size(2)
                            xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_chunk, :emax_chunk]
                            xx_aws_center[:, :, offset:offset + N_c,
                                          offset:offset + N_c] = xx_aws_chunk
                        self.aws_dict['xx_aws_layer%d' %
                                      lth] = tensor2np(xx_aws_center)
                    elif self.streaming_type == 'mask':
                        self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                            layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    emax = xlens.max().item()
                    N_l = max(0, N_l // self.subsample[lth].subsampling_factor)
                    N_c = N_c // self.subsample[lth].subsampling_factor
                    N_r = N_r // self.subsample[lth].subsampling_factor
                    if self.pe_type in ['relative', 'relative_xl']:
                        # Create sinusoidal positional embeddings for relative positional encoding
                        pos_embs = self.pos_emb(xs, zero_center_offset=True)
                        # NOTE: no clamp_len for streaming
                    if self.streaming_type == 'mask':
                        _, xx_mask = time_restricted_mask(
                            xs, xlens, N_l, N_c, N_r, n_chunks)

            # Extract the center region
            if self.streaming_type == 'reshape':
                xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
                xs = xs.contiguous().view(bs, -1, xs.size(2))
                xs = xs[:, :emax]

        else:
            if self.pe_type in ['relative', 'relative_xl']:
                xs = xs * self.scale
                # Create sinusoidal positional embeddings for relative positional encoding
                pos_embs = self.pos_emb(xs,
                                        clamp_len=clamp_len,
                                        zero_center_offset=True)
            else:
                xs = self.pos_enc(xs, scale=True)
                pos_embs = None

            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens.to(self.device)).unsqueeze(1).repeat(
                [1, xs.size(1), 1])

            for lth, layer in enumerate(self.layers):
                xs = layer(xs,
                           xx_mask,
                           pos_embs=pos_embs,
                           u_bias=self.u_bias,
                           v_bias=self.v_bias)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                        layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1 = self.sub_module(xs, xx_mask, lth, pos_embs,
                                              'sub1')
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.sub_module(xs, xx_mask, lth, pos_embs,
                                              'sub2')
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens
                        return eouts

                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    # Create the self-attention mask
                    xx_mask = make_pad_mask(xlens.to(
                        self.device)).unsqueeze(1).repeat([1, xs.size(1), 1])
                    if self.pe_type in ['relative', 'relative_xl']:
                        # Create sinusoidal positional embeddings for relative positional encoding
                        clamp_len = clamp_len // self.subsample[
                            lth].subsampling_factor
                        pos_embs = self.pos_emb(xs,
                                                clamp_len=clamp_len,
                                                zero_center_offset=True)

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts
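
`chunkwise()` above rearranges `[B, T, idim]` into overlapping chunks `[B * n_chunks, N_l + N_c + N_r, idim]` so that each chunk can be encoded independently and later recombined (as in the `'reshape'` branch). A sketch of that splitting, under the assumption that chunks are ordered batch-major so that `.view(bs, n_chunks, ...)` recovers them, as the attention bookkeeping above expects; the library helper may differ in its padding details.

import torch
import torch.nn.functional as F

def chunkwise_sketch(xs, N_l, N_c, N_r):
    """Split xs of shape [B, T, D] into overlapping chunks [B * n_chunks, N_l + N_c + N_r, D]."""
    bs, xmax, dim = xs.size()
    n_chunks = (xmax + N_c - 1) // N_c                 # ceil(T / N_c)
    # pad N_l frames of left context and enough right context to fill the last chunk
    pad_right = N_r + n_chunks * N_c - xmax
    xs_pad = F.pad(xs, (0, 0, N_l, pad_right))         # pad along the time axis
    window = N_l + N_c + N_r
    chunks = [xs_pad[:, i * N_c:i * N_c + window] for i in range(n_chunks)]
    return torch.stack(chunks, dim=1).view(bs * n_chunks, window, dim)

# toy usage
xs = torch.arange(2 * 10 * 1, dtype=torch.float).view(2, 10, 1)
print(chunkwise_sketch(xs, N_l=2, N_c=4, N_r=2).shape)  # torch.Size([6, 8, 1])
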
Ejemplo n.º 23
0
    def forward(self, eouts, elens, ylens=None, mode='parallel'):
        """Forward pass.

        Args:
            eouts (FloatTensor): `[B, T, enc_dim]`
            elens (IntTensor): `[B]`
            ylens (IntTensor): `[B]`
            mode (str): parallel/incremental
        Returns:
            cv (FloatTensor): `[B, L, enc_dim]`
            alpha (FloatTensor): `[B, T]`
            aws (FloatTensor): `[B, L, T]`

        """
        bs, xmax, enc_dim = eouts.size()

        # 1d conv
        conv_feat = self.conv1d(eouts.transpose(2, 1)).transpose(
            2, 1)  # `[B, T, enc_dim]`
        conv_feat = torch.relu(self.norm(conv_feat))
        alpha = torch.sigmoid(self.proj(conv_feat)).squeeze(2)  # `[B, T]`

        # normalization
        if mode == 'parallel':
            # padding
            assert ylens is not None
            device = eouts.device
            ylens = ylens.to(device)
            mask = make_pad_mask(elens.to(device))
            alpha = alpha.clone().masked_fill_(mask == 0, 0)

            alpha_norm = alpha / alpha.sum(
                1, keepdim=True) * ylens.float().unsqueeze(1)
            ymax = int(ylens.max().item())
        elif mode == 'incremental':
            alpha_norm = alpha  # inference time
            ymax = 1
            if bs > 1:
                raise NotImplementedError('Batch mode is not supported.')
                # TODO(hirofumi0810): support batch mode
        else:
            raise ValueError(mode)

        cv = eouts.new_zeros(bs, ymax + 1, enc_dim)
        aws = eouts.new_zeros(bs, ymax + 1, xmax)
        n_tokens = torch.zeros(bs, dtype=torch.int64)
        state = eouts.new_zeros(bs, self.enc_dim)
        alpha_accum = eouts.new_zeros(bs)
        for j in range(xmax):
            alpha_accum_prev = alpha_accum
            alpha_accum += alpha_norm[:, j]

            if mode == 'parallel' and (alpha_accum >= self.beta).sum() == 0:
                # No boundary is located in any utterance in the mini-batch
                # Carry over to the next frame
                state += alpha_norm[:, j, None] * eouts[:, j]
                aws[:, n_tokens, j] += alpha_norm[:, j]
            else:
                for b in range(bs):
                    # skip the padding region
                    if j > elens[b] - 1:
                        continue

                    # skip utterances that have already fired all their tokens
                    if mode == 'parallel' and n_tokens[b].item() >= ylens[b]:
                        continue

                    if alpha_accum[b] < self.beta:
                        # No boundary is located
                        # Carry over to the next frame
                        state[b] += alpha_norm[b, j, None] * eouts[b, j]
                        aws[b, n_tokens[b], j] += alpha_norm[b, j]

                        # tail handling
                        if mode == 'incremental' and j == elens[b] - 1:
                            if alpha_accum[b] >= 0.5:
                                n_tokens[b] += 1
                                cv[b, n_tokens[b]] = state[b]
                            break
                    else:
                        # A boundary is located
                        ak1 = 1 - alpha_accum_prev[b]
                        ak2 = alpha_norm[b, j] - ak1
                        cv[b, n_tokens[b]] = state[b] + ak1 * eouts[b, j]
                        aws[b, n_tokens[b], j] += ak1
                        n_tokens[b] += 1
                        # Carry over to the next frame
                        state[b] = ak2 * eouts[b, j]
                        alpha_accum[b] = ak2
                        aws[b, n_tokens[b], j] += ak2

                        if mode == 'incremental':
                            break

                if mode == 'incremental' and n_tokens[0] >= 1:
                    break
                    # TODO(hirofumi0810): support batch mode

        # truncate
        cv = cv[:, :ymax]
        aws = aws[:, :ymax]

        return cv, alpha, aws
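
The module above implements continuous integrate-and-fire (CIF): encoder frames are accumulated with weights `alpha`, and a context vector is emitted ("fired") whenever the accumulated weight crosses the threshold `beta`, with the leftover weight carried into the next token. A simplified single-utterance sketch of that firing rule; the version above additionally handles batching, padding, target-length normalization in `'parallel'` mode, and the attention-weight bookkeeping.

import torch

def cif_fire_sketch(eouts, alpha, beta=1.0):
    """eouts: [T, D] encoder outputs; alpha: [T] firing weights in [0, 1].
    Returns the integrated context vectors, one per fired token: [n_tokens, D]."""
    T, D = eouts.size()
    cvs = []
    state = eouts.new_zeros(D)   # running weighted sum of frames
    accum = 0.0                  # accumulated firing weight
    for j in range(T):
        a = float(alpha[j])
        if accum + a < beta:
            # no boundary here: keep integrating
            state = state + a * eouts[j]
            accum += a
        else:
            # boundary: spend only (beta - accum) on this frame, fire, carry the rest over
            a1 = beta - accum
            a2 = a - a1
            cvs.append(state + a1 * eouts[j])
            state = a2 * eouts[j]
            accum = a2
    return torch.stack(cvs) if cvs else eouts.new_zeros(0, D)

# toy usage: total weight 20 * 0.33 = 6.6 -> 6 fired tokens
print(cif_fire_sketch(torch.randn(20, 4), torch.full((20,), 0.33)).shape)  # torch.Size([6, 4])
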
Ejemplo n.º 24
0
    def align(self, logits, elens, ys, ylens):
        bs, xmax, vocab = logits.size()

        # zero padding
        device_id = torch.cuda.device_of(logits).idx
        mask = make_pad_mask(elens, device_id)
        mask = mask.unsqueeze(2).repeat([1, 1, vocab])
        logits = logits.masked_fill_(mask == 0, self.log0)
        log_probs = torch.log_softmax(logits, dim=-1).transpose(0, 1)  # `[T, B, vocab]`

        path = _label_to_path(ys, self.blank)
        path_lens = 2 * ylens.long() + 1

        ymax = ys.size(1)
        max_path_len = path.size(1)
        assert ys.size() == (bs, ymax), ys.size()
        assert path.size() == (bs, ymax * 2 + 1)

        alpha = log_probs.new_zeros(bs, max_path_len).fill_(self.log0)
        alpha[:, 0] = LOG_1
        beta = alpha.clone()
        gamma = alpha.clone()

        batch_index = torch.arange(bs, dtype=torch.int64).unsqueeze(1)
        seq_index = torch.arange(xmax,
                                 dtype=torch.int64).unsqueeze(1).unsqueeze(2)
        log_probs_fwd_bwd = log_probs[seq_index, batch_index, path]

        # forward algorithm
        for t in range(xmax):
            alpha = self._computes_transition(alpha, path, path_lens,
                                              log_probs_fwd_bwd[t],
                                              log_probs[t])

        # backward algorithm
        r_path = _flip_path(path, path_lens)
        log_probs_inv = _flip_label_probability(
            log_probs, elens.long())  # `[T, B, vocab]`
        log_probs_fwd_bwd = _flip_path_probability(
            log_probs_fwd_bwd, elens.long(), path_lens)  # `[T, B, 2*L+1]`
        for t in range(xmax):
            beta = self._computes_transition(beta, r_path, path_lens,
                                             log_probs_fwd_bwd[t],
                                             log_probs_inv[t])

        # pick up the best CTC path
        best_lattices = log_probs.new_zeros((bs, xmax), dtype=torch.int64)

        # forward algorithm
        log_probs_fwd_bwd = _flip_path_probability(log_probs_fwd_bwd,
                                                   elens.long(), path_lens)
        for t in range(xmax):
            gamma = self._computes_transition(gamma,
                                              path,
                                              path_lens,
                                              log_probs_fwd_bwd[t],
                                              log_probs[t],
                                              skip_accum=True)

            # select paths where gamma is valid
            log_probs_fwd_bwd[t] = log_probs_fwd_bwd[t].masked_fill_(
                gamma == self.log0, self.log0)

            # pick up the best lattice
            offsets = log_probs_fwd_bwd[t].argmax(1)
            for b in range(bs):
                if t <= elens[b] - 1:
                    token_idx = path[b, offsets[b]]
                    best_lattices[b, t] = token_idx

            # keep only the selected path and prune the rest
            gamma = log_probs.new_zeros(bs, max_path_len).fill_(self.log0)
            for b in range(bs):
                gamma[b, offsets[b]] = LOG_1

        # pick up trigger points
        trigger_lattices = torch.zeros((bs, xmax), dtype=torch.int64)
        trigger_points = log_probs.new_zeros((bs, ymax + 1),
                                             dtype=torch.int32)  # +1 for <eos>
        for b in range(bs):
            n_triggers = 0
            trigger_points[b, ylens[b]] = elens[b] - 1  # for <eos>
            for t in range(elens[b]):
                token_idx = best_lattices[b, t]
                if token_idx == self.blank:
                    continue
                if not (t == 0 or token_idx != best_lattices[b, t - 1]):
                    continue

                # NOTE: select the left-most trigger point for each token
                trigger_lattices[b, t] = token_idx
                trigger_points[b, n_triggers] = t
                n_triggers += 1

        # print(trigger_points[0])
        # print(trigger_lattices[0])
        # print(ys[0])

        assert ylens.sum() == (trigger_lattices != 0).sum()
        return trigger_points
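
The final loop above converts the best CTC lattice into trigger points: the first frame of every non-blank, non-repeated token on the path. A standalone sketch of that collapse rule for a single utterance; batching and the extra <eos> slot handled above are omitted.

import torch

def trigger_points_sketch(best_path, blank=0):
    """best_path: [T] LongTensor of per-frame token ids on the best CTC path.
    Returns the frame index at which each output token first fires."""
    points = []
    prev = None
    for t, tok in enumerate(best_path.tolist()):
        # a token fires on the first frame of a non-blank run
        if tok != blank and tok != prev:
            points.append(t)
        prev = tok
    return torch.tensor(points, dtype=torch.int64)

# toy usage: path "- a a - b b - -" with blank=0, a=1, b=2
print(trigger_points_sketch(torch.tensor([0, 1, 1, 0, 2, 2, 0, 0])))  # tensor([1, 4])
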
Ejemplo n.º 25
0
    def forward_att(self, eouts, elens, ys, trigger_points=None):
        """Compute XE loss for the Transformer decoder.

        Args:
            eouts (FloatTensor): `[B, T, d_model]`
            elens (IntTensor): `[B]`
            ys (List): length `[B]`, each of which contains a list of size `[L]`
            trigger_points (IntTensor): `[B, L]`
        Returns:
            loss (FloatTensor): `[1]`
            acc (float): accuracy for token prediction
            ppl (float): perplexity
            losses_auxiliary (dict):

        """
        losses_auxiliary = {}

        # Append <sos> and <eos>
        ys_in, ys_out, ylens = append_sos_eos(ys, self.eos, self.eos, self.pad, self.device, self.bwd)
        if not self.training:
            self.data_dict['elens'] = tensor2np(elens)
            self.data_dict['ylens'] = tensor2np(ylens)
            self.data_dict['ys'] = tensor2np(ys_out)

        # Create target self-attention mask
        bs, ymax = ys_in.size()[:2]
        tgt_mask = (ys_out != self.pad).unsqueeze(1).repeat([1, ymax, 1])
        causal_mask = tgt_mask.new_ones(ymax, ymax, dtype=tgt_mask.dtype)
        causal_mask = torch.tril(causal_mask).unsqueeze(0)
        tgt_mask = tgt_mask & causal_mask  # `[B, L (query), L (key)]`

        # Create source-target mask
        src_mask = make_pad_mask(elens.to(self.device)).unsqueeze(1).repeat([1, ymax, 1])  # `[B, L, T]`

        # Create attention padding mask for quantity loss
        if self.attn_type == 'mocha':
            attn_mask = (ys_out != self.pad).unsqueeze(1).unsqueeze(3)  # `[B, 1, L, 1]`
        else:
            attn_mask = None

        # external LM integration
        lmout = None
        if self.lm is not None:
            self.lm.eval()
            with torch.no_grad():
                lmout, lmstate, _ = self.lm.predict(ys_in, None)
            lmout = self.lm_output_proj(lmout)

        out = self.pos_enc(self.embed_token_id(ys_in), scale=True)  # scaled + dropout

        xy_aws_layers = []
        xy_aws = None
        for lth, layer in enumerate(self.layers):
            out = layer(out, tgt_mask, eouts, src_mask, mode='parallel', lmout=lmout)
            # Attention padding
            xy_aws = layer.xy_aws
            if xy_aws is not None and self.attn_type == 'mocha':
                xy_aws_masked = xy_aws.masked_fill_(attn_mask.expand_as(xy_aws) == 0, 0)
                # NOTE: attention padding is quite effective for quantity loss
                xy_aws_layers.append(xy_aws_masked.clone())
            if not self.training:
                self.aws_dict['yy_aws_layer%d' % lth] = tensor2np(layer.yy_aws)
                self.aws_dict['xy_aws_layer%d' % lth] = tensor2np(layer.xy_aws)
                self.aws_dict['xy_aws_beta_layer%d' % lth] = tensor2np(layer.xy_aws_beta)
                self.aws_dict['xy_aws_p_choose%d' % lth] = tensor2np(layer.xy_aws_p_choose)
                self.aws_dict['yy_aws_lm_layer%d' % lth] = tensor2np(layer.yy_aws_lm)
        logits = self.output(self.norm_out(out))

        # Compute XE loss (+ label smoothing)
        loss, ppl = cross_entropy_lsm(logits, ys_out, self.lsm_prob, self.pad, self.training)

        # Quantity loss
        losses_auxiliary['loss_quantity'] = 0.
        if self.attn_type == 'mocha':
            # Average over all heads across all layers
            n_tokens_ref = tgt_mask[:, -1, :].sum(1).float()  # `[B]`
            # NOTE: count <eos> tokens
            n_tokens_pred = sum([torch.abs(aws.sum(3).sum(2).sum(1) / aws.size(1))
                                 for aws in xy_aws_layers])  # `[B]`
            n_tokens_pred /= len(xy_aws_layers)
            losses_auxiliary['loss_quantity'] = torch.mean(torch.abs(n_tokens_pred - n_tokens_ref))

        # Compute token-level accuracy in teacher-forcing
        acc = compute_accuracy(logits, ys_out, self.pad)

        return loss, acc, ppl, losses_auxiliary
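
`compute_accuracy` above reports teacher-forced token accuracy while ignoring padded positions. A small sketch of a plausible implementation under that definition; the actual helper may mask or normalize slightly differently.

import torch

def token_accuracy_sketch(logits, ys_out, pad):
    """logits: [B, L, vocab]; ys_out: [B, L] with `pad` on ignored positions."""
    pred = logits.argmax(dim=-1)                 # [B, L]
    mask = ys_out != pad
    n_correct = ((pred == ys_out) & mask).sum().item()
    n_total = mask.sum().item()
    return float(n_correct) / max(1, n_total)

# toy usage
logits = torch.randn(2, 5, 10)
ys_out = torch.randint(0, 10, (2, 5))
ys_out[:, -1] = -1                               # pretend the last position is padding
print(token_accuracy_sketch(logits, ys_out, pad=-1))
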
Ejemplo n.º 26
0
    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): not supported now
            use_cache (bool):
            streaming (bool): streaming encoding
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks before the self-attention layers
            xs, xlens = self.conv(xs, xlens)

        bs, xmax, idim = xs.size()
        xs = self.pos_enc(xs)
        if self.chunk_size_left > 0:
            # Time-restricted self-attention for streaming models
            cs_l = self.chunk_size_left
            cs_c = self.chunk_size_current
            cs_r = self.chunk_size_right
            hop_size = self.chunk_size_current
            xs_chunks = []
            xx_aws = [[] for l in range(self.n_layers)]
            xs_pad = torch.cat([
                xs.new_zeros(bs, cs_l, idim), xs,
                xs.new_zeros(bs, cs_r, idim)
            ],
                               dim=1)
            # TODO: remove right padding
            for t in range(cs_l, cs_l + xmax, hop_size):
                xs_chunk = xs_pad[:, t - cs_l:t + cs_c + cs_r]
                for l in range(self.n_layers):
                    xs_chunk, xx_aws_chunk = self.layers[l](xs_chunk,
                                                            None)  # no mask
                    xx_aws[l].append(xx_aws_chunk[:, :, cs_l:cs_l + cs_c,
                                                  cs_l:cs_l + cs_c])
                xs_chunks.append(xs_chunk[:, cs_l:cs_l + cs_c])
            xs = torch.cat(xs_chunks, dim=1)[:, :xmax]
            if not self.training:
                for l in range(self.n_layers):
                    setattr(
                        self, 'xx_aws_layer%d' % l,
                        tensor2np(
                            torch.cat(xx_aws[l], dim=3)[:, :, :xmax, :xmax]))
        else:
            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat(
                [1, 1, xmax])

            for l in range(self.n_layers):
                xs, xx_aws = self.layers[l](xs, xx_mask)
                if not self.training:
                    setattr(self, 'xx_aws_layer%d' % l, tensor2np(xx_aws))
        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        eouts['ys']['xs'] = xs
        eouts['ys']['xlens'] = xlens
        return eouts
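
Nearly every snippet above builds its masks from `make_pad_mask`. The sketch below captures the role it plays here, with True on valid frames so that `mask == 0` selects padding; the library helper differs across versions in device handling and, in some, in the sign convention, so treat this only as an illustration.

import torch

def pad_mask_sketch(xlens, max_len=None):
    """xlens: [B] lengths. Returns a [B, T] boolean mask, True on valid frames."""
    xlens = torch.as_tensor(xlens, dtype=torch.int64)
    T = int(max_len) if max_len is not None else int(xlens.max())
    return torch.arange(T).unsqueeze(0) < xlens.unsqueeze(1)

# toy usage: expand to a [B, T, T] self-attention mask as the encoders above do
mask = pad_mask_sketch(torch.tensor([4, 2]))             # [B, T]
xx_mask = mask.unsqueeze(1).repeat(1, mask.size(1), 1)   # [B, T (query), T (key)]
print(mask)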