def __init__(self, args, vocab, transition_system):
    super(WikiSqlParser, self).__init__(args, vocab, transition_system)

    self.table_header_lstm = nn.LSTM(args.embed_size,
                                     int(args.hidden_size / 2),
                                     bidirectional=True,
                                     batch_first=True)
    self.column_pointer_net = PointerNet(args.hidden_size,
                                         args.hidden_size,
                                         attention_type=args.column_att)
    self.column_rnn_input = nn.Linear(args.hidden_size, args.embed_size, bias=False)
def __init__(self, src_vocab, tgt_vocab, embed_size, hidden_size,
             dropout=0., cuda=False,
             src_embed_layer=None, tgt_embed_layer=None):
    super(Seq2SeqWithCopy, self).__init__(src_vocab, tgt_vocab,
                                          embed_size, hidden_size,
                                          dropout=dropout,
                                          src_embed_layer=src_embed_layer,
                                          tgt_embed_layer=tgt_embed_layer,
                                          cuda=cuda)

    # pointer net to the source
    self.src_pointer_net = PointerNet(src_encoding_size=hidden_size * 2,
                                      query_vec_size=hidden_size)

    self.tgt_token_predictor = nn.Linear(hidden_size, 2)
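# The PointerNet referenced throughout these modules is not shown in this section.
# Below is a minimal sketch of what such a pointer network commonly looks like: a
# bilinear attention that scores every source encoding against a decoder query
# vector and returns a distribution over source positions. All names and shapes
# here are assumptions for illustration, not this repository's implementation.
import torch
import torch.nn as nn


class PointerNetSketch(nn.Module):
    def __init__(self, src_encoding_size, query_vec_size):
        super().__init__()
        # project source encodings into the query space before the dot product
        self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False)

    def forward(self, src_encodings, src_token_mask, query_vec):
        # src_encodings: (batch, src_len, src_encoding_size)
        # src_token_mask: (batch, src_len) bool mask, True for padding positions (or None)
        # query_vec: (batch, query_vec_size)
        keys = self.src_encoding_linear(src_encodings)               # (batch, src_len, query_vec_size)
        scores = torch.bmm(keys, query_vec.unsqueeze(2)).squeeze(2)  # (batch, src_len)
        if src_token_mask is not None:
            scores = scores.masked_fill(src_token_mask, float('-inf'))
        # distribution over source tokens to copy from
        return torch.softmax(scores, dim=-1)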
def __init__(self, args, vocab, transition_system):
    super(GRUParser, self).__init__()

    self.args = args
    self.vocab = vocab
    self.transition_system = transition_system
    self.grammar = self.transition_system.grammar

    # Embedding layers

    # source token embedding
    self.src_embed = nn.Embedding(len(vocab.source), args.embed_size)

    # embedding table of ASDL production rules (constructors), one for each ApplyConstructor action,
    # the last entry is the embedding for Reduce action
    self.production_embed = nn.Embedding(len(transition_system.grammar) + 1, args.action_embed_size)

    # embedding table for target primitive tokens
    self.primitive_embed = nn.Embedding(len(vocab.primitive), args.action_embed_size)

    # embedding table for ASDL fields in constructors
    self.field_embed = nn.Embedding(len(transition_system.grammar.fields), args.field_embed_size)

    # embedding table for ASDL types
    self.type_embed = nn.Embedding(len(transition_system.grammar.types), args.type_embed_size)

    nn.init.xavier_normal_(self.src_embed.weight.data)
    nn.init.xavier_normal_(self.production_embed.weight.data)
    nn.init.xavier_normal_(self.primitive_embed.weight.data)
    nn.init.xavier_normal_(self.field_embed.weight.data)
    nn.init.xavier_normal_(self.type_embed.weight.data)

    # LSTMs
    if args.lstm == 'lstm':
        self.encoder_lstm = nn.GRU(args.embed_size, int(args.hidden_size / 2), bidirectional=True)

        input_dim = args.action_embed_size  # previous action
        # frontier info
        input_dim += args.action_embed_size * (not args.no_parent_production_embed)
        input_dim += args.field_embed_size * (not args.no_parent_field_embed)
        input_dim += args.type_embed_size * (not args.no_parent_field_type_embed)
        input_dim += args.hidden_size * (not args.no_parent_state)
        input_dim += args.att_vec_size * (not args.no_input_feed)  # input feeding
        self.decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size)
    elif args.lstm == 'parent_feed':
        self.encoder_lstm = nn.LSTM(args.embed_size, int(args.hidden_size / 2), bidirectional=True)

        from .lstm import ParentFeedingLSTMCell
        input_dim = args.action_embed_size  # previous action
        # frontier info
        input_dim += args.action_embed_size * (not args.no_parent_production_embed)
        input_dim += args.field_embed_size * (not args.no_parent_field_embed)
        input_dim += args.type_embed_size * (not args.no_parent_field_type_embed)
        input_dim += args.att_vec_size * (not args.no_input_feed)  # input feeding
        self.decoder_lstm = ParentFeedingLSTMCell(input_dim, args.hidden_size)
    else:
        raise ValueError('Unknown LSTM type %s' % args.lstm)

    if args.no_copy is False:
        # pointer net for copying tokens from source side
        self.src_pointer_net = PointerNet(query_vec_size=args.att_vec_size,
                                          src_encoding_size=args.hidden_size)

        # given the decoder's hidden state, predict whether to copy or generate a target primitive token
        # output: [p(gen(token)) | s_t, p(copy(token)) | s_t]
        self.primitive_predictor = nn.Linear(args.att_vec_size, 2)

    if args.primitive_token_label_smoothing:
        self.label_smoothing = LabelSmoothing(args.primitive_token_label_smoothing,
                                              len(self.vocab.primitive),
                                              ignore_indices=[0, 1, 2])

    # initialize the decoder's state and cells with encoder hidden states
    self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)

    # attention: dot product attention
    # project source encoding to decoder rnn's hidden space
    self.att_src_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    # transformation of decoder hidden states and context vectors before reading out target words
    # this produces the `attentional vector` in (Luong et al., 2015)
    self.att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.att_vec_size, bias=False)

    # bias for predicting ApplyConstructor and GenToken actions
    self.production_readout_b = nn.Parameter(torch.FloatTensor(len(transition_system.grammar) + 1).zero_())
    self.tgt_token_readout_b = nn.Parameter(torch.FloatTensor(len(vocab.primitive)).zero_())

    if args.no_query_vec_to_action_map:
        # if there is no additional linear layer between the attentional vector (i.e., the query vector)
        # and the final softmax layer over target actions, we use the attentional vector to compute action
        # probabilities
        assert args.att_vec_size == args.action_embed_size
        self.production_readout = lambda q: F.linear(q, self.production_embed.weight, self.production_readout_b)
        self.tgt_token_readout = lambda q: F.linear(q, self.primitive_embed.weight, self.tgt_token_readout_b)
    else:
        # by default, we feed the attentional vector (i.e., the query vector) into a linear layer without bias, and
        # compute action probabilities by dot-producting the resulting vector and (GenToken, ApplyConstructor) action embeddings
        # i.e., p(action) = query_vec^T \cdot W \cdot embedding
        self.query_vec_to_action_embed = nn.Linear(args.att_vec_size, args.embed_size,
                                                   bias=args.readout == 'non_linear')
        if args.query_vec_to_action_diff_map:
            # use different linear transformations for GenToken and ApplyConstructor actions
            self.query_vec_to_primitive_embed = nn.Linear(args.att_vec_size, args.embed_size,
                                                          bias=args.readout == 'non_linear')
        else:
            self.query_vec_to_primitive_embed = self.query_vec_to_action_embed

        self.read_out_act = torch.tanh if args.readout == 'non_linear' else nn_utils.identity

        self.production_readout = lambda q: F.linear(self.read_out_act(self.query_vec_to_action_embed(q)),
                                                     self.production_embed.weight,
                                                     self.production_readout_b)
        self.tgt_token_readout = lambda q: F.linear(self.read_out_act(self.query_vec_to_primitive_embed(q)),
                                                    self.primitive_embed.weight,
                                                    self.tgt_token_readout_b)

    # dropout layer
    self.dropout = nn.Dropout(args.dropout)

    if args.cuda:
        self.new_long_tensor = torch.cuda.LongTensor
        self.new_tensor = torch.cuda.FloatTensor
    else:
        self.new_long_tensor = torch.LongTensor
        self.new_tensor = torch.FloatTensor
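# Usage sketch for the tied-embedding readout defined above: action logits are
# computed as emb_a . W(q) + b_a, so the output projection shares weights with
# production_embed. All sizes below are assumed values, not this project's config.
import torch
import torch.nn.functional as F

att_vec_size, embed_size, num_productions = 256, 128, 90       # assumed sizes
q = torch.randn(4, att_vec_size)                                # a batch of attentional vectors
query_vec_to_action_embed = torch.nn.Linear(att_vec_size, embed_size, bias=False)
production_embed = torch.nn.Embedding(num_productions, embed_size)
production_readout_b = torch.zeros(num_productions)

# p(action) = softmax(query_vec^T W E^T + b), with E the production embedding table
logits = F.linear(query_vec_to_action_embed(q), production_embed.weight, production_readout_b)
p_action = F.softmax(logits, dim=-1)                            # (4, num_productions)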
def __init__(self, args, vocab, transition_system):
    super(Parser, self).__init__()

    self.args = args
    self.vocab = vocab
    self.transition_system = transition_system
    self.grammar = self.transition_system.grammar

    # Embedding layers
    self.src_embed = nn.Embedding(len(vocab.source), args.embed_size)
    self.production_embed = nn.Embedding(len(transition_system.grammar) + 1, args.action_embed_size)
    self.primitive_embed = nn.Embedding(len(vocab.primitive), args.action_embed_size)
    self.field_embed = nn.Embedding(len(transition_system.grammar.fields), args.field_embed_size)
    self.type_embed = nn.Embedding(len(transition_system.grammar.types), args.type_embed_size)

    nn.init.xavier_normal(self.src_embed.weight.data)
    nn.init.xavier_normal(self.production_embed.weight.data)
    nn.init.xavier_normal(self.primitive_embed.weight.data)
    nn.init.xavier_normal(self.field_embed.weight.data)
    nn.init.xavier_normal(self.type_embed.weight.data)

    # LSTMs
    if args.lstm == 'lstm':
        self.encoder_lstm = nn.LSTM(args.embed_size, args.hidden_size // 2, bidirectional=True)
        self.decoder_lstm = nn.LSTMCell(
            args.action_embed_size +  # previous action
            args.action_embed_size + args.field_embed_size + args.type_embed_size +  # frontier info
            args.hidden_size +  # parent hidden state
            args.hidden_size,  # input feeding
            args.hidden_size)
    else:
        from .lstm import LSTM, LSTMCell
        self.encoder_lstm = LSTM(args.embed_size, args.hidden_size // 2,
                                 bidirectional=True, dropout=args.dropout)
        self.decoder_lstm = LSTMCell(
            args.action_embed_size +  # previous action
            args.action_embed_size + args.field_embed_size + args.type_embed_size +  # frontier info
            args.hidden_size + args.hidden_size,  # parent hidden state
            args.hidden_size,
            dropout=args.dropout)

    # pointer net
    self.src_pointer_net = PointerNet(args.hidden_size, args.hidden_size)

    self.primitive_predictor = nn.Linear(args.hidden_size, 2)

    # initialize the decoder's state and cells with encoder hidden states
    self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)

    # attention: dot product attention
    # project source encoding to decoder rnn's h space
    self.att_src_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    # transformation of decoder hidden states and context vectors before reading out target words
    # this produces the `attentional vector` in (Luong et al., 2015)
    self.att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.hidden_size, bias=False)

    # embedding layers
    self.query_vec_to_embed = nn.Linear(args.hidden_size, args.embed_size, bias=False)

    self.production_readout_b = nn.Parameter(torch.FloatTensor(len(transition_system.grammar) + 1).zero_())
    self.tgt_token_readout_b = nn.Parameter(torch.FloatTensor(len(vocab.primitive)).zero_())

    self.production_readout = self.production_readout_func
    self.tgt_token_readout = self.tgt_token_readout_func
    # self.production_readout = nn.Linear(args.hidden_size, len(transition_system.grammar) + 1)
    # self.tgt_token_readout = nn.Linear(args.hidden_size, len(vocab.primitive))

    # dropout layer
    self.dropout = nn.Dropout(args.dropout)

    if args.cuda:
        self.new_long_tensor = torch.cuda.LongTensor
        self.new_tensor = torch.cuda.FloatTensor
    else:
        self.new_long_tensor = torch.LongTensor
        self.new_tensor = torch.FloatTensor
def __init__(self, args, vocab, transition_system):
    super(Parser, self).__init__()

    self.args = args
    self.vocab = vocab
    self.transition_system = transition_system
    self.grammar = self.transition_system.grammar

    # Embedding layers
    self.src_embed = nn.Embedding(len(vocab.source), args.embed_size)
    self.production_embed = nn.Embedding(len(transition_system.grammar) + 1, args.action_embed_size)
    self.primitive_embed = nn.Embedding(len(vocab.primitive), args.action_embed_size)
    self.field_embed = nn.Embedding(len(transition_system.grammar.fields), args.field_embed_size)
    self.type_embed = nn.Embedding(len(transition_system.grammar.types), args.type_embed_size)

    nn.init.xavier_normal(self.src_embed.weight.data)
    nn.init.xavier_normal(self.production_embed.weight.data)
    nn.init.xavier_normal(self.primitive_embed.weight.data)
    nn.init.xavier_normal(self.field_embed.weight.data)
    nn.init.xavier_normal(self.type_embed.weight.data)

    # LSTMs
    if args.lstm == 'lstm':
        self.encoder_lstm = nn.LSTM(args.embed_size, int(args.hidden_size / 2), bidirectional=True)

        input_dim = args.action_embed_size  # previous action
        # frontier info
        input_dim += args.action_embed_size * (not args.no_parent_production_embed)
        input_dim += args.field_embed_size * (not args.no_parent_field_embed)
        input_dim += args.type_embed_size * (not args.no_parent_field_type_embed)
        input_dim += args.hidden_size * (not args.no_parent_state)
        input_dim += args.att_vec_size * (not args.no_input_feed)  # input feeding
        self.decoder_lstm = nn.LSTMCell(input_dim, args.hidden_size)
    elif args.lstm == 'parent_feed':
        self.encoder_lstm = nn.LSTM(args.embed_size, int(args.hidden_size / 2), bidirectional=True)

        from .lstm import ParentFeedingLSTMCell
        input_dim = args.action_embed_size  # previous action
        # frontier info
        input_dim += args.action_embed_size * (not args.no_parent_production_embed)
        input_dim += args.field_embed_size * (not args.no_parent_field_embed)
        input_dim += args.type_embed_size * (not args.no_parent_field_type_embed)
        input_dim += args.att_vec_size * (not args.no_input_feed)  # input feeding
        self.decoder_lstm = ParentFeedingLSTMCell(input_dim, args.hidden_size)
    else:
        from .lstm import LSTM, LSTMCell
        self.encoder_lstm = LSTM(args.embed_size, args.hidden_size // 2,
                                 bidirectional=True, dropout=args.dropout)
        self.decoder_lstm = LSTMCell(
            args.action_embed_size +  # previous action
            args.action_embed_size + args.field_embed_size + args.type_embed_size +  # frontier info
            args.hidden_size,  # parent hidden state
            args.hidden_size,
            dropout=args.dropout)

    # pointer net
    self.src_pointer_net = PointerNet(query_vec_size=args.att_vec_size,
                                      src_encoding_size=args.hidden_size)

    self.primitive_predictor = nn.Linear(args.att_vec_size, 2)

    # initialize the decoder's state and cells with encoder hidden states
    self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)

    # attention: dot product attention
    # project source encoding to decoder rnn's h space
    self.att_src_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    # transformation of decoder hidden states and context vectors before reading out target words
    # this produces the `attentional vector` in (Luong et al., 2015)
    self.att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.att_vec_size, bias=False)

    # embedding layers
    self.production_readout_b = nn.Parameter(torch.FloatTensor(len(transition_system.grammar) + 1).zero_())
    self.tgt_token_readout_b = nn.Parameter(torch.FloatTensor(len(vocab.primitive)).zero_())

    if args.no_query_vec_to_action_map:
        assert args.att_vec_size == args.action_embed_size
        self.production_readout = lambda q: F.linear(q, self.production_embed.weight, self.production_readout_b)
        self.tgt_token_readout = lambda q: F.linear(q, self.primitive_embed.weight, self.tgt_token_readout_b)
    else:
        self.query_vec_to_action_embed = nn.Linear(args.att_vec_size, args.embed_size,
                                                   bias=args.readout == 'non_linear')
        if args.query_vec_to_action_diff_map:
            self.query_vec_to_primitive_embed = nn.Linear(args.att_vec_size, args.embed_size,
                                                          bias=args.readout == 'non_linear')
        else:
            self.query_vec_to_primitive_embed = self.query_vec_to_action_embed

        self.read_out_act = F.tanh if args.readout == 'non_linear' else nn_utils.identity

        self.production_readout = lambda q: F.linear(self.read_out_act(self.query_vec_to_action_embed(q)),
                                                     self.production_embed.weight,
                                                     self.production_readout_b)
        self.tgt_token_readout = lambda q: F.linear(self.read_out_act(self.query_vec_to_primitive_embed(q)),
                                                    self.primitive_embed.weight,
                                                    self.tgt_token_readout_b)

    # dropout layer
    self.dropout = nn.Dropout(args.dropout)

    if args.cuda:
        self.new_long_tensor = torch.cuda.LongTensor
        self.new_tensor = torch.cuda.FloatTensor
    else:
        self.new_long_tensor = torch.LongTensor
        self.new_tensor = torch.FloatTensor
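# How primitive_predictor and src_pointer_net are typically combined downstream:
# p(token w) = p(gen) * p_vocab(w) + p(copy) * sum_{i: src_i == w} p_copy(i).
# This is a sketch under assumed shapes, not this repository's exact forward pass.
import torch
import torch.nn.functional as F

att_vec_size, vocab_size, src_len = 256, 5000, 30           # assumed sizes
att_vec = torch.randn(4, att_vec_size)                      # decoder attentional vectors
primitive_predictor = torch.nn.Linear(att_vec_size, 2)

gen_copy = F.softmax(primitive_predictor(att_vec), dim=-1)  # [:, 0] = p(gen), [:, 1] = p(copy)
p_vocab = F.softmax(torch.randn(4, vocab_size), dim=-1)     # stand-in for the tgt_token_readout softmax
p_copy = F.softmax(torch.randn(4, src_len), dim=-1)         # stand-in for src_pointer_net scores

p_gen_token = gen_copy[:, 0:1] * p_vocab                    # generation contribution per vocabulary entry
p_copy_pos = gen_copy[:, 1:2] * p_copy                      # copy contribution per source position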
def __init__(self, args, vocab, transition_system):
    super(TransformerParser, self).__init__()

    self.args = args
    self.vocab = vocab
    self.transition_system = transition_system
    self.grammar = self.transition_system.grammar

    # Embedding layers

    # source token embedding
    self.src_embed = nn.Embedding(len(vocab.source), args.embed_size)

    # embedding table of ASDL production rules (constructors), one for each ApplyConstructor action,
    # the last entry is the embedding for Reduce action
    self.production_embed = nn.Embedding(len(transition_system.grammar) + 1, args.action_embed_size)

    # embedding table for target primitive tokens
    self.primitive_embed = nn.Embedding(len(vocab.primitive), args.action_embed_size)

    # embedding table for ASDL fields in constructors
    self.field_embed = nn.Embedding(len(transition_system.grammar.fields), args.field_embed_size)

    # embedding table for ASDL types
    self.type_embed = nn.Embedding(len(transition_system.grammar.types), args.type_embed_size)

    nn.init.xavier_normal_(self.src_embed.weight.data)
    nn.init.xavier_normal_(self.production_embed.weight.data)
    nn.init.xavier_normal_(self.primitive_embed.weight.data)
    nn.init.xavier_normal_(self.field_embed.weight.data)
    nn.init.xavier_normal_(self.type_embed.weight.data)

    # decoder input dimension
    input_dim = args.action_embed_size  # previous action
    # frontier info
    input_dim += args.action_embed_size * (not args.no_parent_production_embed)
    input_dim += args.field_embed_size * (not args.no_parent_field_embed)
    input_dim += args.type_embed_size * (not args.no_parent_field_type_embed)
    self.input_dim = input_dim

    #### Transformer ####
    # Transformer Encoder
    transformer_encoder_layer = nn.TransformerEncoderLayer(args.hidden_size, nhead=args.enc_nhead)
    self.transformer_encoder = nn.TransformerEncoder(transformer_encoder_layer, num_layers=args.enc_nlayer)
    self.src_pos_encoder = PositionalEncoding(args.hidden_size, dropout=0.1)

    # Transformer Decoder
    transformer_decoder_layer = nn.TransformerDecoderLayer(args.hidden_size, nhead=args.dec_nhead)
    self.transformer_decoder = nn.TransformerDecoder(transformer_decoder_layer, num_layers=args.dec_nlayer)
    self.tgt_pos_encoder = PositionalEncoding(args.hidden_size, dropout=0.1)

    # The Transformer decoder must accept vectors of the same hidden_size as the encoder.
    self.src_enc_linear = nn.Linear(args.embed_size, args.hidden_size)
    self.tgt_dec_linear = nn.Linear(self.input_dim, args.hidden_size)
    #####################

    if args.no_copy is False:
        # pointer net for copying tokens from source side
        self.src_pointer_net = PointerNet(query_vec_size=args.hidden_size,
                                          src_encoding_size=args.hidden_size)

        # given the decoder's hidden state, predict whether to copy or generate a target primitive token
        # output: [p(gen(token)) | s_t, p(copy(token)) | s_t]
        self.primitive_predictor = nn.Linear(args.hidden_size, 2)

    if args.primitive_token_label_smoothing:
        self.label_smoothing = LabelSmoothing(args.primitive_token_label_smoothing,
                                              len(self.vocab.primitive),
                                              ignore_indices=[0, 1, 2])

    # bias for predicting ApplyConstructor and GenToken actions
    self.production_readout_b = nn.Parameter(torch.FloatTensor(len(transition_system.grammar) + 1).zero_())
    self.tgt_token_readout_b = nn.Parameter(torch.FloatTensor(len(vocab.primitive)).zero_())

    if args.no_query_vec_to_action_map:
        # if there is no additional linear layer between the attentional vector (i.e., the query vector)
        # and the final softmax layer over target actions, we use the attentional vector to compute action
        # probabilities
        assert args.att_vec_size == args.action_embed_size
        self.production_readout = lambda q: F.linear(q, self.production_embed.weight, self.production_readout_b)
        self.tgt_token_readout = lambda q: F.linear(q, self.primitive_embed.weight, self.tgt_token_readout_b)
    else:
        # by default, we feed the attentional vector (i.e., the query vector) into a linear layer without bias, and
        # compute action probabilities by dot-producting the resulting vector and (GenToken, ApplyConstructor) action embeddings
        # i.e., p(action) = query_vec^T \cdot W \cdot embedding
        self.query_vec_to_action_embed = nn.Linear(args.att_vec_size, args.embed_size,
                                                   bias=args.readout == 'non_linear')
        if args.query_vec_to_action_diff_map:
            # use different linear transformations for GenToken and ApplyConstructor actions
            self.query_vec_to_primitive_embed = nn.Linear(args.att_vec_size, args.embed_size,
                                                          bias=args.readout == 'non_linear')
        else:
            self.query_vec_to_primitive_embed = self.query_vec_to_action_embed

        self.read_out_act = torch.tanh if args.readout == 'non_linear' else nn_utils.identity

        self.production_readout = lambda q: F.linear(self.read_out_act(self.query_vec_to_action_embed(q)),
                                                     self.production_embed.weight,
                                                     self.production_readout_b)
        self.tgt_token_readout = lambda q: F.linear(self.read_out_act(self.query_vec_to_primitive_embed(q)),
                                                    self.primitive_embed.weight,
                                                    self.tgt_token_readout_b)

    # dropout layer
    self.dropout = nn.Dropout(args.dropout)

    if args.cuda:
        self.new_long_tensor = torch.cuda.LongTensor
        self.new_tensor = torch.cuda.FloatTensor
    else:
        self.new_long_tensor = torch.LongTensor
        self.new_tensor = torch.FloatTensor
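# The PositionalEncoding module used by both Transformer parsers is not defined in
# this section. A common choice (assumed here, following the fixed sinusoidal
# encoding of "Attention Is All You Need" and the PyTorch tutorial code) is:
import math

import torch
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # precompute sin/cos position features once; they are not learned
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

    def forward(self, x):
        # x: (seq_len, batch, d_model), matching nn.Transformer's default layout
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)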
def __init__(self, args, vocab, transition_system):
    super(TransformerParser, self).__init__()

    self.args = args
    self.vocab = vocab
    self.device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    self.transition_system = transition_system
    self.grammar = self.transition_system.grammar

    # Transformer parameters
    self.num_layers = args.num_layers
    self.d_model = args.hidden_size
    self.d_ff = args.ffn_size
    self.h = args.num_heads
    self.dropout = args.dropout_model

    self.position = PositionalEncoding(self.d_model, self.dropout)
    attn = MultiHeadedAttention(self.h, self.d_model)
    parent_attn = StrictMultiHeadedAttention(self.h, 1, self.d_model)
    ff = PositionwiseFeedForward(self.d_model, self.d_ff, self.dropout)

    # Embedding layers

    # source token embedding
    self.src_embed = nn.Sequential(Embeddings(self.d_model, len(vocab.source)),
                                   copy.deepcopy(self.position))

    # embedding table of ASDL production rules (constructors), one for each ApplyConstructor action,
    # the last entry is the embedding for Reduce action
    self.action_embed_size = args.action_embed_size
    self.field_embed_size = args.field_embed_size
    self.type_embed_size = args.type_embed_size
    assert self.d_model == (self.action_embed_size +
                            self.action_embed_size * (not self.args.no_parent_production_embed) +
                            self.field_embed_size * (not self.args.no_parent_field_embed) +
                            self.type_embed_size * (not self.args.no_parent_field_type_embed))
    self.production_embed = Embeddings(self.action_embed_size, len(transition_system.grammar) + 1)

    # embedding table for target primitive tokens
    self.primitive_embed = Embeddings(self.action_embed_size, len(vocab.primitive))

    # embedding table for ASDL fields in constructors
    self.field_embed = Embeddings(self.field_embed_size, len(transition_system.grammar.fields))

    # embedding table for ASDL types
    self.type_embed = Embeddings(self.type_embed_size, len(transition_system.grammar.types))

    assert args.lstm == "transformer"
    self.encoder = Encoder(
        EncoderLayer(self.d_model, copy.deepcopy(attn), copy.deepcopy(ff), self.dropout),
        self.num_layers).to(self.device)
    self.decoder = Decoder(
        DecoderLayer(self.d_model, copy.deepcopy(parent_attn), copy.deepcopy(attn),
                     copy.deepcopy(ff), self.dropout),
        self.num_layers,
    ).to(self.device)

    if args.no_copy is False:
        # pointer net for copying tokens from source side
        self.src_pointer_net = PointerNet(query_vec_size=args.att_vec_size,
                                          src_encoding_size=args.hidden_size)

        # given the decoder's hidden state, predict whether to copy or generate a target primitive token
        # output: [p(gen(token)) | s_t, p(copy(token)) | s_t]
        self.primitive_predictor = nn.Linear(args.att_vec_size, 2)

    if args.primitive_token_label_smoothing:
        self.label_smoothing = LabelSmoothing(args.primitive_token_label_smoothing,
                                              len(self.vocab.primitive),
                                              ignore_indices=[0, 1, 2])

    # initialize the decoder's state and cells with encoder hidden states
    self.decoder_cell_init = nn.Linear(args.hidden_size, args.hidden_size)

    # attention: dot product attention
    # project source encoding to decoder rnn's hidden space
    self.att_src_linear = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    # transformation of decoder hidden states and context vectors before reading out target words
    # this produces the `attentional vector` in (Luong et al., 2015)
    self.att_vec_linear = nn.Linear(args.hidden_size + args.hidden_size, args.att_vec_size, bias=False)

    # bias for predicting ApplyConstructor and GenToken actions
    self.production_readout_b = nn.Parameter(torch.zeros(len(transition_system.grammar) + 1, dtype=torch.float32))
    self.tgt_token_readout_b = nn.Parameter(torch.zeros(len(vocab.primitive), dtype=torch.float32))

    if args.no_query_vec_to_action_map:
        # if there is no additional linear layer between the attentional vector (i.e., the query vector)
        # and the final softmax layer over target actions, we use the attentional vector to compute action
        # probabilities
        assert args.att_vec_size == args.action_embed_size
        self.production_readout = lambda q: F.linear(q * math.sqrt(self.d_model),
                                                     self.production_embed.lut.weight,
                                                     self.production_readout_b)
        self.tgt_token_readout = lambda q: F.linear(q * math.sqrt(self.d_model),
                                                    self.primitive_embed.lut.weight,
                                                    self.tgt_token_readout_b)
    else:
        # by default, we feed the attentional vector (i.e., the query vector) into a linear layer without bias, and
        # compute action probabilities by dot-producting the resulting vector and (GenToken, ApplyConstructor) action embeddings
        # i.e., p(action) = query_vec^T \cdot W \cdot embedding
        self.query_vec_to_action_embed = nn.Linear(args.att_vec_size, args.action_embed_size,
                                                   bias=args.readout == "non_linear")
        if args.query_vec_to_action_diff_map:
            # use different linear transformations for GenToken and ApplyConstructor actions
            self.query_vec_to_primitive_embed = nn.Linear(args.att_vec_size, args.action_embed_size,
                                                          bias=args.readout == "non_linear")
        else:
            self.query_vec_to_primitive_embed = self.query_vec_to_action_embed

        self.read_out_act = F.tanh if args.readout == "non_linear" else nn_utils.identity

        self.production_readout = lambda q: F.linear(
            self.read_out_act(self.query_vec_to_action_embed(q)) * math.sqrt(self.d_model),
            self.production_embed.lut.weight,
            self.production_readout_b,
        )
        self.tgt_token_readout = lambda q: F.linear(
            self.read_out_act(self.query_vec_to_primitive_embed(q)) * math.sqrt(self.d_model),
            self.primitive_embed.lut.weight,
            self.tgt_token_readout_b,
        )

    # dropout layer
    self.dropout = nn.Dropout(args.dropout)

    for p in self.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
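# Illustrative check of the d_model constraint asserted above (the sizes are
# assumed example values, not this project's defaults): with all parent-feature
# flags enabled, hidden_size must equal the concatenated decoder-input width.
action_embed_size, field_embed_size, type_embed_size = 128, 64, 64
d_model = (action_embed_size        # previous action
           + action_embed_size      # parent production
           + field_embed_size       # parent field
           + type_embed_size)       # parent field type
assert d_model == 384               # so args.hidden_size would be set to 384 here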