Example no. 1
    def load_optimizer(self, config):
        with self.init_random:
            # 0. create inner_optimizer
            inner_parameters = self.model.get_trainable_parameters()
            inner_optimizer = registry.construct("optimizer",
                                                 self.train_config.inner_opt,
                                                 params=inner_parameters)

            # 1. MAML trainer, might add new parameters to the optimizer, e.g., step size
            maml_trainer = maml.MAML(
                model=self.model,
                inner_opt=inner_optimizer,
                device=self.device,
                first_order=self.train_config.first_order,
            )
            maml_trainer.to(self.device)

            opt_params = maml_trainer.get_inner_opt_params()
            self.logger.info(f"{len(opt_params)} opt meta parameters")

            # 2. Outer optimizer
            optimizer = registry.construct(
                "optimizer",
                config["optimizer"],
                params=itertools.chain(self.model.get_trainable_parameters(),
                                       opt_params),
            )

            lr_scheduler = registry.construct(
                "lr_scheduler",
                config.get("lr_scheduler", {"name": "noop"}),
                param_groups=optimizer.param_groups,
            )
            return inner_optimizer, maml_trainer, optimizer, lr_scheduler
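All of these snippets go through a small registry module (construct, lookup, instantiate). The repository's actual implementation is not shown here; the following is a minimal sketch, assuming dispatch on the "name" key of each config dict:

    # Minimal sketch of the registry API used throughout these examples.
    # Assumption: each config dict carries a "name" key that selects the
    # registered class; the real implementation may differ.
    _REGISTRY = {}

    def register(kind, name):
        """Class decorator: file a class under (kind, name)."""
        def decorator(cls):
            _REGISTRY.setdefault(kind, {})[name] = cls
            return cls
        return decorator

    def lookup(kind, name):
        if isinstance(name, dict):  # lookup is also called with a whole config dict
            name = name["name"]
        return _REGISTRY[kind][name]

    def instantiate(cls, config, unused_keys=(), **kwargs):
        """Build cls from a config dict, dropping keys its constructor
        does not accept (e.g. unused_keys={"name"})."""
        clean = {k: v for k, v in config.items() if k not in unused_keys}
        return cls(**clean, **kwargs)

    def construct(kind, config, unused_keys=(), **kwargs):
        """Look up config["name"] under kind, then instantiate it with the
        remaining config entries plus kwargs."""
        cls = lookup(kind, config)
        return instantiate(cls, config, set(unused_keys) | {"name"}, **kwargs)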
Example no. 2
 def load_optimizer(self, config):
     with self.init_random:
         if self.train_config.use_bert_training:
             bert_params = self.model.get_bert_parameters()
             non_bert_params = self.model.get_non_bert_parameters()
             assert len(non_bert_params) + len(bert_params) == len(
                 list(self.model.parameters()))
             optimizer = registry.construct(
                 "optimizer",
                 config["optimizer"],
                 non_bert_params=non_bert_params,
                 bert_params=bert_params,
             )
             lr_scheduler = registry.construct(
                 "lr_scheduler",
                 config.get("lr_scheduler", {"name": "noop"}),
                 param_groups=[
                     optimizer.non_bert_param_group,
                     optimizer.bert_param_group,
                 ],
             )
         else:
             optimizer = registry.construct(
                 "optimizer",
                 config["optimizer"],
                 params=self.model.get_trainable_parameters(),
             )
             lr_scheduler = registry.construct(
                 "lr_scheduler",
                 config.get("lr_scheduler", {"name": "noop"}),
                 param_groups=optimizer.param_groups,
             )
         return optimizer, lr_scheduler
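The use_bert_training branch assumes an optimizer that exposes non_bert_param_group and bert_param_group, typically so the BERT weights get a smaller learning rate than the rest of the model. A minimal sketch with plain PyTorch (the lr/bert_lr defaults are illustrative, not taken from the source):

    import torch

    class BertAdamW(torch.optim.AdamW):
        # Two named parameter groups so a scheduler can treat BERT and
        # non-BERT weights differently (sketch; defaults are illustrative).
        def __init__(self, non_bert_params, bert_params, lr=1e-3, bert_lr=2e-5):
            super().__init__([
                {"params": non_bert_params, "lr": lr},
                {"params": bert_params, "lr": bert_lr},
            ])
            self.non_bert_param_group, self.bert_param_group = self.param_groups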
Example no. 3
    def __init__(self, preproc, device, encoder, decoder):
        super().__init__()
        self.preproc = preproc
        self.encoder = registry.construct(
            "encoder", encoder, device=device, preproc=preproc.enc_preproc
        )
        self.decoder = registry.construct(
            "decoder", decoder, device=device, preproc=preproc.dec_preproc
        )

        assert self.encoder.batched  # the batched encoder is used by default
Example no. 4
 def load_optimizer(self, config):
     with self.init_random:
         optimizer = registry.construct(
             "optimizer",
             config["optimizer"],
             params=self.model.get_trainable_parameters(),
         )
         lr_scheduler = registry.construct(
             "lr_scheduler",
             config.get("lr_scheduler", {"name": "noop"}),
             param_groups=optimizer.param_groups,
         )
         return optimizer, lr_scheduler
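The scheduler config consistently falls back to {"name": "noop"}. A no-op scheduler only needs to hold the param groups and leave their learning rates alone; a sketch of what such a class might look like (the update_lr method name is an assumption, not taken from the source):

    class NoOpLRScheduler:
        # Sketch of the "noop" fallback: keep every group's lr fixed.
        def __init__(self, param_groups):
            self.param_groups = param_groups

        def update_lr(self, current_step):
            pass  # intentionally leaves each group's lr unchanged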
Example no. 5
    def preprocess_item(self, item, validation_info):
        tokens = self.tokenize(item.question)
        grammar = registry.construct("grammar",
                                     self.grammar_config,
                                     domain=item.domain)
        raw_properties, ref_properties = grammar.get_properties()
        raw_values, ref_values = grammar.get_values()
        schema_raw_relations = grammar.get_schema_relations()

        if item.domain in self.schema_cache:
            context = self.schema_cache[item.domain]
            processed_properties = context.schema["columns"]
            processed_values = context.schema["values"]
            schema_relations = context.compute_schema_relations()
        else:
            processed_properties = [self.tokenize(p) for p in raw_properties]
            processed_values = [self.tokenize(v) for v in raw_values]

            context = registry.construct(
                "context",
                self.context_config,
                schema={
                    "columns": processed_properties,
                    "values": processed_values,
                    "schema_relations": schema_raw_relations,
                },
            )
            self.schema_cache[item.domain] = context
            schema_relations = context.compute_schema_relations()

        sc_relations = (context.compute_schema_linking(tokens)
                        if self.compute_sc_link else {})
        cv_relations = (context.compute_cell_value_linking(tokens)
                        if self.compute_cv_link else {})
        for relation_name in itertools.chain(schema_relations.keys(),
                                             sc_relations.keys(),
                                             cv_relations.keys()):
            self.relations.add(relation_name)

        return {
            "db_id": item.domain,  # comply with data_scheduler
            "question": tokens,
            "raw_question": item.question,
            "columns": processed_properties,
            "values": processed_values,
            "ref_columns": ref_properties,
            "ref_values": ref_values,
            "schema_relations": schema_relations,
            "sc_relations": sc_relations,
            "cv_relations": cv_relations,
        }
Example no. 6
    def infer(self, model, output_path, args):
        output = open(output_path, "w")

        infer_func = registry.lookup("infer_method", args.method)
        with torch.no_grad():
            if args.mode == "infer":
                orig_data = registry.construct(
                    "dataset", self.config["data"][args.section])
                preproc_data = self.model_preproc.dataset(args.section)
                if args.limit:
                    sliced_orig_data = itertools.islice(orig_data, args.limit)
                    sliced_preproc_data = itertools.islice(
                        preproc_data, args.limit)
                else:
                    sliced_orig_data = orig_data
                    sliced_preproc_data = preproc_data
                assert len(orig_data) == len(preproc_data)
                self._inner_infer(
                    model,
                    infer_func,
                    args.beam_size,
                    sliced_orig_data,
                    sliced_preproc_data,
                    output,
                    args.debug,
                )
Example no. 7
    def load_model(self, logdir, step):
        """Load a model (identified by the config used for construction) and return it"""
        # 1. Construct model
        model = registry.construct(
            "model",
            self.config["model"],
            preproc=self.model_preproc,
            device=self.device,
            unused_keys=("decoder_preproc", "encoder_preproc"),
        )
        model.to(self.device)
        model.eval()
        model.visualize_flag = False

        # 2. Restore its parameters
        saver = saver_mod.Saver({"model": model})
        last_step = saver.restore(logdir,
                                  step=step,
                                  map_location=self.device,
                                  item_keys=["model"])
        if last_step == 0:  # step 0 would only be acceptable for a pretrained model
            raise CheckpointNotFoundError(
                f"Attempting to infer on untrained model, logdir {logdir}, step {step}"
            )
        return model
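saver_mod.Saver.restore is expected to return the step of the restored checkpoint, or 0 when none is found (hence the CheckpointNotFoundError above). An illustrative stand-in built on torch.load, assuming a model_checkpoint-<step> file layout that is not confirmed by the source:

    import glob
    import os
    import torch

    def restore_latest(model, logdir, map_location="cpu"):
        # Illustrative only: find the newest checkpoint by step number,
        # load its "model" state dict, and report the step (0 if absent).
        paths = glob.glob(os.path.join(logdir, "model_checkpoint-*"))
        if not paths:
            return 0
        latest = max(paths, key=lambda p: int(p.rsplit("-", 1)[-1]))
        state = torch.load(latest, map_location=map_location)
        model.load_state_dict(state["model"])
        return int(latest.rsplit("-", 1)[-1])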
Example no. 8
    def __init__(
        self,
        grammar,
        save_path,
        min_freq=3,
        max_count=5000,
        use_seq_elem_rules=False,
        value_tokenizer=None,
    ):
        self.grammar = registry.construct("grammar", grammar)
        self.ast_wrapper = self.grammar.ast_wrapper

        # tokenizer for value prediction, lazy init
        self.value_tokenizer_config = value_tokenizer

        self.vocab_path = os.path.join(save_path, "dec_vocab.json")
        self.observed_productions_path = os.path.join(
            save_path, "observed_productions.json"
        )
        self.grammar_rules_path = os.path.join(save_path, "grammar_rules.json")
        self.data_dir = os.path.join(save_path, "dec")

        self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
        self.use_seq_elem_rules = use_seq_elem_rules

        self.items = collections.defaultdict(list)
        self.sum_type_constructors = collections.defaultdict(set)
        self.field_presence_infos = collections.defaultdict(set)
        self.seq_lengths = collections.defaultdict(set)
        self.primitive_types = set()

        self.vocab = None
        self.all_rules = None
        self.rules_mask = None
Example no. 9
    def __init__(
        self,
        save_path,
        grammar,
        context,
        word_emb,
        min_freq=3,
        max_count=5000,
        sc_link=True,
        cv_link=True,
    ):
        self.data_dir = os.path.join(save_path, "enc")
        self.compute_sc_link = sc_link
        self.compute_cv_link = cv_link
        self.grammar_config = grammar
        self.context_config = context
        self.texts = collections.defaultdict(list)
        self.word_emb = registry.construct("word_emb", word_emb)

        # vocab
        self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
        self.vocab_path = os.path.join(self.data_dir, "vocab.json")
        self.vocab_word_freq_path = os.path.join(self.data_dir,
                                                 "word_freq.json")

        self.relations = set()
        self.schema_cache = {}
Example no. 10
    def __init__(
        self,
        device,
        preproc,
        bert_token_type=False,
        bert_version="bert-base-uncased",
        summarize_header="avg",
        include_in_memory=("question", "column", "table"),
        rat_config={},
        linking_config={},
    ):
        super().__init__()
        self._device = device
        self.preproc = preproc
        self.bert_token_type = bert_token_type
        self.base_enc_hidden_size = (1024 if "large" in bert_version else 768)
        self.include_in_memory = include_in_memory

        # ways to summarize header
        assert summarize_header in ["first", "avg"]
        self.summarize_header = summarize_header
        self.enc_hidden_size = self.base_enc_hidden_size

        # matching
        self.schema_linking = registry.construct(
            "schema_linking",
            linking_config,
            preproc=preproc,
            device=device,
        )

        # rat
        rat_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}
        self.rat_update = registry.instantiate(
            rat_modules[rat_config["name"]],
            rat_config,
            unused_keys={"name"},
            device=self._device,
            relations2id=preproc.relations2id,
            hidden_size=self.enc_hidden_size,
        )

        # aligner
        self.aligner = rat.AlignmentWithRAT(
            device=device,
            hidden_size=self.enc_hidden_size,
            relations2id=preproc.relations2id,
            enable_latent_relations=False,
        )

        if "electra" in bert_version:
            modelclass = ElectraModel
        elif "bert" in bert_version:
            modelclass = BertModel
        else:
            raise NotImplementedError
        self.bert_model = modelclass.from_pretrained(bert_version)
        self.tokenizer = self.preproc.tokenizer
Example no. 11
    def __init__(
        self,
        save_path,
        context,
        min_freq=3,
        max_count=5000,
        include_table_name_in_column=True,
        word_emb=None,
        count_tokens_in_word_emb_for_vocab=False,
        compute_sc_link=False,
        compute_cv_link=False,
        use_ch_vocab=False,
        ch_word_emb=None,
    ):
        if word_emb is None:
            self.word_emb = None
        else:
            self.word_emb = registry.construct("word_emb", word_emb)

        self.data_dir = os.path.join(save_path, "enc")
        self.include_table_name_in_column = include_table_name_in_column
        self.count_tokens_in_word_emb_for_vocab = count_tokens_in_word_emb_for_vocab
        self.compute_sc_link = compute_sc_link
        self.compute_cv_link = compute_cv_link
        self.context_config = context

        self.texts = collections.defaultdict(list)
        self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
        self.vocab_path = os.path.join(save_path, "enc_vocab.json")
        self.vocab_word_freq_path = os.path.join(save_path,
                                                 "enc_word_freq.json")
        self.vocab = None
        self.use_ch_vocab = use_ch_vocab
        if use_ch_vocab:
            assert ch_word_emb is not None
            self.ch_word_emb = registry.construct("word_emb", ch_word_emb)
            self.ch_vocab_builder = vocab.VocabBuilder(min_freq, max_count)
            self.ch_vocab_path = os.path.join(save_path, "ch_enc_vocab.json")
            self.ch_vocab_word_freq_path = os.path.join(
                save_path, "ch_enc_word_freq.json")
            self.ch_vocab = None
        self.counted_db_ids = set()
        self.relations = set()

        self.context_cache = {}
Example no. 12
 def preprocess(self):
     self.model_preproc.clear_items()
     for section in self.config["data"]:
         data = registry.construct("dataset", self.config["data"][section])
         for item in tqdm.tqdm(data, desc=section, dynamic_ncols=True):
             to_add, validation_info = self.model_preproc.validate_item(
                 item, section)
             if to_add:
                 self.model_preproc.add_item(item, section, validation_info)
     self.model_preproc.save()
Example no. 13
 def load_train_data(self):
     with self.data_random:
         train_data = self.model_preproc.dataset("train")
         train_data_scheduler = registry.construct(
             "data_scheduler",
             self.train_config.data_scheduler,
             examples=train_data,
             max_train_step=self.train_config.max_steps,
         )
     return train_data_scheduler
Example no. 14
    def __init__(self, save_path, word_emb=None, min_freq=0, max_count=10000):
        self.save_path = save_path
        self.data_dir = os.path.join(save_path, "enc")
        self.texts = collections.defaultdict(list)
        self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
        self.vocab_path = os.path.join(save_path, "enc_vocab")

        # pretrained embeddings, e.g., Glove
        if word_emb is None:
            self.embedder = word_emb
        else:
            self.embedder = registry.construct("word_emb", word_emb)
        self.learnable_words = None
Example no. 15
def compute_metrics(config_path,
                    config_args,
                    section,
                    inferred_path,
                    etype,
                    logdir=None):
    if config_args:
        config = json.loads(
            _jsonnet.evaluate_file(config_path,
                                   tla_codes={"args": config_args}))
    else:
        config = json.loads(_jsonnet.evaluate_file(config_path))

    # update config to wandb
    wandb.config.update(config)

    if "model_name" in config and logdir:
        logdir = os.path.join(logdir, config["model_name"])
    if logdir:
        inferred_path = inferred_path.replace("__LOGDIR__", logdir)

    inferred = open(inferred_path)
    data = registry.construct("dataset", config["data"][section])
    metrics = data.Metrics(data, etype)

    inferred_lines = list(inferred)
    if len(inferred_lines) < len(data):
        raise Exception("Not enough inferred: {} vs {}".format(
            len(inferred_lines), len(data)))

    for i, line in enumerate(inferred_lines):
        infer_results = json.loads(line)
        if infer_results["beams"]:
            inferred_codes = [
                beam["inferred_code"] for beam in infer_results["beams"]
            ]
        else:
            inferred_codes = [None]
        assert "index" in infer_results

        if etype in ["execution", "all"]:
            # when evaluating by execution, choose the first executable candidate from the beams
            metrics.add_beams(data[infer_results["index"]], inferred_codes,
                              data[i].orig["question"])
        else:
            assert etype in ["match", "sacreBLEU", "tokenizedBLEU"]
            metrics.add_one(data[infer_results["index"]], inferred_codes[0])
    return logdir, metrics.finalize()
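For reference, compute_metrics reads inferred_path as JSON lines, each carrying an "index" into the dataset and a ranked "beams" list of {"inferred_code": ...} entries; one such line could be produced like this (the SQL strings are made up for illustration):

    import json

    line = json.dumps({
        "index": 0,
        "beams": [
            {"inferred_code": "SELECT name FROM singer"},
            {"inferred_code": "SELECT * FROM singer"},
        ],
    })
    print(line)  # one record of the inferred file consumed above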
Example no. 16
    def load_model(self, config):
        with self.init_random:
            # 0. Construct preprocessors
            self.model_preproc = registry.instantiate(
                registry.lookup("model", config["model"]).Preproc,
                config["model"],
                unused_keys=("name", ),
            )
            self.model_preproc.load()

            # 1. Construct model
            self.model = registry.construct(
                "model",
                config["model"],
                unused_keys=("encoder_preproc", "decoder_preproc"),
                preproc=self.model_preproc,
                device=self.device,
            )
            self.model.to(self.device)
Example no. 17
    def preprocess_item(self, item, validation_info):
        if self.use_ch_vocab:
            question, question_for_copying = self._ch_tokenize_for_copying(
                item.text, item.orig["question"])
        else:
            question, question_for_copying = self._tokenize_for_copying(
                item.text, item.orig["question"])

        if item.schema.db_id in self.context_cache:
            context = self.context_cache[item.schema.db_id]
        else:
            context = registry.construct(
                "context",
                self.context_config,
                schema=item.schema,
                word_emb=self.word_emb,
            )
            self.context_cache[item.schema.db_id] = context

        preproc_schema = context.preproc_schema
        schema_relations = context.compute_schema_relations()
        sc_relations = (context.compute_schema_linking(question)
                        if self.compute_sc_link else {})
        cv_relations = (context.compute_cell_value_linking(question)
                        if self.compute_cv_link else {})

        return {
            "raw_question": item.orig["question"],
            "question": question,
            "question_for_copying": question_for_copying,
            "db_id": item.schema.db_id,
            "schema_relations": schema_relations,
            "sc_relations": sc_relations,
            "cv_relations": cv_relations,
            "columns": preproc_schema.column_names,
            "tables": preproc_schema.table_names,
            "table_bounds": preproc_schema.table_bounds,
            "column_to_table": preproc_schema.column_to_table,
            "table_to_columns": preproc_schema.table_to_columns,
            "foreign_keys": preproc_schema.foreign_keys,
            "foreign_keys_tables": preproc_schema.foreign_keys_tables,
            "primary_keys": preproc_schema.primary_keys,
        }
Example no. 18
    def preprocess_item(self, item, validation_info):
        q_text = " ".join(item.text)

        # use the original words for copying, even though they are not necessarily used for encoding
        # question_for_copying = self.tokenizer.tokenize_and_lemmatize(q_text)
        question_for_copying = self.tokenizer.tokenize_with_orig(q_text)

        if item.schema.db_id in self.context_cache:
            context = self.context_cache[item.schema.db_id]
        else:
            context = registry.construct(
                "context",
                self.context_config,
                schema=item.schema,
                tokenizer=self.tokenizer,
            )
            self.context_cache[item.schema.db_id] = context

        preproc_schema = context.preproc_schema
        schema_relations = context.compute_schema_relations()
        sc_relations = (context.compute_schema_linking(q_text)
                        if self.compute_sc_link else {})
        cv_relations = (context.compute_cell_value_linking(q_text)
                        if self.compute_cv_link else {})

        return {
            "question_text": q_text,
            "question_for_copying": question_for_copying,
            "db_id": item.schema.db_id,
            "schema_relations": schema_relations,
            "sc_relations": sc_relations,
            "cv_relations": cv_relations,
            "columns": preproc_schema.column_names,
            "tables": preproc_schema.table_names,
            "table_bounds": preproc_schema.table_bounds,
            "column_to_table": preproc_schema.column_to_table,
            "table_to_columns": preproc_schema.table_to_columns,
            "foreign_keys": preproc_schema.foreign_keys,
            "foreign_keys_tables": preproc_schema.foreign_keys_tables,
            "primary_keys": preproc_schema.primary_keys,
        }
Example no. 19
    def load_train_data(self):
        with self.data_random:
            train_data = self.model_preproc.dataset("train")
            syn_train_data = self.model_preproc.dataset("syn_train")
            self.logger.info(
                f"Load {len(train_data)} orig examples and {len(syn_train_data)} synthetic examples"
            )

            assert not self.train_config.use_kd_train
            assert not self.train_config.check_syn_consistency

            train_data_scheduler = registry.construct(
                "syn_data_scheduler",
                self.train_config.data_scheduler,
                examples=train_data,
                syn_examples=syn_train_data,
                batch_size=self.train_config.batch_size,
                warm_up_steps=self.config["lr_scheduler"]["num_warmup_steps"],
                decay_steps=self.config["train"]["max_steps"] -
                2 * self.config["lr_scheduler"]["num_warmup_steps"],
            )
        return train_data_scheduler
Example no. 20
    def preprocess_item(self, item, section, validation_info):
        grammar = registry.construct("grammar",
                                     self.grammar_config,
                                     domain=item.domain)
        norm_lf = grammar.normalize_lf(item.lf)
        actions = grammar.logical_form_to_action_sequence(norm_lf)

        # for convenience; this should be adapted so that it does not depend on the data
        d_t_rules_dict = grammar.get_domain_terminal_productions()
        for pt in d_t_rules_dict:
            self.domain_prod_dict[item.domain][pt] = self.domain_prod_dict[
                item.domain][pt].union(d_t_rules_dict[pt])

        # these are the pre-defined rules; they do not need to be induced from the training data
        p_rules_dict = grammar.get_non_terminal_productions()
        t_rules_dict = grammar.get_general_terminal_productions()  # treated as labels
        for pt in p_rules_dict:
            self.prod_dict[pt] = self.prod_dict[pt].union(p_rules_dict[pt])
        for pt in t_rules_dict:
            self.prod_dict[pt] = self.prod_dict[pt].union(t_rules_dict[pt])

        return {"domain": item.domain, "productions": actions}
Example no. 21
    def infer(self, model, output_path, args):
        output = open(output_path, "w")
        chunk_size = 128  # this is manually set, TODO add it to config

        assert args.method.startswith("batched")
        infer_func = registry.lookup("infer_method", args.method)
        with torch.no_grad():
            orig_data = registry.construct("dataset",
                                           self.config["data"][args.section])
            preproc_data = self.model_preproc.dataset(args.section)
            assert len(orig_data) == len(preproc_data)
            chunked_orig_data = chunks(orig_data, chunk_size)
            chunked_preproc_data = chunks(preproc_data, chunk_size)
            pbar = tqdm.tqdm(total=len(preproc_data))
            self._inner_infer(
                model,
                infer_func,
                args.beam_size,
                chunked_orig_data,
                chunked_preproc_data,
                output,
                pbar,
            )
            pbar.close()
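The chunks helper is used above but not shown; a minimal sketch consistent with how it is called on list-like datasets:

    def chunks(items, chunk_size):
        # Yield successive chunk_size-sized slices of a sequence
        # (sketch of the helper assumed by the batched infer loop).
        for start in range(0, len(items), chunk_size):
            yield items[start:start + chunk_size]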
Example no. 22
    def __init__(
            self,
            device,
            preproc,
            word_emb_size=128,
            recurrent_size=256,
            dropout=0.0,
            question_encoder=("emb", "bilstm"),
            column_encoder=("emb", "bilstm"),
            table_encoder=("emb", "bilstm"),
            linking_config={},
            rat_config={},
            top_k_learnable=0,
            include_in_memory=("question", "column", "table"),
    ):
        super().__init__()
        self._device = device
        self.preproc = preproc

        self.vocab = preproc.vocab
        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
        assert self.recurrent_size % 2 == 0
        word_freq = self.preproc.vocab_builder.word_freq
        top_k_words = set(
            [_a[0] for _a in word_freq.most_common(top_k_learnable)])
        self.learnable_words = top_k_words
        self.include_in_memory = set(include_in_memory)
        self.dropout = dropout

        shared_modules = {
            "shared-en-emb":
            embedders.LookupEmbeddings(
                self._device,
                self.vocab,
                self.preproc.word_emb,
                self.word_emb_size,
                self.learnable_words,
            ),
            "shared-bilstm":
            lstm.BiLSTM(
                input_size=self.word_emb_size,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                summarize=False,
            ),
        }

        # chinese vocab and module
        if self.preproc.use_ch_vocab:
            self.ch_vocab = preproc.ch_vocab
            ch_word_freq = self.preproc.ch_vocab_builder.word_freq
            ch_top_k_words = set(
                [_a[0] for _a in ch_word_freq.most_common(top_k_learnable)])
            self.ch_learnable_words = ch_top_k_words
            shared_modules["shared-ch-emb"] = embedders.LookupEmbeddings(
                self._device,
                self.ch_vocab,
                self.preproc.ch_word_emb,
                self.preproc.ch_word_emb.dim,
                self.ch_learnable_words,
            )
            shared_modules["ch-bilstm"] = lstm.BiLSTM(
                input_size=self.preproc.ch_word_emb.dim,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                use_native=False,
                summarize=False,
            )
            shared_modules["ch-bilstm-native"] = lstm.BiLSTM(
                input_size=self.preproc.ch_word_emb.dim,
                output_size=self.recurrent_size,
                dropout=self.dropout,
                use_native=True,
                summarize=False,
            )

        self.question_encoder = self._build_modules(
            question_encoder, shared_modules=shared_modules)
        self.column_encoder = self._build_modules(
            column_encoder, shared_modules=shared_modules)
        self.table_encoder = self._build_modules(table_encoder,
                                                 shared_modules=shared_modules)

        # matching
        self.schema_linking = registry.construct(
            "schema_linking",
            linking_config,
            device=device,
            word_emb_size=word_emb_size,
            preproc=preproc,
        )

        # rat
        rat_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}
        self.rat_update = registry.instantiate(
            rat_modules[rat_config["name"]],
            rat_config,
            unused_keys={"name"},
            device=self._device,
            relations2id=preproc.relations2id,
            hidden_size=recurrent_size,
        )

        # aligner
        self.aligner = rat.AlignmentWithRAT(
            device=device,
            hidden_size=recurrent_size,
            relations2id=preproc.relations2id,
            enable_latent_relations=rat_config["enable_latent_relations"],
            num_latent_relations=rat_config.get("num_latent_relations", None),
            combine_latent_relations=rat_config["combine_latent_relations"],
        )
Example no. 23
    def load_optimizer(self, config):
        with self.init_random:
            # 0. create inner_optimizer
            # inner_parameters = list(self.model.get_trainable_parameters())
            inner_parameters = list(self.model.get_non_bert_parameters())
            inner_optimizer = registry.construct("optimizer",
                                                 self.train_config.inner_opt,
                                                 params=inner_parameters)
            self.logger.info(
                f"{len(inner_parameters)} parameters for inner update")

            # 1. MAML trainer, might add new parameters to the optimizer, e.g., step size
            maml_trainer = maml.MAML(
                model=self.model,
                inner_opt=inner_optimizer,
                device=self.device,
            )
            maml_trainer.to(self.device)

            opt_params = maml_trainer.get_inner_opt_params()
            self.logger.info(f"{len(opt_params)} opt meta parameters")

            # 2. Outer optimizer
            # if config["optimizer"].get("name", None) in ["bertAdamw", "torchAdamw"]:
            if self.train_config.use_bert_training:
                bert_params = self.model.get_bert_parameters()
                non_bert_params = self.model.get_non_bert_parameters()
                assert len(non_bert_params) + len(bert_params) == len(
                    list(self.model.parameters()))
                assert len(bert_params) > 0
                self.logger.info(
                    f"{len(bert_params)} BERT parameters and {len(non_bert_params)} non-BERT parameters"
                )

                optimizer = registry.construct(
                    "optimizer",
                    config["optimizer"],
                    non_bert_params=non_bert_params,
                    bert_params=bert_params,
                )
                lr_scheduler = registry.construct(
                    "lr_scheduler",
                    config.get("lr_scheduler", {"name": "noop"}),
                    param_groups=[
                        optimizer.non_bert_param_group,
                        optimizer.bert_param_group,
                    ],
                )
            else:
                optimizer = registry.construct(
                    "optimizer",
                    config["optimizer"],
                    params=self.model.get_trainable_parameters(),
                )
                lr_scheduler = registry.construct(
                    "lr_scheduler",
                    config.get("lr_scheduler", {"name": "noop"}),
                    param_groups=optimizer.param_groups,
                )

            return inner_optimizer, maml_trainer, optimizer, lr_scheduler
Example no. 24
    def __init__(
            self,
            device,
            preproc,
            word_emb_size=128,
            recurrent_size=256,
            dropout=0.0,
            question_encoder=("emb", "bilstm"),
            column_encoder=("emb", "bilstm"),
            value_encoder=("emb", "bilstm"),
            linking_config={},
            rat_config={},
            top_k_learnable=0,
            include_in_memory=("question", ),
    ):
        super().__init__()
        self._device = device
        self.preproc = preproc

        self.vocab = preproc.vocab
        self.word_emb_size = word_emb_size
        self.recurrent_size = recurrent_size
        assert self.recurrent_size % 2 == 0
        word_freq = self.preproc.vocab_builder.word_freq
        top_k_words = set(
            [_a[0] for _a in word_freq.most_common(top_k_learnable)])
        self.learnable_words = top_k_words
        self.include_in_memory = set(include_in_memory)
        self.dropout = dropout

        shared_modules = {
            "shared-en-emb":
            embedders.LookupEmbeddings(
                self._device,
                self.vocab,
                self.preproc.word_emb,
                self.word_emb_size,
                self.learnable_words,
            )
        }

        self.question_encoder = self._build_modules(
            question_encoder, "question", shared_modules=shared_modules)
        self.column_encoder = self._build_modules(
            column_encoder, "column", shared_modules=shared_modules)
        self.value_encoder = self._build_modules(value_encoder,
                                                 "value",
                                                 shared_modules=shared_modules)

        update_modules = {"rat": rat.RAT, "none": rat.NoOpUpdate}

        self.schema_linking = registry.construct(
            "schema_linking",
            linking_config,
            device=device,
            word_emb_size=word_emb_size,
            preproc=preproc,
        )

        self.rat_update = registry.instantiate(
            update_modules[rat_config["name"]],
            rat_config,
            unused_keys={"name"},
            device=self._device,
            relations2id=self.preproc.relations2id,
            hidden_size=recurrent_size,
        )
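A side note on the recurring linking_config={} and rat_config={} signatures: Python evaluates default argument values once, so a mutable default is shared across every call. This is harmless while the dicts are only read, but the safer idiom is a None default, as this small demo shows:

    def bad(config={}):   # one dict object shared by every call
        config["touched"] = True
        return config

    def good(config=None):  # a fresh dict per call
        config = {} if config is None else config
        config["touched"] = True
        return config

    assert bad() is bad()        # same shared object both times
    assert good() is not good()  # independent objects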