def __init__(
            self,
            device,
            num_layers,
            num_heads,
            hidden_size,
            ff_size=None,
            dropout=0.1,
            merge_types=False,
            tie_layers=False,
            qq_max_dist=2,
            # qc_token_match=True,
            # qt_token_match=True,
            # cq_token_match=True,
            cc_foreign_key=True,
            cc_table_match=True,
            cc_max_dist=2,
            ct_foreign_key=True,
            ct_table_match=True,
            # tq_token_match=True,
            tc_table_match=True,
            tc_foreign_key=True,
            tt_max_dist=2,
            tt_foreign_key=True,
            sc_link=False,
            cv_link=False,
            cv_token_link=False,
            cv_token_start_link=False,
            original_bert_config=None):
        """Register relation kinds and build the relation-aware encoder.

        Every relation name (a string, or a ``(name, distance)`` tuple for
        distance-bucketed relations) is mapped to an integer id in
        ``self.relation_ids``. A ``transformer.Encoder`` stack, a
        ``transformer.PointerWithRelations`` alignment module and an output
        ``torch.nn.Dropout`` are then built using the number of relation ids.

        Args:
            device: Stored as ``self._device``.
            num_layers: Number of encoder layers passed to ``transformer.Encoder``.
            num_heads: Number of attention heads.
            hidden_size: Model width.
            ff_size: Inner width of the feed-forward sublayer; defaults to
                ``4 * hidden_size``.
            dropout: Dropout rate used for every sublayer when
                ``original_bert_config`` is None; ignored otherwise.
            merge_types: If true, every ``*_default`` relation (and the enabled
                schema/value-linking relations) shares one ``'xx_default'`` id,
                and the column/table distance relations share the
                question-question distance ids. Requires all schema-structure
                flags to be False and ``cc_max_dist == tt_max_dist == qq_max_dist``.
            tie_layers: Forwarded to ``transformer.Encoder``.
            qq_max_dist, cc_max_dist, tt_max_dist: Each distance in
                ``[-max, +max]`` gets its own relation id.
            cc_foreign_key, cc_table_match, ct_foreign_key, ct_table_match,
            tc_table_match, tc_foreign_key, tt_foreign_key: Enable the
                corresponding schema-structure relations.
            sc_link: Enable schema-linking relations (``*CEM``/``*CPM`` between
                question and columns, ``*TEM``/``*TPM`` between question and tables).
            cv_link: Enable value-linking relations (NUMBER, TIME, CELLMATCH).
            cv_token_link: With ``cv_link``, give ``*CELLTOKENMATCH`` its own ids;
                otherwise those keys alias the ``*CELLMATCH`` ids.
            cv_token_start_link: With ``cv_link``, give ``*CELLMATCHSTART`` its
                own ids; otherwise those keys alias the ``*CELLMATCH`` ids.
            original_bert_config: Object providing
                ``attention_probs_dropout_prob`` and ``hidden_dropout_prob``.
                When None, ``dropout`` is used for both rates.
        """
        super().__init__()
        self._device = device
        self.num_heads = num_heads

        self.qq_max_dist = qq_max_dist
        # self.qc_token_match = qc_token_match
        # self.qt_token_match = qt_token_match
        # self.cq_token_match = cq_token_match
        self.cc_foreign_key = cc_foreign_key
        self.cc_table_match = cc_table_match
        self.cc_max_dist = cc_max_dist
        self.ct_foreign_key = ct_foreign_key
        self.ct_table_match = ct_table_match
        # self.tq_token_match = tq_token_match
        self.tc_table_match = tc_table_match
        self.tc_foreign_key = tc_foreign_key
        self.tt_max_dist = tt_max_dist
        self.tt_foreign_key = tt_foreign_key

        self.cv_token_start_link = cv_token_start_link
        self.cv_link = cv_link
        self.relation_ids = {}  # relation name -> integer id

        def add_relation(name):
            # The next id is the current number of keys, so alias keys added
            # directly to self.relation_ids make later ids skip values.
            self.relation_ids[name] = len(self.relation_ids)

        def add_rel_dist(name, max_dist):
            for i in range(-max_dist, max_dist + 1):
                add_relation((name, i))

        add_rel_dist('qq_dist', qq_max_dist)

        add_relation('qc_default')
        # if qc_token_match:
        #    add_relation('qc_token_match')

        add_relation('qt_default')
        # if qt_token_match:
        #    add_relation('qt_token_match')

        add_relation('cq_default')
        # if cq_token_match:
        #    add_relation('cq_token_match')

        add_relation('cc_default')
        if cc_foreign_key:
            add_relation('cc_foreign_key_forward')
            add_relation('cc_foreign_key_backward')
        if cc_table_match:
            add_relation('cc_table_match')
        add_rel_dist('cc_dist', cc_max_dist)

        add_relation('ct_default')
        if ct_foreign_key:
            add_relation('ct_foreign_key')
        if ct_table_match:
            add_relation('ct_primary_key')
            add_relation('ct_table_match')
            add_relation('ct_any_table')

        add_relation('tq_default')
        # if cq_token_match:
        #    add_relation('tq_token_match')

        add_relation('tc_default')
        if tc_table_match:
            add_relation('tc_primary_key')
            add_relation('tc_table_match')
            add_relation('tc_any_table')
        if tc_foreign_key:
            add_relation('tc_foreign_key')

        add_relation('tt_default')
        if tt_foreign_key:
            add_relation('tt_foreign_key_forward')
            add_relation('tt_foreign_key_backward')
            add_relation('tt_foreign_key_both')
        add_rel_dist('tt_dist', tt_max_dist)

        # schema linking relations
        # forward_backward
        if sc_link:
            add_relation('qcCEM')
            add_relation('cqCEM')
            add_relation('qtTEM')
            add_relation('tqTEM')
            add_relation('qcCPM')
            add_relation('cqCPM')
            add_relation('qtTPM')
            add_relation('tqTPM')

        if cv_link:
            add_relation("qcNUMBER")
            add_relation("cqNUMBER")
            add_relation("qcTIME")
            add_relation("cqTIME")
            add_relation("qcCELLMATCH")
            add_relation("cqCELLMATCH")
            if cv_token_link:
                add_relation("qcCELLTOKENMATCH")
                add_relation("cqCELLTOKENMATCH")
            else:
                # Alias: reuse the CELLMATCH ids instead of allocating new ones.
                self.relation_ids['qcCELLTOKENMATCH'] = self.relation_ids[
                    'qcCELLMATCH']
                self.relation_ids['cqCELLTOKENMATCH'] = self.relation_ids[
                    'cqCELLMATCH']
            if cv_token_start_link:
                add_relation("qcCELLMATCHSTART")
                add_relation("cqCELLMATCHSTART")
            else:
                # Alias: reuse the CELLMATCH ids instead of allocating new ones.
                self.relation_ids['qcCELLMATCHSTART'] = self.relation_ids[
                    'qcCELLMATCH']
                self.relation_ids['cqCELLMATCHSTART'] = self.relation_ids[
                    'cqCELLMATCH']

        if merge_types:
            assert not cc_foreign_key
            assert not cc_table_match
            assert not ct_foreign_key
            assert not ct_table_match
            assert not tc_foreign_key
            assert not tc_table_match
            assert not tt_foreign_key

            assert cc_max_dist == qq_max_dist
            assert tt_max_dist == qq_max_dist

            add_relation('xx_default')
            self.relation_ids['qc_default'] = self.relation_ids['xx_default']
            self.relation_ids['qt_default'] = self.relation_ids['xx_default']
            self.relation_ids['cq_default'] = self.relation_ids['xx_default']
            self.relation_ids['cc_default'] = self.relation_ids['xx_default']
            self.relation_ids['ct_default'] = self.relation_ids['xx_default']
            self.relation_ids['tq_default'] = self.relation_ids['xx_default']
            self.relation_ids['tc_default'] = self.relation_ids['xx_default']
            self.relation_ids['tt_default'] = self.relation_ids['xx_default']

            if sc_link:
                self.relation_ids['qcCEM'] = self.relation_ids['xx_default']
                self.relation_ids['qcCPM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTEM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTPM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCEM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCPM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTEM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTPM'] = self.relation_ids['xx_default']
            if cv_link:
                self.relation_ids["qcNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["cqNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["qcTIME"] = self.relation_ids['xx_default']
                self.relation_ids["cqTIME"] = self.relation_ids['xx_default']
                self.relation_ids["qcCELLMATCH"] = self.relation_ids[
                    'xx_default']
                self.relation_ids["cqCELLMATCH"] = self.relation_ids[
                    'xx_default']
                # NOTE(review): the CELLTOKENMATCH / CELLMATCHSTART keys are not
                # remapped here; when they alias CELLMATCH they keep pointing at
                # the pre-merge CELLMATCH ids.

            # Column-column and table-table distances share the
            # question-question distance ids (the asserts above guarantee the
            # same range).
            for i in range(-qq_max_dist, qq_max_dist + 1):
                self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
                self.relation_ids['tt_dist', i] = self.relation_ids['qq_dist', i]

        if ff_size is None:
            ff_size = hidden_size * 4

        # Ids are not guaranteed to be contiguous. Alias keys enlarge
        # len(self.relation_ids) without taking an id, so later add_relation
        # calls skip values, and merge_types leaves the remapped ids unused.
        # Counting distinct ids can therefore be smaller than the largest id.
        # Size by the largest id so that every registered id is in range
        # (presumably the transformer modules index a table of this size).
        # When ids are contiguous this equals the number of distinct ids.
        num_relations = max(self.relation_ids.values()) + 1

        if original_bert_config is None:
            attention_dropout = hidden_dropout = dropout
        else:
            attention_dropout = original_bert_config.attention_probs_dropout_prob
            hidden_dropout = original_bert_config.hidden_dropout_prob

        self.encoder = transformer.Encoder(
            lambda: transformer.EncoderLayer(
                hidden_size,
                transformer.MultiHeadedAttentionWithRelations(
                    num_heads, hidden_size, attention_dropout),
                transformer.PositionwiseFeedForward(
                    hidden_size, ff_size, hidden_dropout),
                num_relations,
                hidden_dropout),
            hidden_size, num_layers, tie_layers)

        self.align_attn = transformer.PointerWithRelations(
            hidden_size, num_relations, attention_dropout)
        self.dropout = torch.nn.Dropout(hidden_dropout)
    def __init__(self, device, num_layers, num_heads, hidden_size, num_utterance_keep,
                 ff_size=None,
                 dropout=0.1,
                 merge_types=False,
                 tie_layers=False,
                 qq_max_dist=2,
                 # qc_token_match=True,
                 # qt_token_match=True,
                 # cq_token_match=True,
                 cc_foreign_key=True,
                 cc_table_match=True,
                 cc_max_dist=2,
                 ct_foreign_key=True,
                 ct_table_match=True,
                 # tq_token_match=True,
                 tc_table_match=True,
                 tc_foreign_key=True,
                 tt_max_dist=2,
                 tt_foreign_key=True,
                 sc_link=False,         # original note: True
                 cv_link=False,         # original note: True
                 ):
        """Register per-utterance relation kinds and build the relation-aware encoder.

        Question-token relations are keyed by utterance index so that
        ``num_utterance_keep`` utterances can be encoded together:

        * ``("eqq_dist", i, d)``: tokens within utterance ``i`` at relative
          distance ``d`` in ``[-qq_max_dist, qq_max_dist]``;
        * ``("dqq_dist", i, j)``: tokens of different utterances ``i != j``;
        * ``("qc_default", i)``, ``("qt_default", i)``, ``("cq_default", i)``,
          ``("tq_default", i)``, plus the enabled schema-linking (``sc_link``)
          and value-linking (``cv_link``) relations, each tagged with ``i``.

        Column/table relations (``cc_*``, ``ct_*``, ``tc_*``, ``tt_*``) are
        shared across utterances. Ids are assigned contiguously in registration
        order, so ``len(self.relation_ids)`` is the number of relation kinds.

        Args:
            device: Stored as ``self._device``.
            num_layers, num_heads, hidden_size: Encoder depth, attention heads
                and model width. An original inline note listed ``8, 8, 768``
                here, presumably the values used in the authors' setup.
            num_utterance_keep: Number of utterances whose tokens get their own
                relation ids.
            ff_size: Feed-forward inner width; defaults to ``4 * hidden_size``.
            dropout: Dropout rate for attention, feed-forward and the pointer.
            merge_types: Must be False (see Raises).
            tie_layers: Forwarded to ``transformer.Encoder``.
            qq_max_dist, cc_max_dist, tt_max_dist: Each distance in
                ``[-max, +max]`` gets its own relation id.
            cc_foreign_key, cc_table_match, ct_foreign_key, ct_table_match,
            tc_table_match, tc_foreign_key, tt_foreign_key: Enable the
                corresponding schema-structure relations.
            sc_link: Enable per-utterance schema-linking relations.
            cv_link: Enable per-utterance value-linking relations.

        Raises:
            NotImplementedError: if ``merge_types`` is true. The merging scheme
                of the single-utterance encoder remaps keys such as
                ``'qc_default'`` and ``('qq_dist', d)``. Those keys are never
                registered here, so it cannot be applied to per-utterance keys.
        """
        super().__init__()
        if merge_types:
            # Previously this flag ran the single-utterance merge code, which
            # always stopped on an assert or a KeyError for ('qq_dist', d).
            # The original author noted that it must be set to False.
            raise NotImplementedError(
                'merge_types=True is not supported by the multi-utterance '
                'relation encoder: relation keys are per-utterance tuples.')

        self._device = device
        self.num_heads = num_heads
        self.num_utterance_keep = num_utterance_keep

        self.qq_max_dist = qq_max_dist
        # self.qc_token_match = qc_token_match
        # self.qt_token_match = qt_token_match
        # self.cq_token_match = cq_token_match
        self.cc_foreign_key = cc_foreign_key
        self.cc_table_match = cc_table_match
        self.cc_max_dist = cc_max_dist
        self.ct_foreign_key = ct_foreign_key
        self.ct_table_match = ct_table_match
        # self.tq_token_match = tq_token_match
        self.tc_table_match = tc_table_match
        self.tc_foreign_key = tc_foreign_key
        self.tt_max_dist = tt_max_dist
        self.tt_foreign_key = tt_foreign_key

        self.relation_ids = {}  # relation key -> integer id

        def add_relation(name):
            self.relation_ids[name] = len(self.relation_ids)

        def add_rel_dist(name, max_dist):
            # A tuple prefix such as ("eqq_dist", i) is extended to
            # ("eqq_dist", i, d); a plain string becomes (name, d).
            for i in range(-max_dist, max_dist + 1):
                if isinstance(name, tuple):
                    add_relation(name + (i,))
                else:
                    add_relation((name, i))

        # Question-question relations for every ordered utterance pair:
        # distance-bucketed within an utterance, a single relation across two.
        for i, j in itertools.product(range(self.num_utterance_keep), repeat=2):
            if i != j:
                add_relation(("dqq_dist", i, j))
            else:
                add_rel_dist(("eqq_dist", i), qq_max_dist)

        # Per-utterance question<->schema default relations.
        for i in range(self.num_utterance_keep):
            add_relation(("qc_default", i))
        # if qc_token_match:
        #    add_relation('qc_token_match')

        for i in range(self.num_utterance_keep):
            add_relation(("qt_default", i))
        # if qt_token_match:
        #    add_relation('qt_token_match')

        for i in range(self.num_utterance_keep):
            add_relation(("cq_default", i))
        # if cq_token_match:
        #    add_relation('cq_token_match')

        for i in range(self.num_utterance_keep):
            add_relation(("tq_default", i))
        # if tq_token_match:
        #    add_relation('tq_token_match')

        add_relation('cc_default')
        if cc_foreign_key:
            add_relation('cc_foreign_key_forward')
            add_relation('cc_foreign_key_backward')
        if cc_table_match:
            add_relation('cc_table_match')
        add_rel_dist('cc_dist', cc_max_dist)

        add_relation('ct_default')
        if ct_foreign_key:
            add_relation('ct_foreign_key')
        if ct_table_match:
            add_relation('ct_primary_key')
            add_relation('ct_table_match')
            add_relation('ct_any_table')

        add_relation('tc_default')
        if tc_table_match:
            add_relation('tc_primary_key')
            add_relation('tc_table_match')
            add_relation('tc_any_table')
        if tc_foreign_key:
            add_relation('tc_foreign_key')

        add_relation('tt_default')
        if tt_foreign_key:
            add_relation('tt_foreign_key_forward')
            add_relation('tt_foreign_key_backward')
            add_relation('tt_foreign_key_both')
        add_rel_dist('tt_dist', tt_max_dist)

        # schema linking relations (forward and backward), per utterance
        if sc_link:
            for i in range(self.num_utterance_keep):
                add_relation(('qcCEM', i))
                add_relation(('cqCEM', i))
                add_relation(('qtTEM', i))
                add_relation(('tqTEM', i))
                add_relation(('qcCPM', i))
                add_relation(('cqCPM', i))
                add_relation(('qtTPM', i))
                add_relation(('tqTPM', i))

        # value linking relations, per utterance
        if cv_link:
            for i in range(self.num_utterance_keep):
                add_relation(("qcNUMBER", i))
                add_relation(("cqNUMBER", i))
                add_relation(("qcTIME", i))
                add_relation(("cqTIME", i))
                add_relation(("qcCELLMATCH", i))
                add_relation(("cqCELLMATCH", i))

        if ff_size is None:
            ff_size = hidden_size * 4
        num_relations = len(self.relation_ids)
        self.encoder = transformer.Encoder(
            lambda: transformer.EncoderLayer(
                hidden_size,
                transformer.MultiHeadedAttentionWithRelations(
                    num_heads,
                    hidden_size,
                    dropout),
                transformer.PositionwiseFeedForward(
                    hidden_size,
                    ff_size,
                    dropout),
                num_relations,
                dropout),
            hidden_size,
            num_layers,
            tie_layers)

        self.align_attn = transformer.PointerWithRelations(hidden_size,
                                                           num_relations, dropout)
# Example #3
# 0
    def __init__(
        self,
        device,
        num_layers,
        num_heads,
        hidden_size,
        ff_size=None,
        dropout=0.1,
        merge_types=False,
        tie_layers=False,
        qq_max_dist=2,
        # qc_token_match=True,
        # qt_token_match=True,
        # cq_token_match=True,
        cc_foreign_key=True,
        cc_table_match=True,
        cc_max_dist=2,
        ct_foreign_key=True,
        ct_table_match=True,
        # tq_token_match=True,
        tc_table_match=True,
        tc_foreign_key=True,
        tt_max_dist=2,
        tt_foreign_key=True,
        sc_link=False,
        cv_link=False,
    ):
        """Register relation kinds and build the relation-aware encoder.

        Every relation name (a string, or a ``(name, distance)`` tuple for
        distance-bucketed relations) is mapped to an integer id in
        ``self.relation_ids``. ``len(self.relation_ids)`` is then passed to
        ``transformer.EncoderLayer`` and ``transformer.PointerWithRelations``
        as the number of relation kinds. Ids are assigned in registration
        order. With ``merge_types`` some keys are remapped to existing ids,
        and ``len(self.relation_ids)`` still exceeds every id.

        Args:
            device: Stored as ``self._device``.
            num_layers: Number of encoder layers passed to ``transformer.Encoder``.
            num_heads: Number of attention heads.
            hidden_size: Model width.
            ff_size: Feed-forward inner width; defaults to ``4 * hidden_size``.
            dropout: Dropout rate for attention, feed-forward and the pointer.
            merge_types: If true, every ``*_default`` relation (and the enabled
                schema/value-linking relations) shares one ``'xx_default'`` id,
                and column/table distance relations share the question-question
                distance ids. Requires all schema-structure flags to be False
                and ``cc_max_dist == tt_max_dist == qq_max_dist``.
            tie_layers: Forwarded to ``transformer.Encoder``.
            qq_max_dist, cc_max_dist, tt_max_dist: Each distance in
                ``[-max, +max]`` gets its own relation id.
            cc_foreign_key, cc_table_match, ct_foreign_key, ct_table_match,
            tc_table_match, tc_foreign_key, tt_foreign_key: Enable the
                corresponding schema-structure relations.
            sc_link: Enable schema-linking relations (``*CEM``/``*CPM``,
                ``*TEM``/``*TPM``).
            cv_link: Enable value-linking relations (NUMBER, TIME, CELLMATCH).
        """
        super().__init__()
        self._device = device
        self.num_heads = num_heads

        self.qq_max_dist = qq_max_dist
        # self.qc_token_match = qc_token_match
        # self.qt_token_match = qt_token_match
        # self.cq_token_match = cq_token_match
        self.cc_foreign_key = cc_foreign_key
        self.cc_table_match = cc_table_match
        self.cc_max_dist = cc_max_dist
        self.ct_foreign_key = ct_foreign_key
        self.ct_table_match = ct_table_match
        # self.tq_token_match = tq_token_match
        self.tc_table_match = tc_table_match
        self.tc_foreign_key = tc_foreign_key
        self.tt_max_dist = tt_max_dist
        self.tt_foreign_key = tt_foreign_key

        self.relation_ids = {}  # relation name -> integer id

        def add_relation(name):
            self.relation_ids[name] = len(self.relation_ids)

        def add_rel_dist(name, max_dist):
            for i in range(-max_dist, max_dist + 1):
                add_relation((name, i))

        add_rel_dist('qq_dist', qq_max_dist)  # question-question distances

        add_relation('qc_default')  # question -> column
        # if qc_token_match:
        #    add_relation('qc_token_match')

        add_relation('qt_default')  # question -> table
        # if qt_token_match:
        #    add_relation('qt_token_match')

        add_relation('cq_default')  # column -> question
        # if cq_token_match:
        #    add_relation('cq_token_match')

        add_relation('cc_default')  # column -> column
        if cc_foreign_key:
            add_relation('cc_foreign_key_forward')
            add_relation('cc_foreign_key_backward')
        if cc_table_match:
            add_relation('cc_table_match')
        add_rel_dist('cc_dist', cc_max_dist)

        add_relation('ct_default')  # column -> table
        if ct_foreign_key:
            add_relation('ct_foreign_key')
        if ct_table_match:
            add_relation('ct_primary_key')
            add_relation('ct_table_match')
            add_relation('ct_any_table')

        add_relation('tq_default')  # table -> question
        # if tq_token_match:
        #    add_relation('tq_token_match')

        add_relation('tc_default')  # table -> column
        if tc_table_match:
            add_relation('tc_primary_key')
            add_relation('tc_table_match')
            add_relation('tc_any_table')
        if tc_foreign_key:
            add_relation('tc_foreign_key')

        add_relation('tt_default')  # table -> table
        if tt_foreign_key:
            add_relation('tt_foreign_key_forward')
            add_relation('tt_foreign_key_backward')
            add_relation('tt_foreign_key_both')
        add_rel_dist('tt_dist', tt_max_dist)

        # schema linking relations (forward and backward)
        if sc_link:
            add_relation('qcCEM')
            add_relation('cqCEM')
            add_relation('qtTEM')
            add_relation('tqTEM')
            add_relation('qcCPM')
            add_relation('cqCPM')
            add_relation('qtTPM')
            add_relation('tqTPM')

        # value linking relations
        if cv_link:
            add_relation("qcNUMBER")
            add_relation("cqNUMBER")
            add_relation("qcTIME")
            add_relation("cqTIME")
            add_relation("qcCELLMATCH")
            add_relation("cqCELLMATCH")

        if merge_types:
            assert not cc_foreign_key
            assert not cc_table_match
            assert not ct_foreign_key
            assert not ct_table_match
            assert not tc_foreign_key
            assert not tc_table_match
            assert not tt_foreign_key

            assert cc_max_dist == qq_max_dist
            assert tt_max_dist == qq_max_dist

            add_relation('xx_default')
            self.relation_ids['qc_default'] = self.relation_ids['xx_default']
            self.relation_ids['qt_default'] = self.relation_ids['xx_default']
            self.relation_ids['cq_default'] = self.relation_ids['xx_default']
            self.relation_ids['cc_default'] = self.relation_ids['xx_default']
            self.relation_ids['ct_default'] = self.relation_ids['xx_default']
            self.relation_ids['tq_default'] = self.relation_ids['xx_default']
            self.relation_ids['tc_default'] = self.relation_ids['xx_default']
            self.relation_ids['tt_default'] = self.relation_ids['xx_default']

            if sc_link:
                self.relation_ids['qcCEM'] = self.relation_ids['xx_default']
                self.relation_ids['qcCPM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTEM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTPM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCEM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCPM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTEM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTPM'] = self.relation_ids['xx_default']
            if cv_link:
                self.relation_ids["qcNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["cqNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["qcTIME"] = self.relation_ids['xx_default']
                self.relation_ids["cqTIME"] = self.relation_ids['xx_default']
                self.relation_ids["qcCELLMATCH"] = self.relation_ids[
                    'xx_default']
                self.relation_ids["cqCELLMATCH"] = self.relation_ids[
                    'xx_default']

            # Column-column and table-table distances share the
            # question-question distance ids (the asserts above guarantee the
            # same range). The tt_dist line previously assigned tt_dist to
            # itself, a no-op that left table distances unmerged.
            for i in range(-qq_max_dist, qq_max_dist + 1):
                self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
                self.relation_ids['tt_dist', i] = self.relation_ids['qq_dist', i]

        if ff_size is None:
            ff_size = hidden_size * 4
        self.encoder = transformer.Encoder(
            lambda: transformer.EncoderLayer(
                hidden_size,
                transformer.MultiHeadedAttentionWithRelations(
                    num_heads, hidden_size, dropout),
                transformer.PositionwiseFeedForward(hidden_size, ff_size,
                                                    dropout),
                len(self.relation_ids), dropout), hidden_size, num_layers,
            tie_layers)

        self.align_attn = transformer.PointerWithRelations(
            hidden_size, len(self.relation_ids), dropout)