示例#1
0
    def forward(self,
                variable,
                condition,
                variable_lengths=None,
                condition_lengths=None):
        """
        Encode ``variable`` conditioned on ``condition`` with ``self.attention``,
        optionally pool the result with ``self.bag_of_word``, then project it
        with ``self.W``.

        A lengths argument that is None means "no padding" for that sequence:
        every position of it is treated as valid.

        Args:
            variable [batch_size, var_seq_len, hidden_dim]: encoding of variable
            condition [batch_size, cond_seq_len, hidden_dim]: encoding of what we want to condition on
            variable_lengths np.array[batch_size] or None: lengths of each sequence in the variable
            condition_lengths np.array[batch_size] or None: lengths of each sequence in the condition variable
        Returns:
            H_var_cond: [batch_size, hidden_dim] when ``self.use_bag_of_word`` is set;
                otherwise the un-pooled attention output passed through ``self.W``
                (presumably [batch_size, cond_seq_len, hidden_dim] — TODO confirm
                against ``self.attention``).
        """
        if variable_lengths is None and condition_lengths is None:
            # No lengths for either sequence: run attention without a mask.
            mask_var_cond = None
            mask_cond = None
        elif condition_lengths is None:
            # We might not need masking for the condition, e.g. if cond_seq_len == 1.
            mask_var_cond = length_to_mask(variable_lengths).to(
                variable.device)
            mask_cond = None
        else:
            if variable_lengths is None:
                # Only the condition lengths are known. Treat every variable
                # position as valid rather than handing None to length_to_mask.
                import numpy as np
                variable_lengths = np.full(len(condition_lengths),
                                           variable.size(1))
            mask_var_cond = length_to_mask(
                variable_lengths, condition_lengths).to(
                    variable.device)  # [batch_size, var_seq_len, cond_seq_len]
            mask_cond = length_to_mask(condition_lengths).to(
                variable.device)  # [batch_size, cond_seq_len, 1]

        # Run attention of the condition over the variable.
        H_var_cond, _ = self.attention(
            variable, key=condition,
            mask=mask_var_cond)

        if self.use_bag_of_word:
            # Pool away the sequence dimension with the condition mask.
            H_var_cond, _ = self.bag_of_word(
                H_var_cond, mask=mask_cond)  # [batch_size, 1, hidden_dim]
            H_var_cond = H_var_cond.squeeze(1)  # [batch_size, hidden_dim]

        # Project embedding
        H_var_cond = self.W(H_var_cond)

        return H_var_cond
示例#2
0
    def forward(self, inp, seq_len):
        """
        Encode ``inp`` with ``self.encoder`` in both directions, mean-pool the
        features over the valid positions, and decode the pooled vector.

        Args:
            inp: input batch, passed unchanged to ``self.encoder``
            seq_len: per-sequence lengths, passed to ``self.encoder`` and ``length_to_mask``
        Returns:
            ``self.decoder`` applied to the pooled (bs, 2560) features
        """
        (bi_awd_hid, bi_awd_hid_rev), _, _ = self.encoder(inp, seq_len)

        # Sum entries 0 and 1 of each direction's hidden states (presumably two
        # layers — TODO confirm), then concatenate forward and backward features.
        output = torch.cat(((bi_awd_hid[0] + bi_awd_hid[1]),
                            (bi_awd_hid_rev[0] + bi_awd_hid_rev[1])),
                           dim=2)  # (seq_len, bs, 2560)
        # Drop references to the per-layer states; only `output` is needed now.
        del bi_awd_hid, bi_awd_hid_rev
        # Put the mask on the same device as the encoder output, as the other
        # masking code does; otherwise the product below fails on GPU.
        mask = length_to_mask(seq_len).unsqueeze(2).to(output.device)  # (bs, seq_len, 1)
        output = output.permute(1, 0, 2) * mask  # (bs, seq_len, 2560)
        # Masked mean: sum over the masked positions divided by their count.
        output = output.sum(1) / mask.sum(1)  # (bs, 2560)

        output = self.decoder(output)
        return output
示例#3
0
 def forward(self, variable, lengths):
     """
     Masked mean-pool ``variable`` over its sequence dimension.

     Args:
         variable [batch_size, seq_len, hidden_dim]: embedding of the sequences
         lengths np.array[batch_size]: lengths of each sequence in the variable
     Returns:
         context [batch_size, hidden_dim]: ``self.attention`` output with dim 1 squeezed out
     """
     # Build the length mask directly on the input's device.
     padding_mask = length_to_mask(lengths).to(variable.device)
     # Uniform attention over the unmasked positions gives [batch_size, 1, hidden_dim].
     pooled, _ = self.attention(variable, mask=padding_mask)
     return pooled.squeeze(1)
