def __init__(self, input_dim, enc_type, n_heads, n_layers,
             n_layers_sub1, n_layers_sub2,
             d_model, d_ff, ffn_bottleneck_dim, ffn_activation,
             pe_type, layer_norm_eps, last_proj_dim,
             dropout_in, dropout, dropout_att, dropout_layer,
             subsample, subsample_type, n_stacks, n_splices,
             conv_in_channel, conv_channels, conv_kernel_sizes,
             conv_strides, conv_poolings, conv_batch_norm,
             conv_layer_norm, conv_bottleneck_dim, conv_param_init,
             task_specific_layer, param_init, clamp_len,
             lookahead, chunk_size_left, chunk_size_current,
             chunk_size_right, streaming_type):

    super(TransformerEncoder, self).__init__()

    # parse per-layer subsampling factors, e.g. '1_2_1' -> [1, 2, 1]
    subsamples = [1] * n_layers
    for lth, s in enumerate(list(map(int, subsample.split('_')[:n_layers]))):
        subsamples[lth] = s
    # parse per-layer lookahead frames, e.g. '0_1_0' -> [0, 1, 0]
    lookaheads = [0] * n_layers
    for lth, s in enumerate(list(map(int, lookahead.split('_')[:n_layers]))):
        lookaheads[lth] = s

    if len(subsample.split('_')) != n_layers:
        # NOTE: check the raw string; the padded list above always has length n_layers
        raise ValueError('subsample must be the same size as n_layers. '
                         'n_layers: %d, subsample: %s' % (n_layers, subsample))
    if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1):
        raise ValueError('Set n_layers_sub1 between 1 and n_layers. '
                         'n_layers: %d, n_layers_sub1: %d' % (n_layers, n_layers_sub1))
    if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2):
        raise ValueError('Set n_layers_sub2 between 1 and n_layers_sub1. '
                         'n_layers_sub1: %d, n_layers_sub2: %d' % (n_layers_sub1, n_layers_sub2))

    self.enc_type = enc_type
    self.d_model = d_model
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.pe_type = pe_type
    self.scale = math.sqrt(d_model)  # scaling factor for the input embedding

    # for compatibility
    chunk_size_left = str(chunk_size_left)
    chunk_size_current = str(chunk_size_current)
    chunk_size_right = str(chunk_size_right)

    # for streaming encoder
    self.unidir = 'uni' in enc_type
    self.lookaheads = lookaheads
    if sum(lookaheads) > 0:
        assert self.unidir
    self.chunk_size_left = int(chunk_size_left.split('_')[-1]) // n_stacks
    self.chunk_size_current = int(chunk_size_current.split('_')[-1]) // n_stacks
    self.chunk_size_right = int(chunk_size_right.split('_')[-1]) // n_stacks
    self.lc_bidir = self.chunk_size_current > 0 and enc_type != 'conv' and 'uni' not in enc_type
    self.cnn_lookahead = self.unidir or enc_type == 'conv'
    self.streaming_type = streaming_type if self.lc_bidir else ''
    # -: past context
    # *: current context
    # +: future context
    # reshape) overlapped windowing. Additional redundant computation is introduced,
    # and caching is not applied during inference. However, because (N_l+N_c+N_r) is
    # very short and independent of the layer depth, the overhead is negligible.
    # chunk1: |**|++
    # chunk2: --|**|++
    # chunk3:    --|**|++
    # chunk4:       --|**|++
    # chunk5:          --|**|++
    # mask) chunkwise masking. Future context is restricted to the current chunk
    # to avoid accumulation of future context with the layer depth.
    # chunk1: |**|
    # chunk2: --|**|
    # chunk3: -- --|**|
    # chunk4:    -- --|**|
    # chunk5:       -- --|**|
    # (illustrative sketches of both streaming types follow after this method)
    if self.unidir:
        assert self.chunk_size_left == self.chunk_size_current == self.chunk_size_right == 0
    if self.streaming_type == 'mask':
        assert self.chunk_size_right == 0
        assert self.chunk_size_left == self.chunk_size_current
        # NOTE: this is important to cache the CNN output at each chunk
    if self.lc_bidir:
        assert n_layers_sub1 == 0
        assert n_layers_sub2 == 0
        assert not self.unidir

    # for hierarchical encoder
    self.n_layers_sub1 = n_layers_sub1
    self.n_layers_sub2 = n_layers_sub2
    self.task_specific_layer = task_specific_layer

    # for bridge layers
    self.bridge = None
    self.bridge_sub1 = None
    self.bridge_sub2 = None

    # for attention plot
    self.aws_dict = {}
    self.data_dict = {}

    # Setting for CNNs
    if 'conv' in enc_type:
        assert conv_channels
        assert n_stacks == 1 and n_splices == 1
        self.conv = ConvEncoder(input_dim,
                                in_channel=conv_in_channel,
                                channels=conv_channels,
                                kernel_sizes=conv_kernel_sizes,
                                strides=conv_strides,
                                poolings=conv_poolings,
                                dropout=0.,
                                batch_norm=conv_batch_norm,
                                layer_norm=conv_layer_norm,
                                layer_norm_eps=layer_norm_eps,
                                residual=False,
                                bottleneck_dim=d_model,
                                param_init=conv_param_init)
        self._odim = self.conv.output_dim
    else:
        self.conv = None
        self._odim = input_dim * n_splices * n_stacks
        self.embed = nn.Linear(self._odim, d_model)

    # calculate subsampling factor
    self._factor = 1
    if self.conv is not None:
        self._factor *= self.conv.subsampling_factor
    self.subsample = None
    if np.prod(subsamples) > 1:
        self._factor *= np.prod(subsamples)
        if subsample_type == 'max_pool':
            self.subsample = nn.ModuleList([MaxpoolSubsampler(factor)
                                            for factor in subsamples])
        elif subsample_type == 'concat':
            self.subsample = nn.ModuleList([ConcatSubsampler(factor, self._odim)
                                            for factor in subsamples])
        elif subsample_type == 'drop':
            self.subsample = nn.ModuleList([DropSubsampler(factor)
                                            for factor in subsamples])
        elif subsample_type == '1dconv':
            self.subsample = nn.ModuleList([Conv1dSubsampler(factor, self._odim)
                                            for factor in subsamples])
        elif subsample_type == 'add':
            self.subsample = nn.ModuleList([AddSubsampler(factor)
                                            for factor in subsamples])

    # chunk sizes are given in input frames and must be divisible by the total
    # subsampling factor (e.g., a CNN factor of 4 times a layer subsampling
    # product of 2 requires multiples of 8)
    if self.chunk_size_left > 0:
        assert self.chunk_size_left % self._factor == 0
    if self.chunk_size_current > 0:
        assert self.chunk_size_current % self._factor == 0
    if self.chunk_size_right > 0:
        assert self.chunk_size_right % self._factor == 0

    self.pos_enc, self.pos_emb = None, None
    self.u_bias, self.v_bias = None, None
    if pe_type in ['relative', 'relative_xl']:
        self.pos_emb = XLPositionalEmbedding(d_model, dropout)
        if pe_type == 'relative_xl':
            self.u_bias = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            self.v_bias = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            # NOTE: u_bias and v_bias are global parameters shared in the whole model
    else:
        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type, param_init)

    self.layers = nn.ModuleList([copy.deepcopy(TransformerEncoderBlock(
        d_model, d_ff, n_heads,
        dropout, dropout_att, dropout_layer,
        layer_norm_eps, ffn_activation, param_init,
        pe_type, clamp_len, ffn_bottleneck_dim))
        for _ in range(n_layers)])
    self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
    self._odim = d_model

    if n_layers_sub1 > 0:
        if task_specific_layer:
            self.layer_sub1 = TransformerEncoderBlock(
                d_model, d_ff, n_heads,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                pe_type, clamp_len, ffn_bottleneck_dim)
        odim_sub1 = d_model
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)
            odim_sub1 = last_proj_dim
        if n_layers_sub1 == n_layers:
            self.norm_out_sub1 = None
        else:
            self.norm_out_sub1 = nn.LayerNorm(odim_sub1, eps=layer_norm_eps)

    if n_layers_sub2 > 0:
        if task_specific_layer:
            self.layer_sub2 = TransformerEncoderBlock(
                d_model, d_ff, n_heads,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                pe_type, clamp_len, ffn_bottleneck_dim)
        odim_sub2 = d_model
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)
            odim_sub2 = last_proj_dim
        if n_layers_sub2 == n_layers:
            self.norm_out_sub2 = None
        else:
            self.norm_out_sub2 = nn.LayerNorm(odim_sub2, eps=layer_norm_eps)

    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
        self.bridge = nn.Linear(self._odim, last_proj_dim)
        self._odim = last_proj_dim

    self.reset_parameters(param_init)

    # for streaming inference
    self.reset_cache()
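# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the class above): one way to realize the
# 'reshape' streaming type described in the comments. The input is cut into
# overlapping windows of (N_l + N_c + N_r) frames with hop size N_c, so the
# left/right context is recomputed redundantly for every chunk instead of
# being cached. The helper name `overlap_chunk` and its interface are
# hypothetical; it relies only on standard PyTorch calls.
# ---------------------------------------------------------------------------
import math

import torch
import torch.nn.functional as F


def overlap_chunk(xs, N_l, N_c, N_r):
    """Unfold [B, T, d] into overlapping windows [B * n_chunks, N_l + N_c + N_r, d]."""
    bs, xmax, dim = xs.size()
    n_chunks = math.ceil(xmax / N_c)
    # pad N_l frames of left context, plus right context and a tail so that
    # the last chunk is complete
    xs = F.pad(xs, (0, 0, N_l, N_r + n_chunks * N_c - xmax))
    windows = [xs[:, i * N_c:i * N_c + N_l + N_c + N_r]
               for i in range(n_chunks)]
    return torch.cat(windows, dim=0)

# e.g. overlap_chunk(torch.zeros(1, 10, 4), N_l=2, N_c=2, N_r=2) returns a
# [5, 6, 4] tensor: five chunks, each seeing 2 past + 2 current + 2 future frames.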
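# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the class above): a chunkwise
# self-attention mask for the 'mask' streaming type. Each query frame may
# attend to its own chunk plus N_l past frames but never to a future chunk,
# so future context does not accumulate with layer depth (hence the
# chunk_size_right == 0 assertion in __init__). The helper name
# `make_chunkwise_mask` is hypothetical.
# ---------------------------------------------------------------------------
import torch


def make_chunkwise_mask(xmax, N_l, N_c):
    """Return a [xmax, xmax] boolean mask (True = attendable)."""
    mask = torch.zeros(xmax, xmax, dtype=torch.bool)
    for t in range(xmax):
        chunk_start = (t // N_c) * N_c  # left edge of the chunk containing t
        mask[t, max(0, chunk_start - N_l):chunk_start + N_c] = True
    return mask

# e.g. make_chunkwise_mask(6, N_l=2, N_c=2) lets frames 2-3 attend to frames
# 0-3 and frames 4-5 attend to frames 2-5, mirroring the diagram above.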