def score(self, context, context_mask, utt_pointer, batch):
    B = len(batch)
    # BERT encoding of the flattened context
    bert = self.bert_dropout(self.utt_bert_embedder(context)[0])
    cand_trans, _ = rnn.run_rnn(self.utt_cand_trans, bert, context_mask.sum(1).long())
    enc_trans, _ = rnn.run_rnn(self.utt_enc_trans, bert, context_mask.sum(1).long())
    # candidate pointer targets: vocabulary embeddings followed by context positions
    cand = torch.cat([
        self.utt_emb.weight.unsqueeze(0).repeat(B, 1, 1),
        cand_trans,
    ], dim=1)
    cand_mask = torch.cat([
        torch.ones(B, len(self.utt_vocab)).float().to(self.device),
        context_mask,
    ], dim=1)
    # teacher-forced decoding against the gold pointer sequence
    utt_dec = self.utt_pointer_decoder(
        emb=self.dropout(cand),
        emb_mask=cand_mask,
        enc=self.dropout(enc_trans),
        enc_mask=context_mask.float(),
        state0=None,
        gt=utt_pointer,
        max_len=self.args.max_query_len,
        batch=batch,
    )
    normed = torch.log_softmax(utt_dec, dim=2)
    eos = self.utt_vocab.word2index('EOS')
    scores = []
    for score, inds, ex in zip(normed, utt_pointer, batch):
        # truncate the gold pointer sequence at the first EOS (inclusive)
        valid = inds.tolist()
        if eos in valid:
            valid = valid[:valid.index(eos) + 1]
        # greedy decode and its score, computed for inspection but not used in the result
        greedy = score.max(1)[1].tolist()
        if eos in greedy:
            greedy = greedy[:greedy.index(eos) + 1]
        greedy_score = score.max(1)[0].sum()
        # average gold log-probability per decoding step
        score_sum = sum([score[i, j].item() for i, j in enumerate(valid)]) / len(valid)
        scores.append(score_sum)
    return scores
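# Minimal standalone sketch (not part of the model) of the scoring rule used in
# score() above: log-softmax the teacher-forced decoder outputs, truncate the
# gold pointer sequence at the first EOS, and average the per-step gold
# log-probabilities. The names logits, gold, and eos_index are illustrative only.
import torch

def mean_gold_log_prob(logits: torch.Tensor, gold: list, eos_index: int) -> float:
    """logits: (T, C) decoder outputs for one example; gold: list of target indices."""
    log_probs = torch.log_softmax(logits, dim=1)
    if eos_index in gold:
        gold = gold[:gold.index(eos_index) + 1]  # keep steps up to and including EOS
    step_scores = [log_probs[t, j].item() for t, j in enumerate(gold)]
    return sum(step_scores) / len(step_scores)

# example: mean_gold_log_prob(torch.randn(3, 5), [4, 1, 2], eos_index=2)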
def forward(self, context, context_mask, utt_pointer, batch):
    B = len(batch)
    # BERT encoding of the flattened context
    bert = self.bert_dropout(self.utt_bert_embedder(context)[0])
    cand_trans, _ = rnn.run_rnn(self.utt_cand_trans, bert, context_mask.sum(1).long())
    enc_trans, _ = rnn.run_rnn(self.utt_enc_trans, bert, context_mask.sum(1).long())
    # candidate pointer targets: vocabulary embeddings followed by context positions
    cand = torch.cat([
        self.utt_emb.weight.unsqueeze(0).repeat(B, 1, 1),
        cand_trans,
    ], dim=1)
    cand_mask = torch.cat([
        torch.ones(B, len(self.utt_vocab)).float().to(self.device),
        context_mask,
    ], dim=1)
    if not self.should_beam_search():
        # teacher forcing during training, greedy decoding otherwise
        utt_dec = self.utt_pointer_decoder(
            emb=self.dropout(cand),
            emb_mask=cand_mask,
            enc=self.dropout(enc_trans),
            enc_mask=context_mask.float(),
            state0=None,
            gt=utt_pointer if self.training else None,
            max_len=self.args.max_query_len,
            batch=batch,
        )
    else:
        utt_dec = self.utt_pointer_decoder.beam_search(
            emb=self.dropout(cand),
            emb_mask=cand_mask,
            enc=self.dropout(enc_trans),
            enc_mask=context_mask.float(),
            eos_ind=self.utt_vocab.word2index('EOS'),
            max_len=self.args.max_query_len,
            batch=batch,
            beam_size=self.args.beam_size,
        )
    return dict(utt_dec=utt_dec)
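# Hypothetical sketch of the decode-path switch used above; the actual
# should_beam_search is defined elsewhere in this class, so this is only an
# assumption about its intent: use beam search only at inference time and
# only when a beam width greater than 1 is requested.
def should_beam_search(self):
    return (not self.training) and getattr(self.args, 'beam_size', 1) > 1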
def forward(self, utterance, utterance_mask, tables, tables_mask, starts, ends,
            query_pointer, value_pointer, utt_tables, utt_tables_mask,
            utt_starts, utt_ends, utt_pointer, batch):
    B = len(batch)
    # reps for flattened columns
    col_reps = []
    col_mask = []
    # reps for each table
    table_reps = []
    table_mask = []
    for ids_table, mask_table, start_table, end_table in zip(tables, tables_mask, starts, ends):
        bert_table = self.bert_dropout(self.bert_embedder(ids_table)[0])
        table_col_reps = []
        for bert_col, start_col, end_col in zip(bert_table, start_table, end_table):
            cols = [bert_col[cs:ce] for cs, ce in zip(start_col, end_col)]
            mask = [torch.ones(len(e)) for e in cols]
            pad = nn.utils.rnn.pad_sequence(cols, batch_first=True, padding_value=0)
            mask = nn.utils.rnn.pad_sequence(mask, batch_first=True, padding_value=0).float().to(self.device)
            # compute self-attention for this column
            scores = self.col_sa_scorer(pad).squeeze(2)
            normalized_scores = F.softmax(scores - (1 - mask) * 1e20, dim=1)
            col_sa = pad.mul(normalized_scores.unsqueeze(2).expand_as(pad)).sum(1)
            table_col_reps.append(col_sa)
        table_col_reps = torch.cat(table_col_reps, dim=0)
        col_reps.append(table_col_reps)
        col_mask.append(torch.ones(len(table_col_reps)))
        # compute self-attention for this table
        scores = self.table_sa_scorer(bert_table).squeeze(2)
        normalized_scores = F.softmax(scores - (1 - mask_table) * 1e20, dim=1)
        tab_sa = bert_table.mul(normalized_scores.unsqueeze(2).expand_as(bert_table)).sum(1)
        table_reps.append(tab_sa)
        table_mask.append(torch.ones(len(tab_sa)))
    col_reps = nn.utils.rnn.pad_sequence(col_reps, batch_first=True, padding_value=0)
    col_mask = nn.utils.rnn.pad_sequence(col_mask, batch_first=True, padding_value=0).to(self.device)
    table_reps = nn.utils.rnn.pad_sequence(table_reps, batch_first=True, padding_value=0)
    table_mask = nn.utils.rnn.pad_sequence(table_mask, batch_first=True, padding_value=0).to(self.device)
    col_trans, _ = rnn.run_rnn(self.col_trans, col_reps, col_mask.sum(1).long())
    table_trans, _ = rnn.run_rnn(self.table_trans, table_reps, table_mask.sum(1).long())
    table_trans = self.dropout(table_trans)
    # candidate pointer targets: SQL vocabulary embeddings followed by column representations
    cand = self.dropout(torch.cat([
        self.sql_emb.weight.unsqueeze(0).repeat(B, 1, 1),
        col_trans,
    ], dim=1))
    cand_mask = torch.cat([
        torch.ones(B, len(self.sql_vocab)).float().to(self.device),
        col_mask,
    ], dim=1)
    query_dec = self.pointer_decoder(
        emb=cand,
        emb_mask=cand_mask,
        enc=table_trans,
        enc_mask=table_mask.float(),
        state0=None,
        gt=query_pointer if self.training else None,
        max_len=self.args.max_query_len,
        batch=batch,
    )
    # value decoder points into the utterance tokens
    utt = self.bert_dropout(self.value_bert_embedder(utterance)[0])
    utt_trans, _ = rnn.run_rnn(self.utt_trans, utt, utterance_mask.sum(1).long())
    cand = self.dropout(torch.cat([
        self.sql_emb.weight.unsqueeze(0).repeat(B, 1, 1),
        utt_trans,
    ], dim=1))
    cand_mask = torch.cat([
        torch.ones(B, len(self.sql_vocab)).float().to(self.device),
        utterance_mask,
    ], dim=1)
    value_dec = self.value_decoder(
        emb=cand,
        emb_mask=cand_mask,
        enc=utt,
        enc_mask=utterance_mask.float(),
        state0=None,
        gt=value_pointer if self.training else None,
        max_len=self.args.max_value_len,
        batch=batch,
    )
    # second encoding pass: reps for flattened columns and for each table,
    # this time from the utterance-side BERT encoder
    col_reps = []
    col_mask = []
    table_reps = []
    table_mask = []
    for ids_table, mask_table, start_table, end_table in zip(utt_tables, utt_tables_mask, utt_starts, utt_ends):
        bert_table = self.bert_dropout(self.utt_bert_embedder(ids_table)[0])
        table_col_reps = []
        for bert_col, start_col, end_col in zip(bert_table, start_table, end_table):
            cols = [bert_col[cs:ce] for cs, ce in zip(start_col, end_col)]
            mask = [torch.ones(len(e)) for e in cols]
            pad = nn.utils.rnn.pad_sequence(cols, batch_first=True, padding_value=0)
            mask = nn.utils.rnn.pad_sequence(mask, batch_first=True, padding_value=0).float().to(self.device)
            # compute self-attention for this column
            scores = self.utt_col_sa_scorer(pad).squeeze(2)
            normalized_scores = F.softmax(scores - (1 - mask) * 1e20, dim=1)
            col_sa = pad.mul(normalized_scores.unsqueeze(2).expand_as(pad)).sum(1)
            table_col_reps.append(col_sa)
        table_col_reps = torch.cat(table_col_reps, dim=0)
        col_reps.append(table_col_reps)
        col_mask.append(torch.ones(len(table_col_reps)))
        # compute self-attention for this table
        scores = self.utt_table_sa_scorer(bert_table).squeeze(2)
        normalized_scores = F.softmax(scores - (1 - mask_table) * 1e20, dim=1)
        tab_sa = bert_table.mul(normalized_scores.unsqueeze(2).expand_as(bert_table)).sum(1)
        table_reps.append(tab_sa)
        table_mask.append(torch.ones(len(tab_sa)))
    col_reps = nn.utils.rnn.pad_sequence(col_reps, batch_first=True, padding_value=0)
    col_mask = nn.utils.rnn.pad_sequence(col_mask, batch_first=True, padding_value=0).to(self.device)
    table_reps = nn.utils.rnn.pad_sequence(table_reps, batch_first=True, padding_value=0)
    table_mask = nn.utils.rnn.pad_sequence(table_mask, batch_first=True, padding_value=0).to(self.device)
    col_trans, _ = rnn.run_rnn(self.utt_col_trans, col_reps, col_mask.sum(1).long())
    table_trans, _ = rnn.run_rnn(self.utt_table_trans, table_reps, table_mask.sum(1).long())
    table_trans = self.dropout(table_trans)
    # candidate pointer targets: utterance vocabulary embeddings followed by column representations
    cand = self.dropout(torch.cat([
        self.utt_emb.weight.unsqueeze(0).repeat(B, 1, 1),
        col_trans,
    ], dim=1))
    cand_mask = torch.cat([
        torch.ones(B, len(self.utt_vocab)).float().to(self.device),
        col_mask,
    ], dim=1)
    utt_dec = self.utt_pointer_decoder(
        emb=cand,
        emb_mask=cand_mask,
        enc=table_trans,
        enc_mask=table_mask.float(),
        state0=None,
        gt=utt_pointer if self.training else None,
        max_len=self.args.max_query_len,
        batch=batch,
    )
    return dict(query_dec=query_dec, value_dec=value_dec, utt_dec=utt_dec)
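# Standalone sketch (illustrative only) of the masked self-attention pooling
# used above for the column and table representations: score each position,
# push padded positions toward -inf before the softmax, and take the weighted
# sum. Here scorer is assumed to be any module mapping (B, T, H) -> (B, T, 1),
# e.g. nn.Linear(H, 1), matching col_sa_scorer / table_sa_scorer.
import torch
import torch.nn.functional as F

def masked_self_attention_pool(scorer, hidden, mask):
    """hidden: (B, T, H); mask: (B, T) float with 1 for real tokens, 0 for padding."""
    scores = scorer(hidden).squeeze(2)            # (B, T)
    scores = scores - (1 - mask) * 1e20           # mask out padded positions
    weights = F.softmax(scores, dim=1)            # attention weights over real tokens
    return hidden.mul(weights.unsqueeze(2).expand_as(hidden)).sum(1)  # (B, H)

# example: masked_self_attention_pool(nn.Linear(8, 1), torch.randn(2, 4, 8), torch.ones(2, 4))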