Code Example #1
    def __init__(self, args, vocab, transition_system):
        super(WikiSqlParser, self).__init__(args, vocab, transition_system)

        self.table_header_lstm = nn.LSTM(args.embed_size, int(args.hidden_size / 2), bidirectional=True, batch_first=True)
        self.column_pointer_net = PointerNet(args.hidden_size, args.hidden_size, attention_type=args.column_att)

        self.column_rnn_input = nn.Linear(args.hidden_size, args.embed_size, bias=False)
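
A minimal sketch (toy sizes, hypothetical; not the repo's forward pass) of why each direction of table_header_lstm gets hidden_size / 2 units: concatenating both directions restores a full hidden_size vector per column header, which is the size column_pointer_net and column_rnn_input expect.

import torch
import torch.nn as nn

embed_size, hidden_size = 64, 128          # illustrative values only
batch, n_columns, n_header_tokens = 2, 5, 3

header_lstm = nn.LSTM(embed_size, hidden_size // 2,
                      bidirectional=True, batch_first=True)

# one embedded header per (example, column): fold columns into the batch dimension
header_tokens = torch.randn(batch * n_columns, n_header_tokens, embed_size)
outputs, (h_n, c_n) = header_lstm(header_tokens)

# concatenating the last states of both directions gives hidden_size per column
column_enc = torch.cat([h_n[0], h_n[1]], dim=-1).view(batch, n_columns, hidden_size)
print(outputs.shape)     # torch.Size([10, 3, 128])
print(column_enc.shape)  # torch.Size([2, 5, 128])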
Code Example #2
    def __init__(self,
                 src_vocab,
                 tgt_vocab,
                 embed_size,
                 hidden_size,
                 dropout=0.,
                 cuda=False,
                 src_embed_layer=None,
                 tgt_embed_layer=None):

        super(Seq2SeqWithCopy, self).__init__(src_vocab,
                                              tgt_vocab,
                                              embed_size,
                                              hidden_size,
                                              dropout=dropout,
                                              src_embed_layer=src_embed_layer,
                                              tgt_embed_layer=tgt_embed_layer,
                                              cuda=cuda)

        # pointer net to the source
        self.src_pointer_net = PointerNet(src_encoding_size=hidden_size * 2,
                                          query_vec_size=hidden_size)

        self.tgt_token_predictor = nn.Linear(hidden_size, 2)
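
PointerNet is constructed here with src_encoding_size=hidden_size * 2 (a bidirectional encoder) and query_vec_size=hidden_size. A minimal sketch of one plausible bilinear pointer network with that interface; PointerNetSketch and its internals are assumptions for illustration, not the actual PointerNet implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PointerNetSketch(nn.Module):
    """Bilinear pointer: score(src_i, q_t) = (W src_i)^T q_t."""
    def __init__(self, src_encoding_size, query_vec_size):
        super().__init__()
        self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)

    def forward(self, src_encodings, query_vec):
        # src_encodings: (batch, src_len, src_encoding_size); query_vec: (batch, query_vec_size)
        keys = self.src_encoding_linear(src_encodings)                # (batch, src_len, query)
        scores = torch.bmm(keys, query_vec.unsqueeze(2)).squeeze(2)   # (batch, src_len)
        return F.softmax(scores, dim=-1)                              # copy distribution over source

ptr = PointerNetSketch(src_encoding_size=256 * 2, query_vec_size=256)
copy_dist = ptr(torch.randn(4, 7, 512), torch.randn(4, 256))
print(copy_dist.shape)  # torch.Size([4, 7]); each row sums to 1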
Code Example #3
    def __init__(self, args, vocab, transition_system):
        super(GRUParser, self).__init__()

        self.args = args
        self.vocab = vocab

        self.transition_system = transition_system
        self.grammar = self.transition_system.grammar

        # Embedding layers

        # source token embedding
        self.src_embed = nn.Embedding(len(vocab.source), args.embed_size)

        # embedding table of ASDL production rules (constructors), one for each ApplyConstructor action,
        # the last entry is the embedding for Reduce action
        self.production_embed = nn.Embedding(
            len(transition_system.grammar) + 1, args.action_embed_size)

        # embedding table for target primitive tokens
        self.primitive_embed = nn.Embedding(len(vocab.primitive),
                                            args.action_embed_size)

        # embedding table for ASDL fields in constructors
        self.field_embed = nn.Embedding(len(transition_system.grammar.fields),
                                        args.field_embed_size)

        # embedding table for ASDL types
        self.type_embed = nn.Embedding(len(transition_system.grammar.types),
                                       args.type_embed_size)

        nn.init.xavier_normal_(self.src_embed.weight.data)
        nn.init.xavier_normal_(self.production_embed.weight.data)
        nn.init.xavier_normal_(self.primitive_embed.weight.data)
        nn.init.xavier_normal_(self.field_embed.weight.data)
        nn.init.xavier_normal_(self.type_embed.weight.data)

        # LSTMs
        if args.lstm == 'lstm':
            self.encoder_lstm = nn.GRU(args.embed_size,
                                       int(args.hidden_size / 2),
                                       bidirectional=True)

            input_dim = args.action_embed_size  # previous action
            # frontier info
            input_dim += args.action_embed_size * (
                not args.no_parent_production_embed)
            input_dim += args.field_embed_size * (
                not args.no_parent_field_embed)
            input_dim += args.type_embed_size * (
                not args.no_parent_field_type_embed)
            input_dim += args.hidden_size * (not args.no_parent_state)

            input_dim += args.att_vec_size * (not args.no_input_feed)  # input feeding

            self.decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size)
        elif args.lstm == 'parent_feed':
            self.encoder_lstm = nn.LSTM(args.embed_size,
                                        int(args.hidden_size / 2),
                                        bidirectional=True)
            from .lstm import ParentFeedingLSTMCell

            input_dim = args.action_embed_size  # previous action
            # frontier info
            input_dim += args.action_embed_size * (
                not args.no_parent_production_embed)
            input_dim += args.field_embed_size * (
                not args.no_parent_field_embed)
            input_dim += args.type_embed_size * (
                not args.no_parent_field_type_embed)
            input_dim += args.att_vec_size * (not args.no_input_feed)  # input feeding

            self.decoder_lstm = ParentFeedingLSTMCell(input_dim,
                                                      args.hidden_size)
        else:
            raise ValueError('Unknown LSTM type %s' % args.lstm)

        if args.no_copy is False:
            # pointer net for copying tokens from source side
            self.src_pointer_net = PointerNet(
                query_vec_size=args.att_vec_size,
                src_encoding_size=args.hidden_size)

            # given the decoder's hidden state, predict whether to copy or generate a target primitive token
            # output: [p(gen(token) | s_t), p(copy(token) | s_t)]

            self.primitive_predictor = nn.Linear(args.att_vec_size, 2)

        if args.primitive_token_label_smoothing:
            self.label_smoothing = LabelSmoothing(
                args.primitive_token_label_smoothing,
                len(self.vocab.primitive),
                ignore_indices=[0, 1, 2])

        # initialize the decoder's state and cells with encoder hidden states
        self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)

        # attention: dot product attention
        # project source encoding to decoder rnn's hidden space

        self.att_src_linear = nn.Linear(args.hidden_size,
                                        args.hidden_size,
                                        bias=False)

        # transformation of decoder hidden states and context vectors before reading out target words
        # this produces the `attentional vector` in (Luong et al., 2015)

        self.att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size,
                                        args.att_vec_size,
                                        bias=False)

        # bias for predicting ApplyConstructor and GenToken actions
        self.production_readout_b = nn.Parameter(
            torch.FloatTensor(len(transition_system.grammar) + 1).zero_())
        self.tgt_token_readout_b = nn.Parameter(
            torch.FloatTensor(len(vocab.primitive)).zero_())

        if args.no_query_vec_to_action_map:
            # if there is no additional linear layer between the attentional vector (i.e., the query vector)
            # and the final softmax layer over target actions, we use the attentional vector to compute action
            # probabilities

            assert args.att_vec_size == args.action_embed_size
            self.production_readout = lambda q: F.linear(
                q, self.production_embed.weight, self.production_readout_b)
            self.tgt_token_readout = lambda q: F.linear(
                q, self.primitive_embed.weight, self.tgt_token_readout_b)
        else:
            # by default, we feed the attentional vector (i.e., the query vector) into a linear layer without bias, and
            # compute action probabilities by dot-producting the resulting vector and (GenToken, ApplyConstructor) action embeddings
            # i.e., p(action) = query_vec^T \cdot W \cdot embedding

            self.query_vec_to_action_embed = nn.Linear(
                args.att_vec_size,
                args.embed_size,
                bias=args.readout == 'non_linear')
            if args.query_vec_to_action_diff_map:
                # use different linear transformations for GenToken and ApplyConstructor actions
                self.query_vec_to_primitive_embed = nn.Linear(
                    args.att_vec_size,
                    args.embed_size,
                    bias=args.readout == 'non_linear')
            else:
                self.query_vec_to_primitive_embed = self.query_vec_to_action_embed

            self.read_out_act = torch.tanh if args.readout == 'non_linear' else nn_utils.identity

            self.production_readout = lambda q: F.linear(
                self.read_out_act(self.query_vec_to_action_embed(q)),
                self.production_embed.weight, self.production_readout_b)
            self.tgt_token_readout = lambda q: F.linear(
                self.read_out_act(self.query_vec_to_primitive_embed(q)),
                self.primitive_embed.weight, self.tgt_token_readout_b)

        # dropout layer
        self.dropout = nn.Dropout(args.dropout)

        if args.cuda:
            self.new_long_tensor = torch.cuda.LongTensor
            self.new_tensor = torch.cuda.FloatTensor
        else:
            self.new_long_tensor = torch.LongTensor
            self.new_tensor = torch.FloatTensor
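
In the default branch above (no_query_vec_to_action_map unset), an action's score is the dot product between a transformed query vector and that action's embedding plus a bias, exactly as the readout lambdas express. A small numeric sketch with toy sizes; it assumes embed_size == action_embed_size so the dot product is well defined.

import torch
import torch.nn as nn
import torch.nn.functional as F

att_vec_size, embed_size, n_productions = 32, 16, 10

query_vec_to_action_embed = nn.Linear(att_vec_size, embed_size, bias=False)
production_embed = nn.Embedding(n_productions + 1, embed_size)   # +1 for the Reduce action
production_readout_b = torch.zeros(n_productions + 1)

q = torch.randn(3, att_vec_size)   # attentional vectors for three decoding steps
# p(action) = softmax(tanh(W q) E^T + b): one dot product per candidate action
logits = F.linear(torch.tanh(query_vec_to_action_embed(q)),
                  production_embed.weight, production_readout_b)
print(F.softmax(logits, dim=-1).shape)  # torch.Size([3, 11])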
Code Example #4
File: parser.py  Project: thu-spmi/seq2seq-JAE
    def __init__(self, args, vocab, transition_system):
        super(Parser, self).__init__()

        self.args = args
        self.vocab = vocab

        self.transition_system = transition_system
        self.grammar = self.transition_system.grammar

        # Embedding layers
        self.src_embed = nn.Embedding(len(vocab.source), args.embed_size)
        self.production_embed = nn.Embedding(
            len(transition_system.grammar) + 1, args.action_embed_size)
        self.primitive_embed = nn.Embedding(len(vocab.primitive),
                                            args.action_embed_size)
        self.field_embed = nn.Embedding(len(transition_system.grammar.fields),
                                        args.field_embed_size)
        self.type_embed = nn.Embedding(len(transition_system.grammar.types),
                                       args.type_embed_size)

        nn.init.xavier_normal_(self.src_embed.weight.data)
        nn.init.xavier_normal_(self.production_embed.weight.data)
        nn.init.xavier_normal_(self.primitive_embed.weight.data)
        nn.init.xavier_normal_(self.field_embed.weight.data)
        nn.init.xavier_normal_(self.type_embed.weight.data)

        # LSTMs
        if args.lstm == 'lstm':
            self.encoder_lstm = nn.LSTM(args.embed_size,
                                        args.hidden_size // 2,
                                        bidirectional=True)
            self.decoder_lstm = nn.LSTMCell(
                args.action_embed_size +  # previous action
                args.action_embed_size + args.field_embed_size +
                args.type_embed_size +  # frontier info
                args.hidden_size +  # parent hidden state
                args.hidden_size,  # input feeding
                args.hidden_size)
        else:
            from .lstm import LSTM, LSTMCell
            self.encoder_lstm = LSTM(args.embed_size,
                                     args.hidden_size // 2,
                                     bidirectional=True,
                                     dropout=args.dropout)
            self.decoder_lstm = LSTMCell(
                args.action_embed_size +  # previous action
                args.action_embed_size + args.field_embed_size +
                args.type_embed_size +  # frontier info
                args.hidden_size + args.hidden_size,  # parent hidden state
                args.hidden_size,
                dropout=args.dropout)

        # pointer net
        self.src_pointer_net = PointerNet(args.hidden_size, args.hidden_size)

        self.primitive_predictor = nn.Linear(args.hidden_size, 2)

        # initialize the decoder's state and cells with encoder hidden states
        self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)

        # attention: dot product attention
        # project source encoding to decoder rnn's h space
        self.att_src_linear = nn.Linear(args.hidden_size,
                                        args.hidden_size,
                                        bias=False)

        # transformation of decoder hidden states and context vectors before reading out target words
        # this produces the `attentional vector` in (Luong et al., 2015)
        self.att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size,
                                        args.hidden_size,
                                        bias=False)

        # embedding layers
        self.query_vec_to_embed = nn.Linear(args.hidden_size,
                                            args.embed_size,
                                            bias=False)
        self.production_readout_b = nn.Parameter(
            torch.FloatTensor(len(transition_system.grammar) + 1).zero_())
        self.tgt_token_readout_b = nn.Parameter(
            torch.FloatTensor(len(vocab.primitive)).zero_())
        self.production_readout = self.production_readout_func
        self.tgt_token_readout = self.tgt_token_readout_func
        # self.production_readout = nn.Linear(args.hidden_size, len(transition_system.grammar) + 1)
        # self.tgt_token_readout = nn.Linear(args.hidden_size, len(vocab.primitive))

        # dropout layer
        self.dropout = nn.Dropout(args.dropout)

        if args.cuda:
            self.new_long_tensor = torch.cuda.LongTensor
            self.new_tensor = torch.cuda.FloatTensor
        else:
            self.new_long_tensor = torch.LongTensor
            self.new_tensor = torch.FloatTensor
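
The attention modules declared here (att_src_linear, att_vec_linear) implement dot-product attention and the attentional vector of Luong et al. (2015). A hedged sketch of that step with toy tensors; the repo's actual forward pass may differ in details such as masking.

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, src_len, batch = 64, 9, 2
att_src_linear = nn.Linear(hidden_size, hidden_size, bias=False)
att_vec_linear = nn.Linear(hidden_size * 2, hidden_size, bias=False)

src_encodings = torch.randn(batch, src_len, hidden_size)  # encoder outputs
h_t = torch.randn(batch, hidden_size)                     # decoder hidden state at step t

# dot-product attention: score(h_t, s_i) = h_t^T (W s_i)
keys = att_src_linear(src_encodings)                                  # (batch, src_len, hidden)
alpha = F.softmax(torch.bmm(keys, h_t.unsqueeze(2)).squeeze(2), dim=-1)
ctx = torch.bmm(alpha.unsqueeze(1), src_encodings).squeeze(1)         # context vector

# attentional vector: tanh(W_c [ctx; h_t])
att_t = torch.tanh(att_vec_linear(torch.cat([ctx, h_t], dim=-1)))
print(att_t.shape)  # torch.Size([2, 64])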
Code Example #5
File: parser.py  Project: zorrock/tranX
    def __init__(self, args, vocab, transition_system):
        super(Parser, self).__init__()

        self.args = args
        self.vocab = vocab

        self.transition_system = transition_system
        self.grammar = self.transition_system.grammar

        # Embedding layers
        self.src_embed = nn.Embedding(len(vocab.source), args.embed_size)
        self.production_embed = nn.Embedding(
            len(transition_system.grammar) + 1, args.action_embed_size)
        self.primitive_embed = nn.Embedding(len(vocab.primitive),
                                            args.action_embed_size)
        self.field_embed = nn.Embedding(len(transition_system.grammar.fields),
                                        args.field_embed_size)
        self.type_embed = nn.Embedding(len(transition_system.grammar.types),
                                       args.type_embed_size)

        nn.init.xavier_normal_(self.src_embed.weight.data)
        nn.init.xavier_normal_(self.production_embed.weight.data)
        nn.init.xavier_normal_(self.primitive_embed.weight.data)
        nn.init.xavier_normal_(self.field_embed.weight.data)
        nn.init.xavier_normal_(self.type_embed.weight.data)

        # LSTMs
        if args.lstm == 'lstm':
            self.encoder_lstm = nn.LSTM(args.embed_size,
                                        int(args.hidden_size / 2),
                                        bidirectional=True)

            input_dim = args.action_embed_size  # previous action
            # frontier info
            input_dim += args.action_embed_size * (
                not args.no_parent_production_embed)
            input_dim += args.field_embed_size * (
                not args.no_parent_field_embed)
            input_dim += args.type_embed_size * (
                not args.no_parent_field_type_embed)
            input_dim += args.hidden_size * (not args.no_parent_state)

            input_dim += args.att_vec_size * (not args.no_input_feed)  # input feeding

            self.decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size)
        elif args.lstm == 'parent_feed':
            self.encoder_lstm = nn.LSTM(args.embed_size,
                                        int(args.hidden_size / 2),
                                        bidirectional=True)
            from .lstm import ParentFeedingLSTMCell

            input_dim = args.action_embed_size  # previous action
            # frontier info
            input_dim += args.action_embed_size * (
                not args.no_parent_production_embed)
            input_dim += args.field_embed_size * (
                not args.no_parent_field_embed)
            input_dim += args.type_embed_size * (
                not args.no_parent_field_type_embed)
            input_dim += args.att_vec_size * (not args.no_input_feed)  # input feeding

            self.decoder_lstm = ParentFeedingLSTMCell(input_dim,
                                                      args.hidden_size)
        else:
            from .lstm import LSTM, LSTMCell
            self.encoder_lstm = LSTM(args.embed_size,
                                     args.hidden_size // 2,
                                     bidirectional=True,
                                     dropout=args.dropout)
            self.decoder_lstm = LSTMCell(
                args.action_embed_size +  # previous action
                args.action_embed_size + args.field_embed_size +
                args.type_embed_size +  # frontier info
                args.hidden_size,  # parent hidden state
                args.hidden_size,
                dropout=args.dropout)

        # pointer net
        self.src_pointer_net = PointerNet(query_vec_size=args.att_vec_size,
                                          src_encoding_size=args.hidden_size)

        self.primitive_predictor = nn.Linear(args.att_vec_size, 2)

        # initialize the decoder's state and cells with encoder hidden states
        self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)

        # attention: dot product attention
        # project source encoding to decoder rnn's h space
        self.att_src_linear = nn.Linear(args.hidden_size,
                                        args.hidden_size,
                                        bias=False)

        # transformation of decoder hidden states and context vectors before reading out target words
        # this produces the `attentional vector` in (Luong et al., 2015)
        self.att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size,
                                        args.att_vec_size,
                                        bias=False)

        # embedding layers
        self.production_readout_b = nn.Parameter(
            torch.FloatTensor(len(transition_system.grammar) + 1).zero_())
        self.tgt_token_readout_b = nn.Parameter(
            torch.FloatTensor(len(vocab.primitive)).zero_())

        if args.no_query_vec_to_action_map:
            assert args.att_vec_size == args.action_embed_size
            self.production_readout = lambda q: F.linear(
                q, self.production_embed.weight, self.production_readout_b)
            self.tgt_token_readout = lambda q: F.linear(
                q, self.primitive_embed.weight, self.tgt_token_readout_b)
        else:
            self.query_vec_to_action_embed = nn.Linear(
                args.att_vec_size,
                args.embed_size,
                bias=args.readout == 'non_linear')
            if args.query_vec_to_action_diff_map:
                self.query_vec_to_primitive_embed = nn.Linear(
                    args.att_vec_size,
                    args.embed_size,
                    bias=args.readout == 'non_linear')
            else:
                self.query_vec_to_primitive_embed = self.query_vec_to_action_embed

            self.read_out_act = torch.tanh if args.readout == 'non_linear' else nn_utils.identity

            self.production_readout = lambda q: F.linear(
                self.read_out_act(self.query_vec_to_action_embed(q)),
                self.production_embed.weight, self.production_readout_b)
            self.tgt_token_readout = lambda q: F.linear(
                self.read_out_act(self.query_vec_to_primitive_embed(q)),
                self.primitive_embed.weight, self.tgt_token_readout_b)

        # dropout layer
        self.dropout = nn.Dropout(args.dropout)

        if args.cuda:
            self.new_long_tensor = torch.cuda.LongTensor
            self.new_tensor = torch.cuda.FloatTensor
        else:
            self.new_long_tensor = torch.LongTensor
            self.new_tensor = torch.FloatTensor
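
primitive_predictor emits two logits per step that gate between generating a primitive token from the vocabulary and copying one from the source through src_pointer_net. A sketch of how such a gen/copy mixture is typically combined; this is an assumption about the forward pass, not code from the repo.

import torch
import torch.nn.functional as F

# toy distributions for a single decoding step
vocab_probs = F.softmax(torch.randn(1, 100), dim=-1)  # p(token | gen), over the primitive vocab
copy_probs = F.softmax(torch.randn(1, 12), dim=-1)    # p(position | copy), over source tokens
gate = F.softmax(torch.randn(1, 2), dim=-1)           # [p(gen | s_t), p(copy | s_t)]

# probability of emitting a primitive token that also appears at source position j
token_id, src_pos = 42, 3
p_token = gate[0, 0] * vocab_probs[0, token_id] + gate[0, 1] * copy_probs[0, src_pos]
print(float(p_token))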
Code Example #6
File: transformer.py  Project: kzCassie/ucl_nlp
    def __init__(self, args, vocab, transition_system):
        super(TransformerParser, self).__init__()

        self.args = args
        self.vocab = vocab

        self.transition_system = transition_system
        self.grammar = self.transition_system.grammar

        # Embedding layers

        # source token embedding
        self.src_embed = nn.Embedding(len(vocab.source), args.embed_size)

        # embedding table of ASDL production rules (constructors), one for each ApplyConstructor action,
        # the last entry is the embedding for Reduce action
        self.production_embed = nn.Embedding(
            len(transition_system.grammar) + 1, args.action_embed_size)

        # embedding table for target primitive tokens
        self.primitive_embed = nn.Embedding(len(vocab.primitive),
                                            args.action_embed_size)

        # embedding table for ASDL fields in constructors
        self.field_embed = nn.Embedding(len(transition_system.grammar.fields),
                                        args.field_embed_size)

        # embedding table for ASDL types
        self.type_embed = nn.Embedding(len(transition_system.grammar.types),
                                       args.type_embed_size)

        nn.init.xavier_normal_(self.src_embed.weight.data)
        nn.init.xavier_normal_(self.production_embed.weight.data)
        nn.init.xavier_normal_(self.primitive_embed.weight.data)
        nn.init.xavier_normal_(self.field_embed.weight.data)
        nn.init.xavier_normal_(self.type_embed.weight.data)

        # decoder input dimension
        input_dim = args.action_embed_size  # previous action
        # frontier info
        input_dim += args.action_embed_size * (
            not args.no_parent_production_embed)
        input_dim += args.field_embed_size * (not args.no_parent_field_embed)
        input_dim += args.type_embed_size * (
            not args.no_parent_field_type_embed)
        self.input_dim = input_dim

        #### Transformer ####
        # Transformer Encoder
        transformer_encoder_layer = nn.TransformerEncoderLayer(
            args.hidden_size, nhead=args.enc_nhead)
        self.transformer_encoder = nn.TransformerEncoder(
            transformer_encoder_layer, num_layers=args.enc_nlayer)
        self.src_pos_encoder = PositionalEncoding(args.hidden_size,
                                                  dropout=0.1)

        # Transformer Decoder
        transformer_decoder_layer = nn.TransformerDecoderLayer(
            args.hidden_size, nhead=args.dec_nhead)
        self.transformer_decoder = nn.TransformerDecoder(
            transformer_decoder_layer, num_layers=args.dec_nlayer)
        self.tgt_pos_encoder = PositionalEncoding(args.hidden_size,
                                                  dropout=0.1)

        # Transformer decoder must accept vectors of the same hidden_size as the encoder.
        self.src_enc_linear = nn.Linear(args.embed_size, args.hidden_size)
        self.tgt_dec_linear = nn.Linear(self.input_dim, args.hidden_size)
        #####################

        if args.no_copy is False:
            # pointer net for copying tokens from source side
            self.src_pointer_net = PointerNet(
                query_vec_size=args.hidden_size,
                src_encoding_size=args.hidden_size)

            # given the decoder's hidden state, predict whether to copy or generate a target primitive token
            # output: [p(gen(token) | s_t), p(copy(token) | s_t)]

            self.primitive_predictor = nn.Linear(args.hidden_size, 2)

        if args.primitive_token_label_smoothing:
            self.label_smoothing = LabelSmoothing(
                args.primitive_token_label_smoothing,
                len(self.vocab.primitive),
                ignore_indices=[0, 1, 2])

        # bias for predicting ApplyConstructor and GenToken actions
        self.production_readout_b = nn.Parameter(
            torch.FloatTensor(len(transition_system.grammar) + 1).zero_())
        self.tgt_token_readout_b = nn.Parameter(
            torch.FloatTensor(len(vocab.primitive)).zero_())

        if args.no_query_vec_to_action_map:
            # if there is no additional linear layer between the attentional vector (i.e., the query vector)
            # and the final softmax layer over target actions, we use the attentional vector to compute action
            # probabilities

            assert args.att_vec_size == args.action_embed_size
            self.production_readout = lambda q: F.linear(
                q, self.production_embed.weight, self.production_readout_b)
            self.tgt_token_readout = lambda q: F.linear(
                q, self.primitive_embed.weight, self.tgt_token_readout_b)
        else:
            # by default, we feed the attentional vector (i.e., the query vector) into a linear layer without bias, and
            # compute action probabilities by dot-producting the resulting vector and (GenToken, ApplyConstructor) action embeddings
            # i.e., p(action) = query_vec^T \cdot W \cdot embedding

            self.query_vec_to_action_embed = nn.Linear(
                args.att_vec_size,
                args.embed_size,
                bias=args.readout == 'non_linear')
            if args.query_vec_to_action_diff_map:
                # use different linear transformations for GenToken and ApplyConstructor actions
                self.query_vec_to_primitive_embed = nn.Linear(
                    args.att_vec_size,
                    args.embed_size,
                    bias=args.readout == 'non_linear')
            else:
                self.query_vec_to_primitive_embed = self.query_vec_to_action_embed

            self.read_out_act = torch.tanh if args.readout == 'non_linear' else nn_utils.identity

            self.production_readout = lambda q: F.linear(
                self.read_out_act(self.query_vec_to_action_embed(q)),
                self.production_embed.weight, self.production_readout_b)
            self.tgt_token_readout = lambda q: F.linear(
                self.read_out_act(self.query_vec_to_primitive_embed(q)),
                self.primitive_embed.weight, self.tgt_token_readout_b)

        # dropout layer
        self.dropout = nn.Dropout(args.dropout)

        if args.cuda:
            self.new_long_tensor = torch.cuda.LongTensor
            self.new_tensor = torch.cuda.FloatTensor
        else:
            self.new_long_tensor = torch.LongTensor
            self.new_tensor = torch.FloatTensor
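
src_enc_linear projects embed_size token embeddings up to hidden_size before they enter nn.TransformerEncoder, which by default expects (seq_len, batch, d_model) inputs. A shape-check sketch with made-up sizes; the original also applies src_pos_encoder, omitted here for brevity.

import torch
import torch.nn as nn

embed_size, hidden_size, nhead, nlayer = 128, 256, 4, 2
src_enc_linear = nn.Linear(embed_size, hidden_size)
enc_layer = nn.TransformerEncoderLayer(hidden_size, nhead=nhead)
encoder = nn.TransformerEncoder(enc_layer, num_layers=nlayer)

src_len, batch = 11, 3
src_embeddings = torch.randn(src_len, batch, embed_size)  # (seq_len, batch, embed) layout

memory = encoder(src_enc_linear(src_embeddings))          # source memory for the decoder
print(memory.shape)  # torch.Size([11, 3, 256])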
Code Example #7
    def __init__(self, args, vocab, transition_system):
        super(TransformerParser, self).__init__()

        self.args = args
        self.vocab = vocab
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() and args.cuda else "cpu")

        self.transition_system = transition_system
        self.grammar = self.transition_system.grammar

        # Transformer parameters
        self.num_layers = args.num_layers
        self.d_model = args.hidden_size
        self.d_ff = args.ffn_size
        self.h = args.num_heads
        self.dropout = args.dropout_model
        self.position = PositionalEncoding(self.d_model, self.dropout)
        attn = MultiHeadedAttention(self.h, self.d_model)
        parent_attn = StrictMultiHeadedAttention(self.h, 1, self.d_model)
        ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout)

        # Embedding layers

        # source token embedding
        self.src_embed = nn.Sequential(
            Embeddings(self.d_model, len(vocab.source)),
            copy.deepcopy(self.position))

        # embedding table of ASDL production rules (constructors), one for each ApplyConstructor action,
        # the last entry is the embedding for Reduce action
        self.action_embed_size = args.action_embed_size
        self.field_embed_size = args.field_embed_size
        self.type_embed_size = args.type_embed_size

        assert self.d_model == (self.action_embed_size +
                                self.action_embed_size *
                                (not self.args.no_parent_production_embed) +
                                self.field_embed_size *
                                (not self.args.no_parent_field_embed) +
                                self.type_embed_size *
                                (not self.args.no_parent_field_type_embed))

        self.production_embed = Embeddings(self.action_embed_size,
                                           len(transition_system.grammar) + 1)

        # embedding table for target primitive tokens
        self.primitive_embed = Embeddings(self.action_embed_size,
                                          len(vocab.primitive))

        # embedding table for ASDL fields in constructors
        self.field_embed = Embeddings(self.field_embed_size,
                                      len(transition_system.grammar.fields))

        # embedding table for ASDL types
        self.type_embed = Embeddings(self.type_embed_size,
                                     len(transition_system.grammar.types))

        assert args.lstm == "transformer"
        self.encoder = Encoder(
            EncoderLayer(self.d_model, copy.deepcopy(attn), copy.deepcopy(ff),
                         self.dropout), self.num_layers).to(self.device)
        self.decoder = Decoder(
            DecoderLayer(self.d_model, copy.deepcopy(parent_attn),
                         copy.deepcopy(attn), copy.deepcopy(ff), self.dropout),
            self.num_layers,
        ).to(self.device)

        if args.no_copy is False:
            # pointer net for copying tokens from source side
            self.src_pointer_net = PointerNet(
                query_vec_size=args.att_vec_size,
                src_encoding_size=args.hidden_size)

            # given the decoder's hidden state, predict whether to copy or generate a target primitive token
            # output: [p(gen(token) | s_t), p(copy(token) | s_t)]

            self.primitive_predictor = nn.Linear(args.att_vec_size, 2)

        if args.primitive_token_label_smoothing:
            self.label_smoothing = LabelSmoothing(
                args.primitive_token_label_smoothing,
                len(self.vocab.primitive),
                ignore_indices=[0, 1, 2])

        # initialize the decoder's state and cells with encoder hidden states
        self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)

        # attention: dot product attention
        # project source encoding to decoder rnn's hidden space

        self.att_src_linear = nn.Linear(args.hidden_size,
                                        args.hidden_size,
                                        bias=False)

        # transformation of decoder hidden states and context vectors before reading out target words
        # this produces the `attentional vector` in (Luong et al., 2015)

        self.att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size,
                                        args.att_vec_size,
                                        bias=False)

        # bias for predicting ApplyConstructor and GenToken actions
        self.production_readout_b = nn.Parameter(
            torch.zeros(len(transition_system.grammar) + 1,
                        dtype=torch.float32))
        self.tgt_token_readout_b = nn.Parameter(
            torch.zeros(len(vocab.primitive), dtype=torch.float32))

        if args.no_query_vec_to_action_map:
            # if there is no additional linear layer between the attentional vector (i.e., the query vector)
            # and the final softmax layer over target actions, we use the attentional vector to compute action
            # probabilities

            assert args.att_vec_size == args.action_embed_size
            self.production_readout = lambda q: F.linear(
                q * math.sqrt(self.d_model), self.production_embed.lut.weight,
                self.production_readout_b)
            self.tgt_token_readout = lambda q: F.linear(
                q * math.sqrt(self.d_model), self.primitive_embed.lut.weight,
                self.tgt_token_readout_b)
        else:
            # by default, we feed the attentional vector (i.e., the query vector) into a linear layer without bias, and
            # compute action probabilities by dot-producting the resulting vector and (GenToken, ApplyConstructor) action embeddings
            # i.e., p(action) = query_vec^T \cdot W \cdot embedding

            self.query_vec_to_action_embed = nn.Linear(
                args.att_vec_size,
                args.action_embed_size,
                bias=args.readout == "non_linear")
            if args.query_vec_to_action_diff_map:
                # use different linear transformations for GenToken and ApplyConstructor actions
                self.query_vec_to_primitive_embed = nn.Linear(
                    args.att_vec_size,
                    args.action_embed_size,
                    bias=args.readout == "non_linear")
            else:
                self.query_vec_to_primitive_embed = self.query_vec_to_action_embed

            self.read_out_act = torch.tanh if args.readout == "non_linear" else nn_utils.identity

            self.production_readout = lambda q: F.linear(
                self.read_out_act(self.query_vec_to_action_embed(q)) * math.sqrt(self.d_model),
                self.production_embed.lut.weight,
                self.production_readout_b)
            self.tgt_token_readout = lambda q: F.linear(
                self.read_out_act(self.query_vec_to_primitive_embed(q)) * math.sqrt(self.d_model),
                self.primitive_embed.lut.weight,
                self.tgt_token_readout_b)

        # dropout layer
        self.dropout = nn.Dropout(args.dropout)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
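
The assert near the top of this constructor ties d_model to the concatenation of the action, parent-production, field, and type embeddings that form each decoder input. A quick arithmetic check under hypothetical sizes; real values come from args.

action_embed_size, field_embed_size, type_embed_size = 128, 64, 64
no_parent_production_embed = False
no_parent_field_embed = False
no_parent_field_type_embed = False

d_model = (action_embed_size
           + action_embed_size * (not no_parent_production_embed)
           + field_embed_size * (not no_parent_field_embed)
           + type_embed_size * (not no_parent_field_type_embed))
print(d_model)  # 384: hidden_size must be set to exactly this sum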