示例#1
0
    def forward(self,
                xs,
                xlens,
                task,
                streaming=False,
                lookback=False,
                lookahead=False):
        """Encode input features with the self-attention layer stack.

        Runs the CNN front-end (or the embedding layer when `self.conv` is
        None), adds positional information, and applies `self.layers` either
        chunk-wise (when `self.lc_bidir` is set) or over the whole sequence.
        The output normalization and the optional bridge layer are applied
        last. When `self.enc_type == 'conv'`, the front-end output is returned
        directly.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (IntTensor): `[B]` (on CPU)
            task (str): all/ys/ys_sub1/ys_sub2
            streaming (bool): streaming encoding
            lookback (bool): truncate leftmost frames for lookback in CNN context
            lookahead (bool): truncate rightmost frames for lookahead in CNN context
        Returns:
            eouts (dict): keyed by `ys`/`ys_sub1`/`ys_sub2`; entries that are
                not filled for the given `task` keep `None` values
                xs (FloatTensor): `[B, T, d_model]`
                xlens (IntTensor): `[B]` (on CPU)

        """
        # Output container; only the entries selected by `task` are filled.
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        bs, xmax = xs.size()[:2]
        n_chunks = 0
        unidir = self.unidir
        lc_bidir = self.lc_bidir
        # Left/current/right chunk sizes; they are rescaled below after every
        # subsampling step.
        N_l, N_c, N_r = self.chunk_size_left, self.chunk_size_current, self.chunk_size_right

        # In streaming mode the caller must feed at most one chunk at a time.
        if streaming and self.streaming_type == 'mask':
            assert xmax <= N_c
        elif streaming and self.streaming_type == 'reshape':
            assert xmax <= (N_l + N_c + N_r)

        if lc_bidir:
            if self.streaming_type == 'mask' and not streaming:
                xs = chunkwise(xs, 0, N_c, 0,
                               padding=True)  # `[B * n_chunks, N_c, idim]`
                # NOTE: CNN consumes inputs in the current chunk to avoid extra lookahead latency
                # That is, CNN outputs are independent on chunk boundary
            elif self.streaming_type == 'reshape':
                xs = chunkwise(xs, N_l, N_c, N_r, padding=not streaming
                               )  # `[B * n_chunks, N_l+N_c+N_r, idim]`
            # The chunking folds chunks into the batch dimension, so the
            # number of chunks is the growth factor of dim 0.
            n_chunks = xs.size(0) // bs
            assert bs * n_chunks == xs.size(0)
            if streaming:
                assert n_chunks == 1, xs.size()

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Path through CNN blocks
            xs, xlens = self.conv(xs,
                                  xlens,
                                  lookback=False if lc_bidir else lookback,
                                  lookahead=False if lc_bidir else lookahead)
            # NOTE: CNN lookahead surpassing a chunk is not allowed in chunkwise processing
            # Chunk sizes are now measured in CNN output frames.
            N_l = max(0, N_l // self.conv.subsampling_factor)
            N_c = N_c // self.conv.subsampling_factor
            N_r = N_r // self.conv.subsampling_factor

        if lc_bidir:
            # Do nothing in the streaming mode
            if self.streaming_type == 'mask' and not streaming:
                # back to the original shape (during training only)
                xs = xs.contiguous().view(
                    bs, -1,
                    xs.size(2))[:, :xlens.max()]  # `[B, emax, d_model]`
        elif streaming:
            xs = xs[:, :xlens.max()]  # for unidirectional

        # CNN-only encoder: return the front-end output as is.
        if self.enc_type == 'conv':
            eouts['ys']['xs'] = xs
            eouts['ys']['xlens'] = xlens
            return eouts

        # Cached state is only carried across calls in streaming mode.
        if not streaming:
            self.reset_cache()
        # Number of cached frames of the first layer (0 when there is no cache).
        n_hist = self.cache[0]['input_san'].size(
            1) if streaming and self.cache[0] is not None else 0

        # positional encoding
        if self.pe_type in ['relative', 'relative_xl']:
            xs = xs * self.scale  # NOTE: first layer only
            rel_pos_embs = self.pos_emb(xs, mlen=n_hist)
        else:
            xs = self.pos_enc(xs, scale=True, offset=max(0, n_hist))
            rel_pos_embs = None

        new_cache = [None] * self.n_layers
        if lc_bidir:
            # chunkwise streaming encoder
            # NOTE(review): `xx_mask` stays unassigned when `streaming_type` is
            # neither 'reshape' nor 'mask' -- confirm this is validated elsewhere.
            if self.streaming_type == 'reshape':
                xx_mask = None  # NOTE: no mask to avoid masking all frames in a chunk
            elif self.streaming_type == 'mask':
                if streaming:
                    # Count chunks over the cached history plus the new frames.
                    n_chunks = math.ceil((xlens.max().item() + n_hist) / N_c)
                xx_mask = make_chunkwise_san_mask(xs, xlens + n_hist, N_l, N_c,
                                                  n_chunks)

            for lth, layer in enumerate(self.layers):
                xs, cache = layer(xs,
                                  xx_mask,
                                  cache=self.cache[lth],
                                  pos_embs=rel_pos_embs,
                                  u_bias=self.u_bias,
                                  v_bias=self.v_bias)
                # The per-layer cache is only kept for the 'mask' type.
                if self.streaming_type == 'mask':
                    new_cache[lth] = cache
                # Export attention weights for offline evaluation only.
                if not self.training and not streaming:
                    if self.streaming_type == 'reshape':
                        n_heads = layer.xx_aws.size(1)
                        # Keep only current-chunk queries/keys of each chunk.
                        xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c,
                                              N_l:N_l + N_c]
                        xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                        emax = xlens.max().item()
                        # Place each chunk's weights on the block diagonal of a
                        # zero-initialized `[B, n_heads, emax, emax]` tensor.
                        xx_aws_center = xx_aws.new_zeros(
                            bs, n_heads, emax, emax)
                        for chunk_idx in range(n_chunks):
                            offset = chunk_idx * N_c
                            # The last chunk may be shorter than N_c.
                            emax_chunk = xx_aws_center[:, :, offset:offset +
                                                       N_c].size(2)
                            xx_aws_chunk = xx_aws[:, chunk_idx, :, :
                                                  emax_chunk, :emax_chunk]
                            xx_aws_center[:, :, offset:offset + N_c,
                                          offset:offset + N_c] = xx_aws_chunk
                        self.aws_dict['xx_aws_layer%d' %
                                      lth] = tensor2np(xx_aws_center)
                    elif self.streaming_type == 'mask':
                        self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                            layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    # Rescale the chunk sizes to the subsampled frame rate.
                    N_l = max(0, N_l // self.subsample[lth].factor)
                    N_c = N_c // self.subsample[lth].factor
                    N_r = N_r // self.subsample[lth].factor
                    if self.pe_type in ['relative', 'relative_xl']:
                        rel_pos_embs = self.pos_emb(xs)
                    if self.streaming_type == 'mask':
                        xx_mask = make_chunkwise_san_mask(
                            xs, xlens, N_l, N_c, n_chunks)

            # Extract the center region
            if self.streaming_type == 'reshape':
                xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
                xs = xs.contiguous().view(bs, -1, xs.size(2))
                xs = xs[:, :xlens.max()]

        else:
            # Full-sequence encoder; the mask covers cached history plus the
            # current input.
            xx_mask = make_san_mask(xs, xlens + n_hist, unidir,
                                    self.lookaheads[0])
            for lth, layer in enumerate(self.layers):
                xs, cache = layer(xs,
                                  xx_mask,
                                  cache=self.cache[lth],
                                  pos_embs=rel_pos_embs,
                                  u_bias=self.u_bias,
                                  v_bias=self.v_bias)
                new_cache[lth] = cache
                if not self.training and not streaming:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                        layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1 = self.sub_module(xs, xx_mask, lth, rel_pos_embs,
                                              'sub1')
                    # Snapshot the lengths: `xlens` may change in later layers.
                    xlens_sub1 = xlens.clone()
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub1, xlens_sub1
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.sub_module(xs, xx_mask, lth, rel_pos_embs,
                                              'sub2')
                    xlens_sub2 = xlens.clone()
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub2, xlens_sub2
                        return eouts

                # Prepare the inputs of the next layer (skipped after the last one).
                if lth < len(self.layers) - 1:
                    if self.subsample is not None and self.subsample[
                            lth].factor > 1:
                        xs, xlens = self.subsample[lth](xs, xlens)
                        # The history length is read from the next layer's cache.
                        n_hist = self.cache[lth + 1]['input_san'].size(
                            1) if streaming and self.cache[
                                lth + 1] is not None else 0
                        if self.pe_type in ['relative', 'relative_xl']:
                            rel_pos_embs = self.pos_emb(xs, mlen=n_hist)
                        xx_mask = make_san_mask(xs, xlens + n_hist, unidir,
                                                self.lookaheads[lth + 1])
                    elif self.lookaheads[lth] != self.lookaheads[lth + 1]:
                        # Rebuild the mask only when the lookahead changes.
                        xx_mask = make_san_mask(xs, xlens + n_hist, unidir,
                                                self.lookaheads[lth + 1])

        xs = self.norm_out(xs)

        # Commit the caches collected in this call for the next streaming call.
        if streaming:
            self.cache = new_cache

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        # NOTE(review): `xs_sub1`/`xs_sub2` are only assigned in the
        # non-chunkwise branch above; with `lc_bidir` set and `task == 'all'`
        # these reads would hit an unbound local -- confirm that combination
        # is never used.
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1'][
                'xlens'] = xs_sub1, xlens_sub1
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2'][
                'xlens'] = xs_sub2, xlens_sub2
        return eouts