示例#4
0
 def forward(self,
             variable,
             condition,
             variable_lengths=None,
             condition_lengths=None):
     """
     Encode ``variable`` conditioned on ``condition`` with ``self.attention``,
     optionally pool with ``self.bag_of_word``, then project with ``self.W``.

     A lengths argument that is None means "no padding" for that sequence.

     Args:
         variable [batch_size, var_seq_len, hidden_dim]: encoding of variable
         condition [batch_size, cond_seq_len, hidden_dim]: encoding to condition on
         variable_lengths np.array[batch_size] or None: lengths of the variable sequences
         condition_lengths np.array[batch_size] or None: lengths of the condition sequences
     Returns:
         H_var_cond: [batch_size, hidden_dim] when ``self.use_bag_of_word`` is set,
             otherwise the un-pooled attention output passed through ``self.W``.
     """
     if variable_lengths is None and condition_lengths is None:
         mask_var_cond = None
         mask_cond = None
     elif condition_lengths is None:
         mask_var_cond = length_to_mask(variable_lengths).to(
             variable.device)
         mask_cond = None
     else:
         if variable_lengths is None:
             # Only condition lengths known: every variable position is valid,
             # instead of passing None to length_to_mask.
             import numpy as np
             variable_lengths = np.full(len(condition_lengths),
                                        variable.size(1))
         mask_var_cond = length_to_mask(
             variable_lengths, condition_lengths).to(variable.device)
         mask_cond = length_to_mask(condition_lengths).to(variable.device)
     H_var_cond, _ = self.attention(variable,
                                    key=condition,
                                    mask=mask_var_cond)
     if self.use_bag_of_word:
         # Pool away the sequence dimension using the condition mask.
         H_var_cond, _ = self.bag_of_word(H_var_cond, mask=mask_cond)
         H_var_cond = H_var_cond.squeeze(1)
     H_var_cond = self.W(H_var_cond)
     return H_var_cond
示例#5
0
    def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len):
        """
        Predict how many value tokens to emit and score every question token as a value token.

        Args:
            q_emb_var [batch_size, question_seq_len, embedding_dim] : embedding of question
            q_len [batch_size] : lengths of questions
            hs_emb_var [batch_size, history_seq_len, embedding_dim] : embedding of history
            hs_len [batch_size] : lengths of history
            col_emb_var [batch_size*num_cols_in_db, col_name_len, embedding_dim] : embedding of the column names
            col_len [batch_size] : number of columns for each query
            col_name_len [batch_size*num_cols_in_db] : number of tokens for each column name.
                                        Each column has information about [type, table, name_token1, name_token2,...]
        Returns:
            num_tokens : ``self.tokens_num_out`` applied to the pooled [batch_size, hidden_dim]
                         encoding (presumably scores over the number of value tokens — TODO confirm)
            values [batch_size, question_seq_len] : per-question-token scores; positions selected
                         by ``length_to_mask(q_len)`` are filled with ``self.value_pad_token``
        """
        batch_size = len(col_len)

        q_enc,_ = self.q_lstm(q_emb_var, q_len)  # [batch_size, question_seq_len, hidden_dim]
        hs_enc,_ = self.hs_lstm(hs_emb_var, hs_len)  # [batch_size, history_seq_len, hidden_dim]
        _, col_enc = self.col_lstm(col_emb_var, col_name_len) # [batch_size*num_cols_in_db, hidden_dim]
        # Regroup the flat per-column encodings by query.
        col_enc = col_enc.reshape(batch_size, col_len.max(), self.hidden_dim) # [batch_size, num_cols_in_db, hidden_dim]

        #############################
        # Predict number of tokens  #
        #############################

        # Pooled conditional encodings of columns and history w.r.t. the question
        H_col_q = self.col_q_num(col_enc, q_enc, col_len, q_len)  # [batch_size, hidden_dim]
        H_hs_q = self.hs_q_num(hs_enc, q_enc, hs_len, q_len)  # [batch_size, hidden_dim]

        # The history term is zeroed out (but still computed) when use_hs is False.
        num_tokens = self.tokens_num_out(H_col_q + int(self.use_hs)*H_hs_q)

        ################################
        # Predict start index of token #
        ################################

        # Per-question-token conditional encodings of columns and history
        H_col_q = self.col_q(col_enc, q_enc, col_len, q_len)  # [batch_size, question_seq_len, hidden_dim]
        H_hs_q = self.hs_q(hs_enc, q_enc, hs_len, q_len)  # [batch_size, question_seq_len, hidden_dim]
        H_value = self.W_value(q_enc)  # [batch_size, question_seq_len, hidden_dim]

        values = self.value_out(H_col_q + int(self.use_hs)*H_hs_q + H_value).squeeze(2)  # [batch_size, question_seq_len]
        values_mask = length_to_mask(q_len).squeeze(2).to(values.device)

        # Questions have different lengths, so the padded positions are overwritten with the pad token.
        values = values.masked_fill_(values_mask, self.value_pad_token)

        return (num_tokens, values)
