def load_train_config(self):
    self.train_config = registry.instantiate(MetaTrainConfig,
                                             self.config["meta_train"])

    if self.train_config.num_batch_accumulated > 1:
        self.logger.warn("Batch accumulation is used only at MAML-step level")
        raise NotImplementedError
def load_train_config(self):
    self.train_config = registry.instantiate(TrainConfig, self.config["train"])

    if self.train_config.use_bert_training:
        if self.train_config.clip_grad is None:
            self.logger.info("Grad clipping is recommended for BERT training")
def __init__(
    self,
    device,
    preproc,
    bert_token_type=False,
    bert_version="bert-base-uncased",
    summarize_header="avg",
    include_in_memory=("question", "column", "table"),
    rat_config={},
    linking_config={},
):
    super().__init__()
    self._device = device
    self.preproc = preproc
    self.bert_token_type = bert_token_type
    # "large" checkpoints use 1024-dim hidden states, "base" checkpoints 768
    self.base_enc_hidden_size = 1024 if "large" in bert_version else 768
    self.include_in_memory = include_in_memory

    # ways to summarize header
    assert summarize_header in ["first", "avg"]
    self.summarize_header = summarize_header
    self.enc_hidden_size = self.base_enc_hidden_size

    # matching
    self.schema_linking = registry.construct(
        "schema_linking",
        linking_config,
        preproc=preproc,
        device=device,
    )

    # rat
    rat_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}
    self.rat_update = registry.instantiate(
        rat_modules[rat_config["name"]],
        rat_config,
        unused_keys={"name"},
        device=self._device,
        relations2id=preproc.relations2id,
        hidden_size=self.enc_hidden_size,
    )

    # aligner
    self.aligner = rat.AlignmentWithRAT(
        device=device,
        hidden_size=self.enc_hidden_size,
        relations2id=preproc.relations2id,
        enable_latent_relations=False,
    )

    # pick the pretrained encoder class that matches the checkpoint name
    if "electra" in bert_version:
        modelclass = ElectraModel
    elif "bert" in bert_version:
        modelclass = BertModel
    else:
        raise NotImplementedError
    self.bert_model = modelclass.from_pretrained(bert_version)
    self.tokenizer = self.preproc.tokenizer
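# The hard-coded 768/1024 hidden sizes above mirror the pretrained checkpoints.
# A minimal, self-contained check of that assumption (assuming BertModel comes
# from HuggingFace transformers, as the from_pretrained call suggests); this is
# an illustrative sketch, not part of the encoder:
from transformers import BertModel

bert = BertModel.from_pretrained("bert-base-uncased")
assert bert.config.hidden_size == 768  # "large" checkpoints report 1024 here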
def load_train_config(self):
    self.train_config = registry.instantiate(MetaTrainConfig,
                                             self.config["meta_train"])

    if self.train_config.num_batch_accumulated > 1:
        self.logger.warn("Batch accumulation is used only at MAML-step level")

    if self.train_config.use_bert_training:
        if self.train_config.clip_grad is None:
            self.logger.info("Gradient clipping is recommended for BERT training")
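# Hypothetical slice of an experiment config that load_train_config above would
# read. The field names come from the attribute accesses in the method; the
# values are illustrative only.
example_config = {
    "meta_train": {
        "num_batch_accumulated": 1,  # >1 only accumulates at the MAML-step level
        "use_bert_training": True,
        "clip_grad": 1.0,            # None triggers the gradient-clipping hint
    }
}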
def __init__(self, config):
    self.config = config
    if torch.cuda.is_available():
        self.device = torch.device("cuda:0")
    else:
        self.device = torch.device("cpu")
        # limit CPU thread usage when no GPU is available
        torch.set_num_threads(1)

    # 0. Construct preprocessors
    self.model_preproc = registry.instantiate(
        registry.lookup("model", config["model"]).Preproc,
        config["model"],
        unused_keys=("name",),
    )
    self.model_preproc.load()
def load_model(self, config):
    with self.init_random:
        # 0. Construct preprocessors
        self.model_preproc = registry.instantiate(
            registry.lookup("model", config["model"]).Preproc,
            config["model"],
            unused_keys=("name",),
        )
        self.model_preproc.load()

        # 1. Construct model
        self.model = registry.construct(
            "model",
            config["model"],
            unused_keys=("encoder_preproc", "decoder_preproc"),
            preproc=self.model_preproc,
            device=self.device,
        )
        self.model.to(self.device)
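# The registry module is project-specific. The sketch below only captures the
# semantics implied by the calls above (resolve a class by config["name"], then
# build it from the remaining config entries plus explicit kwargs); every name
# here is hypothetical and not the repo's actual implementation.
_REGISTRY = {"model": {}}

def lookup(kind, config):
    return _REGISTRY[kind][config["name"]]

def instantiate(cls, config, unused_keys=(), **kwargs):
    kwargs.update({k: v for k, v in config.items() if k not in unused_keys})
    return cls(**kwargs)

def construct(kind, config, unused_keys=(), **kwargs):
    return instantiate(lookup(kind, config), config,
                       unused_keys=set(unused_keys) | {"name"}, **kwargs)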
def load_train_config(self):
    self.train_config = registry.instantiate(TrainConfig, self.config["train"])
def __init__(
    self,
    device,
    preproc,
    word_emb_size=128,
    recurrent_size=256,
    dropout=0.0,
    question_encoder=("emb", "bilstm"),
    column_encoder=("emb", "bilstm"),
    table_encoder=("emb", "bilstm"),
    linking_config={},
    rat_config={},
    top_k_learnable=0,
    include_in_memory=("question", "column", "table"),
):
    super().__init__()
    self._device = device
    self.preproc = preproc

    self.vocab = preproc.vocab
    self.word_emb_size = word_emb_size
    self.recurrent_size = recurrent_size
    assert self.recurrent_size % 2 == 0
    word_freq = self.preproc.vocab_builder.word_freq
    top_k_words = set(
        [_a[0] for _a in word_freq.most_common(top_k_learnable)])
    self.learnable_words = top_k_words
    self.include_in_memory = set(include_in_memory)
    self.dropout = dropout

    shared_modules = {
        "shared-en-emb": embedders.LookupEmbeddings(
            self._device,
            self.vocab,
            self.preproc.word_emb,
            self.word_emb_size,
            self.learnable_words,
        ),
        "shared-bilstm": lstm.BiLSTM(
            input_size=self.word_emb_size,
            output_size=self.recurrent_size,
            dropout=self.dropout,
            summarize=False,
        ),
    }

    # chinese vocab and module
    if self.preproc.use_ch_vocab:
        self.ch_vocab = preproc.ch_vocab
        ch_word_freq = self.preproc.ch_vocab_builder.word_freq
        ch_top_k_words = set(
            [_a[0] for _a in ch_word_freq.most_common(top_k_learnable)])
        self.ch_learnable_words = ch_top_k_words
        shared_modules["shared-ch-emb"] = embedders.LookupEmbeddings(
            self._device,
            self.ch_vocab,
            self.preproc.ch_word_emb,
            self.preproc.ch_word_emb.dim,
            self.ch_learnable_words,
        )
        shared_modules["ch-bilstm"] = lstm.BiLSTM(
            input_size=self.preproc.ch_word_emb.dim,
            output_size=self.recurrent_size,
            dropout=self.dropout,
            use_native=False,
            summarize=False,
        )
        shared_modules["ch-bilstm-native"] = lstm.BiLSTM(
            input_size=self.preproc.ch_word_emb.dim,
            output_size=self.recurrent_size,
            dropout=self.dropout,
            use_native=True,
            summarize=False,
        )

    self.question_encoder = self._build_modules(
        question_encoder, shared_modules=shared_modules)
    self.column_encoder = self._build_modules(
        column_encoder, shared_modules=shared_modules)
    self.table_encoder = self._build_modules(
        table_encoder, shared_modules=shared_modules)

    # matching
    self.schema_linking = registry.construct(
        "schema_linking",
        linking_config,
        device=device,
        word_emb_size=word_emb_size,
        preproc=preproc,
    )

    # rat
    rat_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}
    self.rat_update = registry.instantiate(
        rat_modules[rat_config["name"]],
        rat_config,
        unused_keys={"name"},
        device=self._device,
        relations2id=preproc.relations2id,
        hidden_size=recurrent_size,
    )

    # aligner
    self.aligner = rat.AlignmentWithRAT(
        device=device,
        hidden_size=recurrent_size,
        relations2id=preproc.relations2id,
        enable_latent_relations=rat_config["enable_latent_relations"],
        num_latent_relations=rat_config.get("num_latent_relations", None),
        combine_latent_relations=rat_config["combine_latent_relations"],
    )
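# Why recurrent_size must be even: a bidirectional LSTM concatenates the forward
# and backward states, so each direction gets recurrent_size // 2. A standard
# PyTorch sketch of that constraint (not the repo's lstm.BiLSTM wrapper):
import torch

recurrent_size = 256
bilstm = torch.nn.LSTM(
    input_size=128,
    hidden_size=recurrent_size // 2,
    bidirectional=True,
    batch_first=True,
)
out, _ = bilstm(torch.randn(2, 7, 128))  # (batch, sequence, embedding)
assert out.shape[-1] == recurrent_size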
def __init__(
    self,
    device,
    preproc,
    word_emb_size=128,
    recurrent_size=256,
    dropout=0.0,
    question_encoder=("emb", "bilstm"),
    column_encoder=("emb", "bilstm"),
    value_encoder=("emb", "bilstm"),
    linking_config={},
    rat_config={},
    top_k_learnable=0,
    include_in_memory=("question",),
):
    super().__init__()
    self._device = device
    self.preproc = preproc

    self.vocab = preproc.vocab
    self.word_emb_size = word_emb_size
    self.recurrent_size = recurrent_size
    assert self.recurrent_size % 2 == 0
    word_freq = self.preproc.vocab_builder.word_freq
    top_k_words = set(
        [_a[0] for _a in word_freq.most_common(top_k_learnable)])
    self.learnable_words = top_k_words
    self.include_in_memory = set(include_in_memory)
    self.dropout = dropout

    shared_modules = {
        "shared-en-emb": embedders.LookupEmbeddings(
            self._device,
            self.vocab,
            self.preproc.word_emb,
            self.word_emb_size,
            self.learnable_words,
        )
    }

    self.question_encoder = self._build_modules(
        question_encoder, "question", shared_modules=shared_modules)
    self.column_encoder = self._build_modules(
        column_encoder, "column", shared_modules=shared_modules)
    self.value_encoder = self._build_modules(
        value_encoder, "value", shared_modules=shared_modules)

    update_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}

    # matching
    self.schema_linking = registry.construct(
        "schema_linking",
        linking_config,
        device=device,
        word_emb_size=word_emb_size,
        preproc=preproc,
    )

    # rat
    self.rat_update = registry.instantiate(
        update_modules[rat_config["name"]],
        rat_config,
        unused_keys={"name"},
        device=self._device,
        relations2id=self.preproc.relations2id,
        hidden_size=recurrent_size,
    )
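# Sketch of the top-k learnable-word selection used by both encoders above,
# assuming word_freq behaves like collections.Counter (it exposes the same
# most_common API); the words and counts here are made up.
from collections import Counter

word_freq = Counter({"the": 120, "count": 40, "singer": 9})
learnable_words = set([w for w, _ in word_freq.most_common(2)])
# -> {"the", "count"}; most_common(0) returns [], so the default
#    top_k_learnable=0 yields an empty learnable-word set.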
def __init__(self, config):
    self.config = config
    self.model_preproc = registry.instantiate(
        registry.lookup("model", config["model"]).Preproc,
        config["model"])