Example #1
    def load_train_config(self):
        self.train_config = registry.instantiate(MetaTrainConfig,
                                                 self.config["meta_train"])

        if self.train_config.num_batch_accumulated > 1:
            self.logger.warning(
                "Batch accumulation is used only at MAML-step level")
            raise NotImplementedError
Example #2
    def load_train_config(self):
        self.train_config = registry.instantiate(TrainConfig,
                                                 self.config["train"])

        if self.train_config.use_bert_training:
            if self.train_config.clip_grad is None:
                self.logger.info(
                    "Grad clipping is recommended for BERT training")
    def __init__(
        self,
        device,
        preproc,
        bert_token_type=False,
        bert_version="bert-base-uncased",
        summarize_header="avg",
        include_in_memory=("question", "column", "table"),
        rat_config={},
        linking_config={},
    ):
        super().__init__()
        self._device = device
        self.preproc = preproc
        self.bert_token_type = bert_token_type
        self.base_enc_hidden_size = (1024 if "large" in bert_version else 768)
        self.include_in_memory = include_in_memory

        # ways to summarize header
        assert summarize_header in ["first", "avg"]
        self.summarize_header = summarize_header
        self.enc_hidden_size = self.base_enc_hidden_size

        # matching
        self.schema_linking = registry.construct(
            "schema_linking",
            linking_config,
            preproc=preproc,
            device=device,
        )

        # rat
        rat_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}
        self.rat_update = registry.instantiate(
            rat_modules[rat_config["name"]],
            rat_config,
            unused_keys={"name"},
            device=self._device,
            relations2id=preproc.relations2id,
            hidden_size=self.enc_hidden_size,
        )

        # aligner
        self.aligner = rat.AlignmentWithRAT(
            device=device,
            hidden_size=self.enc_hidden_size,
            relations2id=preproc.relations2id,
            enable_latent_relations=False,
        )

        if "electra" in bert_version:
            modelclass = ElectraModel
        elif "bert" in bert_version:
            modelclass = BertModel
        else:
            raise NotImplementedError
        self.bert_model = modelclass.from_pretrained(bert_version)
        self.tokenizer = self.preproc.tokenizer
Example #4
    def load_train_config(self):
        self.train_config = registry.instantiate(MetaTrainConfig,
                                                 self.config["meta_train"])

        if self.train_config.num_batch_accumulated > 1:
            self.logger.warning(
                "Batch accumulation is used only at MAML-step level")

        if self.train_config.use_bert_training:
            if self.train_config.clip_grad is None:
                self.logger.info(
                    "Gradient clipping is recommended for BERT training")
Example #5
    def __init__(self, config):
        self.config = config
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
            torch.set_num_threads(1)

        # 0. Construct preprocessors
        self.model_preproc = registry.instantiate(
            registry.lookup("model", config["model"]).Preproc,
            config["model"],
            unused_keys=("name", ),
        )
        self.model_preproc.load()
Example #6
    def load_model(self, config):
        with self.init_random:
            # 0. Construct preprocessors
            self.model_preproc = registry.instantiate(
                registry.lookup("model", config["model"]).Preproc,
                config["model"],
                unused_keys=("name", ),
            )
            self.model_preproc.load()

            # 1. Construct model
            self.model = registry.construct(
                "model",
                config["model"],
                unused_keys=("encoder_preproc", "decoder_preproc"),
                preproc=self.model_preproc,
                device=self.device,
            )
            self.model.to(self.device)
Example #7
    def load_train_config(self):
        self.train_config = registry.instantiate(TrainConfig,
                                                 self.config["train"])
