def __init__(
        self,
        device,
        num_layers,
        num_heads,
        hidden_size,
        ff_size=None,
        dropout=0.1,
        merge_types=False,
        tie_layers=False,
        qq_max_dist=2,
        cc_foreign_key=True,
        cc_table_match=True,
        cc_max_dist=2,
        ct_foreign_key=True,
        ct_table_match=True,
        tc_table_match=True,
        tc_foreign_key=True,
        tt_max_dist=2,
        tt_foreign_key=True,
        sc_link=False,
        cv_link=False,
        cv_token_link=False,
        cv_token_start_link=False,
        original_bert_config=None):
    """Register relation ids and build the relation-aware encoder.

    Fills ``self.relation_ids`` (relation name -> integer id) according to
    the feature flags, then builds ``self.encoder``, ``self.align_attn``
    and ``self.dropout``.  In relation names ``q``/``c``/``t`` stand for
    question / column / table (``qc_default`` is the question->column
    default relation, and so on).

    Args:
        device: stored as ``self._device``.
        num_layers: forwarded to ``transformer.Encoder``.
        num_heads: number of attention heads.
        hidden_size: model width.
        ff_size: feed-forward width; defaults to ``4 * hidden_size``.
        dropout: dropout probability used when ``original_bert_config``
            is not supplied.
        merge_types: collapse every ``*_default`` relation into a single
            ``xx_default`` id and share the ``qq_dist`` ids with
            ``cc_dist`` and ``tt_dist``.  Requires all foreign-key and
            table-match flags to be False and equal ``*_max_dist`` values.
        tie_layers: forwarded to ``transformer.Encoder``.
        qq_max_dist, cc_max_dist, tt_max_dist: for each, ids are registered
            for ``(name, d)`` with ``d`` in ``[-max_dist, max_dist]``.
        cc_foreign_key, cc_table_match, ct_foreign_key, ct_table_match,
        tc_table_match, tc_foreign_key, tt_foreign_key: register the
            corresponding schema-structure relations.
        sc_link: register the schema-linking relations (``qcCEM`` ...).
        cv_link: register ``qcNUMBER`` / ``qcTIME`` / ``qcCELLMATCH`` and
            their ``cq`` counterparts.
        cv_token_link: with ``cv_link``, give ``*CELLTOKENMATCH`` its own
            ids instead of aliasing ``*CELLMATCH``.
        cv_token_start_link: with ``cv_link``, give ``*CELLMATCHSTART`` its
            own ids instead of aliasing ``*CELLMATCH``.
        original_bert_config: optional object exposing
            ``attention_probs_dropout_prob`` and ``hidden_dropout_prob``;
            when given, these override ``dropout``.

    Raises:
        AssertionError: if ``merge_types`` is combined with incompatible
            flags.
    """
    super().__init__()
    self._device = device
    self.num_heads = num_heads

    self.qq_max_dist = qq_max_dist
    self.cc_foreign_key = cc_foreign_key
    self.cc_table_match = cc_table_match
    self.cc_max_dist = cc_max_dist
    self.ct_foreign_key = ct_foreign_key
    self.ct_table_match = ct_table_match
    self.tc_table_match = tc_table_match
    self.tc_foreign_key = tc_foreign_key
    self.tt_max_dist = tt_max_dist
    self.tt_foreign_key = tt_foreign_key
    self.cv_token_start_link = cv_token_start_link
    self.cv_link = cv_link

    self.relation_ids = {}

    def add_relation(name):
        # The next id is the current number of keys; registration order
        # therefore fixes the ids and must not be changed.
        self.relation_ids[name] = len(self.relation_ids)

    def add_rel_dist(name, max_dist):
        for i in range(-max_dist, max_dist + 1):
            add_relation((name, i))

    sc_link_names = ('qcCEM', 'cqCEM', 'qtTEM', 'tqTEM',
                     'qcCPM', 'cqCPM', 'qtTPM', 'tqTPM')
    cv_link_names = ("qcNUMBER", "cqNUMBER", "qcTIME", "cqTIME",
                     "qcCELLMATCH", "cqCELLMATCH")

    add_rel_dist('qq_dist', qq_max_dist)

    add_relation('qc_default')
    add_relation('qt_default')
    add_relation('cq_default')

    add_relation('cc_default')
    if cc_foreign_key:
        add_relation('cc_foreign_key_forward')
        add_relation('cc_foreign_key_backward')
    if cc_table_match:
        add_relation('cc_table_match')
    add_rel_dist('cc_dist', cc_max_dist)

    add_relation('ct_default')
    if ct_foreign_key:
        add_relation('ct_foreign_key')
    if ct_table_match:
        add_relation('ct_primary_key')
        add_relation('ct_table_match')
        add_relation('ct_any_table')

    add_relation('tq_default')

    add_relation('tc_default')
    if tc_table_match:
        add_relation('tc_primary_key')
        add_relation('tc_table_match')
        add_relation('tc_any_table')
    if tc_foreign_key:
        add_relation('tc_foreign_key')

    add_relation('tt_default')
    if tt_foreign_key:
        add_relation('tt_foreign_key_forward')
        add_relation('tt_foreign_key_backward')
        add_relation('tt_foreign_key_both')
    add_rel_dist('tt_dist', tt_max_dist)

    # Schema-linking relations (forward and backward directions).
    if sc_link:
        for name in sc_link_names:
            add_relation(name)

    if cv_link:
        for name in cv_link_names:
            add_relation(name)

        # NOTE(review): nesting reconstructed from a flattened source.  The
        # alias fallbacks below read the *CELLMATCH ids, which only exist
        # when cv_link is set, so both switches are kept under cv_link.
        if cv_token_link:
            add_relation("qcCELLTOKENMATCH")
            add_relation("cqCELLTOKENMATCH")
        else:
            self.relation_ids['qcCELLTOKENMATCH'] = self.relation_ids[
                'qcCELLMATCH']
            self.relation_ids['cqCELLTOKENMATCH'] = self.relation_ids[
                'cqCELLMATCH']

        if cv_token_start_link:
            add_relation("qcCELLMATCHSTART")
            add_relation("cqCELLMATCHSTART")
        else:
            self.relation_ids['qcCELLMATCHSTART'] = self.relation_ids[
                'qcCELLMATCH']
            self.relation_ids['cqCELLMATCHSTART'] = self.relation_ids[
                'cqCELLMATCH']

    if merge_types:
        assert not cc_foreign_key
        assert not cc_table_match
        assert not ct_foreign_key
        assert not ct_table_match
        assert not tc_foreign_key
        assert not tc_table_match
        assert not tt_foreign_key

        assert cc_max_dist == qq_max_dist
        assert tt_max_dist == qq_max_dist

        add_relation('xx_default')
        merged_id = self.relation_ids['xx_default']
        for pair in ('qc', 'qt', 'cq', 'cc', 'ct', 'tq', 'tc', 'tt'):
            self.relation_ids[pair + '_default'] = merged_id

        if sc_link:
            for name in sc_link_names:
                self.relation_ids[name] = merged_id
        if cv_link:
            for name in cv_link_names:
                self.relation_ids[name] = merged_id

        for i in range(-qq_max_dist, qq_max_dist + 1):
            self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
            # Previously assigned tt_dist to itself (a no-op), leaving the
            # table-table distances unmerged.
            self.relation_ids['tt_dist', i] = self.relation_ids['qq_dist', i]

    if ff_size is None:
        ff_size = hidden_size * 4

    # Aliasing and merging leave gaps in the id range, so counting distinct
    # ids can be smaller than the largest id.  num_relations presumably
    # sizes tables indexed by relation id (TODO confirm in `transformer`),
    # so it has to cover the largest id.  Equal to the distinct count
    # whenever the ids are contiguous.
    num_relations = max(self.relation_ids.values()) + 1

    # Fall back to the plain `dropout` argument when no BERT config is
    # supplied instead of failing on attribute access on None.
    if original_bert_config is None:
        attn_dropout = dropout
        hidden_dropout = dropout
    else:
        attn_dropout = original_bert_config.attention_probs_dropout_prob
        hidden_dropout = original_bert_config.hidden_dropout_prob

    self.encoder = transformer.Encoder(
        lambda: transformer.EncoderLayer(
            hidden_size,
            transformer.MultiHeadedAttentionWithRelations(
                num_heads, hidden_size, attn_dropout),
            transformer.PositionwiseFeedForward(
                hidden_size, ff_size, hidden_dropout),
            num_relations,
            hidden_dropout),
        hidden_size,
        num_layers,
        tie_layers)

    self.align_attn = transformer.PointerWithRelations(
        hidden_size, num_relations, attn_dropout)

    self.dropout = torch.nn.Dropout(hidden_dropout)
