Example #1
    def _build_modules(self, module_types, shared_modules=None):
        module_builder = {
            "en-emb":
            lambda: embedders.LookupEmbeddings(
                self._device,
                self.vocab,
                self.preproc.word_emb,
                self.word_emb_size,
                learnable_words=self.en_learnable_words,
            ),
            "bilstm":
            lambda: lstm.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                summarize=False,
                use_native=False,
            ),
            "bilstm-native":
            lambda: lstm.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                summarize=False,
                use_native=True,
            ),
        }

        modules = []
        for module_type in module_types:
            # Reuse a pre-built shared module when available; guard against the
            # default shared_modules=None.
            if shared_modules and module_type in shared_modules:
                modules.append(shared_modules[module_type])
            else:
                modules.append(module_builder[module_type]())
        return torch.nn.Sequential(*modules)
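
A minimal, self-contained sketch of the same builder pattern, with torch
built-ins standing in for the project-specific embedders.LookupEmbeddings and
lstm.BiLSTM classes (the layer choices and sizes here are assumptions for
illustration). The point is that entries found in shared_modules are reused,
so several encoders can share the same weights, while everything else is
built fresh from its lambda.

import torch

# Stand-in builder registry: each value is a zero-argument factory.
module_builder = {
    "emb": lambda: torch.nn.Embedding(num_embeddings=1000, embedding_dim=64),
    "proj": lambda: torch.nn.Linear(64, 32),
}

# A pre-built module placed here is reused instead of rebuilt.
shared_modules = {
    "emb": torch.nn.Embedding(num_embeddings=1000, embedding_dim=64),
}

def build(module_types, shared_modules=None):
    modules = []
    for module_type in module_types:
        if shared_modules and module_type in shared_modules:
            modules.append(shared_modules[module_type])
        else:
            modules.append(module_builder[module_type]())
    return torch.nn.Sequential(*modules)

question_encoder = build(("emb", "proj"), shared_modules=shared_modules)
column_encoder = build(("emb", "proj"), shared_modules=shared_modules)
print(question_encoder[0] is column_encoder[0])  # True: the embedding object is shared
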
Example #2
    def _build_modules(self, module_types):
        module_builder = {
            "emb":
            lambda: embedders.LookupEmbeddings(
                device=self._device,
                vocab=self.vocab,
                embedder=self.preproc.embedder,
                emb_size=self.word_emb_size,
                learnable_words=self.preproc.learnable_words,
            ),
            "bilstm":
            lambda: lstm.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                use_native=True,
                summarize=False,
            ),
            "cls_glue":
            lambda: rat.PadCLS(
                device=self._device,
                hidden_size=self.recurrent_size,
                pos_encode=False,
            ),
            "cls_glue_p":
            lambda: rat.PadCLS(
                device=self._device,
                hidden_size=self.recurrent_size,
                pos_encode=True,
            ),
            "transformer":
            lambda: rat.TransformerEncoder(
                device=self._device,
                num_layers=1,
                num_heads=4,
                hidden_size=self.recurrent_size,
            ),
        }

        modules = []
        for module_type in module_types:
            modules.append(module_builder[module_type]())
        return torch.nn.Sequential(*modules)
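
The "transformer" entry builds a project-specific rat.TransformerEncoder with
one layer and four heads over the recurrent hidden size. As a hedged
approximation only, torch's built-in transformer encoder with the same
hyper-parameters would look like the sketch below; the exact shape contract of
rat.TransformerEncoder is not visible in this snippet and is assumed here.

import torch

recurrent_size = 256

# Stand-in for rat.TransformerEncoder(num_layers=1, num_heads=4, ...).
transformer = torch.nn.TransformerEncoder(
    torch.nn.TransformerEncoderLayer(
        d_model=recurrent_size, nhead=4, batch_first=True),
    num_layers=1,
)

x = torch.randn(2, 10, recurrent_size)  # (batch, seq_len, hidden), assumed layout
print(transformer(x).shape)             # torch.Size([2, 10, 256])
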
Example #3
    def __init__(
            self,
            device,
            preproc,
            word_emb_size,
            num_latent_relations,
            hidden_size=300,
            recurrent_size=256,
            discrete_relation=True,
            norm_relation=True,
            symmetric_relation=False,
            combine_latent_relations=False,
            score_type="bilinear",
            learnable_embeddings=False,
            question_encoder=("shared-en-emb", ),
            column_encoder=("shared-en-emb", ),
            table_encoder=("shared-en-emb", ),
    ):
        super().__init__()
        self.preproc = preproc
        self.vocab = preproc.vocab
        self.word_emb_size = word_emb_size
        self._device = device
        self.hidden_size = hidden_size
        self.discrete_relation = discrete_relation
        self.norm_relation = norm_relation
        self.num_latent_relations = num_latent_relations
        self.relations2id = preproc.relations2id
        self.recurrent_size = recurrent_size
        self.dropout = 0.0

        score_funcs = {
            "bilinear":
            lambda: energys.Bilinear(
                hidden_size, num_latent_relations, include_id=True),
            "mlp":
            lambda: energys.MLP(hidden_size, num_latent_relations),
        }

        # build modules
        if learnable_embeddings:
            self.en_learnable_words = self.vocab
        else:
            self.en_learnable_words = None
        shared_modules = {
            "shared-en-emb":
            embedders.LookupEmbeddings(
                self._device,
                self.vocab,
                self.preproc.word_emb,
                self.word_emb_size,
                learnable_words=self.en_learnable_words,
            ),
        }

        if self.preproc.use_ch_vocab:
            self.ch_vocab = preproc.ch_vocab
            if learnable_embeddings:
                self.ch_learnable_words = self.ch_vocab
            else:
                self.ch_learnable_words = None
            shared_modules["shared-ch-emb"] = embedders.LookupEmbeddings(
                self._device,
                self.ch_vocab,
                self.preproc.ch_word_emb,
                self.preproc.ch_word_emb.dim,
                learnable_words=self.ch_learnable_words,
            )
            shared_modules["ch-bilstm"] = lstm.BiLSTM(
                input_size=self.preproc.ch_word_emb.dim,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                use_native=False,
                summarize=False,
            )
            shared_modules["ch-bilstm-native"] = lstm.BiLSTM(
                input_size=self.preproc.ch_word_emb.dim,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                use_native=True,
                summarize=False,
            )

        self.question_encoder = self._build_modules(
            question_encoder, shared_modules=shared_modules)
        self.column_encoder = self._build_modules(
            column_encoder, shared_modules=shared_modules)
        self.table_encoder = self._build_modules(table_encoder,
                                                 shared_modules=shared_modules)

        self.combine_latent_relations = combine_latent_relations
        if combine_latent_relations:
            self.string_link = StringLinking(device, preproc)

        self.symmetric_relation = symmetric_relation
        # Only the symmetric case is currently supported, so the else branch
        # below is unreachable.
        assert self.symmetric_relation
        if self.symmetric_relation:
            relations = ("qc", "qt")
        else:
            relations = ("qc", "cq", "tq", "qt")
        self.relation_score_dic = nn.ModuleDict(
            {k: score_funcs[score_type]()
             for k in relations})

        if discrete_relation:
            self.temperature = 1  # for gumbel

        if not norm_relation:  # then norm q/col/tab
            self.null_q_token = nn.Parameter(torch.zeros([1, hidden_size]))
            self.null_c_token = nn.Parameter(torch.zeros([1, hidden_size]))
            self.null_t_token = nn.Parameter(torch.zeros([1, hidden_size]))
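
The `self.temperature = 1  # for gumbel` line indicates that discrete latent
relations are sampled with a Gumbel-softmax at temperature 1. The forward pass
is not part of this snippet, so the following is only a sketch of what that
sampling step could look like, with toy tensor shapes assumed for illustration.

import torch
import torch.nn.functional as F

num_latent_relations = 4
# Hypothetical relation logits between 3 question tokens and 5 columns.
logits = torch.randn(3, 5, num_latent_relations)

# hard=True gives one-hot (discrete) relations while keeping gradients via the
# straight-through estimator; tau matches self.temperature above.
relations = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
print(relations.shape)     # torch.Size([3, 5, 4])
print(relations.sum(-1))   # all ones: exactly one relation per (token, column) pair
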
Example #4
    def __init__(
            self,
            device,
            preproc,
            word_emb_size=128,
            recurrent_size=256,
            dropout=0.0,
            question_encoder=("emb", "bilstm"),
            column_encoder=("emb", "bilstm"),
            table_encoder=("emb", "bilstm"),
            linking_config={},
            rat_config={},
            top_k_learnable=0,
            include_in_memory=("question", "column", "table"),
    ):
        super().__init__()
        self._device = device
        self.preproc = preproc

        self.vocab = preproc.vocab
        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
        assert self.recurrent_size % 2 == 0
        word_freq = self.preproc.vocab_builder.word_freq
        top_k_words = set(
            [_a[0] for _a in word_freq.most_common(top_k_learnable)])
        self.learnable_words = top_k_words
        self.include_in_memory = set(include_in_memory)
        self.dropout = dropout

        shared_modules = {
            "shared-en-emb":
            embedders.LookupEmbeddings(
                self._device,
                self.vocab,
                self.preproc.word_emb,
                self.word_emb_size,
                self.learnable_words,
            ),
            "shared-bilstm":
            lstm.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                summarize=False,
            ),
        }

        # chinese vocab and module
        if self.preproc.use_ch_vocab:
            self.ch_vocab = preproc.ch_vocab
            ch_word_freq = self.preproc.ch_vocab_builder.word_freq
            ch_top_k_words = set(
                [_a[0] for _a in ch_word_freq.most_common(top_k_learnable)])
            self.ch_learnable_words = ch_top_k_words
            shared_modules["shared-ch-emb"] = embedders.LookupEmbeddings(
                self._device,
                self.ch_vocab,
                self.preproc.ch_word_emb,
                self.preproc.ch_word_emb.dim,
                self.ch_learnable_words,
            )
            shared_modules["ch-bilstm"] = lstm.BiLSTM(
                input_size=self.preproc.ch_word_emb.dim,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                use_native=False,
                summarize=False,
            )
            shared_modules["ch-bilstm-native"] = lstm.BiLSTM(
                input_size=self.preproc.ch_word_emb.dim,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                use_native=True,
                summarize=False,
            )

        self.question_encoder = self._build_modules(
            question_encoder, shared_modules=shared_modules)
        self.column_encoder = self._build_modules(
            column_encoder, shared_modules=shared_modules)
        self.table_encoder = self._build_modules(table_encoder,
                                                 shared_modules=shared_modules)

        # matching
        self.schema_linking = registry.construct(
            "schema_linking",
            linking_config,
            device=device,
            word_emb_size=word_emb_size,
            preproc=preproc,
        )

        # rat
        rat_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}
        self.rat_update = registry.instantiate(
            rat_modules[rat_config["name"]],
            rat_config,
            unused_keys={"name"},
            device=self._device,
            relations2id=preproc.relations2id,
            hidden_size=recurrent_size,
        )

        # aligner
        self.aligner = rat.AlignmentWithRAT(
            device=device,
            hidden_size=recurrent_size,
            relations2id=preproc.relations2id,
            enable_latent_relations=rat_config["enable_latent_relations"],
            num_latent_relations=rat_config.get("num_latent_relations", None),
            combine_latent_relations=rat_config["combine_latent_relations"],
        )
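
The top_k_learnable logic above selects the k most frequent vocabulary words
to receive trainable embeddings, presumably leaving the rest at fixed
pretrained vectors. A small sketch with a toy Counter standing in for
preproc.vocab_builder.word_freq (the words and counts are made up):

from collections import Counter

word_freq = Counter({"the": 50, "id": 30, "name": 20, "zygote": 1})
top_k_learnable = 2

learnable_words = {w for w, _ in word_freq.most_common(top_k_learnable)}
print(learnable_words)  # {'the', 'id'} (set order may vary)

# With the constructor default top_k_learnable=0, most_common(0) returns an
# empty list, so no word embeddings are learnable.
print({w for w, _ in word_freq.most_common(0)})  # set()
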