Пример #1
0
    def __init__(self, input_dim, enc_type, n_heads, n_layers, n_layers_sub1,
                 n_layers_sub2, d_model, d_ff, ffn_bottleneck_dim,
                 ffn_activation, pe_type, layer_norm_eps, last_proj_dim,
                 dropout_in, dropout, dropout_att, dropout_layer, subsample,
                 subsample_type, n_stacks, n_splices, conv_in_channel,
                 conv_channels, conv_kernel_sizes, conv_strides, conv_poolings,
                 conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
                 conv_param_init, task_specific_layer, param_init, clamp_len,
                 lookahead, chunk_size_left, chunk_size_current,
                 chunk_size_right, streaming_type):
        """Build the Transformer encoder and all of its sub-modules.

        Args:
            input_dim (int): input feature dimension. Passed to ``ConvEncoder``
                when a CNN front-end is used; otherwise multiplied by
                ``n_splices * n_stacks`` to size the input linear projection.
            enc_type (str): encoder type. A CNN front-end is built when it
                contains ``'conv'``; the encoder is treated as unidirectional
                when it contains ``'uni'``.
            n_heads (int): number of attention heads (forwarded to each block).
            n_layers (int): number of ``TransformerEncoderBlock`` layers.
            n_layers_sub1 (int): setting for the 1st auxiliary output
                (0 disables the related modules).
            n_layers_sub2 (int): setting for the 2nd auxiliary output
                (0 disables the related modules).
            d_model (int): model dimension.
            d_ff (int): forwarded to each block.
            ffn_bottleneck_dim (int): forwarded to each block.
            ffn_activation (str): forwarded to each block.
            pe_type (str): ``'relative'`` / ``'relative_xl'`` select
                ``XLPositionalEmbedding``; anything else is handed to
                ``PositionalEncoding``.
            layer_norm_eps (float): epsilon for every ``nn.LayerNorm``.
            last_proj_dim (int): when positive and different from the encoder
                output dimension, linear bridge layers project to this size.
            dropout_in (float): dropout given to ``PositionalEncoding``.
            dropout (float): dropout given to the blocks and to
                ``XLPositionalEmbedding``.
            dropout_att (float): forwarded to each block.
            dropout_layer (float): forwarded to each block.
            subsample (str): ``'_'``-separated per-layer subsampling factors.
                Fewer than ``n_layers`` entries are padded with 1; extra
                entries are ignored.
            subsample_type (str): one of ``'max_pool'``, ``'concat'``,
                ``'drop'``, ``'1dconv'``, ``'add'``. Only consulted when the
                product of the factors exceeds 1.
            n_stacks (int): must be 1 with a CNN front-end; also divides the
                chunk sizes.
            n_splices (int): must be 1 with a CNN front-end.
            conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides,
            conv_poolings, conv_batch_norm, conv_layer_norm, conv_param_init:
                forwarded to ``ConvEncoder``.
            conv_bottleneck_dim: accepted but not used in this method
                (``ConvEncoder`` receives ``bottleneck_dim=d_model``).
            task_specific_layer (bool): add one extra block per enabled
                auxiliary output.
            param_init: forwarded to the blocks, ``PositionalEncoding`` and
                ``self.reset_parameters``.
            clamp_len: forwarded to each block.
            lookahead (str): ``'_'``-separated per-layer lookahead values,
                padded with 0. Any non-zero value requires a unidirectional
                encoder.
            chunk_size_left, chunk_size_current, chunk_size_right (int or str):
                chunk sizes. Each is stringified, its last ``'_'``-separated
                field is used, and the result is floor-divided by ``n_stacks``.
            streaming_type (str): kept only for the chunkwise (latency-
                controlled) bidirectional encoder; ``'mask'`` adds extra
                constraints on the chunk sizes.

        Raises:
            ValueError: if ``n_layers_sub1`` / ``n_layers_sub2`` are out of
                range, if ``subsample_type`` is unknown while subsampling is
                requested, or if a numeric field cannot be parsed by ``int``.
            AssertionError: if the streaming, chunk-size or CNN constraints
                asserted below are violated.
        """

        super(TransformerEncoder, self).__init__()

        # parse subsample
        # NOTE: shorter specs are padded with 1 and longer specs are truncated,
        # so `subsamples` always has exactly n_layers entries.
        subsamples = [1] * n_layers
        for lth, s in enumerate(map(int, subsample.split('_')[:n_layers])):
            subsamples[lth] = s
        # parse lookahead (padded with 0 in the same way)
        lookaheads = [0] * n_layers
        for lth, s in enumerate(map(int, lookahead.split('_')[:n_layers])):
            lookaheads[lth] = s

        # Invalid arguments are reported with ValueError (not Warning, which
        # is meant for the `warnings` machinery rather than for error signalling).
        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            raise ValueError(
                'Set n_layers_sub1 between 1 to n_layers. n_layers: %d, n_layers_sub1: %d'
                % (n_layers, n_layers_sub1))
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            raise ValueError(
                'Set n_layers_sub2 between 1 to n_layers_sub1. n_layers_sub1: %d, n_layers_sub2: %d'
                % (n_layers_sub1, n_layers_sub2))

        self.enc_type = enc_type
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.scale = math.sqrt(d_model)

        # for compatibility (chunk sizes may be given as int or as '_'-joined str)
        chunk_size_left = str(chunk_size_left)
        chunk_size_current = str(chunk_size_current)
        chunk_size_right = str(chunk_size_right)

        # for streaming encoder
        self.unidir = 'uni' in enc_type
        self.lookaheads = lookaheads
        if sum(lookaheads) > 0:
            assert self.unidir
        # only the last '_'-separated field of each chunk size is used
        self.chunk_size_left = int(chunk_size_left.split('_')[-1]) // n_stacks
        self.chunk_size_current = int(
            chunk_size_current.split('_')[-1]) // n_stacks
        self.chunk_size_right = int(
            chunk_size_right.split('_')[-1]) // n_stacks
        self.lc_bidir = self.chunk_size_current > 0 and enc_type != 'conv' and 'uni' not in enc_type
        self.cnn_lookahead = self.unidir or enc_type == 'conv'
        self.streaming_type = streaming_type if self.lc_bidir else ''
        # -: past context
        # *: current context
        # +: future context
        # reshape) overlapped windowing. additional redundant computation is introduced.
        # During inference, caching is not applied. However, considering (N_l+N_c+N_r) is very short
        # and independent of layer depth, the overhead is negligible.
        # chunk1: |**|++
        # chunk2:  --|**|++
        # chunk3:     --|**|++
        # chunk4:        --|**|++
        # chunk5:           --|**|++
        # mask) chunkwise masking. future context is restricted within the current chunk
        # to avoid accumulation of future context depending on the layer depth.
        # chunk1: |**|
        # chunk2:  --|**|
        # chunk3:  -- --|**|
        # chunk4:     -- --|**|
        # chunk5:        -- --|**|
        if self.unidir:
            assert self.chunk_size_left == self.chunk_size_current == self.chunk_size_right == 0
        if self.streaming_type == 'mask':
            assert self.chunk_size_right == 0
            assert self.chunk_size_left == self.chunk_size_current
            # NOTE: this is important to cache CNN output at each chunk
        if self.lc_bidir:
            # chunkwise encoding is not combined with auxiliary outputs
            assert n_layers_sub1 == 0
            assert n_layers_sub2 == 0
            assert not self.unidir

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # Setting for CNNs
        if 'conv' in enc_type:
            assert conv_channels
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor
        self.subsample = None
        if np.prod(subsamples) > 1:
            self._factor *= np.prod(subsamples)
            # NOTE(review): 'concat' and '1dconv' are sized with self._odim,
            # which at this point is the front-end output dimension rather
            # than d_model -- confirm this is intended without a CNN.
            if subsample_type == 'max_pool':
                self.subsample = nn.ModuleList(
                    [MaxpoolSubsampler(factor) for factor in subsamples])
            elif subsample_type == 'concat':
                self.subsample = nn.ModuleList([
                    ConcatSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'drop':
                self.subsample = nn.ModuleList(
                    [DropSubsampler(factor) for factor in subsamples])
            elif subsample_type == '1dconv':
                self.subsample = nn.ModuleList([
                    Conv1dSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'add':
                self.subsample = nn.ModuleList(
                    [AddSubsampler(factor) for factor in subsamples])
            else:
                # Without this, _factor would advertise a subsampling rate
                # that no module actually applies.
                raise ValueError('Unknown subsample_type: %s' %
                                 subsample_type)

        # chunk sizes must be multiples of the total subsampling factor
        if self.chunk_size_left > 0:
            assert self.chunk_size_left % self._factor == 0
        if self.chunk_size_current > 0:
            assert self.chunk_size_current % self._factor == 0
        if self.chunk_size_right > 0:
            assert self.chunk_size_right % self._factor == 0

        self.pos_enc, self.pos_emb = None, None
        self.u_bias, self.v_bias = None, None
        if pe_type in ['relative', 'relative_xl']:
            self.pos_emb = XLPositionalEmbedding(d_model, dropout)
            if pe_type == 'relative_xl':
                # torch.Tensor(...) allocates uninitialized storage
                self.u_bias = nn.Parameter(
                    torch.Tensor(n_heads, d_model // n_heads))
                self.v_bias = nn.Parameter(
                    torch.Tensor(n_heads, d_model // n_heads))
                # NOTE: u_bias and v_bias are global parameters shared in the whole model
        else:
            self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type,
                                              param_init)

        self.layers = nn.ModuleList([
            copy.deepcopy(
                TransformerEncoderBlock(d_model, d_ff, n_heads, dropout,
                                        dropout_att, dropout_layer,
                                        layer_norm_eps, ffn_activation,
                                        param_init, pe_type, clamp_len,
                                        ffn_bottleneck_dim))
            for _ in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
        # from here on, _odim is the encoder output dimension
        self._odim = d_model

        if n_layers_sub1 > 0:
            if task_specific_layer:
                self.layer_sub1 = TransformerEncoderBlock(
                    d_model, d_ff, n_heads, dropout, dropout_att,
                    dropout_layer, layer_norm_eps, ffn_activation, param_init,
                    pe_type, clamp_len, ffn_bottleneck_dim)
            odim_sub1 = d_model
            # self.output_dim is defined outside this method
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)
                odim_sub1 = last_proj_dim
            if n_layers_sub1 == n_layers:
                self.norm_out_sub1 = None
            else:
                self.norm_out_sub1 = nn.LayerNorm(odim_sub1,
                                                  eps=layer_norm_eps)

        if n_layers_sub2 > 0:
            if task_specific_layer:
                self.layer_sub2 = TransformerEncoderBlock(
                    d_model, d_ff, n_heads, dropout, dropout_att,
                    dropout_layer, layer_norm_eps, ffn_activation, param_init,
                    pe_type, clamp_len, ffn_bottleneck_dim)
            odim_sub2 = d_model
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)
                odim_sub2 = last_proj_dim
            if n_layers_sub2 == n_layers:
                self.norm_out_sub2 = None
            else:
                self.norm_out_sub2 = nn.LayerNorm(odim_sub2,
                                                  eps=layer_norm_eps)

        # The main bridge is created last: it overwrites _odim, which the
        # auxiliary branches above still need at its d_model value.
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

        self.reset_parameters(param_init)

        # for streaming inference
        self.reset_cache()
Пример #2
0
    def __init__(self, input_dim, enc_type, n_heads, n_layers, n_layers_sub1,
                 n_layers_sub2, d_model, d_ff, ffn_bottleneck_dim,
                 ffn_activation, pe_type, layer_norm_eps, last_proj_dim,
                 dropout_in, dropout, dropout_att, dropout_layer, subsample,
                 subsample_type, n_stacks, n_splices, conv_in_channel,
                 conv_channels, conv_kernel_sizes, conv_strides, conv_poolings,
                 conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
                 conv_param_init, task_specific_layer, param_init, clamp_len,
                 lookahead, chunk_size_left, chunk_size_current,
                 chunk_size_right, streaming_type):
        """Build the Transformer encoder and all of its sub-modules.

        Args:
            input_dim (int): input feature dimension. Passed to ``ConvEncoder``
                when a CNN front-end is used; otherwise multiplied by
                ``n_splices * n_stacks`` to size the input linear projection.
            enc_type (str): encoder type. A CNN front-end is built when it
                contains ``'conv'``; the encoder is treated as unidirectional
                when it contains ``'uni'``.
            n_heads (int): number of attention heads (forwarded to each block).
            n_layers (int): number of ``TransformerEncoderBlock`` layers.
            n_layers_sub1 (int): setting for the 1st auxiliary output
                (0 disables the related modules).
            n_layers_sub2 (int): setting for the 2nd auxiliary output
                (0 disables the related modules).
            d_model (int): model dimension.
            d_ff (int): forwarded to each block.
            ffn_bottleneck_dim (int): forwarded to each block.
            ffn_activation (str): forwarded to each block.
            pe_type (str): ``'relative'`` / ``'relative_xl'`` select
                ``XLPositionalEmbedding``; anything else is handed to
                ``PositionalEncoding``.
            layer_norm_eps (float): epsilon for every ``nn.LayerNorm``.
            last_proj_dim (int): when positive and different from the encoder
                output dimension, linear bridge layers project to this size.
            dropout_in (float): dropout given to ``PositionalEncoding``.
            dropout (float): dropout given to the blocks and to
                ``XLPositionalEmbedding``.
            dropout_att (float): forwarded to each block.
            dropout_layer (float): forwarded to each block.
            subsample (str): ``'_'``-separated per-layer subsampling factors.
                Fewer than ``n_layers`` entries are padded with 1; extra
                entries are ignored.
            subsample_type (str): one of ``'max_pool'``, ``'concat'``,
                ``'drop'``, ``'1dconv'``, ``'add'``. Only consulted when the
                product of the factors exceeds 1.
            n_stacks (int): must be 1 with a CNN front-end; also divides the
                chunk sizes.
            n_splices (int): must be 1 with a CNN front-end.
            conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides,
            conv_poolings, conv_batch_norm, conv_layer_norm, conv_param_init:
                forwarded to ``ConvEncoder``.
            conv_bottleneck_dim: accepted but not used in this method
                (``ConvEncoder`` receives ``bottleneck_dim=d_model``).
            task_specific_layer (bool): add one extra block per enabled
                auxiliary output.
            param_init: forwarded to the blocks, ``PositionalEncoding`` and
                ``self.reset_parameters``.
            clamp_len: stored as ``self.clamp_len``; not forwarded to the
                blocks here.
            lookahead (str): ``'_'``-separated per-layer lookahead values,
                padded with 0. Any non-zero value requires a unidirectional
                encoder.
            chunk_size_left, chunk_size_current, chunk_size_right (int or str):
                chunk sizes. Each is stringified, its last ``'_'``-separated
                field is used, and the result is floor-divided by ``n_stacks``.
            streaming_type (str): stored as ``self.streaming_type`` unchanged.

        Raises:
            ValueError: if ``n_layers_sub1`` / ``n_layers_sub2`` are out of
                range, if ``subsample_type`` is unknown while subsampling is
                requested, or if a numeric field cannot be parsed by ``int``.
            AssertionError: if the streaming, chunk-size or CNN constraints
                asserted below are violated.
        """

        super(TransformerEncoder, self).__init__()

        # parse subsample
        # NOTE: shorter specs are padded with 1 and longer specs are truncated,
        # so `subsamples` always has exactly n_layers entries.
        subsamples = [1] * n_layers
        for lth, s in enumerate(map(int, subsample.split('_')[:n_layers])):
            subsamples[lth] = s
        # parse lookahead (padded with 0 in the same way)
        lookaheads = [0] * n_layers
        for lth, s in enumerate(map(int, lookahead.split('_')[:n_layers])):
            lookaheads[lth] = s

        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            raise ValueError('Set n_layers_sub1 between 1 to n_layers.')
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1.')

        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.scale = math.sqrt(d_model)
        self.unidir = 'uni' in enc_type

        # for compatibility (chunk sizes may be given as int or as '_'-joined str)
        chunk_size_left = str(chunk_size_left)
        chunk_size_current = str(chunk_size_current)
        chunk_size_right = str(chunk_size_right)

        # for streaming encoder
        self.lookaheads = lookaheads
        if sum(lookaheads) > 0:
            assert self.unidir
        # only the last '_'-separated field of each chunk size is used
        self.chunk_size_left = int(chunk_size_left.split('_')[-1]) // n_stacks
        self.chunk_size_current = int(
            chunk_size_current.split('_')[-1]) // n_stacks
        self.chunk_size_right = int(
            chunk_size_right.split('_')[-1]) // n_stacks
        # chunkwise (latency-controlled) mode is on as soon as any chunk size is set
        self.lc_bidir = self.chunk_size_left > 0 or self.chunk_size_current > 0 or self.chunk_size_right > 0
        self.streaming_type = streaming_type
        # reshape) no lookahead frames in CNN layers, but requires some additional computation
        # mask) there are some lookahead frames in CNN layers, no additional computation
        if self.lc_bidir:
            # chunkwise encoding is not combined with auxiliary outputs
            assert n_layers_sub1 == 0
            assert n_layers_sub2 == 0
            assert not self.unidir

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # Setting for CNNs
        if 'conv' in enc_type:
            assert conv_channels
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor
        self.subsample = None
        if np.prod(subsamples) > 1:
            self._factor *= np.prod(subsamples)
            # NOTE(review): 'concat' and '1dconv' are sized with self._odim,
            # which at this point is the front-end output dimension rather
            # than d_model -- confirm this is intended without a CNN.
            if subsample_type == 'max_pool':
                self.subsample = nn.ModuleList(
                    [MaxpoolSubsampler(factor) for factor in subsamples])
            elif subsample_type == 'concat':
                self.subsample = nn.ModuleList([
                    ConcatSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'drop':
                self.subsample = nn.ModuleList(
                    [DropSubsampler(factor) for factor in subsamples])
            elif subsample_type == '1dconv':
                self.subsample = nn.ModuleList([
                    Conv1dSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'add':
                self.subsample = nn.ModuleList(
                    [AddSubsampler(factor) for factor in subsamples])
            else:
                # Without this, _factor would advertise a subsampling rate
                # that no module actually applies.
                raise ValueError('Unknown subsample_type: %s' %
                                 subsample_type)

        # chunk sizes must be multiples of the total subsampling factor
        if self.chunk_size_left > 0:
            assert self.chunk_size_left % self._factor == 0
        if self.chunk_size_current > 0:
            assert self.chunk_size_current % self._factor == 0
        if self.chunk_size_right > 0:
            assert self.chunk_size_right % self._factor == 0

        self.clamp_len = clamp_len
        self.u_bias, self.v_bias = None, None
        # Initialize both positional modules so that either attribute can be
        # read safely whichever pe_type is selected.
        self.pos_enc, self.pos_emb = None, None
        if pe_type == 'relative_xl':
            self.pos_emb = XLPositionalEmbedding(d_model, dropout)
            # torch.Tensor(...) allocates uninitialized storage
            self.u_bias = nn.Parameter(
                torch.Tensor(n_heads, d_model // n_heads))
            self.v_bias = nn.Parameter(
                torch.Tensor(n_heads, d_model // n_heads))
            # NOTE: u_bias and v_bias are global parameters
        elif pe_type == 'relative':
            self.pos_emb = XLPositionalEmbedding(d_model, dropout)
        else:
            self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type,
                                              param_init)

        self.layers = nn.ModuleList([
            copy.deepcopy(
                TransformerEncoderBlock(d_model, d_ff, n_heads, dropout,
                                        dropout_att, dropout_layer,
                                        layer_norm_eps, ffn_activation,
                                        param_init, pe_type,
                                        ffn_bottleneck_dim))
            for _ in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
        # from here on, _odim is the encoder output dimension
        self._odim = d_model

        if n_layers_sub1 > 0:
            if task_specific_layer:
                self.layer_sub1 = TransformerEncoderBlock(
                    d_model, d_ff, n_heads, dropout, dropout_att,
                    dropout_layer, layer_norm_eps, ffn_activation, param_init,
                    pe_type, ffn_bottleneck_dim)
            odim_sub1 = d_model
            # self.output_dim is defined outside this method
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)
                odim_sub1 = last_proj_dim
            if n_layers_sub1 == n_layers:
                self.norm_out_sub1 = None
            else:
                self.norm_out_sub1 = nn.LayerNorm(odim_sub1,
                                                  eps=layer_norm_eps)

        if n_layers_sub2 > 0:
            if task_specific_layer:
                self.layer_sub2 = TransformerEncoderBlock(
                    d_model, d_ff, n_heads, dropout, dropout_att,
                    dropout_layer, layer_norm_eps, ffn_activation, param_init,
                    pe_type, ffn_bottleneck_dim)
            odim_sub2 = d_model
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)
                odim_sub2 = last_proj_dim
            if n_layers_sub2 == n_layers:
                self.norm_out_sub2 = None
            else:
                self.norm_out_sub2 = nn.LayerNorm(odim_sub2,
                                                  eps=layer_norm_eps)

        # The main bridge is created last: it overwrites _odim, which the
        # auxiliary branches above still need at its d_model value.
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

        self.reset_parameters(param_init)