def __init__(self,
             device,
             num_layers,
             num_heads,
             hidden_size,
             num_utterance_keep,
             ff_size=None,
             dropout=0.1,
             merge_types=False,
             tie_layers=False,
             qq_max_dist=2,
             cc_foreign_key=True,
             cc_table_match=True,
             cc_max_dist=2,
             ct_foreign_key=True,
             ct_table_match=True,
             tc_table_match=True,
             tc_foreign_key=True,
             tt_max_dist=2,
             tt_foreign_key=True,
             sc_link=False,
             cv_link=False):
    """Register per-utterance relation ids and build the encoder.

    Fills ``self.relation_ids`` (relation key -> integer id) and then
    builds ``self.encoder`` and ``self.align_attn``.  Question-side
    relations are replicated for every index in
    ``range(num_utterance_keep)``; their keys are tuples such as
    ``("qc_default", i)``.  In relation names ``q``/``c``/``t`` stand for
    question / column / table.

    Args:
        device: stored as ``self._device``.
        num_layers: forwarded to ``transformer.Encoder``.
        num_heads: number of attention heads.
        hidden_size: model width.
        num_utterance_keep: number of question slots the question-side
            relations are replicated for (presumably the number of
            dialogue utterances kept as context - confirm with caller).
        ff_size: feed-forward width; defaults to ``4 * hidden_size``.
        dropout: dropout probability passed to every sub-module.
        merge_types: not supported by this variant; must be False.
        tie_layers: forwarded to ``transformer.Encoder``.
        qq_max_dist: ids are registered for ``("eqq_dist", i, d)`` with
            ``d`` in ``[-qq_max_dist, qq_max_dist]`` for every slot ``i``.
        cc_max_dist, tt_max_dist: ids are registered for ``(name, d)``
            with ``d`` in ``[-max_dist, max_dist]``.
        cc_foreign_key, cc_table_match, ct_foreign_key, ct_table_match,
        tc_table_match, tc_foreign_key, tt_foreign_key: register the
            corresponding schema-structure relations.
        sc_link: register the schema-linking relations per slot.
        cv_link: register ``qcNUMBER`` / ``qcTIME`` / ``qcCELLMATCH`` and
            their ``cq`` counterparts per slot.

    Raises:
        NotImplementedError: if ``merge_types`` is True.
    """
    # The original note on this flag said it must be set to False: the
    # merge step looked up ('qq_dist', i) and plain '*_default' keys, none
    # of which this per-utterance variant registers, so it could only end
    # in a KeyError.  Fail early with an explicit message instead.
    if merge_types:
        raise NotImplementedError(
            "merge_types=True is not supported by the multi-utterance "
            "relation encoder; pass merge_types=False")

    super().__init__()
    self._device = device
    self.num_heads = num_heads
    self.num_utterance_keep = num_utterance_keep

    self.qq_max_dist = qq_max_dist
    self.cc_foreign_key = cc_foreign_key
    self.cc_table_match = cc_table_match
    self.cc_max_dist = cc_max_dist
    self.ct_foreign_key = ct_foreign_key
    self.ct_table_match = ct_table_match
    self.tc_table_match = tc_table_match
    self.tc_foreign_key = tc_foreign_key
    self.tt_max_dist = tt_max_dist
    self.tt_foreign_key = tt_foreign_key

    self.relation_ids = {}

    def add_relation(name):
        # The next id is the current number of keys; registration order
        # therefore fixes the ids and must not be changed.
        self.relation_ids[name] = len(self.relation_ids)

    def add_rel_dist(name, max_dist):
        # `name` may already be a tuple (e.g. ("eqq_dist", slot)); the
        # distance is appended so the key stays a flat tuple.
        for i in range(-max_dist, max_dist + 1):
            if isinstance(name, tuple):
                add_relation(name + (i,))
            else:
                add_relation((name, i))

    sc_link_names = ('qcCEM', 'cqCEM', 'qtTEM', 'tqTEM',
                     'qcCPM', 'cqCPM', 'qtTPM', 'tqTPM')
    cv_link_names = ("qcNUMBER", "cqNUMBER", "qcTIME", "cqTIME",
                     "qcCELLMATCH", "cqCELLMATCH")

    # Question-question relations: one id per ordered pair of different
    # slots, and relative-distance ids within the same slot.
    for i, j in itertools.product(range(num_utterance_keep), repeat=2):
        if i != j:
            add_relation(("dqq_dist", i, j))
        else:
            add_rel_dist(("eqq_dist", i), qq_max_dist)

    # Per-slot default relations, registered one family at a time.
    for family in ("qc_default", "qt_default", "cq_default", "tq_default"):
        for i in range(num_utterance_keep):
            add_relation((family, i))

    add_relation('cc_default')
    if cc_foreign_key:
        add_relation('cc_foreign_key_forward')
        add_relation('cc_foreign_key_backward')
    if cc_table_match:
        add_relation('cc_table_match')
    add_rel_dist('cc_dist', cc_max_dist)

    add_relation('ct_default')
    if ct_foreign_key:
        add_relation('ct_foreign_key')
    if ct_table_match:
        add_relation('ct_primary_key')
        add_relation('ct_table_match')
        add_relation('ct_any_table')

    add_relation('tc_default')
    if tc_table_match:
        add_relation('tc_primary_key')
        add_relation('tc_table_match')
        add_relation('tc_any_table')
    if tc_foreign_key:
        add_relation('tc_foreign_key')

    add_relation('tt_default')
    if tt_foreign_key:
        add_relation('tt_foreign_key_forward')
        add_relation('tt_foreign_key_backward')
        add_relation('tt_foreign_key_both')
    add_rel_dist('tt_dist', tt_max_dist)

    # Schema-linking relations (forward and backward), one set per slot.
    if sc_link:
        for i in range(num_utterance_keep):
            for name in sc_link_names:
                add_relation((name, i))

    if cv_link:
        for i in range(num_utterance_keep):
            for name in cv_link_names:
                add_relation((name, i))

    if ff_size is None:
        ff_size = hidden_size * 4

    num_relations = len(self.relation_ids)

    self.encoder = transformer.Encoder(
        lambda: transformer.EncoderLayer(
            hidden_size,
            transformer.MultiHeadedAttentionWithRelations(
                num_heads, hidden_size, dropout),
            transformer.PositionwiseFeedForward(
                hidden_size, ff_size, dropout),
            num_relations,
            dropout),
        hidden_size,
        num_layers,
        tie_layers)

    self.align_attn = transformer.PointerWithRelations(
        hidden_size, num_relations, dropout)