Example #8
    def __init__(
            self,
            device,
            preproc,
            word_emb_size=128,
            recurrent_size=256,
            dropout=0.0,
            question_encoder=("emb", "bilstm"),
            column_encoder=("emb", "bilstm"),
            table_encoder=("emb", "bilstm"),
            linking_config={},
            rat_config={},
            top_k_learnable=0,
            include_in_memory=("question", "column", "table"),
    ):
        super().__init__()
        self._device = device
        self.preproc = preproc

        self.vocab = preproc.vocab
        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
        assert self.recurrent_size % 2 == 0
        word_freq = self.preproc.vocab_builder.word_freq
        top_k_words = set(
            [_a[0] for _a in word_freq.most_common(top_k_learnable)])
        self.learnable_words = top_k_words
        self.include_in_memory = set(include_in_memory)
        self.dropout = dropout

        shared_modules = {
            "shared-en-emb":
            embedders.LookupEmbeddings(
                self._device,
                self.vocab,
                self.preproc.word_emb,
                self.word_emb_size,
                self.learnable_words,
            ),
            "shared-bilstm":
            lstm.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                summarize=False,
            ),
        }

        # chinese vocab and module
        if self.preproc.use_ch_vocab:
            self.ch_vocab = preproc.ch_vocab
            ch_word_freq = self.preproc.ch_vocab_builder.word_freq
            ch_top_k_words = set(
                [_a[0] for _a in ch_word_freq.most_common(top_k_learnable)])
            self.ch_learnable_words = ch_top_k_words
            shared_modules["shared-ch-emb"] = embedders.LookupEmbeddings(
                self._device,
                self.ch_vocab,
                self.preproc.ch_word_emb,
                self.preproc.ch_word_emb.dim,
                self.ch_learnable_words,
            )
            shared_modules["ch-bilstm"] = lstm.BiLSTM(
                input_size=self.preproc.ch_word_emb.dim,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                use_native=False,
                summarize=False,
            )
            shared_modules["ch-bilstm-native"] = lstm.BiLSTM(
                input_size=self.preproc.ch_word_emb.dim,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                use_native=True,
                summarize=False,
            )

        self.question_encoder = self._build_modules(
            question_encoder, shared_modules=shared_modules)
        self.column_encoder = self._build_modules(
            column_encoder, shared_modules=shared_modules)
        self.table_encoder = self._build_modules(table_encoder,
                                                 shared_modules=shared_modules)

        # matching
        self.schema_linking = registry.construct(
            "schema_linking",
            linking_config,
            device=device,
            word_emb_size=word_emb_size,
            preproc=preproc,
        )

        # rat
        rat_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}
        self.rat_update = registry.instantiate(
            rat_modules[rat_config["name"]],
            rat_config,
            unused_keys={"name"},
            device=self._device,
            relations2id=preproc.relations2id,
            hidden_size=recurrent_size,
        )

        # aligner
        self.aligner = rat.AlignmentWithRAT(
            device=device,
            hidden_size=recurrent_size,
            relations2id=preproc.relations2id,
            enable_latent_relations=rat_config["enable_latent_relations"],
            num_latent_relations=rat_config.get("num_latent_relations", None),
            combine_latent_relations=rat_config["combine_latent_relations"],
        )
Example #9
    def __init__(
            self,
            device,
            preproc,
            word_emb_size=128,
            recurrent_size=256,
            dropout=0.0,
            question_encoder=("emb", "bilstm"),
            column_encoder=("emb", "bilstm"),
            value_encoder=("emb", "bilstm"),
            linking_config={},
            rat_config={},
            top_k_learnable=0,
            include_in_memory=("question", ),
    ):
        super().__init__()
        self._device = device
        self.preproc = preproc

        self.vocab = preproc.vocab
        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
        assert self.recurrent_size % 2 == 0
        word_freq = self.preproc.vocab_builder.word_freq
        top_k_words = set(
            [_a[0] for _a in word_freq.most_common(top_k_learnable)])
        self.learnable_words = top_k_words
        self.include_in_memory = set(include_in_memory)
        self.dropout = dropout

        shared_modules = {
            "shared-en-emb":
            embedders.LookupEmbeddings(
                self._device,
                self.vocab,
                self.preproc.word_emb,
                self.word_emb_size,
                self.learnable_words,
            )
        }

        self.question_encoder = self._build_modules(
            question_encoder, "question", shared_modules=shared_modules)
        self.column_encoder = self._build_modules(
            column_encoder, "column", shared_modules=shared_modules)
        self.value_encoder = self._build_modules(value_encoder,
                                                 "value",
                                                 shared_modules=shared_modules)

        update_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}

        self.schema_linking = registry.construct(
            "schema_linking",
            linking_config,
            device=device,
            word_emb_size=word_emb_size,
            preproc=preproc,
        )

        self.rat_update = registry.instantiate(
            update_modules[rat_config["name"]],
            rat_config,
            unused_keys={"name"},
            device=self._device,
            relations2id=self.preproc.relations2id,
            hidden_size=recurrent_size,
        )
Example #10
    def __init__(self, config):
        self.config = config
        self.model_preproc = registry.instantiate(
            registry.lookup("model", config["model"]).Preproc, config["model"])