Example #1
class TransformerEncoder(EncoderBase):
    """Transformer encoder.

    Args:
        input_dim (int): dimension of input features (freq * channel)
        enc_type (str): type of encoder
        attn_type (str): type of attention
        n_heads (int): number of heads for multi-head attention
        n_layers (int): number of blocks
        n_layers_sub1 (int): number of layers in the 1st auxiliary task
        n_layers_sub2 (int): number of layers in the 2nd auxiliary task
        d_model (int): dimension of MultiheadAttentionMechanism
        d_ff (int): dimension of PositionwiseFeedForward
        last_proj_dim (int): dimension of the last projection layer
        pe_type (str): type of positional encoding
        layer_norm_eps (float): epsilon value for layer normalization
        ffn_activation (str): nonlinear function for PositionwiseFeedForward
        dropout_in (float): dropout probability for input-hidden connection
        dropout (float): dropout probabilities for linear layers
        dropout_att (float): dropout probabilities for attention distributions
        dropout_residual (float): dropout probability for stochastic residual connections
        n_stacks (int): number of frames to stack
        n_splices (int): number of frames to splice. Default is 1 frame.
        conv_in_channel (int): number of channels of input features
        conv_channels (int): number of channels in the CNN blocks
        conv_kernel_sizes (list): size of kernels in the CNN blocks
        conv_strides (list): number of strides in the CNN blocks
        conv_poolings (list): size of poolings in the CNN blocks
        conv_batch_norm (bool): apply batch normalization only in the CNN blocks
        conv_layer_norm (bool): apply layer normalization only in the CNN blocks
        conv_bottleneck_dim (int): dimension of the bottleneck layer between CNN and self-attention layers
        conv_param_init (float): parameter initialization for CNN layers before Transformer layers
        chunk_size_left (int): left chunk size for time-restricted Transformer encoder
        chunk_size_current (int): current chunk size for time-restricted Transformer encoder
        chunk_size_right (int): right chunk size for time-restricted Transformer encoder
        task_specific_layer (bool): add a task specific layer for each sub task
        param_init (str): parameter initialization method

    """
    def __init__(self, input_dim, enc_type, attn_type, n_heads, n_layers,
                 n_layers_sub1, n_layers_sub2, d_model, d_ff, last_proj_dim,
                 pe_type, layer_norm_eps, ffn_activation, dropout_in, dropout,
                 dropout_att, dropout_residual, n_stacks, n_splices,
                 conv_in_channel, conv_channels, conv_kernel_sizes,
                 conv_strides, conv_poolings, conv_batch_norm, conv_layer_norm,
                 conv_bottleneck_dim, conv_param_init, task_specific_layer,
                 param_init, chunk_size_left, chunk_size_current,
                 chunk_size_right):

        super(TransformerEncoder, self).__init__()

        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            raise ValueError('Set n_layers_sub1 between 1 and n_layers.')
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            raise ValueError('Set n_layers_sub2 between 1 and n_layers_sub1.')

        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type

        # for latency-controlled
        self.chunk_size_left = chunk_size_left
        self.chunk_size_cur = chunk_size_current
        self.chunk_size_right = chunk_size_right

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # Setting for CNNs before self-attention layers
        if conv_channels:
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type)
        self.layers = nn.ModuleList([
            copy.deepcopy(
                TransformerEncoderBlock(d_model, d_ff, attn_type, n_heads,
                                        dropout, dropout_att,
                                        dropout_residual * (l + 1) / n_layers,
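                                        # residual-drop rate grows linearly with depth (stochastic depth)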
                                        layer_norm_eps, ffn_activation,
                                        param_init)) for l in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self._odim = d_model

        if n_layers_sub1 > 0:
            if task_specific_layer:
                self.layer_sub1 = TransformerEncoderBlock(
                    d_model, d_ff, attn_type, n_heads, dropout, dropout_att,
                    dropout_residual * n_layers_sub1 / n_layers,
                    layer_norm_eps, ffn_activation, param_init)
            self.norm_out_sub1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if last_proj_dim != self.output_dim:
                self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)

        if n_layers_sub2 > 0:
            if task_specific_layer:
                self.layer_sub2 = TransformerEncoderBlock(
                    d_model, d_ff, attn_type, n_heads, dropout, dropout_att,
                    dropout_residual * n_layers_sub2 / n_layers,
                    layer_norm_eps, ffn_activation, param_init)
            self.norm_out_sub2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if last_proj_dim != self.output_dim:
                self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)

        if last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor()

        if param_init == 'xavier_uniform':
            self.reset_parameters()

    def reset_parameters(self):
        """Initialize parameters with Xavier uniform distribution."""
        logger.info(
            '===== Initialize %s with Xavier uniform distribution =====' %
            self.__class__.__name__)
        if self.conv is None:
            nn.init.xavier_uniform_(self.embed.weight)
            nn.init.constant_(self.embed.bias, 0.)
        if self.bridge is not None:
            nn.init.xavier_uniform_(self.bridge.weight)
            nn.init.constant_(self.bridge.bias, 0.)

    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): all or ys or ys_sub1 or ys_sub2
            use_cache (bool):
            streaming (bool): streaming encoding
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks before the self-attention layers
            xs, xlens = self.conv(xs, xlens)
        if not self.training:
            self.data_dict['elens'] = tensor2np(xlens)

        bs, xmax, idim = xs.size()
        xs = self.pos_enc(xs)
        if self.chunk_size_left > 0:
            # Time-restricted self-attention for streaming models
            cs_l = self.chunk_size_left
            cs_c = self.chunk_size_cur
            cs_r = self.chunk_size_right
            xs_chunks = []
            xx_aws = [[] for l in range(self.n_layers)]
            xs_pad = torch.cat([xs.new_zeros(bs, cs_l, idim), xs,
                                xs.new_zeros(bs, cs_r, idim)], dim=1)
            # TODO: remove right padding
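            # Each step takes a window of the padded sequence:
            #   [t - cs_l, t)               -> left (history) context
            #   [t, t + cs_c)               -> current chunk (kept in the output)
            #   [t + cs_c, t + cs_c + cs_r) -> right (look-ahead) context
            # Only the current-chunk slice is appended to xs_chunks below.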
            for t in range(cs_l, cs_l + xmax, self.chunk_size_cur):
                xs_chunk = xs_pad[:, t - cs_l:t + cs_c + cs_r]
                for l, layer in enumerate(self.layers):
                    xs_chunk, xx_aws_chunk = layer(xs_chunk, None)  # no mask
                    xx_aws[l].append(xx_aws_chunk[:, :, cs_l:cs_l + cs_c,
                                                  cs_l:cs_l + cs_c])
                xs_chunks.append(xs_chunk[:, cs_l:cs_l + cs_c])
            xs = torch.cat(xs_chunks, dim=1)[:, :xmax]
            if not self.training:
                for l in range(self.n_layers):
                    self.aws_dict['xx_aws_layer%d' % l] = tensor2np(
                        torch.cat(xx_aws[l], dim=3)[:, :, :xmax, :xmax])
        else:
            # Create the self-attention mask
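            # xx_mask: `[B, xmax, xmax]`, repeated along the last (key) dimension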
            xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat(
                [1, 1, xmax])

            for l, layer in enumerate(self.layers):
                xs, xx_aws = layer(xs, xx_mask)
                if not self.training:
                    self.aws_dict['xx_aws_layer%d' % l] = tensor2np(xx_aws)

                # Pick up outputs in the sub task before the projection layer
                if l == self.n_layers_sub1 - 1:
                    if self.task_specific_layer:
                        xs_sub1 = self.layer_sub1(xs, xx_mask)[0]
                    else:
                        xs_sub1 = xs.clone()
                    xs_sub1 = self.norm_out_sub1(xs_sub1)
                    if self.bridge_sub1 is not None:
                        xs_sub1 = self.bridge_sub1(xs_sub1)
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens
                        return eouts
                if l == self.n_layers_sub2 - 1:
                    if self.task_specific_layer:
                        xs_sub2 = self.layer_sub2(xs, xx_mask)[0]
                    else:
                        xs_sub2 = xs.clone()
                    xs_sub2 = self.norm_out_sub2(xs_sub2)
                    if self.bridge_sub2 is not None:
                        xs_sub2 = self.bridge_sub2(xs_sub2)
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens
                        return eouts

        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens
        return eouts

    def _plot_attention(self, save_path, n_cols=2):
        """Plot attention for each head in all layers."""
        from matplotlib import pyplot as plt
        from matplotlib.ticker import MaxNLocator

        _save_path = mkdir_join(save_path, 'enc_att_weights')

        # Clean directory
        if _save_path is not None and os.path.isdir(_save_path):
            shutil.rmtree(_save_path)
            os.mkdir(_save_path)

        for k, aw in self.aws_dict.items():
            elens = self.data_dict['elens']

            plt.clf()
            n_heads = aw.shape[1]
            n_cols_tmp = 1 if n_heads == 1 else n_cols
            fig, axes = plt.subplots(max(1, n_heads // n_cols_tmp),
                                     n_cols_tmp,
                                     figsize=(20, 8),
                                     squeeze=False)
            for h in range(n_heads):
                ax = axes[h // n_cols_tmp, h % n_cols_tmp]
                ax.imshow(aw[-1, h, :elens[-1], :elens[-1]], aspect="auto")
                ax.grid(False)
                ax.set_xlabel("Input (head%d)" % h)
                ax.set_ylabel("Output (head%d)" % h)
                ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                ax.yaxis.set_major_locator(MaxNLocator(integer=True))

            fig.tight_layout()
            fig.savefig(os.path.join(_save_path, '%s.png' % k), dpi=500)
            plt.close()
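A minimal instantiation sketch for the encoder above. All argument values (attention type, positional-encoding type, chunk sizes, feature dimension, and so on) are illustrative assumptions rather than values from any particular recipe; last_proj_dim is set equal to d_model so that no bridge layer is created, and the helper classes referenced by the encoder (ConvEncoder, PositionalEncoding, TransformerEncoderBlock) must already be importable.

# Hypothetical usage sketch; argument values are assumptions, not a recipe.
import torch

enc = TransformerEncoder(
    input_dim=80, enc_type='transformer', attn_type='scaled_dot',
    n_heads=4, n_layers=6, n_layers_sub1=0, n_layers_sub2=0,
    d_model=256, d_ff=2048, last_proj_dim=256,
    pe_type='add', layer_norm_eps=1e-12, ffn_activation='relu',
    dropout_in=0.1, dropout=0.1, dropout_att=0.1, dropout_residual=0.0,
    n_stacks=1, n_splices=1,
    conv_in_channel=1, conv_channels=[], conv_kernel_sizes=[],
    conv_strides=[], conv_poolings=[], conv_batch_norm=False,
    conv_layer_norm=False, conv_bottleneck_dim=0, conv_param_init=0.1,
    task_specific_layer=False, param_init='xavier_uniform',
    chunk_size_left=0, chunk_size_current=0, chunk_size_right=0)

xs = torch.randn(2, 100, 80)                 # `[B, T, input_dim]`
eouts = enc(xs, xlens=[100, 80], task='ys')  # offline (non-chunked) encoding
print(eouts['ys']['xs'].size())              # `[B, T, d_model]`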
Example #2
class RNNEncoder(EncoderBase):
    """RNN encoder.

    Args:
        input_dim (int): dimension of input features (freq * channel)
        rnn_type (str): type of encoder (including pure CNN layers)
        n_units (int): number of units in each layer
        n_projs (int): number of units in each projection layer
        last_proj_dim (int): dimension of the last projection layer
        n_layers (int): number of layers
        n_layers_sub1 (int): number of layers in the 1st auxiliary task
        n_layers_sub2 (int): number of layers in the 2nd auxiliary task
        dropout_in (float): dropout probability for input-hidden connection
        dropout (float): dropout probability for hidden-hidden connection
        subsample (list): subsample in the corresponding RNN layers
            e.g., [False, True, True, False] means that subsampling is conducted in the 2nd and 3rd layers.
        subsample_type (str): drop/concat/max_pool/1dconv
        n_stacks (int): number of frames to stack
        n_splices (int): number of frames to splice
        conv_in_channel (int): number of channels of input features
        conv_channels (int): number of channels in the CNN blocks
        conv_kernel_sizes (list): size of kernels in the CNN blocks
        conv_strides (list): number of strides in the CNN blocks
        conv_poolings (list): size of poolings in the CNN blocks
        conv_batch_norm (bool): apply batch normalization only in the CNN blocks
        conv_layer_norm (bool): apply layer normalization only in the CNN blocks
        conv_bottleneck_dim (int): dimension of the bottleneck layer between CNN and RNN layers
        nin (bool): insert 1*1 conv + batch normalization + ReLU
        bidirectional_sum_fwd_bwd (bool): sum forward and backward RNN outputs instead of concatenating them
        task_specific_layer (bool): add a task-specific RNN layer for each sub task
        param_init (float): parameter initialization range for the uniform distribution
        lc_chunk_size_left (int): left chunk size for latency-controlled bidirectional encoder
        lc_chunk_size_right (int): right chunk size for latency-controlled bidirectional encoder
        lc_state_reset_prob (float): probability to reset states for latency-controlled bidirectional encoder

    """

    def __init__(self, input_dim, rnn_type, n_units, n_projs, last_proj_dim,
                 n_layers, n_layers_sub1, n_layers_sub2,
                 dropout_in, dropout,
                 subsample, subsample_type, n_stacks, n_splices,
                 conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides, conv_poolings,
                 conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
                 nin, bidirectional_sum_fwd_bwd,
                 task_specific_layer, param_init,
                 lc_chunk_size_left, lc_chunk_size_right, lc_state_reset_prob):

        super(RNNEncoder, self).__init__()

        if len(subsample) > 0 and len(subsample) != n_layers:
            raise ValueError('subsample must be the same size as n_layers.')
        if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1):
            raise ValueError('Set n_layers_sub1 between 1 to n_layers.')
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2):
            raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1.')

        self.rnn_type = rnn_type
        self.bidirectional = 'blstm' in rnn_type or 'bgru' in rnn_type
        self.n_units = n_units
        self.n_dirs = 2 if self.bidirectional else 1
        self.n_layers = n_layers

        # for latency-controlled
        self.latency_controlled = lc_chunk_size_left > 0 or lc_chunk_size_right > 0
        self.lc_chunk_size_left = lc_chunk_size_left
        self.lc_chunk_size_right = lc_chunk_size_right
        self.lc_state_reset_prob = lc_state_reset_prob
        if self.latency_controlled:
            assert n_layers_sub1 == 0
            assert n_layers_sub2 == 0

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # Dropout for input-hidden connection
        self.dropout_in = nn.Dropout(p=dropout_in)

        if rnn_type == 'tds':
            self.conv = TDSEncoder(input_dim=input_dim * n_stacks,
                                   in_channel=conv_in_channel,
                                   channels=conv_channels,
                                   kernel_sizes=conv_kernel_sizes,
                                   dropout=dropout,
                                   bottleneck_dim=last_proj_dim)
        elif rnn_type == 'gated_conv':
            self.conv = GatedConvEncoder(input_dim=input_dim * n_stacks,
                                         in_channel=conv_in_channel,
                                         channels=conv_channels,
                                         kernel_sizes=conv_kernel_sizes,
                                         dropout=dropout,
                                         bottleneck_dim=last_proj_dim,
                                         param_init=param_init)

        elif 'conv' in rnn_type:
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    residual=False,
                                    bottleneck_dim=conv_bottleneck_dim,
                                    param_init=param_init)
        else:
            self.conv = None

        if self.conv is None:
            self._odim = input_dim * n_splices * n_stacks
        else:
            self._odim = self.conv.output_dim
            subsample = [1] * self.n_layers
            logger.warning('Subsampling is automatically ignored because CNN layers are used before RNN layers.')

        self.padding = Padding(bidirectional_sum_fwd_bwd=bidirectional_sum_fwd_bwd)

        if rnn_type not in ['conv', 'tds', 'gated_conv']:
            self.rnn = nn.ModuleList()
            if self.latency_controlled:
                self.rnn_bwd = nn.ModuleList()
            self.dropout = nn.Dropout(p=dropout)
            self.proj = None
            if n_projs > 0:
                self.proj = nn.ModuleList()

            # subsample
            self.subsample_layer = None
            if subsample_type == 'max_pool' and np.prod(subsample) > 1:
                self.subsample_layer = nn.ModuleList([MaxpoolSubsampler(subsample[l])
                                                      for l in range(n_layers)])
            elif subsample_type == 'concat' and np.prod(subsample) > 1:
                self.subsample_layer = nn.ModuleList([ConcatSubsampler(subsample[l], n_units * self.n_dirs)
                                                      for l in range(n_layers)])
            elif subsample_type == 'drop' and np.prod(subsample) > 1:
                self.subsample_layer = nn.ModuleList([DropSubsampler(subsample[l])
                                                      for l in range(n_layers)])
            elif subsample_type == '1dconv' and np.prod(subsample) > 1:
                self.subsample_layer = nn.ModuleList([Conv1dSubsampler(subsample[l], n_units * self.n_dirs)
                                                      for l in range(n_layers)])

            # NiN
            self.nin = nn.ModuleList() if nin else None

            for l in range(n_layers):
                if 'lstm' in rnn_type:
                    rnn_i = nn.LSTM
                elif 'gru' in rnn_type:
                    rnn_i = nn.GRU
                else:
                    raise ValueError('rnn_type must be "(conv_)(b/lcb)lstm" or "(conv_)(b/lcb)gru".')

                if self.latency_controlled:
                    self.rnn += [rnn_i(self._odim, n_units, 1, batch_first=True)]
                    self.rnn_bwd += [rnn_i(self._odim, n_units, 1, batch_first=True)]
                else:
                    self.rnn += [rnn_i(self._odim, n_units, 1, batch_first=True,
                                       bidirectional=self.bidirectional)]
                self._odim = n_units if bidirectional_sum_fwd_bwd else n_units * self.n_dirs
                self.bidirectional_sum_fwd_bwd = bidirectional_sum_fwd_bwd

                # Projection layer
                if self.proj is not None:
                    if l != n_layers - 1:
                        self.proj += [nn.Linear(n_units * self.n_dirs, n_projs)]
                        self._odim = n_projs

                # Task specific layer
                if l == n_layers_sub1 - 1 and task_specific_layer:
                    self.rnn_sub1 = rnn_i(self._odim, n_units, 1,
                                          batch_first=True,
                                          bidirectional=self.bidirectional)
                    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                        self.bridge_sub1 = nn.Linear(n_units, last_proj_dim)
                if l == n_layers_sub2 - 1 and task_specific_layer:
                    self.rnn_sub2 = rnn_i(self._odim, n_units, 1,
                                          batch_first=True,
                                          bidirectional=self.bidirectional)
                    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                        self.bridge_sub2 = nn.Linear(n_units, last_proj_dim)

                # Network in network
                if self.nin is not None:
                    if l != n_layers - 1:
                        self.nin += [NiN(self._odim)]
                    # if n_layers_sub1 > 0 or n_layers_sub2 > 0:
                    #     assert task_specific_layer

            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge = nn.Linear(self._odim, last_proj_dim)
                self._odim = last_proj_dim

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor()
        self._factor *= np.prod(subsample)

        self.reset_parameters(param_init)

        # for streaming inference
        self.reset_cache()

    def reset_parameters(self, param_init):
        """Initialize parameters with uniform distribution."""
        logger.info('===== Initialize %s =====' % self.__class__.__name__)
        for n, p in self.named_parameters():
            if 'conv' in n or 'tds' in n or 'gated_conv' in n:
                continue  # for CNN layers before RNN layers
            if p.dim() == 1:
                nn.init.constant_(p, 0.)  # bias
                logger.info('Initialize %s with %s / %.3f' % (n, 'constant', 0.))
            elif p.dim() in [2, 4]:
                nn.init.uniform_(p, a=-param_init, b=param_init)
                logger.info('Initialize %s with %s / %.3f' % (n, 'uniform', param_init))
            else:
                raise ValueError(n)

    def reset_cache(self):
        self.fwd_states = [None] * self.n_layers
        logger.debug('Reset cache.')

    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): A list of length `[B]`
            task (str): all or ys or ys_sub1 or ys_sub2
            use_cache (bool): use the cached forward encoder state in the previous chunk as the initial state
            streaming (bool): streaming encoding
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T // prod(subsample), n_units (*2)]`
                xlens (IntTensor): `[B]`
                xs_sub1 (FloatTensor): `[B, T // prod(subsample), n_units (*2)]`
                xlens_sub1 (IntTensor): `[B]`
                xs_sub2 (FloatTensor): `[B, T // prod(subsample), n_units (*2)]`
                xlens_sub2 (IntTensor): `[B]`

        """
        eouts = {'ys': {'xs': None, 'xlens': None},
                 'ys_sub1': {'xs': None, 'xlens': None},
                 'ys_sub2': {'xs': None, 'xlens': None}}

        # Sort by lengths in descending order for pack_padded_sequence
        xlens, perm_ids = torch.IntTensor(xlens).sort(0, descending=True)
        xs = xs[perm_ids]
        _, perm_ids_unsort = perm_ids.sort()

        # Dropout for input-hidden connection
        xs = self.dropout_in(xs)

        # Pass through CNN blocks before RNN layers
        if self.conv is not None:
            xs, xlens = self.conv(xs, xlens)
            if self.rnn_type in ['conv', 'tds', 'gated_conv']:
                eouts['ys']['xs'] = xs
                eouts['ys']['xlens'] = xlens
                return eouts

        if not use_cache:
            self.reset_cache()

        if self.latency_controlled:
            # Flip the layer and time loop
            xs, xlens = self._forward_streaming(xs, xlens, streaming)
        else:
            for l in range(self.n_layers):
                self.rnn[l].flatten_parameters()  # for multi-GPUs
                xs, self.fwd_states[l] = self.padding(xs, xlens, self.rnn[l],
                                                      prev_state=self.fwd_states[l])
                xs = self.dropout(xs)

                # Pick up outputs in the sub task before the projection layer
                if l == self.n_layers_sub1 - 1:
                    xs_sub1, xlens_sub1 = self.sub_module(xs, xlens, perm_ids_unsort, 'sub1')
                    if task == 'ys_sub1':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens_sub1
                        return eouts
                if l == self.n_layers_sub2 - 1:
                    xs_sub2, xlens_sub2 = self.sub_module(xs, xlens, perm_ids_unsort, 'sub2')
                    if task == 'ys_sub2':
                        eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens_sub2
                        return eouts

                # NOTE: Exclude the last layer
                if l != self.n_layers - 1:
                    # Projection layer -> Subsampling -> NiN
                    if self.proj is not None:
                        xs = torch.tanh(self.proj[l](xs))
                    if self.subsample_layer is not None:
                        xs, xlens = self.subsample_layer[l](xs, xlens)
                    if self.nin is not None:
                        xs = self.nin[l](xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        # Unsort
        xs = xs[perm_ids_unsort]
        xlens = xlens[perm_ids_unsort]

        if task in ['all', 'ys']:
            eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens
        if self.n_layers_sub1 >= 1 and task == 'all':
            eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens_sub1
        if self.n_layers_sub2 >= 1 and task == 'all':
            eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens_sub2
        return eouts

    def _forward_streaming(self, xs, xlens, streaming):
        """Streaming encoding for the latency-controlled bidirectional encoder.

        Args:
            xs (FloatTensor): `[B, T, n_units]`
            xlens (IntTensor): `[B]`
            streaming (bool): streaming encoding
        Returns:
            xs (FloatTensor): `[B, T, n_units]`
            xlens (IntTensor): `[B]`

        """
        cs_l = self.lc_chunk_size_left // self.subsampling_factor()
        cs_r = self.lc_chunk_size_right // self.subsampling_factor()

        # full context BPTT
        if cs_l < 0:
            for l in range(self.n_layers):
                self.rnn[l].flatten_parameters()  # for multi-GPUs
                self.rnn_bwd[l].flatten_parameters()  # for multi-GPUs
                # bwd
                xs_bwd = torch.flip(xs, dims=[1])
                xs_bwd, _ = self.rnn_bwd[l](xs_bwd, hx=None)
                xs_bwd = torch.flip(xs_bwd, dims=[1])
                # fwd
                xs_fwd, _ = self.rnn[l](xs, hx=None)
                if self.bidirectional_sum_fwd_bwd:
                    xs = xs_fwd + xs_bwd
                else:
                    xs = torch.cat([xs_fwd, xs_bwd], dim=-1)
                xs = self.dropout(xs)

                # Projection layer
                if self.proj is not None and l != self.n_layers - 1:
                    xs = torch.tanh(self.proj[l](xs))
            return xs, xlens

        bs, xmax, input_dim = xs.size()
        n_chunks = 1 if streaming else math.ceil(xmax / cs_l)
        xlens = torch.IntTensor(bs).fill_(cs_l if streaming else xmax)

        xs_chunks = []
        for t in range(0, cs_l * n_chunks, cs_l):
            xs_chunk = xs[:, t:t + (cs_l + cs_r)]
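            # xs_chunk holds cs_l current frames plus cs_r look-ahead frames;
            # only the first cs_l frames are kept in the output below, while
            # the look-ahead frames provide future context via the backward RNN.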
            for l in range(self.n_layers):
                self.rnn[l].flatten_parameters()  # for multi-GPUs
                self.rnn_bwd[l].flatten_parameters()  # for multi-GPUs
                # bwd
                xs_chunk_bwd = torch.flip(xs_chunk, dims=[1])
                xs_chunk_bwd, _ = self.rnn_bwd[l](xs_chunk_bwd, hx=None)
                xs_chunk_bwd = torch.flip(xs_chunk_bwd, dims=[1])  # `[B, cs_l+cs_r, n_units]`
                # fwd
                if xs_chunk.size(1) <= cs_l:
                    xs_chunk_fwd, self.fwd_states[l] = self.rnn[l](xs_chunk, hx=self.fwd_states[l])
                    if self.training and self.lc_state_reset_prob > 0 and random.random() < self.lc_state_reset_prob:
                        self.fwd_states[l] = None
                else:
                    xs_chunk_fwd1, self.fwd_states[l] = self.rnn[l](xs_chunk[:, :cs_l], hx=self.fwd_states[l])
                    if self.training and self.lc_state_reset_prob > 0 and random.random() < self.lc_state_reset_prob:
                        self.fwd_states[l] = None
                    xs_chunk_fwd2, _ = self.rnn[l](xs_chunk[:, cs_l:], hx=self.fwd_states[l])
                    xs_chunk_fwd = torch.cat([xs_chunk_fwd1, xs_chunk_fwd2], dim=1)  # `[B, cs_l+cs_r, n_units]`
                    # NOTE: xs_chunk_fwd2 is for xs_chunk_bwd in the next layer
                if self.bidirectional_sum_fwd_bwd:
                    xs_chunk = xs_chunk_fwd + xs_chunk_bwd
                else:
                    xs_chunk = torch.cat([xs_chunk_fwd, xs_chunk_bwd], dim=-1)
                xs_chunk = self.dropout(xs_chunk)

                # Projection layer
                if self.proj is not None and l != self.n_layers - 1:
                    xs_chunk = torch.tanh(self.proj[l](xs_chunk))
            xs_chunks.append(xs_chunk[:, :cs_l])
        xs = torch.cat(xs_chunks, dim=1)

        return xs, xlens

    def sub_module(self, xs, xlens, perm_ids_unsort, module='sub1'):
        if self.task_specific_layer:
            getattr(self, 'rnn_' + module).flatten_parameters()  # for multi-GPUs
            xs_sub, _ = self.padding(xs, xlens, getattr(self, 'rnn_' + module))
            xs_sub = self.dropout(xs_sub)
        else:
            xs_sub = xs.clone()[perm_ids_unsort]
        if getattr(self, 'bridge_' + module) is not None:
            xs_sub = getattr(self, 'bridge_' + module)(xs_sub)
        xlens_sub = xlens[perm_ids_unsort]
        return xs_sub, xlens_sub
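To make the chunk arithmetic in `_forward_streaming` concrete, the standalone sketch below (plain Python, no model or weights) prints which input frames each chunk reads and which frames it contributes to the output. The chunk sizes are assumed values, already divided by the subsampling factor as in the method above.

import math

def chunk_schedule(xmax, cs_l, cs_r):
    """Mirror the frame ranges iterated in RNNEncoder._forward_streaming."""
    n_chunks = math.ceil(xmax / cs_l)
    for t in range(0, cs_l * n_chunks, cs_l):
        read = (t, min(t + cs_l + cs_r, xmax))   # frames fed to the fwd/bwd RNNs
        emit = (t, min(t + cs_l, xmax))          # frames kept in the output
        print('chunk read %s -> emit %s' % (read, emit))

# Assumed sizes: 40-frame chunks with 20 look-ahead frames over 100 frames.
chunk_schedule(xmax=100, cs_l=40, cs_r=20)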
Example #3
class TransformerEncoder(EncoderBase):
    """Transformer encoder.

    Args:
        input_dim (int): dimension of input features (freq * channel)
        attn_type (str): type of attention
        n_heads (int): number of heads for multi-head attention
        n_layers (int): number of blocks
        d_model (int): dimension of MultiheadAttentionMechanism
        d_ff (int): dimension of PositionwiseFeedForward
        last_proj_dim (int): dimension of the last projection layer
        pe_type (str): type of positional encoding
        layer_norm_eps (float): epsilon value for layer normalization
        ffn_activation (str): nonlinear function for PositionwiseFeedForward
        dropout_in (float): dropout probability for input-hidden connection
        dropout (float): dropout probabilities for linear layers
        dropout_att (float): dropout probabilities for attention distributions
        n_stacks (int): number of frames to stack
        n_splices (int): number of frames to splice. Default is 1 frame.
        conv_in_channel (int): number of channels of input features
        conv_channels (int): number of channels in the CNN blocks
        conv_kernel_sizes (list): size of kernels in the CNN blocks
        conv_strides (list): number of strides in the CNN blocks
        conv_poolings (list): size of poolings in the CNN blocks
        conv_batch_norm (bool): apply batch normalization only in the CNN blocks
        conv_layer_norm (bool): apply layer normalization only in the CNN blocks
        conv_bottleneck_dim (int): dimension of the bottleneck layer between CNN and self-attention layers
        conv_param_init (float): parameter initialization for CNN layers before Transformer layers
        chunk_size_left (int): left chunk size for time-restricted Transformer encoder
        chunk_size_current (int): current chunk size for time-restricted Transformer encoder
        chunk_size_right (int): right chunk size for time-restricted Transformer encoder
        param_init (str): parameter initialization method

    """
    def __init__(self, input_dim, attn_type, n_heads, n_layers, d_model, d_ff,
                 last_proj_dim, pe_type, layer_norm_eps, ffn_activation,
                 dropout_in, dropout, dropout_att, n_stacks, n_splices,
                 conv_in_channel, conv_channels, conv_kernel_sizes,
                 conv_strides, conv_poolings, conv_batch_norm, conv_layer_norm,
                 conv_bottleneck_dim, conv_param_init, param_init,
                 chunk_size_left, chunk_size_current, chunk_size_right):

        super(TransformerEncoder, self).__init__()

        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.chunk_size_left = chunk_size_left
        self.chunk_size_current = chunk_size_current
        self.chunk_size_right = chunk_size_right

        # Setting for CNNs before self-attention layers
        if conv_channels:
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type)
        self.layers = repeat(
            TransformerEncoderBlock(d_model, d_ff, attn_type, n_heads, dropout,
                                    dropout_att, layer_norm_eps,
                                    ffn_activation, param_init), n_layers)
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        if last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim
        else:
            self.bridge = None
            self._odim = d_model

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor()

        if param_init == 'xavier_uniform':
            self.reset_parameters()

    def reset_parameters(self):
        """Initialize parameters with Xavier uniform distribution."""
        logger.info(
            '===== Initialize %s with Xavier uniform distribution =====' %
            self.__class__.__name__)
        if self.conv is None:
            nn.init.xavier_uniform_(self.embed.weight)
            nn.init.constant_(self.embed.bias, 0.)
        if self.bridge is not None:
            nn.init.xavier_uniform_(self.bridge.weight)
            nn.init.constant_(self.bridge.bias, 0.)

    def forward(self, xs, xlens, task, use_cache=False, streaming=False):
        """Forward computation.

        Args:
            xs (FloatTensor): `[B, T, input_dim]`
            xlens (list): `[B]`
            task (str): not supported now
            use_cache (bool):
            streaming (bool): streaming encoding
        Returns:
            eouts (dict):
                xs (FloatTensor): `[B, T, d_model]`
                xlens (list): `[B]`

        """
        eouts = {
            'ys': {
                'xs': None,
                'xlens': None
            },
            'ys_sub1': {
                'xs': None,
                'xlens': None
            },
            'ys_sub2': {
                'xs': None,
                'xlens': None
            }
        }

        if self.conv is None:
            xs = self.embed(xs)
        else:
            # Pass through CNN blocks before the self-attention layers
            xs, xlens = self.conv(xs, xlens)

        bs, xmax, idim = xs.size()
        xs = self.pos_enc(xs)
        if self.chunk_size_left > 0:
            # Time-restricted self-attention for streaming models
            cs_l = self.chunk_size_left
            cs_c = self.chunk_size_current
            cs_r = self.chunk_size_right
            hop_size = self.chunk_size_current
            xs_chunks = []
            xx_aws = [[] for l in range(self.n_layers)]
            xs_pad = torch.cat([xs.new_zeros(bs, cs_l, idim), xs,
                                xs.new_zeros(bs, cs_r, idim)], dim=1)
            # TODO: remove right padding
            for t in range(cs_l, cs_l + xmax, hop_size):
                xs_chunk = xs_pad[:, t - cs_l:t + cs_c + cs_r]
                for l in range(self.n_layers):
                    # no mask within each chunk
                    xs_chunk, xx_aws_chunk = self.layers[l](xs_chunk, None)
                    xx_aws[l].append(xx_aws_chunk[:, :, cs_l:cs_l + cs_c,
                                                  cs_l:cs_l + cs_c])
                xs_chunks.append(xs_chunk[:, cs_l:cs_l + cs_c])
            xs = torch.cat(xs_chunks, dim=1)[:, :xmax]
            if not self.training:
                for l in range(self.n_layers):
                    setattr(
                        self, 'xx_aws_layer%d' % l,
                        tensor2np(
                            torch.cat(xx_aws[l], dim=3)[:, :, :xmax, :xmax]))
        else:
            # Create the self-attention mask
            xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat(
                [1, 1, xmax])

            for l in range(self.n_layers):
                xs, xx_aws = self.layers[l](xs, xx_mask)
                if not self.training:
                    setattr(self, 'xx_aws_layer%d' % l, tensor2np(xx_aws))
        xs = self.norm_out(xs)

        # Bridge layer
        if self.bridge is not None:
            xs = self.bridge(xs)

        eouts['ys']['xs'] = xs
        eouts['ys']['xlens'] = xlens
        return eouts

    def _plot_attention(self, save_path, n_cols=2):
        """Plot attention for each head in all layers."""
        from matplotlib import pyplot as plt
        from matplotlib.ticker import MaxNLocator

        save_path = mkdir_join(save_path, 'enc_xx_att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        for l in range(self.n_layers):
            if not hasattr(self, 'xx_aws_layer%d' % l):
                continue

            xx_aws = getattr(self, 'xx_aws_layer%d' % l)

            plt.clf()
            fig, axes = plt.subplots(self.n_heads // n_cols,
                                     n_cols,
                                     figsize=(20, 8))
            for h in range(self.n_heads):
                if self.n_heads > n_cols:
                    ax = axes[h // n_cols, h % n_cols]
                else:
                    ax = axes[h]
                ax.imshow(xx_aws[-1, h, :, :], aspect="auto")
                ax.grid(False)
                ax.set_xlabel("Input (head%d)" % h)
                ax.set_ylabel("Output (head%d)" % h)
                ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                ax.yaxis.set_major_locator(MaxNLocator(integer=True))

            fig.tight_layout()
            fig.savefig(os.path.join(save_path, 'layer%d.png' % (l)), dpi=500)
            plt.close()
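For reference, the `[B, T, T]` self-attention mask built in forward via make_pad_mask(...).unsqueeze(2).repeat([1, 1, xmax]) can be sketched with plain PyTorch as below. This is only an illustration of the shape and broadcasting involved; whether the real make_pad_mask marks valid frames or padded frames (and how it handles devices) is an assumption here, so treat it as a conceptual stand-in rather than the library helper.

import torch

def pad_mask_3d(xlens, xmax):
    """Conceptual stand-in for make_pad_mask(xlens).unsqueeze(2).repeat([1, 1, xmax])."""
    xlens = torch.as_tensor(xlens)
    valid = torch.arange(xmax).unsqueeze(0) < xlens.unsqueeze(1)  # `[B, xmax]`
    return valid.unsqueeze(2).repeat(1, 1, xmax)                  # `[B, xmax, xmax]`

print(pad_mask_3d([3, 2], xmax=4).int())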