def __init__(
        self,
        device,
        num_layers,
        num_heads,
        hidden_size,
        ff_size=None,
        dropout=0.1,
        merge_types=False,
        tie_layers=False,
        qq_max_dist=2,
        cc_foreign_key=True,
        cc_table_match=True,
        cc_max_dist=2,
        ct_foreign_key=True,
        ct_table_match=True,
        tc_table_match=True,
        tc_foreign_key=True,
        tt_max_dist=2,
        tt_foreign_key=True,
        sc_link=False,
        cv_link=False,
):
    """Register relation ids and build the relation-aware encoder.

    Fills ``self.relation_ids`` (relation name -> integer id) according to
    the feature flags, then builds ``self.encoder`` and
    ``self.align_attn``.  In relation names ``q``/``c``/``t`` stand for
    question / column / table.

    Args:
        device: stored as ``self._device``.
        num_layers: forwarded to ``transformer.Encoder``.
        num_heads: number of attention heads.
        hidden_size: model width.
        ff_size: feed-forward width; defaults to ``4 * hidden_size``.
        dropout: dropout probability passed to every sub-module.
        merge_types: collapse every ``*_default`` relation into a single
            ``xx_default`` id and share the ``qq_dist`` ids with
            ``cc_dist`` and ``tt_dist``.  Requires all foreign-key and
            table-match flags to be False and equal ``*_max_dist`` values.
        tie_layers: forwarded to ``transformer.Encoder``.
        qq_max_dist, cc_max_dist, tt_max_dist: for each, ids are registered
            for ``(name, d)`` with ``d`` in ``[-max_dist, max_dist]``.
        cc_foreign_key, cc_table_match, ct_foreign_key, ct_table_match,
        tc_table_match, tc_foreign_key, tt_foreign_key: register the
            corresponding schema-structure relations.
        sc_link: register the schema-linking relations (``qcCEM`` ...).
        cv_link: register ``qcNUMBER`` / ``qcTIME`` / ``qcCELLMATCH`` and
            their ``cq`` counterparts.

    Raises:
        AssertionError: if ``merge_types`` is combined with incompatible
            flags.
    """
    super().__init__()
    self._device = device
    self.num_heads = num_heads

    self.qq_max_dist = qq_max_dist
    self.cc_foreign_key = cc_foreign_key
    self.cc_table_match = cc_table_match
    self.cc_max_dist = cc_max_dist
    self.ct_foreign_key = ct_foreign_key
    self.ct_table_match = ct_table_match
    self.tc_table_match = tc_table_match
    self.tc_foreign_key = tc_foreign_key
    self.tt_max_dist = tt_max_dist
    self.tt_foreign_key = tt_foreign_key

    self.relation_ids = {}  # relation name -> id

    def add_relation(name):
        # The next id is the current number of keys; registration order
        # therefore fixes the ids and must not be changed.
        self.relation_ids[name] = len(self.relation_ids)

    def add_rel_dist(name, max_dist):
        for i in range(-max_dist, max_dist + 1):
            add_relation((name, i))

    sc_link_names = ('qcCEM', 'cqCEM', 'qtTEM', 'tqTEM',
                     'qcCPM', 'cqCPM', 'qtTPM', 'tqTPM')
    cv_link_names = ("qcNUMBER", "cqNUMBER", "qcTIME", "cqTIME",
                     "qcCELLMATCH", "cqCELLMATCH")

    add_rel_dist('qq_dist', qq_max_dist)

    add_relation('qc_default')  # question -> column
    add_relation('qt_default')  # question -> table
    add_relation('cq_default')  # column -> question

    add_relation('cc_default')  # column -> column
    if cc_foreign_key:
        add_relation('cc_foreign_key_forward')
        add_relation('cc_foreign_key_backward')
    if cc_table_match:
        add_relation('cc_table_match')
    add_rel_dist('cc_dist', cc_max_dist)

    add_relation('ct_default')  # column -> table
    if ct_foreign_key:
        add_relation('ct_foreign_key')
    if ct_table_match:
        add_relation('ct_primary_key')
        add_relation('ct_table_match')
        add_relation('ct_any_table')

    add_relation('tq_default')  # table -> question

    add_relation('tc_default')  # table -> column
    if tc_table_match:
        add_relation('tc_primary_key')
        add_relation('tc_table_match')
        add_relation('tc_any_table')
    if tc_foreign_key:
        add_relation('tc_foreign_key')

    add_relation('tt_default')  # table -> table
    if tt_foreign_key:
        add_relation('tt_foreign_key_forward')
        add_relation('tt_foreign_key_backward')
        add_relation('tt_foreign_key_both')
    add_rel_dist('tt_dist', tt_max_dist)

    # Schema-linking relations (forward and backward directions).
    if sc_link:
        for name in sc_link_names:
            add_relation(name)

    if cv_link:
        for name in cv_link_names:
            add_relation(name)

    if merge_types:
        assert not cc_foreign_key
        assert not cc_table_match
        assert not ct_foreign_key
        assert not ct_table_match
        assert not tc_foreign_key
        assert not tc_table_match
        assert not tt_foreign_key

        assert cc_max_dist == qq_max_dist
        assert tt_max_dist == qq_max_dist

        add_relation('xx_default')
        merged_id = self.relation_ids['xx_default']
        for pair in ('qc', 'qt', 'cq', 'cc', 'ct', 'tq', 'tc', 'tt'):
            self.relation_ids[pair + '_default'] = merged_id

        if sc_link:
            for name in sc_link_names:
                self.relation_ids[name] = merged_id
        if cv_link:
            for name in cv_link_names:
                self.relation_ids[name] = merged_id

        for i in range(-qq_max_dist, qq_max_dist + 1):
            self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
            # Previously assigned tt_dist to itself (a no-op), leaving the
            # table-table distances unmerged.
            self.relation_ids['tt_dist', i] = self.relation_ids['qq_dist', i]

    if ff_size is None:
        ff_size = hidden_size * 4

    # The key count is an upper bound on every assigned id, including
    # after merging.
    num_relations = len(self.relation_ids)

    self.encoder = transformer.Encoder(
        lambda: transformer.EncoderLayer(
            hidden_size,
            transformer.MultiHeadedAttentionWithRelations(
                num_heads, hidden_size, dropout),
            transformer.PositionwiseFeedForward(
                hidden_size, ff_size, dropout),
            num_relations,
            dropout),
        hidden_size,
        num_layers,
        tie_layers)

    self.align_attn = transformer.PointerWithRelations(
        hidden_size, num_relations, dropout)