class TransformerEncoder(EncoderBase): """Transformer encoder. Args: input_dim (int): dimension of input features (freq * channel) enc_type (str): type of encoder attn_type (str): type of attention n_heads (int): number of heads for multi-head attention n_layers (int): number of blocks n_layers_sub1 (int): number of layers in the 1st auxiliary task n_layers_sub2 (int): number of layers in the 2nd auxiliary task d_model (int): dimension of MultiheadAttentionMechanism d_ff (int): dimension of PositionwiseFeedForward last_proj_dim (int): dimension of the last projection layer pe_type (str): type of positional encoding layer_norm_eps (float): epsilon value for layer normalization ffn_activation (str): nonlinear function for PositionwiseFeedForward dropout_in (float): dropout probability for input-hidden connection dropout (float): dropout probability for linear layers dropout_att (float): dropout probability for attention distributions dropout_residual (float): dropout probability for stochastic residual connections n_stacks (int): number of frames to stack n_splices (int): number of frames to splice. Default is 1 frame. conv_in_channel (int): number of channels of input features conv_channels (int): number of channels in the CNN blocks conv_kernel_sizes (list): size of kernels in the CNN blocks conv_strides (list): number of strides in the CNN blocks conv_poolings (list): size of poolings in the CNN blocks conv_batch_norm (bool): apply batch normalization only in the CNN blocks conv_layer_norm (bool): apply layer normalization only in the CNN blocks conv_bottleneck_dim (int): dimension of the bottleneck layer between CNN and self-attention layers conv_param_init (float): parameter initialization only for the CNN layers before Transformer layers chunk_size_left (int): left chunk size for time-restricted Transformer encoder chunk_size_current (int): current chunk size for time-restricted Transformer encoder chunk_size_right (int): right chunk size for time-restricted Transformer encoder task_specific_layer (bool): add a task specific layer for each sub task param_init (str): parameter initialization method """ def __init__(self, input_dim, enc_type, attn_type, n_heads, n_layers, n_layers_sub1, n_layers_sub2, d_model, d_ff, last_proj_dim, pe_type, layer_norm_eps, ffn_activation, dropout_in, dropout, dropout_att, dropout_residual, n_stacks, n_splices, conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides, conv_poolings, conv_batch_norm, conv_layer_norm, conv_bottleneck_dim, conv_param_init, task_specific_layer, param_init, chunk_size_left, chunk_size_current, chunk_size_right): super(TransformerEncoder, self).__init__() if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1): raise ValueError('Set n_layers_sub1 between 1 and n_layers.') if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2): raise ValueError('Set n_layers_sub2 between 1 and n_layers_sub1.') self.d_model = d_model self.n_layers = n_layers self.n_heads = n_heads self.pe_type = pe_type # for latency-controlled self.chunk_size_left = chunk_size_left self.chunk_size_cur = chunk_size_current self.chunk_size_right = chunk_size_right # for hierarchical encoder self.n_layers_sub1 = n_layers_sub1 self.n_layers_sub2 = n_layers_sub2 self.task_specific_layer = task_specific_layer # for bridge layers self.bridge = None self.bridge_sub1 = None self.bridge_sub2 = None # for attention plot self.aws_dict = {} self.data_dict = {} # Setting for CNNs before self-attention layers if conv_channels: assert n_stacks == 1 and n_splices == 1 self.conv =
ConvEncoder(input_dim, in_channel=conv_in_channel, channels=conv_channels, kernel_sizes=conv_kernel_sizes, strides=conv_strides, poolings=conv_poolings, dropout=0., batch_norm=conv_batch_norm, layer_norm=conv_layer_norm, layer_norm_eps=layer_norm_eps, residual=False, bottleneck_dim=d_model, param_init=conv_param_init) self._odim = self.conv.output_dim else: self.conv = None self._odim = input_dim * n_splices * n_stacks self.embed = nn.Linear(self._odim, d_model) self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type) self.layers = nn.ModuleList([ copy.deepcopy( TransformerEncoderBlock(d_model, d_ff, attn_type, n_heads, dropout, dropout_att, dropout_residual * (l + 1) / n_layers, layer_norm_eps, ffn_activation, param_init)) for l in range(n_layers) ]) self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps) self._odim = d_model if n_layers_sub1 > 0: if task_specific_layer: self.layer_sub1 = TransformerEncoderBlock( d_model, d_ff, attn_type, n_heads, dropout, dropout_att, dropout_residual * n_layers_sub1 / n_layers, layer_norm_eps, ffn_activation, param_init) self.norm_out_sub1 = nn.LayerNorm(d_model, eps=layer_norm_eps) if last_proj_dim != self.output_dim: self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim) if n_layers_sub2 > 0: if task_specific_layer: self.layer_sub2 = TransformerEncoderBlock( d_model, d_ff, attn_type, n_heads, dropout, dropout_att, dropout_residual * n_layers_sub2 / n_layers, layer_norm_eps, ffn_activation, param_init) self.norm_out_sub2 = nn.LayerNorm(d_model, eps=layer_norm_eps) if last_proj_dim != self.output_dim: self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim) if last_proj_dim != self.output_dim: self.bridge = nn.Linear(self._odim, last_proj_dim) self._odim = last_proj_dim # calculate subsampling factor self._factor = 1 if self.conv is not None: self._factor *= self.conv.subsampling_factor() if param_init == 'xavier_uniform': self.reset_parameters() def reset_parameters(self): """Initialize parameters with Xavier uniform distribution.""" logger.info( '===== Initialize %s with Xavier uniform distribution =====' % self.__class__.__name__) if self.conv is None: nn.init.xavier_uniform_(self.embed.weight) nn.init.constant_(self.embed.bias, 0.) if self.bridge is not None: nn.init.xavier_uniform_(self.bridge.weight) nn.init.constant_(self.bridge.bias, 0.) def forward(self, xs, xlens, task, use_cache=False, streaming=False): """Forward computation. 
Args: xs (FloatTensor): `[B, T, input_dim]` xlens (list): `[B]` task (str): not supported now use_cache (bool): streaming (bool): streaming encoding Returns: eouts (dict): xs (FloatTensor): `[B, T, d_model]` xlens (list): `[B]` """ eouts = { 'ys': { 'xs': None, 'xlens': None }, 'ys_sub1': { 'xs': None, 'xlens': None }, 'ys_sub2': { 'xs': None, 'xlens': None } } if self.conv is None: xs = self.embed(xs) else: # Path through CNN blocks before RNN layers xs, xlens = self.conv(xs, xlens) if not self.training: self.data_dict['elens'] = tensor2np(xlens) bs, xmax, idim = xs.size() xs = self.pos_enc(xs) if self.chunk_size_left > 0: # Time-restricted self-attention for streaming models cs_l = self.chunk_size_left cs_c = self.chunk_size_cur cs_r = self.chunk_size_right xs_chunks = [] xx_aws = [[] for l in range(self.n_layers)] xs_pad = torch.cat([ xs.new_zeros(bs, cs_l, idim), xs, xs.new_zeros(bs, cs_r, idim) ], dim=1) # TODO: remove right padding for t in range(cs_l, cs_l + xmax, self.chunk_size_cur): xs_chunk = xs_pad[:, t - cs_l:t + cs_c + cs_r] for l, layer in enumerate(self.layers): xs_chunk, xx_aws_chunk = layer(xs_chunk, None) # no mask xx_aws[l].append(xx_aws_chunk[:, :, cs_l:cs_l + cs_c, cs_l:cs_l + cs_c]) xs_chunks.append(xs_chunk[:, cs_l:cs_l + cs_c]) xs = torch.cat(xs_chunks, dim=1)[:, :xmax] if not self.training: for l in range(self.n_layers): self.aws_dict['xx_aws_layer%d' % l] = tensor2np( torch.cat(xx_aws[l], dim=3)[:, :, :xmax, :xmax]) else: # Create the self-attention mask xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat( [1, 1, xmax]) for l, layer in enumerate(self.layers): xs, xx_aws = layer(xs, xx_mask) if not self.training: self.aws_dict['xx_aws_layer%d' % l] = tensor2np(xx_aws) # Pick up outputs in the sub task before the projection layer if l == self.n_layers_sub1 - 1: xs_sub1 = self.layer_sub1( xs, xx_mask )[0] if self.task_specific_layer else xs.clone() xs_sub1 = self.norm_out_sub1(xs_sub1) if self.bridge_sub1 is not None: xs_sub1 = self.bridge_sub1(xs_sub1) if task == 'ys_sub1': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub1, xlens return eouts if l == self.n_layers_sub2 - 1: xs_sub2 = self.layer_sub2( xs, xx_mask )[0] if self.task_specific_layer else xs.clone() xs_sub2 = self.norm_out_sub2(xs_sub2) if self.bridge_sub2 is not None: xs_sub2 = self.bridge_sub2(xs_sub2) if task == 'ys_sub2': eouts[task]['xs'], eouts[task][ 'xlens'] = xs_sub2, xlens return eouts xs = self.norm_out(xs) # Bridge layer if self.bridge is not None: xs = self.bridge(xs) if task in ['all', 'ys']: eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens if self.n_layers_sub1 >= 1 and task == 'all': eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens if self.n_layers_sub2 >= 1 and task == 'all': eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens return eouts def _plot_attention(self, save_path, n_cols=2): """Plot attention for each head in all layers.""" from matplotlib import pyplot as plt from matplotlib.ticker import MaxNLocator _save_path = mkdir_join(save_path, 'enc_att_weights') # Clean directory if _save_path is not None and os.path.isdir(_save_path): shutil.rmtree(_save_path) os.mkdir(_save_path) for k, aw in self.aws_dict.items(): elens = self.data_dict['elens'] plt.clf() n_heads = aw.shape[1] n_cols_tmp = 1 if n_heads == 1 else n_cols fig, axes = plt.subplots(max(1, n_heads // n_cols_tmp), n_cols_tmp, figsize=(20, 8), squeeze=False) for h in range(n_heads): ax = axes[h // n_cols_tmp, h % n_cols_tmp] ax.imshow(aw[-1, h, :elens[-1], :elens[-1]], 
aspect="auto") ax.grid(False) ax.set_xlabel("Input (head%d)" % h) ax.set_ylabel("Output (head%d)" % h) ax.xaxis.set_major_locator(MaxNLocator(integer=True)) ax.yaxis.set_major_locator(MaxNLocator(integer=True)) fig.tight_layout() fig.savefig(os.path.join(_save_path, '%s.png' % k), dvi=500) plt.close()
class RNNEncoder(EncoderBase): """RNN encoder. Args: input_dim (int): dimension of input features (freq * channel) rnn_type (str): type of encoder (including pure CNN layers) n_units (int): number of units in each layer n_projs (int): number of units in each projection layer last_proj_dim (int): dimension of the last projection layer n_layers (int): number of layers n_layers_sub1 (int): number of layers in the 1st auxiliary task n_layers_sub2 (int): number of layers in the 2nd auxiliary task dropout_in (float): dropout probability for input-hidden connection dropout (float): dropout probability for hidden-hidden connection subsample (list): subsample in the corresponding RNN layers, e.g., [False, True, True, False] means that subsampling is conducted in the 2nd and 3rd layers. subsample_type (str): drop/concat/max_pool/1dconv n_stacks (int): number of frames to stack n_splices (int): number of frames to splice conv_in_channel (int): number of channels of input features conv_channels (int): number of channels in the CNN blocks conv_kernel_sizes (list): size of kernels in the CNN blocks conv_strides (list): number of strides in the CNN blocks conv_poolings (list): size of poolings in the CNN blocks conv_batch_norm (bool): apply batch normalization only in the CNN blocks conv_layer_norm (bool): apply layer normalization only in the CNN blocks conv_bottleneck_dim (int): dimension of the bottleneck layer between CNN and RNN layers nin (bool): insert 1*1 conv + batch normalization + ReLU bidirectional_sum_fwd_bwd (bool): sum forward and backward outputs instead of concatenating them task_specific_layer (bool): add a task specific layer for each sub task param_init (float): range of the uniform distribution for parameter initialization lc_chunk_size_left (int): left chunk size for latency-controlled bidirectional encoder lc_chunk_size_right (int): right chunk size for latency-controlled bidirectional encoder lc_state_reset_prob (float): probability to reset states for latency-controlled bidirectional encoder """ def __init__(self, input_dim, rnn_type, n_units, n_projs, last_proj_dim, n_layers, n_layers_sub1, n_layers_sub2, dropout_in, dropout, subsample, subsample_type, n_stacks, n_splices, conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides, conv_poolings, conv_batch_norm, conv_layer_norm, conv_bottleneck_dim, nin, bidirectional_sum_fwd_bwd, task_specific_layer, param_init, lc_chunk_size_left, lc_chunk_size_right, lc_state_reset_prob): super(RNNEncoder, self).__init__() if len(subsample) > 0 and len(subsample) != n_layers: raise ValueError('subsample must be the same size as n_layers.') if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1): raise ValueError('Set n_layers_sub1 between 1 and n_layers.') if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2): raise ValueError('Set n_layers_sub2 between 1 and n_layers_sub1.') self.rnn_type = rnn_type self.bidirectional = True if ('blstm' in rnn_type or 'bgru' in rnn_type) else False self.n_units = n_units self.n_dirs = 2 if self.bidirectional else 1 self.n_layers = n_layers # for latency-controlled self.latency_controlled = lc_chunk_size_left > 0 or lc_chunk_size_right > 0 self.lc_chunk_size_left = lc_chunk_size_left self.lc_chunk_size_right = lc_chunk_size_right self.lc_state_reset_prob = lc_state_reset_prob if self.latency_controlled: assert n_layers_sub1 == 0 assert n_layers_sub2 == 0 # for hierarchical encoder self.n_layers_sub1 = n_layers_sub1 self.n_layers_sub2 = n_layers_sub2 self.task_specific_layer = task_specific_layer # for bridge layers self.bridge = None self.bridge_sub1 = None self.bridge_sub2 = None # Dropout for input-hidden connection self.dropout_in =
nn.Dropout(p=dropout_in) if rnn_type == 'tds': self.conv = TDSEncoder(input_dim=input_dim * n_stacks, in_channel=conv_in_channel, channels=conv_channels, kernel_sizes=conv_kernel_sizes, dropout=dropout, bottleneck_dim=last_proj_dim) elif rnn_type == 'gated_conv': self.conv = GatedConvEncoder(input_dim=input_dim * n_stacks, in_channel=conv_in_channel, channels=conv_channels, kernel_sizes=conv_kernel_sizes, dropout=dropout, bottleneck_dim=last_proj_dim, param_init=param_init) elif 'conv' in rnn_type: assert n_stacks == 1 and n_splices == 1 self.conv = ConvEncoder(input_dim, in_channel=conv_in_channel, channels=conv_channels, kernel_sizes=conv_kernel_sizes, strides=conv_strides, poolings=conv_poolings, dropout=0., batch_norm=conv_batch_norm, layer_norm=conv_layer_norm, residual=False, bottleneck_dim=conv_bottleneck_dim, param_init=param_init) else: self.conv = None if self.conv is None: self._odim = input_dim * n_splices * n_stacks else: self._odim = self.conv.output_dim subsample = [1] * self.n_layers logger.warning('Subsampling is automatically ignored because CNN layers are used before RNN layers.') self.padding = Padding(bidirectional_sum_fwd_bwd=bidirectional_sum_fwd_bwd) if rnn_type not in ['conv', 'tds', 'gated_conv']: self.rnn = nn.ModuleList() if self.latency_controlled: self.rnn_bwd = nn.ModuleList() self.dropout = nn.Dropout(p=dropout) self.proj = None if n_projs > 0: self.proj = nn.ModuleList() # subsample self.subsample_layer = None if subsample_type == 'max_pool' and np.prod(subsample) > 1: self.subsample_layer = nn.ModuleList([MaxpoolSubsampler(subsample[l]) for l in range(n_layers)]) elif subsample_type == 'concat' and np.prod(subsample) > 1: self.subsample_layer = nn.ModuleList([ConcatSubsampler(subsample[l], n_units * self.n_dirs) for l in range(n_layers)]) elif subsample_type == 'drop' and np.prod(subsample) > 1: self.subsample_layer = nn.ModuleList([DropSubsampler(subsample[l]) for l in range(n_layers)]) elif subsample_type == '1dconv' and np.prod(subsample) > 1: self.subsample_layer = nn.ModuleList([Conv1dSubsampler(subsample[l], n_units * self.n_dirs) for l in range(n_layers)]) # NiN self.nin = nn.ModuleList() if nin else None for l in range(n_layers): if 'lstm' in rnn_type: rnn_i = nn.LSTM elif 'gru' in rnn_type: rnn_i = nn.GRU else: raise ValueError('rnn_type must be "(conv_)(b/lcb)lstm" or "(conv_)(b/lcb)gru".') if self.latency_controlled: self.rnn += [rnn_i(self._odim, n_units, 1, batch_first=True)] self.rnn_bwd += [rnn_i(self._odim, n_units, 1, batch_first=True)] else: self.rnn += [rnn_i(self._odim, n_units, 1, batch_first=True, bidirectional=self.bidirectional)] self._odim = n_units if bidirectional_sum_fwd_bwd else n_units * self.n_dirs self.bidirectional_sum_fwd_bwd = bidirectional_sum_fwd_bwd # Projection layer if self.proj is not None: if l != n_layers - 1: self.proj += [nn.Linear(n_units * self.n_dirs, n_projs)] self._odim = n_projs # Task specific layer if l == n_layers_sub1 - 1 and task_specific_layer: self.rnn_sub1 = rnn_i(self._odim, n_units, 1, batch_first=True, bidirectional=self.bidirectional) if last_proj_dim > 0 and last_proj_dim != self.output_dim: self.bridge_sub1 = nn.Linear(n_units, last_proj_dim) if l == n_layers_sub2 - 1 and task_specific_layer: self.rnn_sub2 = rnn_i(self._odim, n_units, 1, batch_first=True, bidirectional=self.bidirectional) if last_proj_dim > 0 and last_proj_dim != self.output_dim: self.bridge_sub2 = nn.Linear(n_units, last_proj_dim) # Network in network if self.nin is not None: if l != n_layers - 1: self.nin += 
[NiN(self._odim)] # if n_layers_sub1 > 0 or n_layers_sub2 > 0: # assert task_specific_layer if last_proj_dim > 0 and last_proj_dim != self.output_dim: self.bridge = nn.Linear(self._odim, last_proj_dim) self._odim = last_proj_dim # calculate subsampling factor self._factor = 1 if self.conv is not None: self._factor *= self.conv.subsampling_factor() self._factor *= np.prod(subsample) self.reset_parameters(param_init) # for streaming inference self.reset_cache() def reset_parameters(self, param_init): """Initialize parameters with uniform distribution.""" logger.info('===== Initialize %s =====' % self.__class__.__name__) for n, p in self.named_parameters(): if 'conv' in n or 'tds' in n or 'gated_conv' in n: continue # for CNN layers before RNN layers if p.dim() == 1: nn.init.constant_(p, 0.) # bias logger.info('Initialize %s with %s / %.3f' % (n, 'constant', 0.)) elif p.dim() in [2, 4]: nn.init.uniform_(p, a=-param_init, b=param_init) logger.info('Initialize %s with %s / %.3f' % (n, 'uniform', param_init)) else: raise ValueError(n) def reset_cache(self): self.fwd_states = [None] * self.n_layers logger.debug('Reset cache.') def forward(self, xs, xlens, task, use_cache=False, streaming=False): """Forward computation. Args: xs (FloatTensor): `[B, T, input_dim]` xlens (list): A list of length `[B]` task (str): all or ys or ys_sub1 or ys_sub2 use_cache (bool): use the cached forward encoder state in the previous chunk as the initial state streaming (bool): streaming encoding Returns: eouts (dict): xs (FloatTensor): `[B, T // prod(subsample), n_units (*2)]` xlens (IntTensor): `[B]` xs_sub1 (FloatTensor): `[B, T // prod(subsample), n_units (*2)]` xlens_sub1 (IntTensor): `[B]` xs_sub2 (FloatTensor): `[B, T // prod(subsample), n_units (*2)]` xlens_sub2 (IntTensor): `[B]` """ eouts = {'ys': {'xs': None, 'xlens': None}, 'ys_sub1': {'xs': None, 'xlens': None}, 'ys_sub2': {'xs': None, 'xlens': None}} # Sort by lenghts in the descending order for pack_padded_sequence xlens, perm_ids = torch.IntTensor(xlens).sort(0, descending=True) xs = xs[perm_ids] _, perm_ids_unsort = perm_ids.sort() # Dropout for inputs-hidden connection xs = self.dropout_in(xs) # Path through CNN blocks before RNN layers if self.conv is not None: xs, xlens = self.conv(xs, xlens) if self.rnn_type in ['conv', 'tds', 'gated_conv']: eouts['ys']['xs'] = xs eouts['ys']['xlens'] = xlens return eouts if not use_cache: self.reset_cache() if self.latency_controlled: # Flip the layer and time loop xs, xlens = self._forward_streaming(xs, xlens, streaming) else: for l in range(self.n_layers): self.rnn[l].flatten_parameters() # for multi-GPUs xs, self.fwd_states[l] = self.padding(xs, xlens, self.rnn[l], prev_state=self.fwd_states[l]) xs = self.dropout(xs) # Pick up outputs in the sub task before the projection layer if l == self.n_layers_sub1 - 1: xs_sub1, xlens_sub1 = self.sub_module(xs, xlens, perm_ids_unsort, 'sub1') if task == 'ys_sub1': eouts[task]['xs'], eouts[task]['xlens'] = xs_sub1, xlens_sub1 return eouts if l == self.n_layers_sub2 - 1: xs_sub2, xlens_sub2 = self.sub_module(xs, xlens, perm_ids_unsort, 'sub2') if task == 'ys_sub2': eouts[task]['xs'], eouts[task]['xlens'] = xs_sub2, xlens_sub2 return eouts # NOTE: Exclude the last layer if l != self.n_layers - 1: # Projection layer -> Subsampling -> NiN if self.proj is not None: xs = torch.tanh(self.proj[l](xs)) if self.subsample_layer is not None: xs, xlens = self.subsample_layer[l](xs, xlens) if self.nin is not None: xs = self.nin[l](xs) # Bridge layer if self.bridge is not None: xs = 
self.bridge(xs) # Unsort xs = xs[perm_ids_unsort] xlens = xlens[perm_ids_unsort] if task in ['all', 'ys']: eouts['ys']['xs'], eouts['ys']['xlens'] = xs, xlens if self.n_layers_sub1 >= 1 and task == 'all': eouts['ys_sub1']['xs'], eouts['ys_sub1']['xlens'] = xs_sub1, xlens_sub1 if self.n_layers_sub2 >= 1 and task == 'all': eouts['ys_sub2']['xs'], eouts['ys_sub2']['xlens'] = xs_sub2, xlens_sub2 return eouts def _forward_streaming(self, xs, xlens, streaming): """Streaming encoding for the latency-controlled bidirectional encoder. Args: xs (FloatTensor): `[B, T, n_units]` Returns: xs (FloatTensor): `[B, T, n_units]` """ cs_l = self.lc_chunk_size_left // self.subsampling_factor() cs_r = self.lc_chunk_size_right // self.subsampling_factor() # full context BPTT if cs_l < 0: for l in range(self.n_layers): self.rnn[l].flatten_parameters() # for multi-GPUs self.rnn_bwd[l].flatten_parameters() # for multi-GPUs # bwd xs_bwd = torch.flip(xs, dims=[1]) xs_bwd, _ = self.rnn_bwd[l](xs_bwd, hx=None) xs_bwd = torch.flip(xs_bwd, dims=[1]) # fwd xs_fwd, _ = self.rnn[l](xs, hx=None) if self.bidirectional_sum_fwd_bwd: xs = xs_fwd + xs_bwd else: xs = torch.cat([xs_fwd, xs_bwd], dim=-1) xs = self.dropout(xs) # Projection layer if self.proj is not None and l != self.n_layers - 1: xs = torch.tanh(self.proj[l](xs)) return xs, xlens bs, xmax, input_dim = xs.size() n_chunks = 1 if streaming else math.ceil(xmax / cs_l) xlens = torch.IntTensor(bs).fill_(cs_l if streaming else xmax) xs_chunks = [] for t in range(0, cs_l * n_chunks, cs_l): xs_chunk = xs[:, t:t + (cs_l + cs_r)] for l in range(self.n_layers): self.rnn[l].flatten_parameters() # for multi-GPUs self.rnn_bwd[l].flatten_parameters() # for multi-GPUs # bwd xs_chunk_bwd = torch.flip(xs_chunk, dims=[1]) xs_chunk_bwd, _ = self.rnn_bwd[l](xs_chunk_bwd, hx=None) xs_chunk_bwd = torch.flip(xs_chunk_bwd, dims=[1]) # `[B, cs_l+cs_r, n_units]` # fwd if xs_chunk.size(1) <= cs_l: xs_chunk_fwd, self.fwd_states[l] = self.rnn[l](xs_chunk, hx=self.fwd_states[l]) if self.training and self.lc_state_reset_prob > 0 and random.random() < self.lc_state_reset_prob: self.fwd_states[l] = None else: xs_chunk_fwd1, self.fwd_states[l] = self.rnn[l](xs_chunk[:, :cs_l], hx=self.fwd_states[l]) if self.training and self.lc_state_reset_prob > 0 and random.random() < self.lc_state_reset_prob: self.fwd_states[l] = None xs_chunk_fwd2, _ = self.rnn[l](xs_chunk[:, cs_l:], hx=self.fwd_states[l]) xs_chunk_fwd = torch.cat([xs_chunk_fwd1, xs_chunk_fwd2], dim=1) # `[B, cs_l+cs_r, n_units]` # NOTE: xs_chunk_fwd2 is for xs_chunk_bwd in the next layer if self.bidirectional_sum_fwd_bwd: xs_chunk = xs_chunk_fwd + xs_chunk_bwd else: xs_chunk = torch.cat([xs_chunk_fwd, xs_chunk_bwd], dim=-1) xs_chunk = self.dropout(xs_chunk) # Projection layer if self.proj is not None and l != self.n_layers - 1: xs_chunk = torch.tanh(self.proj[l](xs_chunk)) xs_chunks.append(xs_chunk[:, :cs_l]) xs = torch.cat(xs_chunks, dim=1) return xs, xlens def sub_module(self, xs, xlens, perm_ids_unsort, module='sub1'): if self.task_specific_layer: getattr(self, 'rnn_' + module).flatten_parameters() # for multi-GPUs xs_sub, _ = self.padding(xs, xlens, getattr(self, 'rnn_' + module)) xs_sub = self.dropout(xs_sub) else: xs_sub = xs.clone()[perm_ids_unsort] if getattr(self, 'bridge_' + module) is not None: xs_sub = getattr(self, 'bridge_' + module)(xs_sub) xlens_sub = xlens[perm_ids_unsort] return xs_sub, xlens_sub
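# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the encoder above): _forward_streaming() in
# RNNEncoder runs a unidirectional forward LSTM whose hidden state is carried
# across chunks, and a backward LSTM that is restarted for every chunk, so the
# look-ahead is bounded by lc_chunk_size_right. The minimal single-layer
# helper below (a hypothetical name, for illustration only) reproduces that
# recurrence pattern with tiny made-up sizes; layer stacking, projection
# layers, fwd/bwd concatenation, dropout and random state reset are omitted on
# purpose, and fwd/bwd outputs are simply summed as in the
# bidirectional_sum_fwd_bwd case.
import torch

def lc_blstm_layer(xs, cs_l, cs_r, rnn_fwd, rnn_bwd):
    """One latency-controlled BLSTM layer over `xs` `[B, T, idim]` (sketch only)."""
    fwd_state = None
    outs = []
    for t in range(0, xs.size(1), cs_l):
        chunk = xs[:, t:t + cs_l + cs_r]                       # current chunk + right context
        # backward RNN is restarted every chunk, so it never sees future chunks
        out_bwd = torch.flip(rnn_bwd(torch.flip(chunk, dims=[1]))[0], dims=[1])
        # forward RNN carries its state across chunks, but only over the cs_l current frames
        out_fwd, fwd_state = rnn_fwd(chunk[:, :cs_l], fwd_state)
        if chunk.size(1) > cs_l:                               # right-context frames; state is discarded
            out_fwd = torch.cat([out_fwd, rnn_fwd(chunk[:, cs_l:], fwd_state)[0]], dim=1)
        outs.append((out_fwd + out_bwd)[:, :cs_l])             # sum fwd/bwd, keep the current frames only
    return torch.cat(outs, dim=1)[:, :xs.size(1)]
# Example (arbitrary sizes), with rnn_fwd and rnn_bwd both
# torch.nn.LSTM(8, 16, batch_first=True):
#   >>> lc_blstm_layer(torch.randn(2, 10, 8), cs_l=4, cs_r=2, rnn_fwd=..., rnn_bwd=...).shape
#   torch.Size([2, 10, 16])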
class TransformerEncoder(EncoderBase): """Transformer encoder. Args: input_dim (int): dimension of input features (freq * channel) attn_type (str): type of attention n_heads (int): number of heads for multi-head attention n_layers (int): number of blocks d_model (int): dimension of MultiheadAttentionMechanism d_ff (int): dimension of PositionwiseFeedForward last_proj_dim (int): dimension of the last projection layer pe_type (str): type of positional encoding layer_norm_eps (float): epsilon value for layer normalization ffn_activation (str): nonolinear function for PositionwiseFeedForward dropout_in (float): dropout probability for input-hidden connection dropout (float): dropout probabilities for linear layers dropout_att (float): dropout probabilities for attention distributions n_stacks (int): number of frames to stack n_splices (int): frames to splice. Default is 1 frame. conv_in_channel (int): number of channels of input features conv_channels (int): number of channles in the CNN blocks conv_kernel_sizes (list): size of kernels in the CNN blocks conv_strides (list): number of strides in the CNN blocks conv_poolings (list): size of poolings in the CNN blocks conv_batch_norm (bool): apply batch normalization only in the CNN blocks conv_layer_norm (bool): apply layer normalization only in the CNN blocks conv_bottleneck_dim (int): dimension of the bottleneck layer between CNN and self-attention layers conv_param_init (float): only for CNN layers before Transformer layers chunk_size_left (int): left chunk size for time-restricted Transformer encoder chunk_size_current (int): current chunk size for time-restricted Transformer encoder chunk_size_right (int): right chunk size for time-restricted Transformer encoder param_init (str): """ def __init__(self, input_dim, attn_type, n_heads, n_layers, d_model, d_ff, last_proj_dim, pe_type, layer_norm_eps, ffn_activation, dropout_in, dropout, dropout_att, n_stacks, n_splices, conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides, conv_poolings, conv_batch_norm, conv_layer_norm, conv_bottleneck_dim, conv_param_init, param_init, chunk_size_left, chunk_size_current, chunk_size_right): super(TransformerEncoder, self).__init__() self.d_model = d_model self.n_layers = n_layers self.n_heads = n_heads self.pe_type = pe_type self.chunk_size_left = chunk_size_left self.chunk_size_current = chunk_size_current self.chunk_size_right = chunk_size_right # Setting for CNNs before RNNs if conv_channels: assert n_stacks == 1 and n_splices == 1 self.conv = ConvEncoder(input_dim, in_channel=conv_in_channel, channels=conv_channels, kernel_sizes=conv_kernel_sizes, strides=conv_strides, poolings=conv_poolings, dropout=0., batch_norm=conv_batch_norm, layer_norm=conv_layer_norm, layer_norm_eps=layer_norm_eps, residual=False, bottleneck_dim=d_model, param_init=conv_param_init) self._odim = self.conv.output_dim else: self.conv = None self._odim = input_dim * n_splices * n_stacks self.embed = nn.Linear(self._odim, d_model) self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type) self.layers = repeat( TransformerEncoderBlock(d_model, d_ff, attn_type, n_heads, dropout, dropout_att, layer_norm_eps, ffn_activation, param_init), n_layers) self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps) if last_proj_dim != self.output_dim: self.bridge = nn.Linear(self._odim, last_proj_dim) self._odim = last_proj_dim else: self.bridge = None self._odim = d_model # calculate subsampling factor self._factor = 1 if self.conv is not None: self._factor *= 
self.conv.subsampling_factor() if param_init == 'xavier_uniform': self.reset_parameters() def reset_parameters(self): """Initialize parameters with Xavier uniform distribution.""" logger.info( '===== Initialize %s with Xavier uniform distribution =====' % self.__class__.__name__) if self.conv is None: nn.init.xavier_uniform_(self.embed.weight) nn.init.constant_(self.embed.bias, 0.) if self.bridge is not None: nn.init.xavier_uniform_(self.bridge.weight) nn.init.constant_(self.bridge.bias, 0.) def forward(self, xs, xlens, task, use_cache=False, streaming=False): """Forward computation. Args: xs (FloatTensor): `[B, T, input_dim]` xlens (list): `[B]` task (str): not supported now use_cache (bool): streaming (bool): streaming encoding Returns: eouts (dict): xs (FloatTensor): `[B, T, d_model]` xlens (list): `[B]` """ eouts = { 'ys': { 'xs': None, 'xlens': None }, 'ys_sub1': { 'xs': None, 'xlens': None }, 'ys_sub2': { 'xs': None, 'xlens': None } } if self.conv is None: xs = self.embed(xs) else: # Path through CNN blocks before RNN layers xs, xlens = self.conv(xs, xlens) bs, xmax, idim = xs.size() xs = self.pos_enc(xs) if self.chunk_size_left > 0: # Time-restricted self-attention for streaming models cs_l = self.chunk_size_left cs_c = self.chunk_size_current cs_r = self.chunk_size_right hop_size = self.chunk_size_current xs_chunks = [] xx_aws = [[] for l in range(self.n_layers)] xs_pad = torch.cat([ xs.new_zeros(bs, cs_l, idim), xs, xs.new_zeros(bs, cs_r, idim) ], dim=1) # TODO: remove right padding for t in range(cs_l, cs_l + xmax, hop_size): xs_chunk = xs_pad[:, t - cs_l:t + cs_c + cs_r] for l in range(self.n_layers): xs_chunk, xx_aws_chunk = self.layers[l](xs_chunk, None) # no mask xx_aws[l].append(xx_aws_chunk[:, :, cs_l:cs_l + cs_c, cs_l:cs_l + cs_c]) xs_chunks.append(xs_chunk[:, cs_l:cs_l + cs_c]) xs = torch.cat(xs_chunks, dim=1)[:, :xmax] if not self.training: for l in range(self.n_layers): setattr( self, 'xx_aws_layer%d' % l, tensor2np( torch.cat(xx_aws[l], dim=3)[:, :, :xmax, :xmax])) else: # Create the self-attention mask xx_mask = make_pad_mask(xlens, self.device_id).unsqueeze(2).repeat( [1, 1, xmax]) for l in range(self.n_layers): xs, xx_aws = self.layers[l](xs, xx_mask) if not self.training: setattr(self, 'xx_aws_layer%d' % l, tensor2np(xx_aws)) xs = self.norm_out(xs) # Bridge layer if self.bridge is not None: xs = self.bridge(xs) eouts['ys']['xs'] = xs eouts['ys']['xlens'] = xlens return eouts def _plot_attention(self, save_path, n_cols=2): """Plot attention for each head in all layers.""" from matplotlib import pyplot as plt from matplotlib.ticker import MaxNLocator save_path = mkdir_join(save_path, 'enc_xx_att_weights') # Clean directory if save_path is not None and os.path.isdir(save_path): shutil.rmtree(save_path) os.mkdir(save_path) for l in range(self.n_layers): if not hasattr(self, 'xx_aws_layer%d' % l): continue xx_aws = getattr(self, 'xx_aws_layer%d' % l) plt.clf() fig, axes = plt.subplots(self.n_heads // n_cols, n_cols, figsize=(20, 8)) for h in range(self.n_heads): if self.n_heads > n_cols: ax = axes[h // n_cols, h % n_cols] else: ax = axes[h] ax.imshow(xx_aws[-1, h, :, :], aspect="auto") ax.grid(False) ax.set_xlabel("Input (head%d)" % h) ax.set_ylabel("Output (head%d)" % h) ax.xaxis.set_major_locator(MaxNLocator(integer=True)) ax.yaxis.set_major_locator(MaxNLocator(integer=True)) fig.tight_layout() fig.savefig(os.path.join(save_path, 'layer%d.png' % (l)), dvi=500) plt.close()
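# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the encoders above): both
# TransformerEncoder.forward() implementations build the self-attention mask
# by expanding a per-frame validity mask of shape `[B, T]` to `[B, T, T]` via
# unsqueeze(2).repeat([1, 1, xmax]). `make_pad_mask` itself is defined
# elsewhere in the repository; the stand-in below only assumes it returns True
# at non-padded frames, and how the resulting `[B, T, T]` tensor is consumed
# (which time axis plays queries vs. keys) is decided inside the attention
# module.
import torch

def make_pad_mask_sketch(xlens):
    """[B] lengths -> [B, T] boolean mask, True at non-padded frames (stand-in only)."""
    xlens = torch.as_tensor(xlens)
    return torch.arange(int(xlens.max())).unsqueeze(0) < xlens.unsqueeze(1)
# Example (arbitrary lengths):
#   >>> make_pad_mask_sketch([5, 3]).unsqueeze(2).repeat([1, 1, 5]).shape
#   torch.Size([2, 5, 5])   # one [T, T] mask per utterance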