def __init__(
            self,
            input_size,
            encoder_type,
            encoder_bidirectional,
            encoder_num_units,
            encoder_num_proj,
            encoder_num_layers,
            encoder_num_layers_sub,  # ***
            fc_list,
            fc_list_sub,
            dropout_input,
            dropout_encoder,
            main_loss_weight,  # ***
            sub_loss_weight,  # ***
            num_classes,
            num_classes_sub,  # ***
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[],
            subsample_type='drop',
            logits_temperature=1,
            num_stack=1,
            splice=1,
            input_channel=1,
            conv_channels=[],
            conv_kernel_sizes=[],
            conv_strides=[],
            poolings=[],
            activation='relu',
            batch_norm=False,
            label_smoothing_prob=0,
            weight_noise_std=0,
            encoder_residual=False,
            encoder_dense_residual=False):

        super(HierarchicalCTC,
              self).__init__(input_size=input_size,
                             encoder_type=encoder_type,
                             encoder_bidirectional=encoder_bidirectional,
                             encoder_num_units=encoder_num_units,
                             encoder_num_proj=encoder_num_proj,
                             encoder_num_layers=encoder_num_layers,
                             dropout_input=dropout_input,
                             dropout_encoder=dropout_encoder,
                             num_classes=num_classes,
                             parameter_init=parameter_init,
                             subsample_list=subsample_list,
                             subsample_type=subsample_type,
                             fc_list=fc_list,
                             num_stack=num_stack,
                             splice=splice,
                             input_channel=input_channel,
                             conv_channels=conv_channels,
                             conv_kernel_sizes=conv_kernel_sizes,
                             conv_strides=conv_strides,
                             poolings=poolings,
                             logits_temperature=logits_temperature,
                             batch_norm=batch_norm,
                             label_smoothing_prob=label_smoothing_prob,
                             weight_noise_std=weight_noise_std)
        self.model_type = 'hierarchical_ctc'

        # Setting for the encoder
        self.encoder_num_layers_sub = encoder_num_layers_sub
        self.fc_list_sub = fc_list_sub

        # Setting for CTC
        self.num_classes_sub = num_classes_sub + 1  # Add the blank class

        # Setting for MTL
        self.main_loss_weight = main_loss_weight
        self.sub_loss_weight = sub_loss_weight

        # Load the encoder
        # NOTE: override the encoder
        if encoder_type in ['lstm', 'gru', 'rnn']:
            self.encoder = load(encoder_type=encoder_type)(
                input_size=input_size,
                rnn_type=encoder_type,
                bidirectional=encoder_bidirectional,
                num_units=encoder_num_units,
                num_proj=encoder_num_proj,
                num_layers=encoder_num_layers,
                num_layers_sub=encoder_num_layers_sub,
                dropout_input=dropout_input,
                dropout_hidden=dropout_encoder,
                subsample_list=subsample_list,
                subsample_type=subsample_type,
                batch_first=True,
                merge_bidirectional=False,
                pack_sequence=True,
                num_stack=num_stack,
                splice=splice,
                input_channel=input_channel,
                conv_channels=conv_channels,
                conv_kernel_sizes=conv_kernel_sizes,
                conv_strides=conv_strides,
                poolings=poolings,
                activation=activation,
                batch_norm=batch_norm,
                residual=encoder_residual,
                dense_residual=encoder_dense_residual)
        elif encoder_type == 'cnn':
            assert num_stack == 1 and splice == 1
            self.encoder = load(encoder_type='cnn')(
                input_size=input_size,
                input_channel=input_channel,
                conv_channels=conv_channels,
                conv_kernel_sizes=conv_kernel_sizes,
                conv_strides=conv_strides,
                poolings=poolings,
                dropout_input=dropout_input,
                dropout_hidden=dropout_encoder,
                activation=activation,
                batch_norm=batch_norm)
        else:
            raise NotImplementedError

        ##################################################
        # Fully-connected layers in the main task
        ##################################################
        if len(fc_list) > 0:
            for i in range(len(fc_list)):
                if i == 0:
                    if encoder_type == 'cnn':
                        bottle_input_size = self.encoder.output_size
                    else:
                        bottle_input_size = self.encoder_num_units

                    # TODO: add batch norm layers

                    setattr(
                        self, 'fc_0',
                        LinearND(bottle_input_size,
                                 fc_list[i],
                                 dropout=dropout_encoder))
                else:
                    # TODO: add batch norm layers

                    setattr(
                        self, 'fc_' + str(i),
                        LinearND(fc_list[i - 1],
                                 fc_list[i],
                                 dropout=dropout_encoder))
            # TODO: remove the bias term when using batch normalization

            self.fc_out = LinearND(fc_list[-1], self.num_classes)
        else:
            self.fc_out = LinearND(self.encoder_num_units, self.num_classes)

        ##################################################
        # Fully-connected layers in the sub task
        ##################################################
        if len(fc_list_sub) > 0:
            for i in range(len(fc_list_sub)):
                if i == 0:
                    if encoder_type == 'cnn':
                        bottle_input_size = self.encoder.output_size
                    else:
                        bottle_input_size = self.encoder_num_units

                    # TODO: add batch norm layers

                    setattr(
                        self, 'fc_sub_0',
                        LinearND(bottle_input_size,
                                 fc_list_sub[i],
                                 dropout=dropout_encoder))
                else:
                    # TODO: add batch norm layers

                    setattr(
                        self, 'fc_sub_' + str(i),
                        LinearND(fc_list_sub[i - 1],
                                 fc_list_sub[i],
                                 dropout=dropout_encoder))
            # TODO: remove the bias term when using batch normalization

            self.fc_out_sub = LinearND(fc_list_sub[-1], self.num_classes_sub)
        else:
            self.fc_out_sub = LinearND(self.encoder_num_units,
                                       self.num_classes_sub)

        ##################################################
        # Initialize parameters
        ##################################################
        self.init_weights(parameter_init,
                          distribution=parameter_init_distribution,
                          ignore_keys=['bias'])

        # Initialize all biases with 0
        self.init_weights(0, distribution='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if recurrent_weight_orthogonal:
            self.init_weights(parameter_init,
                              distribution='orthogonal',
                              keys=['lstm', 'weight'],
                              ignore_keys=['bias'])

        # Initialize bias in forget gate with 1
        if init_forget_gate_bias_with_one:
            self.init_forget_gate_bias_with_one()
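
A minimal instantiation sketch for the hierarchical CTC constructor above. The keyword arguments mirror the signature; every concrete value (feature dimension, layer counts, vocabulary sizes) is hypothetical, and the import of HierarchicalCTC depends on the repository layout.

# Hypothetical hyperparameters; the blank class is added internally for both tasks.
model = HierarchicalCTC(
    input_size=80,                # e.g. 80-dim log-mel filterbank features
    encoder_type='lstm',
    encoder_bidirectional=True,
    encoder_num_units=320,
    encoder_num_proj=0,
    encoder_num_layers=5,
    encoder_num_layers_sub=3,     # the sub task branches off this intermediate layer
    fc_list=[],
    fc_list_sub=[],
    dropout_input=0.2,
    dropout_encoder=0.2,
    main_loss_weight=0.8,
    sub_loss_weight=0.2,
    num_classes=500,              # main-task vocabulary size (without the blank)
    num_classes_sub=30)           # sub-task vocabulary size (without the blank)
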
    def __init__(self,
                 encoder_num_units,
                 decoder_num_units,
                 attention_type,
                 attention_dim,
                 sharpening_factor=1,
                 sigmoid_smoothing=False,
                 out_channels=10,
                 kernel_size=201,
                 num_heads=1):

        super(AttentionMechanism, self).__init__()

        self.attention_type = attention_type
        self.attention_dim = attention_dim
        self.sharpening_factor = sharpening_factor
        self.sigmoid_smoothing = sigmoid_smoothing
        self.num_heads = num_heads

        # Multi-head attention
        if num_heads > 1:
            setattr(self, 'W_mha',
                    LinearND(encoder_num_units * num_heads, encoder_num_units))

        for h in range(num_heads):
            if self.attention_type == 'content':
                setattr(self, 'W_enc_head' + str(h),
                        LinearND(encoder_num_units, attention_dim, bias=True))
                setattr(self, 'W_dec_head' + str(h),
                        LinearND(decoder_num_units, attention_dim, bias=False))
                setattr(self, 'V_head' + str(h),
                        LinearND(attention_dim, 1, bias=False))

            elif self.attention_type == 'location':
                assert kernel_size % 2 == 1

                setattr(self, 'W_enc_head' + str(h),
                        LinearND(encoder_num_units, attention_dim, bias=True))
                setattr(self, 'W_dec_head' + str(h),
                        LinearND(decoder_num_units, attention_dim, bias=False))
                setattr(self, 'W_conv_head' + str(h),
                        LinearND(out_channels, attention_dim, bias=False))
                # setattr(self, 'conv_head' + str(h),
                #         nn.Conv1d(in_channels=1,
                #                   out_channels=out_channels,
                #                   kernel_size=kernel_size,
                #                   stride=1,
                #                   padding=kernel_size // 2,
                #                   bias=False))
                setattr(
                    self, 'conv_head' + str(h),
                    nn.Conv2d(in_channels=1,
                              out_channels=out_channels,
                              kernel_size=(1, kernel_size),
                              stride=1,
                              padding=(0, kernel_size // 2),
                              bias=False))
                setattr(self, 'V_head' + str(h),
                        LinearND(attention_dim, 1, bias=False))

            elif self.attention_type == 'dot_product':
                setattr(
                    self, 'W_enc_head' + str(h),
                    LinearND(encoder_num_units, decoder_num_units, bias=False))

            elif self.attention_type == 'rnn_attention':
                raise NotImplementedError

            elif self.attention_type == 'coverage':
                setattr(self, 'W_enc_head' + str(h),
                        LinearND(encoder_num_units, attention_dim, bias=True))
                setattr(self, 'W_dec_head' + str(h),
                        LinearND(decoder_num_units, attention_dim, bias=False))
                setattr(self, 'W_cov_head' + str(h),
                        LinearND(encoder_num_units, attention_dim, bias=False))
                setattr(self, 'V_head' + str(h),
                        LinearND(attention_dim, 1, bias=False))
                self.aw_cumsum = None

            else:
                raise TypeError(
                    "attention_type should be one of [%s], you provided %s." %
                    (", ".join(ATTENTION_TYPE), attention_type))
Example No. 3
    def __init__(self,
                 input_size,
                 encoder_type,
                 encoder_bidirectional,
                 encoder_num_units,
                 encoder_num_proj,
                 encoder_num_layers,
                 fc_list,
                 dropout_input,
                 dropout_encoder,
                 num_classes,
                 parameter_init_distribution='uniform',
                 parameter_init=0.1,
                 recurrent_weight_orthogonal=False,
                 init_forget_gate_bias_with_one=True,
                 subsample_list=[],
                 subsample_type='drop',
                 logits_temperature=1,
                 num_stack=1,
                 splice=1,
                 input_channel=1,
                 conv_channels=[],
                 conv_kernel_sizes=[],
                 conv_strides=[],
                 poolings=[],
                 activation='relu',
                 batch_norm=False,
                 label_smoothing_prob=0,
                 weight_noise_std=0,
                 encoder_residual=False,
                 encoder_dense_residual=False):

        super(ModelBase, self).__init__()
        self.model_type = 'ctc'

        # Setting for the encoder
        self.input_size = input_size
        self.num_stack = num_stack
        self.encoder_type = encoder_type
        self.encoder_num_units = encoder_num_units
        if encoder_bidirectional:
            self.encoder_num_units *= 2
        self.fc_list = fc_list
        self.subsample_list = subsample_list

        # Setting for CTC
        self.num_classes = num_classes + 1  # Add the blank class
        self.logits_temperature = logits_temperature

        # Setting for regularization
        self.weight_noise_injection = False
        self.weight_noise_std = float(weight_noise_std)
        self.ls_prob = label_smoothing_prob

        # Call the encoder function
        if encoder_type in ['lstm', 'gru', 'rnn']:
            self.encoder = load(encoder_type=encoder_type)(
                input_size=input_size,
                rnn_type=encoder_type,
                bidirectional=encoder_bidirectional,
                num_units=encoder_num_units,
                num_proj=encoder_num_proj,
                num_layers=encoder_num_layers,
                dropout_input=dropout_input,
                dropout_hidden=dropout_encoder,
                subsample_list=subsample_list,
                subsample_type=subsample_type,
                batch_first=True,
                merge_bidirectional=False,
                pack_sequence=True,
                num_stack=num_stack,
                splice=splice,
                input_channel=input_channel,
                conv_channels=conv_channels,
                conv_kernel_sizes=conv_kernel_sizes,
                conv_strides=conv_strides,
                poolings=poolings,
                activation=activation,
                batch_norm=batch_norm,
                residual=encoder_residual,
                dense_residual=encoder_dense_residual,
                nin=0)
        elif encoder_type == 'cnn':
            assert num_stack == 1 and splice == 1
            self.encoder = load(encoder_type='cnn')(
                input_size=input_size,
                input_channel=input_channel,
                conv_channels=conv_channels,
                conv_kernel_sizes=conv_kernel_sizes,
                conv_strides=conv_strides,
                poolings=poolings,
                dropout_input=dropout_input,
                dropout_hidden=dropout_encoder,
                activation=activation,
                batch_norm=batch_norm)
        else:
            raise NotImplementedError

        ##################################################
        # Fully-connected layers
        ##################################################
        if len(fc_list) > 0:
            for i in range(len(fc_list)):
                if i == 0:
                    if encoder_type == 'cnn':
                        bottle_input_size = self.encoder.output_size
                    else:
                        bottle_input_size = self.encoder_num_units

                    # if batch_norm:
                    #     setattr(self, 'bn_fc_0', nn.BatchNorm1d(
                    #         bottle_input_size))

                    setattr(
                        self, 'fc_0',
                        LinearND(bottle_input_size,
                                 fc_list[i],
                                 dropout=dropout_encoder))
                else:
                    # if batch_norm:
                    #     setattr(self, 'fc_bn_' + str(i),
                    #             nn.BatchNorm1d(fc_list[i - 1]))

                    setattr(
                        self, 'fc_' + str(i),
                        LinearND(fc_list[i - 1],
                                 fc_list[i],
                                 dropout=dropout_encoder))
            # TODO: remove the bias term when using batch normalization

            self.fc_out = LinearND(fc_list[-1], self.num_classes)
        else:
            self.fc_out = LinearND(self.encoder_num_units, self.num_classes)

        ##################################################
        # Initialize parameters
        ##################################################
        self.init_weights(parameter_init,
                          distribution=parameter_init_distribution,
                          ignore_keys=['bias'])

        # Initialize all biases with 0
        self.init_weights(0, distribution='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if recurrent_weight_orthogonal and encoder_type != 'cnn':
            self.init_weights(parameter_init,
                              distribution='orthogonal',
                              keys=[encoder_type, 'weight'],
                              ignore_keys=['bias'])

        # Initialize bias in forget gate with 1
        if init_forget_gate_bias_with_one:
            self.init_forget_gate_bias_with_one()

        # Set CTC decoders
        self._decode_greedy_np = GreedyDecoder(blank_index=0)
        self._decode_beam_np = BeamSearchDecoder(blank_index=0)
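
Both the enlarged output layer (num_classes + 1) and the numpy decoders above reserve index 0 for the CTC blank. Below is a standalone PyTorch sketch of that convention using torch.nn.CTCLoss; it is purely illustrative and is not the loss computation used by this repository.

import torch
import torch.nn as nn

num_classes = 30 + 1                                # vocabulary plus the blank at index 0
T, B, U = 100, 4, 20                                # frames, batch size, label length
log_probs = torch.randn(T, B, num_classes).log_softmax(dim=-1)
targets = torch.randint(1, num_classes, (B, U))     # labels 1..num_classes-1; 0 is the blank
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), U, dtype=torch.long)

ctc_loss = nn.CTCLoss(blank=0)                      # blank index matches the decoders above
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
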
    def __init__(self,
                 input_size,
                 rnn_type,
                 bidirectional,
                 num_units,
                 num_proj,
                 num_layers,
                 dropout_input,
                 dropout_hidden,
                 subsample_list=[],
                 subsample_type='drop',
                 use_cuda=False,
                 batch_first=False,
                 merge_bidirectional=False,
                 pack_sequence=True,
                 num_stack=1,
                 splice=1,
                 input_channel=1,
                 conv_channels=[],
                 conv_kernel_sizes=[],
                 conv_strides=[],
                 poolings=[],
                 activation='relu',
                 batch_norm=False,
                 residual=False,
                 dense_residual=False,
                 num_layers_sub=0,
                 nin=0):

        super(RNNEncoder, self).__init__()

        if len(subsample_list) > 0 and len(subsample_list) != num_layers:
            raise ValueError(
                'subsample_list must have length equal to num_layers.')
        if subsample_type not in ['drop', 'concat']:
            raise TypeError('subsample_type must be "drop" or "concat".')
        if num_layers_sub < 0 or (num_layers_sub > 1
                                  and num_layers < num_layers_sub):
            raise ValueError('Set num_layers_sub between 1 and num_layers.')

        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.num_units = num_units
        self.num_proj = num_proj if num_proj is not None else 0
        self.num_layers = num_layers
        self.use_cuda = use_cuda
        self.batch_first = batch_first
        self.merge_bidirectional = merge_bidirectional
        self.pack_sequence = pack_sequence

        # Setting for hierarchical encoder
        self.num_layers_sub = num_layers_sub

        # Setting for subsampling
        if len(subsample_list) == 0:
            self.subsample_list = [False] * num_layers
        else:
            self.subsample_list = subsample_list
        self.subsample_type = subsample_type
        # This implementation is based on
        # https://arxiv.org/abs/1508.01211
        #     Chan, William, et al. "Listen, attend and spell."
        #         arXiv preprint arXiv:1508.01211 (2015).

        # Setting for residual connection
        assert not (residual and dense_residual)
        self.residual = residual
        self.dense_residual = dense_residual
        subsample_last_layer = 0
        for l_reverse, is_subsample in enumerate(subsample_list[::-1]):
            if is_subsample:
                subsample_last_layer = num_layers - l_reverse
                break
        self.residual_start_layer = subsample_last_layer + 1
        # NOTE: residual connection starts from the last subsampling layer

        # Setting for the NiN
        self.batch_norm = batch_norm
        self.nin = nin

        # Dropout for input-hidden connection
        self.dropout_input = nn.Dropout(p=dropout_input)

        # Setting for CNNs before RNNs
        if (len(conv_channels) > 0
                and len(conv_channels) == len(conv_kernel_sizes)
                and len(conv_kernel_sizes) == len(conv_strides)):
            assert num_stack == 1 and splice == 1
            self.conv = CNNEncoder(input_size,
                                   input_channel=input_channel,
                                   conv_channels=conv_channels,
                                   conv_kernel_sizes=conv_kernel_sizes,
                                   conv_strides=conv_strides,
                                   poolings=poolings,
                                   dropout_input=0,
                                   dropout_hidden=dropout_hidden,
                                   activation=activation,
                                   batch_norm=batch_norm)
            input_size = self.conv.output_size
        else:
            input_size = input_size * splice * num_stack
            self.conv = None

        # Fast implementation without using torch.nn.utils.rnn.PackedSequence
        if (sum(self.subsample_list) == 0 and self.num_proj == 0
                and not residual and not dense_residual
                and num_layers_sub == 0 and not batch_norm and nin == 0):
            self.fast_impl = True

            if rnn_type == 'lstm':
                rnn = nn.LSTM(input_size,
                              hidden_size=num_units,
                              num_layers=num_layers,
                              bias=True,
                              batch_first=batch_first,
                              dropout=dropout_hidden,
                              bidirectional=bidirectional)
            elif rnn_type == 'gru':
                rnn = nn.GRU(input_size,
                             hidden_size=num_units,
                             num_layers=num_layers,
                             bias=True,
                             batch_first=batch_first,
                             dropout=dropout_hidden,
                             bidirectional=bidirectional)
            elif rnn_type == 'rnn':
                rnn = nn.RNN(input_size,
                             hidden_size=num_units,
                             num_layers=num_layers,
                             bias=True,
                             batch_first=batch_first,
                             dropout=dropout_hidden,
                             bidirectional=bidirectional)
            else:
                raise ValueError('rnn_type must be "lstm" or "gru" or "rnn".')

            setattr(self, rnn_type, rnn)
            # NOTE: pytorch introduces a dropout layer on the outputs of
            # each RNN layer EXCEPT the last layer

            # Dropout for hidden-output connection
            self.dropout_last = nn.Dropout(p=dropout_hidden)

        else:
            self.fast_impl = False

            for l in range(num_layers):
                if l == 0:
                    encoder_input_size = input_size
                elif nin > 0:
                    encoder_input_size = nin
                elif self.num_proj > 0:
                    encoder_input_size = num_proj
                    if (subsample_type == 'concat' and l > 0
                            and self.subsample_list[l - 1]):
                        encoder_input_size *= 2
                else:
                    encoder_input_size = num_units * self.num_directions
                    if (subsample_type == 'concat' and l > 0
                            and self.subsample_list[l - 1]):
                        encoder_input_size *= 2

                if rnn_type == 'lstm':
                    rnn_i = nn.LSTM(encoder_input_size,
                                    hidden_size=num_units,
                                    num_layers=1,
                                    bias=True,
                                    batch_first=batch_first,
                                    dropout=0,
                                    bidirectional=bidirectional)

                elif rnn_type == 'gru':
                    rnn_i = nn.GRU(encoder_input_size,
                                   hidden_size=num_units,
                                   num_layers=1,
                                   bias=True,
                                   batch_first=batch_first,
                                   dropout=0,
                                   bidirectional=bidirectional)
                elif rnn_type == 'rnn':
                    rnn_i = nn.RNN(encoder_input_size,
                                   hidden_size=num_units,
                                   num_layers=1,
                                   bias=True,
                                   batch_first=batch_first,
                                   dropout=0,
                                   bidirectional=bidirectional)
                else:
                    raise ValueError(
                        'rnn_type must be "lstm" or "gru" or "rnn".')

                setattr(self, rnn_type + '_l' + str(l), rnn_i)
                encoder_output_size = num_units * self.num_directions

                # Dropout for hidden-hidden or hidden-output connection
                setattr(self, 'dropout_l' + str(l),
                        nn.Dropout(p=dropout_hidden))

                if l != self.num_layers - 1 and self.num_proj > 0:
                    proj_i = LinearND(num_units * self.num_directions,
                                      num_proj,
                                      dropout=dropout_hidden)
                    setattr(self, 'proj_l' + str(l), proj_i)
                    encoder_output_size = num_proj

                # Network in network (1*1 conv)
                if nin > 0:
                    setattr(
                        self, 'nin_l' + str(l),
                        nn.Conv1d(in_channels=encoder_output_size,
                                  out_channels=nin,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0,  # a 1x1 convolution needs no padding
                                  bias=not batch_norm))

                    # Batch normalization
                    if batch_norm:
                        if nin:
                            setattr(self, 'bn_0_l' + str(l),
                                    nn.BatchNorm1d(encoder_output_size))
                            setattr(self, 'bn_l' + str(l), nn.BatchNorm1d(nin))
                        elif (subsample_type == 'concat'
                              and self.subsample_list[l]):
                            setattr(self, 'bn_l' + str(l),
                                    nn.BatchNorm1d(encoder_output_size * 2))
                        else:
                            setattr(self, 'bn_l' + str(l),
                                    nn.BatchNorm1d(encoder_output_size))
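
A construction sketch for the RNN encoder above, assuming RNNEncoder is in scope; the values are hypothetical. Note that a non-empty subsample_list (or projections, residual connections, NiN, batch norm, or a sub-task tap) selects the per-layer implementation rather than the fast single stacked nn.LSTM path.

# Hypothetical configuration: 5 BLSTM layers with frame subsampling after layers 2 and 3.
encoder = RNNEncoder(
    input_size=80,
    rnn_type='lstm',
    bidirectional=True,
    num_units=320,
    num_proj=0,
    num_layers=5,
    dropout_input=0.2,
    dropout_hidden=0.2,
    subsample_list=[False, True, True, False, False],  # True marks layers whose output is subsampled
    subsample_type='drop',
    batch_first=True)
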
Example No. 5
    def __init__(self,
                 num_classes,
                 embedding_dim,
                 rnn_type,
                 bidirectional,
                 num_units,
                 num_layers,
                 dropout_embedding,
                 dropout_hidden,
                 dropout_output,
                 parameter_init_distribution='uniform',
                 parameter_init=0.1,
                 tie_weights=False,
                 init_forget_gate_bias_with_one=True):

        super(ModelBase, self).__init__()
        self.model_type = 'rnnlm'

        # TODO: clip_activation

        self.embedding_dim = embedding_dim
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.num_units = num_units
        self.num_layers = num_layers
        self.parameter_init = parameter_init
        self.tie_weights = tie_weights

        self.num_classes = num_classes + 1  # Add <EOS> class
        # self.padded_index = 0
        self.padded_index = -1

        self.embed = Embedding(num_classes=self.num_classes,
                               embedding_dim=embedding_dim,
                               dropout=dropout_embedding,
                               ignore_index=self.padded_index)
        # NOTE: share the embedding layer between inputs and outputs

        if rnn_type == 'lstm':
            rnn = nn.LSTM(embedding_dim,
                          hidden_size=num_units,
                          num_layers=num_layers,
                          bias=True,
                          batch_first=True,
                          dropout=dropout_hidden,
                          bidirectional=bidirectional)
        elif rnn_type == 'gru':
            rnn = nn.GRU(embedding_dim,
                         hidden_size=num_units,
                         num_layers=num_layers,
                         bias=True,
                         batch_first=True,
                         dropout=dropout_hidden,
                         bidirectional=bidirectional)
        elif rnn_type == 'rnn':
            rnn = nn.RNN(
                embedding_dim,
                hidden_size=num_units,
                num_layers=num_layers,
                nonlinearity='tanh',
                # nonlinearity='relu',
                bias=True,
                batch_first=True,
                dropout=dropout_hidden,
                bidirectional=bidirectional)
        else:
            raise ValueError('rnn_type must be "lstm" or "gru" or "rnn".')
        setattr(self, rnn_type, rnn)

        self.output = LinearND(num_units * self.num_directions,
                               self.num_classes,
                               dropout=dropout_output)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if num_units != embedding_dim:
                raise ValueError(
                    'When using the tied flag, num_units must be equal to embedding_dim'
                )
            self.output.fc.weight = self.embed.embed.weight

        ##################################################
        # Initialize parameters
        ##################################################
        self.init_weights(parameter_init,
                          distribution=parameter_init_distribution,
                          ignore_keys=['bias'])

        # Initialize all biases with 0
        self.init_weights(0, distribution='constant', keys=['bias'])

        # Initialize bias in forget gate with 1
        if init_forget_gate_bias_with_one:
            self.init_forget_gate_bias_with_one()
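
A construction sketch for the language model above; the class name is not visible in this excerpt, so RNNLM below is a placeholder, and all sizes are hypothetical. Note that tie_weights=True requires num_units == embedding_dim, as enforced by the check above.

# Hypothetical configuration; 'RNNLM' stands in for whatever the class is actually named.
lm = RNNLM(num_classes=10000,        # <EOS> is added internally
           embedding_dim=512,
           rnn_type='lstm',
           bidirectional=False,
           num_units=512,            # must equal embedding_dim because tie_weights=True
           num_layers=2,
           dropout_embedding=0.2,
           dropout_hidden=0.2,
           dropout_output=0.2,
           tie_weights=True)
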
Example No. 6
    def __init__(
            self,
            input_size,
            encoder_type,
            encoder_bidirectional,
            encoder_num_units,
            encoder_num_proj,
            encoder_num_layers,
            encoder_num_layers_sub,  # ***
            attention_type,
            attention_dim,
            decoder_type,
            decoder_num_units,
            decoder_num_units_sub,  # ***
            decoder_num_layers,
            decoder_num_layers_sub,  # ***
            embedding_dim,
            embedding_dim_sub,  # ***
            dropout_input,
            dropout_encoder,
            dropout_decoder,
            dropout_embedding,
            main_loss_weight,  # ***
            sub_loss_weight,  # ***
            num_classes,
            num_classes_sub,  # ***
            parameter_init_distribution='uniform',
            parameter_init=0.1,
            recurrent_weight_orthogonal=False,
            init_forget_gate_bias_with_one=True,
            subsample_list=[],
            subsample_type='drop',
            bridge_layer=False,
            init_dec_state='first',
            sharpening_factor=1,
            logits_temperature=1,
            sigmoid_smoothing=False,
            coverage_weight=0,
            ctc_loss_weight_sub=0,  # ***
            attention_conv_num_channels=10,
            attention_conv_width=201,
            num_stack=1,
            splice=1,
            input_channel=1,
            conv_channels=[],
            conv_kernel_sizes=[],
            conv_strides=[],
            poolings=[],
            activation='relu',
            batch_norm=False,
            scheduled_sampling_prob=0,
            scheduled_sampling_max_step=0,
            label_smoothing_prob=0,
            weight_noise_std=0,
            encoder_residual=False,
            encoder_dense_residual=False,
            decoder_residual=False,
            decoder_dense_residual=False,
            decoding_order='attend_generate_update',
            bottleneck_dim=256,
            bottleneck_dim_sub=256,  # ***
            backward_sub=False,  # ***
            num_heads=1,
            num_heads_sub=1):  # ***

        super(HierarchicalAttentionSeq2seq, self).__init__(
            input_size=input_size,
            encoder_type=encoder_type,
            encoder_bidirectional=encoder_bidirectional,
            encoder_num_units=encoder_num_units,
            encoder_num_proj=encoder_num_proj,
            encoder_num_layers=encoder_num_layers,
            attention_type=attention_type,
            attention_dim=attention_dim,
            decoder_type=decoder_type,
            decoder_num_units=decoder_num_units,
            decoder_num_layers=decoder_num_layers,
            embedding_dim=embedding_dim,
            dropout_input=dropout_input,
            dropout_encoder=dropout_encoder,
            dropout_decoder=dropout_decoder,
            dropout_embedding=dropout_embedding,
            num_classes=num_classes,
            parameter_init=parameter_init,
            subsample_list=subsample_list,
            subsample_type=subsample_type,
            bridge_layer=bridge_layer,
            init_dec_state=init_dec_state,
            sharpening_factor=sharpening_factor,
            logits_temperature=logits_temperature,
            sigmoid_smoothing=sigmoid_smoothing,
            coverage_weight=coverage_weight,
            ctc_loss_weight=0,
            attention_conv_num_channels=attention_conv_num_channels,
            attention_conv_width=attention_conv_width,
            num_stack=num_stack,
            splice=splice,
            input_channel=input_channel,
            conv_channels=conv_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            poolings=poolings,
            scheduled_sampling_prob=scheduled_sampling_prob,
            scheduled_sampling_max_step=scheduled_sampling_max_step,
            label_smoothing_prob=label_smoothing_prob,
            weight_noise_std=weight_noise_std,
            encoder_residual=encoder_residual,
            encoder_dense_residual=encoder_dense_residual,
            decoder_residual=decoder_residual,
            decoder_dense_residual=decoder_dense_residual,
            decoding_order=decoding_order,
            bottleneck_dim=bottleneck_dim,
            backward_loss_weight=0,
            num_heads=num_heads)
        self.model_type = 'hierarchical_attention'

        # Setting for the encoder
        self.encoder_num_units_sub = encoder_num_units
        if encoder_bidirectional:
            self.encoder_num_units_sub *= 2

        # Setting for the decoder in the sub task
        self.decoder_num_units_1 = decoder_num_units_sub
        self.decoder_num_layers_1 = decoder_num_layers_sub
        self.num_classes_sub = num_classes_sub + 1  # Add <EOS> class
        self.sos_1 = num_classes_sub
        self.eos_1 = num_classes_sub
        # NOTE: <SOS> and <EOS> have the same index
        self.backward_1 = backward_sub

        # Setting for the decoder initialization in the sub task
        if backward_sub:
            if init_dec_state == 'first':
                self.init_dec_state_1_bwd = 'final'
            elif init_dec_state == 'final':
                self.init_dec_state_1_bwd = 'first'
            else:
                self.init_dec_state_1_bwd = init_dec_state
            if encoder_type != decoder_type:
                self.init_dec_state_1_bwd = 'zero'
        else:
            self.init_dec_state_1_fwd = init_dec_state
            if encoder_type != decoder_type:
                self.init_dec_state_1_fwd = 'zero'

        # Setting for the attention in the sub task
        self.num_heads_1 = num_heads_sub

        # Setting for MTL
        self.main_loss_weight = main_loss_weight
        self.sub_loss_weight = sub_loss_weight
        self.ctc_loss_weight_sub = ctc_loss_weight_sub
        if backward_sub:
            self.bwd_weight_1 = sub_loss_weight

        ##############################
        # Encoder
        # NOTE: override the encoder
        ##############################
        if encoder_type in ['lstm', 'gru', 'rnn']:
            self.encoder = load(encoder_type=encoder_type)(
                input_size=input_size,
                rnn_type=encoder_type,
                bidirectional=encoder_bidirectional,
                num_units=encoder_num_units,
                num_proj=encoder_num_proj,
                num_layers=encoder_num_layers,
                num_layers_sub=encoder_num_layers_sub,
                dropout_input=dropout_input,
                dropout_hidden=dropout_encoder,
                subsample_list=subsample_list,
                subsample_type=subsample_type,
                batch_first=True,
                merge_bidirectional=False,
                pack_sequence=True,
                num_stack=num_stack,
                splice=splice,
                input_channel=input_channel,
                conv_channels=conv_channels,
                conv_kernel_sizes=conv_kernel_sizes,
                conv_strides=conv_strides,
                poolings=poolings,
                activation=activation,
                batch_norm=batch_norm,
                residual=encoder_residual,
                dense_residual=encoder_dense_residual)
        elif encoder_type == 'cnn':
            assert num_stack == 1 and splice == 1
            self.encoder = load(encoder_type='cnn')(
                input_size=input_size,
                input_channel=input_channel,
                conv_channels=conv_channels,
                conv_kernel_sizes=conv_kernel_sizes,
                conv_strides=conv_strides,
                poolings=poolings,
                dropout_input=dropout_input,
                dropout_hidden=dropout_encoder,
                activation=activation,
                batch_norm=batch_norm)
            self.init_dec_state_0 = 'zero'
            self.init_dec_state_1 = 'zero'
        else:
            raise NotImplementedError

        dir = 'bwd' if backward_sub else 'fwd'
        self.is_bridge_sub = False
        if self.sub_loss_weight > 0:
            ##################################################
            # Bridge layer between the encoder and decoder
            ##################################################
            if encoder_type == 'cnn':
                self.bridge_1 = LinearND(self.encoder.output_size,
                                         decoder_num_units_sub,
                                         dropout=dropout_encoder)
                self.encoder_num_units_sub = decoder_num_units_sub
                self.is_bridge_sub = True
            elif bridge_layer:
                self.bridge_1 = LinearND(self.encoder_num_units_sub,
                                         decoder_num_units_sub,
                                         dropout=dropout_encoder)
                self.encoder_num_units_sub = decoder_num_units_sub
                self.is_bridge_sub = True
            else:
                self.is_bridge_sub = False

            ##################################################
            # Initialization of the decoder
            ##################################################
            if getattr(self, 'init_dec_state_1_' + dir) != 'zero':
                setattr(
                    self, 'W_dec_init_1_' + dir,
                    LinearND(self.encoder_num_units_sub,
                             decoder_num_units_sub))

            ##############################
            # Decoder (sub)
            ##############################
            if decoding_order == 'conditional':
                setattr(
                    self, 'decoder_first_1_' + dir,
                    RNNDecoder(input_size=embedding_dim_sub,
                               rnn_type=decoder_type,
                               num_units=decoder_num_units_sub,
                               num_layers=1,
                               dropout=dropout_decoder,
                               residual=False,
                               dense_residual=False))
                setattr(
                    self, 'decoder_second_1_' + dir,
                    RNNDecoder(input_size=self.encoder_num_units_sub,
                               rnn_type=decoder_type,
                               num_units=decoder_num_units_sub,
                               num_layers=1,
                               dropout=dropout_decoder,
                               residual=False,
                               dense_residual=False))
                # NOTE: the conditional decoder supports only a single layer
            else:
                setattr(
                    self, 'decoder_1_' + dir,
                    RNNDecoder(input_size=self.encoder_num_units_sub +
                               embedding_dim_sub,
                               rnn_type=decoder_type,
                               num_units=decoder_num_units_sub,
                               num_layers=decoder_num_layers_sub,
                               dropout=dropout_decoder,
                               residual=decoder_residual,
                               dense_residual=decoder_dense_residual))

            ###################################
            # Attention layer (sub)
            ###################################
            setattr(
                self, 'attend_1_' + dir,
                AttentionMechanism(
                    encoder_num_units=self.encoder_num_units_sub,
                    decoder_num_units=decoder_num_units_sub,
                    attention_type=attention_type,
                    attention_dim=attention_dim,
                    sharpening_factor=sharpening_factor,
                    sigmoid_smoothing=sigmoid_smoothing,
                    out_channels=attention_conv_num_channels,
                    kernel_size=attention_conv_width,
                    num_heads=num_heads_sub))

            ##############################
            # Output layer (sub)
            ##############################
            setattr(
                self, 'W_d_1_' + dir,
                LinearND(decoder_num_units_sub,
                         bottleneck_dim_sub,
                         dropout=dropout_decoder))
            setattr(
                self, 'W_c_1_' + dir,
                LinearND(self.encoder_num_units_sub,
                         bottleneck_dim_sub,
                         dropout=dropout_decoder))
            setattr(self, 'fc_1_' + dir,
                    LinearND(bottleneck_dim_sub, self.num_classes_sub))

            ##############################
            # Embedding (sub)
            ##############################
            if label_smoothing_prob > 0:
                self.embed_1 = Embedding_LS(
                    num_classes=self.num_classes_sub,
                    embedding_dim=embedding_dim_sub,
                    dropout=dropout_embedding,
                    label_smoothing_prob=label_smoothing_prob)
            else:
                self.embed_1 = Embedding(num_classes=self.num_classes_sub,
                                         embedding_dim=embedding_dim_sub,
                                         dropout=dropout_embedding,
                                         ignore_index=-1)

        ##############################
        # CTC (sub)
        ##############################
        if ctc_loss_weight_sub > 0:
            self.fc_ctc_1 = LinearND(self.encoder_num_units_sub,
                                     num_classes_sub + 1)

            # Set CTC decoders
            self._decode_ctc_greedy_np = GreedyDecoder(blank_index=0)
            self._decode_ctc_beam_np = BeamSearchDecoder(blank_index=0)
            # NOTE: index 0 is reserved for the blank class

        ##################################################
        # Initialize parameters
        ##################################################
        self.init_weights(parameter_init,
                          distribution=parameter_init_distribution,
                          ignore_keys=['bias'])

        # Initialize all biases with 0
        self.init_weights(0, distribution='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if recurrent_weight_orthogonal:
            self.init_weights(parameter_init,
                              distribution='orthogonal',
                              keys=[encoder_type, 'weight'],
                              ignore_keys=['bias'])
            self.init_weights(parameter_init,
                              distribution='orthogonal',
                              keys=[decoder_type, 'weight'],
                              ignore_keys=['bias'])

        # Initialize bias in forget gate with 1
        if init_forget_gate_bias_with_one:
            self.init_forget_gate_bias_with_one()
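
The constructor stores three loss weights (main_loss_weight, sub_loss_weight, ctc_loss_weight_sub) for multi-task learning. The actual combination happens in the model's forward pass, which is not shown in this excerpt; the helper below is only a sketch of the weighted sum these hyperparameters suggest.

import torch

def combine_mtl_losses(loss_main, loss_sub_attention, loss_sub_ctc,
                       main_loss_weight, sub_loss_weight, ctc_loss_weight_sub):
    """Illustrative weighted combination of the main attention loss and the
    two sub-task losses (attention-based and CTC)."""
    return (main_loss_weight * loss_main
            + sub_loss_weight * loss_sub_attention
            + ctc_loss_weight_sub * loss_sub_ctc)

# Dummy scalar losses with weights a caller might pass to __init__.
loss = combine_mtl_losses(torch.tensor(2.3), torch.tensor(1.7), torch.tensor(3.1),
                          main_loss_weight=0.8, sub_loss_weight=0.2,
                          ctc_loss_weight_sub=0.0)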