示例#2
0
    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Encode input features with the layer stack in `self.layers`.

        Runs the CNN front-end (or the embedding layer when `self.conv` is
        None), scales the features, builds position embeddings, and applies
        the layers either on reshaped chunks (`self.latency_controlled`) or on
        the whole sequence. The output normalization and the optional bridge
        layer are applied last.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): all/ys/ys_sub1/ys_sub2
            use_cache (bool): not referenced in this method
            streaming (bool): streaming encoding (not referenced in this method)
        Returns:
            eouts (dict): keyed by `ys`/`ys_sub1`/`ys_sub2`; entries that are
                not filled for the given `task` keep `None` values
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        # Output container; only the entries selected by `task` are filled.
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        N_l = self.chunk_size_left
        N_c = self.chunk_size_current
        N_r = self.chunk_size_right

        # The unpacking requires a 3-D input; `idim` itself is not used below.
        bs, xmax, idim = xs.size()

        if self.latency_controlled:
            # presumably returns `[B * n_chunks, N_l+N_c+N_r, idim]` -- confirm
            # against `chunkwise`
            xs = chunkwise(xs, N_l, N_c, N_r)

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Path through CNN blocks
            xs, xlens = self.conv(xs, xlens)

        if not self.training:
            self.data_dict['elens'] = tensor2np(xlens)

        if self.latency_controlled:
            # streaming Conformer encoder
            # Chunk sizes measured in frames after subsampling.
            _N_l = max(0, N_l // self.subsampling_factor)
            _N_c = N_c // self.subsampling_factor

            # Chunks are folded into dim 0, so dividing by the batch size
            # recovers the number of chunks.
            n_chunks = math.ceil(xs.size(0) / bs)
            # Output length derived from the original (pre-chunking) length.
            emax = math.ceil(xmax / self.subsampling_factor)

            xs = xs * self.scale
            # Descending position indices: len-1, ..., 1, 0.
            pos_idxs = torch.arange(xs.size(1) - 1,
                                    -1,
                                    -1.0,
                                    dtype=torch.float)
            pos_embs = self.pos_emb(pos_idxs, self.device_id)

            xx_mask = None  # NOTE: no mask
            for lth, layer in enumerate(self.layers):
                xs = layer(xs, xx_mask, pos_embs=pos_embs)
                if not self.training:
                    n_heads = layer.xx_aws.size(1)
                    # Keep only current-chunk queries/keys of each chunk.
                    xx_aws = layer.xx_aws[:, :, _N_l:_N_l + _N_c,
                                          _N_l:_N_l + _N_c]
                    xx_aws = xx_aws.view(bs, n_chunks, n_heads, _N_c, _N_c)
                    # Place each chunk's weights on the block diagonal of a
                    # zero-initialized `[B, n_heads, emax, emax]` tensor.
                    xx_aws_center = xx_aws.new_zeros(bs, n_heads, emax, emax)
                    for chunk_idx in range(n_chunks):
                        offset = chunk_idx * _N_c
                        # The last chunk may be shorter than _N_c.
                        emax_blc = xx_aws_center[:, :,
                                                 offset:offset + _N_c].size(2)
                        xx_aws_chunk = xx_aws[:, chunk_idx, :, :emax_blc, :
                                              emax_blc]
                        xx_aws_center[:, :, offset:offset + _N_c,
                                      offset:offset + _N_c] = xx_aws_chunk
                    self.aws_dict['xx_aws_layer%d' %
                                  lth] = tensor2np(xx_aws_center)

            # Extract the center region
            xs = xs[:, _N_l:_N_l + _N_c]  # `[B * n_chunks, _N_c, d_model]`
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :emax]

        else:
            # Re-read the sizes: `xmax` is now the length after the front-end.
            bs, xmax, idim = xs.size()
            xs = xs * self.scale

            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat(
                [1, 1, xmax])

            # Descending position indices: xmax-1, ..., 1, 0.
            pos_idxs = torch.arange(xmax - 1, -1, -1.0, dtype=torch.float)
            pos_embs = self.pos_emb(pos_idxs, self.device_id)

            for lth, layer in enumerate(self.layers):
                xs = layer(xs, xx_mask, pos_embs=pos_embs)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                        layer.xx_aws)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    # Either run a task-specific layer or branch off a copy.
                    xs_sub1 = self.layer_sub1(
                        xs, xx_mask, pos_embs=pos_embs
                    ) if self.task_specific_layer else xs.clone()
                    xs_sub1 = self.norm_out_sub1(xs_sub1)
                    if self.bridge_sub1 is not None:
                        xs_sub1 = self.bridge_sub1(xs_sub1)
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub1, xlens
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.layer_sub2(
                        xs, xx_mask, pos_embs=pos_embs
                    ) if self.task_specific_layer else xs.clone()
                    xs_sub2 = self.norm_out_sub2(xs_sub2)
                    if self.bridge_sub2 is not None:
                        xs_sub2 = self.bridge_sub2(xs_sub2)
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub2, xlens
                        return eouts

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        # NOTE(review): `xs_sub1`/`xs_sub2` are only assigned in the
        # non-latency-controlled branch; with `latency_controlled` set and
        # `task == 'all'` these reads would hit an unbound local -- confirm
        # that combination is never used.
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts
示例#3
0
    def forward(self,
                xs,
                xlens,
                task,
                streaming=False,
                lookback=False,
                lookahead=False):
        """Encode input features with the self-attention layer stack.

        Runs the CNN front-end (or the embedding layer when `self.conv` is
        None), adds positional information, and applies `self.layers` either
        chunk-wise (`self.latency_controlled`) or over the whole sequence.
        The output normalization and the optional bridge layer are applied
        last.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (IntTensor): `[B]` (on CPU)
            task (str): all/ys/ys_sub1/ys_sub2
            streaming (bool): streaming encoding (not referenced in this method)
            lookback (bool): truncate leftmost frames for lookback in CNN context
                (not referenced in this method)
            lookahead (bool): truncate rightmost frames for lookahead in CNN context
                (not referenced in this method)
        Returns:
            eouts (dict): keyed by `ys`/`ys_sub1`/`ys_sub2`; entries that are
                not filled for the given `task` keep `None` values
                xs (FloatTensor): `[B, T, d_model]`
                xlens (IntTensor): `[B]` (on CPU)

        """
        # Output container; only the entries selected by `task` are filled.
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        # Left/current/right chunk sizes; rescaled below after subsampling.
        N_l = self.chunk_size_left
        N_c = self.chunk_size_current
        N_r = self.chunk_size_right
        bs = xs.size(0)
        n_chunks = 0
        clamp_len = self.clamp_len

        if self.latency_controlled:
            if self.streaming_type == 'reshape':
                xs = chunkwise(xs, N_l, N_c,
                               N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
            elif self.streaming_type == 'mask':
                # xs = chunkwise(xs, N_l, N_c, N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
                xs = chunkwise(xs, 0, N_c, 0)  # `[B * n_chunks, N_c, idim]`
            # Chunks are folded into dim 0, so its growth factor is the
            # number of chunks.
            n_chunks = xs.size(0) // bs

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Path through CNN blocks
            xs, xlens = self.conv(xs, xlens)
            # Express chunk sizes and the clamp length in CNN output frames.
            N_l = max(0, N_l // self.conv.subsampling_factor)
            N_c = N_c // self.conv.subsampling_factor
            N_r = N_r // self.conv.subsampling_factor
            clamp_len = clamp_len // self.conv.subsampling_factor

        # NOTE(review): this reshape runs whenever `streaming_type == 'mask'`,
        # even when `latency_controlled` is off (no chunking happened above).
        if self.streaming_type == 'mask':
            # Extract the center region
            emax = xlens.max().item()
            xs = xs.contiguous().view(
                bs, -1, xs.size(2))[:, :emax]  # `[B, emax, d_model]`

        if self.latency_controlled:
            # streaming encoder
            emax = xlens.max().item()

            pos_embs = None
            if self.pe_type in ['relative', 'relative_xl']:
                xs = xs * self.scale
                pos_embs = self.pos_emb(xs, zero_center_offset=True
                                        )  # NOTE: no clamp_len for streaming
            else:
                xs = self.pos_enc(xs, scale=True)

            # NOTE(review): the masks stay unassigned when `streaming_type` is
            # neither 'reshape' nor 'mask' -- confirm it is validated elsewhere.
            if self.streaming_type == 'reshape':
                xx_mask_first = None
                xx_mask = None  # NOTE: no mask to avoid masking all frames in a chunk
            elif self.streaming_type == 'mask':
                xx_mask_first, xx_mask = make_time_restricted_san_mask(
                    xs, xlens, N_l, N_c, N_r, n_chunks)

            for lth, layer in enumerate(self.layers):
                # The first layer gets its own mask; later layers use `xx_mask`.
                xs = layer(xs,
                           xx_mask if lth >= 1 else xx_mask_first,
                           pos_embs=pos_embs,
                           u_bias=self.u_bias,
                           v_bias=self.v_bias)
                # Export attention weights for evaluation only.
                if not self.training:
                    if self.streaming_type == 'reshape':
                        n_heads = layer.xx_aws.size(1)
                        # Keep only current-chunk queries/keys of each chunk.
                        xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c,
                                              N_l:N_l + N_c]
                        xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                        # Place each chunk's weights on the block diagonal of a
                        # zero-initialized `[B, n_heads, emax, emax]` tensor.
                        xx_aws_center = xx_aws.new_zeros(
                            bs, n_heads, emax, emax)
                        for chunk_idx in range(n_chunks):
                            offset = chunk_idx * N_c
                            # The last chunk may be shorter than N_c.
                            emax_chunk = xx_aws_center[:, :, offset:offset +
                                                       N_c].size(2)
                            xx_aws_chunk = xx_aws[:, chunk_idx, :, :
                                                  emax_chunk, :emax_chunk]
                            xx_aws_center[:, :, offset:offset + N_c,
                                          offset:offset + N_c] = xx_aws_chunk
                        self.aws_dict['xx_aws_layer%d' %
                                      lth] = tensor2np(xx_aws_center)
                    elif self.streaming_type == 'mask':
                        self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                            layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    emax = xlens.max().item()
                    # Rescale the chunk sizes to the subsampled frame rate.
                    N_l = max(0, N_l // self.subsample[lth].subsampling_factor)
                    N_c = N_c // self.subsample[lth].subsampling_factor
                    N_r = N_r // self.subsample[lth].subsampling_factor
                    if self.pe_type in ['relative', 'relative_xl']:
                        # Create sinusoidal positional embeddings for relative positional encoding
                        pos_embs = self.pos_emb(
                            xs, zero_center_offset=True
                        )  # NOTE: no clamp_len for streaming
                    if self.streaming_type == 'mask':
                        # Only the mask for later layers is refreshed here.
                        _, xx_mask = make_time_restricted_san_mask(
                            xs, xlens, N_l, N_c, N_r, n_chunks)

            # Extract the center region
            if self.streaming_type == 'reshape':
                xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
                xs = xs.contiguous().view(bs, -1, xs.size(2))
                xs = xs[:, :emax]

        else:
            # Full-sequence (non latency-controlled) encoder.
            if self.pe_type in ['relative', 'relative_xl']:
                xs = xs * self.scale
                # Create sinusoidal positional embeddings for relative positional encoding
                pos_embs = self.pos_emb(xs,
                                        clamp_len=clamp_len,
                                        zero_center_offset=True)
            else:
                xs = self.pos_enc(xs, scale=True)
                pos_embs = None

            xx_mask = make_san_mask(xs, xlens, self.unidirectional)
            for lth, layer in enumerate(self.layers):
                xs = layer(xs,
                           xx_mask,
                           pos_embs=pos_embs,
                           u_bias=self.u_bias,
                           v_bias=self.v_bias)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                        layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1 = self.sub_module(xs, xx_mask, lth, pos_embs,
                                              'sub1')
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub1, xlens
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.sub_module(xs, xx_mask, lth, pos_embs,
                                              'sub2')
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub2, xlens
                        return eouts

                # Subsample after every layer and refresh the mask and the
                # position embeddings for the shorter sequence.
                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    xx_mask = make_san_mask(xs, xlens, self.unidirectional)
                    if self.pe_type in ['relative', 'relative_xl']:
                        # Create sinusoidal positional embeddings for relative positional encoding
                        clamp_len = clamp_len // self.subsample[
                            lth].subsampling_factor
                        pos_embs = self.pos_emb(xs,
                                                clamp_len=clamp_len,
                                                zero_center_offset=True)

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        # NOTE(review): `xs_sub1`/`xs_sub2` are only assigned in the
        # non-latency-controlled branch; with `latency_controlled` set and
        # `task == 'all'` these reads would hit an unbound local -- confirm
        # that combination is never used. The sub-task outputs are also paired
        # with the final `xlens` here, not the lengths at the layer they were
        # taken from.
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts
示例#4
0
    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Encode input features with the Transformer layer stack.

        Runs the CNN front-end (or the embedding layer when `self.conv` is
        None), adds positional information, and applies `self.layers` either
        chunk-wise (`self.latency_controlled`) or over the whole sequence.
        The output normalization and the optional bridge layer are applied
        last. In evaluation mode the per-layer attention weights and lengths
        are exported to `self.aws_dict` / `self.data_dict`.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (IntTensor): `[B]` (on CPU)
            task (str): all/ys/ys_sub1/ys_sub2
            use_cache (bool): not referenced in this method
            streaming (bool): streaming encoding (not referenced in this method)
        Returns:
            eouts (dict): keyed by `ys`/`ys_sub1`/`ys_sub2`; entries that are
                not filled for the given `task` keep `None` values
                xs (FloatTensor): `[B, T, d_model]`
                xlens (IntTensor): `[B]` (on CPU)
        Raises:
            ValueError: if `self.lc_type` is neither 'reshape' nor 'mask' in
                the latency-controlled mode

        """
        # Output container; only the entries selected by `task` are filled.
        eouts = {
            'ys': {'xs': None, 'xlens': None},
            'ys_sub1': {'xs': None, 'xlens': None},
            'ys_sub2': {'xs': None, 'xlens': None},
        }

        # Left/current/right chunk sizes; rescaled below after subsampling.
        N_l = self.chunk_size_left
        N_c = self.chunk_size_current
        N_r = self.chunk_size_right
        # The 3-way unpacking keeps the original requirement of a 3-D input.
        bs, _, _ = xs.size()
        n_chunks = 0

        if self.latency_controlled:
            if self.lc_type == 'reshape':
                xs = chunkwise(xs, N_l, N_c,
                               N_r)  # `[B * n_chunks, N_l+N_c+N_r, idim]`
            elif self.lc_type == 'mask':
                xs = chunkwise(xs, 0, N_c, 0)  # `[B * n_chunks, N_c, idim]`
            else:
                raise ValueError(self.lc_type)
            # Chunks are folded into dim 0, so its growth factor is the
            # number of chunks.
            n_chunks = xs.size(0) // bs

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Path through CNN blocks
            xs, xlens = self.conv(xs, xlens)
            # Express the chunk sizes in CNN output frames.
            N_l = max(0, N_l // self.conv.subsampling_factor)
            N_c = N_c // self.conv.subsampling_factor
            N_r = N_r // self.conv.subsampling_factor

        # NOTE(review): this reshape runs whenever `lc_type == 'mask'`, even
        # when `latency_controlled` is off (no chunking happened above).
        if self.lc_type == 'mask':
            # Back to the original shape
            emax = xlens.max().item()
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :emax]  # `[B, emax, d_model]`

        if self.latency_controlled:
            # streaming Transformer encoder
            emax = xlens.max().item()

            pos_embs = None
            if self.pe_type == 'relative':
                xs = xs * self.scale
                # Descending position indices: len-1, ..., 1, 0.
                pos_idxs = torch.arange(xs.size(1) - 1,
                                        -1,
                                        -1.0,
                                        dtype=torch.float,
                                        device=self.device)
                pos_embs = self.pos_emb(pos_idxs)
            else:
                xs = self.pos_enc(xs, scale=True)

            # NOTE(review): in the 'mask' mode the layers run unmasked until
            # the first subsampling step builds a mask below -- confirm this
            # is intended.
            xx_mask = None  # NOTE: no mask to avoid all masked region
            for lth, layer in enumerate(self.layers):
                xs = layer(xs, xx_mask, pos_embs=pos_embs)
                # Export attention weights for evaluation only.
                if not self.training:
                    if self.lc_type == 'reshape':
                        n_heads = layer.xx_aws.size(1)
                        # Keep only current-chunk queries/keys of each chunk.
                        xx_aws = layer.xx_aws[:, :, N_l:N_l + N_c,
                                              N_l:N_l + N_c]
                        xx_aws = xx_aws.view(bs, n_chunks, n_heads, N_c, N_c)
                        # Place each chunk's weights on the block diagonal of a
                        # zero-initialized `[B, n_heads, emax, emax]` tensor.
                        xx_aws_center = xx_aws.new_zeros(
                            bs, n_heads, emax, emax)
                        for chunk_idx in range(n_chunks):
                            offset = chunk_idx * N_c
                            # The last chunk may be shorter than N_c.
                            emax_chunk = xx_aws_center[
                                :, :, offset:offset + N_c].size(2)
                            xx_aws_chunk = xx_aws[:, chunk_idx, :,
                                                  :emax_chunk, :emax_chunk]
                            xx_aws_center[:, :, offset:offset + N_c,
                                          offset:offset + N_c] = xx_aws_chunk
                        # Store the assembled weights; without this the
                        # block-diagonal tensor built above is discarded and
                        # no attention weights are exported in 'reshape' mode.
                        self.aws_dict['xx_aws_layer%d' %
                                      lth] = tensor2np(xx_aws_center)
                    elif self.lc_type == 'mask':
                        self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                            layer.xx_aws)
                    else:
                        raise ValueError(self.lc_type)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    emax = xlens.max().item()
                    # Rescale the chunk sizes to the subsampled frame rate.
                    N_l = max(0, N_l // self.subsample[lth].subsampling_factor)
                    N_c = N_c // self.subsample[lth].subsampling_factor
                    N_r = N_r // self.subsample[lth].subsampling_factor
                    if self.lc_type == 'mask':
                        xx_mask = make_pad_mask(xlens.to(self.device))
                        xx_mask = xx_mask.unsqueeze(1).repeat(
                            [1, xs.size(1),
                             1])  # `[B, emax (query), emax (key)]`
                        # Zero out the keys outside the left context and
                        # beyond the current+right context of every chunk.
                        for chunk_idx in range(n_chunks):
                            offset = chunk_idx * N_c
                            xx_mask[:, offset:offset + N_c,
                                    :max(0, offset - N_l)] = 0
                            xx_mask[:, offset:offset + N_c,
                                    offset + (N_c + N_r):] = 0

            # Extract the center region
            if self.lc_type == 'reshape':
                xs = xs[:, N_l:N_l + N_c]  # `[B * n_chunks, N_c, d_model]`
                xs = xs.contiguous().view(bs, -1, xs.size(2))
                xs = xs[:, :emax]

        else:
            # Full-sequence (non latency-controlled) encoder.
            if self.pe_type == 'relative':
                xs = xs * self.scale
                # Create sinusoidal positional embeddings for relative positional encoding
                pos_idxs = torch.arange(xs.size(1) - 1,
                                        -1,
                                        -1.0,
                                        dtype=torch.float,
                                        device=self.device)
                pos_embs = self.pos_emb(pos_idxs)
            else:
                xs = self.pos_enc(xs, scale=True)
                pos_embs = None

            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens.to(self.device)).unsqueeze(1).repeat(
                [1, xs.size(1), 1])

            for lth, layer in enumerate(self.layers):
                xs = layer(xs, xx_mask, pos_embs=pos_embs)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % lth] = tensor2np(
                        layer.xx_aws)
                    self.data_dict['elens%d' % lth] = tensor2np(xlens)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1 = self.sub_module(xs, xx_mask, lth, pos_embs,
                                              'sub1')
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub1, xlens
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2 = self.sub_module(xs, xx_mask, lth, pos_embs,
                                              'sub2')
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub2, xlens
                        return eouts

                # Subsample after every layer and refresh the mask and the
                # position embeddings for the shorter sequence.
                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)
                    # Create the self-attention mask
                    xx_mask = make_pad_mask(xlens.to(
                        self.device)).unsqueeze(1).repeat([1, xs.size(1), 1])
                    if self.pe_type == 'relative':
                        # Create sinusoidal positional embeddings for relative positional encoding
                        pos_idxs = torch.arange(xs.size(1) - 1,
                                                -1,
                                                -1.0,
                                                dtype=torch.float,
                                                device=self.device)
                        pos_embs = self.pos_emb(pos_idxs)

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        # NOTE(review): `xs_sub1`/`xs_sub2` are only assigned in the
        # non-latency-controlled branch; with `latency_controlled` set and
        # `task == 'all'` these reads would hit an unbound local -- confirm
        # that combination is never used.
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts
示例#5
0
    def forward(self,
                xs,
                xlens,
                task,
                streaming=False,
                lookback=False,
                lookahead=False):
        """Forward pass.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list or IntTensor): lengths of size `[B]`. It is converted
                with `torch.IntTensor` only when `self.lc_bidir` is False; the
                `lc_bidir` path calls `xlens.max()` directly, so a tensor is
                required there.
            task (str): all or ys or ys_sub1 or ys_sub2
            streaming (bool): streaming encoding
            lookback (bool): truncate leftmost frames for lookback in CNN context
            lookahead (bool): truncate rightmost frames for lookahead in CNN context
        Returns:
            eouts (dict): keyed by `ys` / `ys_sub1` / `ys_sub2`; each value is
                a dict with
                xs (FloatTensor): `[B, T // prod(subsample), n_units (*2)]`
                xlens (IntTensor): `[B]`
                Entries that are not produced for the requested `task` (or
                that the executed path does not compute) are left as None.

        """
        eouts = {
            key: {
                'xs': None,
                'xlens': None
            }
            for key in ('ys', 'ys_sub1', 'ys_sub2')
        }

        # Sub-task outputs default to None so that a path which never computes
        # them (e.g. `ys_sub2` on the `lc_bidir` path) does not hit an unbound
        # local when `task == 'all'`.
        xs_sub1, xlens_sub1 = None, None
        xs_sub2, xlens_sub2 = None, None

        # Sort by lengths in the descending order for pack_padded_sequence
        perm_ids_unsort = None
        if not self.lc_bidir:
            xlens, perm_ids = torch.IntTensor(xlens).sort(0, descending=True)
            xs = xs[perm_ids]
            # Sorting the permutation yields its inverse, used to restore the
            # caller's batch order before returning
            _, perm_ids_unsort = perm_ids.sort()

        # Dropout for inputs-hidden connection
        xs = self.dropout_in(xs)

        bs, xmax, idim = xs.size()
        N_c, N_r = self.N_c, self.N_r

        if self.lc_bidir and not self.cnn_lookahead:
            xs = chunkwise(xs, 0, N_c, 0)  # `[B * n_chunks, N_c, idim]`
            # Extract the center region
            xs = xs.contiguous().view(bs, -1, xs.size(2))
            xs = xs[:, :xlens.max()]  # `[B, emax, d_model]`

        # Path through CNN blocks before RNN layers
        if self.conv is not None:
            xs, xlens = self.conv(xs,
                                  xlens,
                                  lookback=lookback,
                                  lookahead=lookahead)
            if self.enc_type == 'conv':
                # The batch was reordered above when `lc_bidir` is False;
                # undo it here too, as the regular `ys` exit does.
                if perm_ids_unsort is not None:
                    xs = xs[perm_ids_unsort]
                    xlens = xlens[perm_ids_unsort]
                eouts['ys']['xs'] = xs
                eouts['ys']['xlens'] = xlens
                return eouts
            if self.lc_bidir:
                # Rescale the local chunk sizes by `self.conv_factor`
                N_c = N_c // self.conv_factor
                N_r = N_r // self.conv_factor

        # Keep the cached RNN states only when the random draw (training with
        # `rsp_prob > 0`) succeeds and the cached batch size matches
        carry_over = self.rsp_prob > 0 and self.training and random.random(
        ) < self.rsp_prob
        carry_over = carry_over and (bs == (self.hx_fwd[0][0].size(0) if
                                            self.hx_fwd[0] is not None else 0))
        if not streaming and not carry_over:
            self.reset_cache()
            # NOTE: do not reset here for streaming inference

        if self.lc_bidir:
            # Flip the layer and time loop
            if self.N_c <= 0:
                xs, xlens, xs_sub1, xlens_sub1 = self._forward_full_context(
                    xs, xlens)
            else:
                xs, xlens, xs_sub1, xlens_sub1 = self._forward_latency_controlled(
                    xs, xlens, N_c, N_r, streaming)
            if task == 'ys_sub1':
                eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens_sub1
                return eouts
        else:
            for lth in range(self.n_layers):
                self.rnn[lth].flatten_parameters()  # for multi-GPUs
                xs, state = self.padding(xs,
                                         xlens,
                                         self.rnn[lth],
                                         prev_state=self.hx_fwd[lth],
                                         streaming=streaming)
                # Cache the per-layer state for the next call
                self.hx_fwd[lth] = state
                xs = self.dropout(xs)

                # Pick up outputs in the sub task before the projection layer
                if lth == self.n_layers_sub1 - 1:
                    xs_sub1, xlens_sub1 = self.sub_module(
                        xs, xlens, perm_ids_unsort, 'sub1')
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub1, xlens_sub1
                        return eouts
                if lth == self.n_layers_sub2 - 1:
                    xs_sub2, xlens_sub2 = self.sub_module(
                        xs, xlens, perm_ids_unsort, 'sub2')
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task][
                            'xlens'] = xs_sub2, xlens_sub2
                        return eouts

                # Projection layer
                if self.proj is not None and lth != self.n_layers - 1:
                    xs = torch.relu(self.proj[lth](xs))
                # Subsampling layer
                if self.subsample is not None:
                    xs, xlens = self.subsample[lth](xs, xlens)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        # Drop frames beyond the longest sequence in the batch
        xs = xs[:, :xlens.max()]

        if task in ['all', 'ys']:
            if perm_ids_unsort is not None:
                # Restore the caller's batch order
                xs = xs[perm_ids_unsort]
                xlens = xlens[perm_ids_unsort]
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1'][
                'xlens'] = xs_sub1, xlens_sub1
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2'][
                'xlens'] = xs_sub2, xlens_sub2
        return eouts