示例#6
0
 def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len):
     """
     Predict the number of value tokens and a score for every question token.

     Returns:
         (num_tokens, values): ``self.tokens_num_out`` output, and per-question-token
         scores whose positions selected by ``length_to_mask(q_len)`` hold
         ``self.value_pad_token``.
     """
     n_batch = len(col_len)
     hs_weight = int(self.use_hs)  # 0 disables the history contribution

     q_enc, _ = self.q_lstm(q_emb_var, q_len)
     hs_enc, _ = self.hs_lstm(hs_emb_var, hs_len)
     _, col_enc = self.col_lstm(col_emb_var, col_name_len)
     col_enc = col_enc.reshape(n_batch, col_len.max(), self.hidden_dim)

     # How many value tokens to predict
     col_summary = self.col_q_num(col_enc, q_enc, col_len, q_len)
     hs_summary = self.hs_q_num(hs_enc, q_enc, hs_len, q_len)
     num_tokens = self.tokens_num_out(col_summary + hs_weight * hs_summary)

     # Which question tokens are values
     col_per_tok = self.col_q(col_enc, q_enc, col_len, q_len)
     hs_per_tok = self.hs_q(hs_enc, q_enc, hs_len, q_len)
     combined = col_per_tok + hs_weight * hs_per_tok + self.W_value(q_enc)
     scores = self.value_out(combined).squeeze(2)
     pad_mask = length_to_mask(q_len).squeeze(2).to(scores.device)
     scores = scores.masked_fill_(pad_mask, self.value_pad_token)
     return (num_tokens, scores)
示例#7
0
 def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, col_emb_var, col_len, col_name_len):
     """
     Predict the number of columns, the repetition output and per-column scores.

     Returns:
         (num_cols, num_reps, cols): outputs of ``self.col_num_out`` and
         ``self.col_rep_out`` on the same pooled encoding, and per-column scores
         whose positions selected by ``length_to_mask(col_len)`` hold
         ``self.col_pad_token``.
     """
     n_batch = len(col_len)
     hs_weight = int(self.use_hs)  # 0 disables the history contribution

     q_enc, _ = self.q_lstm(q_emb_var, q_len)
     hs_enc, _ = self.hs_lstm(hs_emb_var, hs_len)
     _, col_enc = self.col_lstm(col_emb_var, col_name_len)
     col_enc = col_enc.reshape(n_batch, col_len.max(), self.hidden_dim)

     # Shared pooled encoding feeds both count heads
     pooled = (self.q_col_num(q_enc, col_enc, q_len, col_len)
               + hs_weight * self.hs_col_num(hs_enc, col_enc, hs_len, col_len))
     num_cols = self.col_num_out(pooled)
     num_reps = self.col_rep_out(pooled)

     # Per-column scores
     per_col = (self.q_col(q_enc, col_enc, q_len, col_len)
                + hs_weight * self.hs_col(hs_enc, col_enc, hs_len, col_len)
                + self.W_col(col_enc))
     scores = self.col_out(per_col).squeeze(2)
     pad_mask = length_to_mask(col_len).squeeze(2).to(scores.device)
     scores = scores.masked_fill_(pad_mask, self.col_pad_token)
     return num_cols, num_reps, scores
示例#8
0
 def forward(self, variable, lengths):
     """
     Masked mean-pool ``variable`` over its sequence dimension using ``self.attention``.

     Args:
         variable: sequence embeddings (assumed [batch_size, seq_len, hidden_dim] — TODO confirm)
         lengths: per-sequence lengths given to ``length_to_mask``
     Returns:
         ``self.attention`` output with dimension 1 squeezed out.
     """
     padding_mask = length_to_mask(lengths).to(variable.device)
     pooled, _ = self.attention(variable, mask=padding_mask)
     return pooled.squeeze(1)