Example #1
0
 def test_fine_tune_nograd_regex(self):
     """``no_grad`` regexes freeze exactly the parameters they match."""
     original_model = load_archive(self.model_archive).model
     # Keep references to the parameter objects themselves (not snapshots),
     # matching how the comparison below reads requires_grad lazily.
     original_parameters = dict(original_model.named_parameters())
     pattern_sets = [
         [],
         [".*attend_feedforward.*", ".*token_embedder.*"],
         [".*compare_feedforward.*"],
     ]
     for patterns in pattern_sets:
         params = Params.from_file(self.config_file)
         params["trainer"]["no_grad"] = patterns
         shutil.rmtree(self.serialization_dir, ignore_errors=True)
         tuned_model = fine_tune_model(model=original_model,
                                       params=params,
                                       serialization_dir=self.serialization_dir)
         # A matched parameter must be frozen; an unmatched one keeps the
         # requires_grad flag of the originally loaded model.
         for name, parameter in tuned_model.named_parameters():
             if any(re.search(pattern, name) for pattern in patterns):
                 assert not parameter.requires_grad
             else:
                 expected = original_parameters[name].requires_grad
                 assert parameter.requires_grad == expected
     # Freezing every single parameter is an error.
     with pytest.raises(Exception):
         params = Params.from_file(self.config_file)
         params["trainer"]["no_grad"] = ["*"]
         shutil.rmtree(self.serialization_dir, ignore_errors=True)
         fine_tune_model(model=original_model,
                         params=params,
                         serialization_dir=self.serialization_dir)
Example #2
0
 def test_fine_tune_nograd_regex(self):
     """Check that the trainer's ``no_grad`` regex list disables gradients
     for matching parameters and leaves all other parameters untouched."""
     original_model = load_archive(self.model_archive).model
     reference = dict(original_model.named_parameters())
     for regexes in ([],
                     [".*attend_feedforward.*", ".*token_embedder.*"],
                     [".*compare_feedforward.*"]):
         params = Params.from_file(self.config_file)
         params["trainer"]["no_grad"] = regexes
         shutil.rmtree(self.serialization_dir, ignore_errors=True)
         tuned = fine_tune_model(
             model=original_model,
             params=params,
             serialization_dir=self.serialization_dir)
         for name, parameter in tuned.named_parameters():
             matched = any(re.search(regex, name) for regex in regexes)
             if matched:
                 # Matched -> gradient must be disabled.
                 assert not parameter.requires_grad
             else:
                 # Unmatched -> same requires_grad as the loaded model.
                 assert (parameter.requires_grad
                         == reference[name].requires_grad)
     # Disabling gradients on *every* parameter must fail.
     with pytest.raises(Exception):
         params = Params.from_file(self.config_file)
         params["trainer"]["no_grad"] = ["*"]
         shutil.rmtree(self.serialization_dir, ignore_errors=True)
         fine_tune_model(
             model=original_model,
             params=params,
             serialization_dir=self.serialization_dir)
Example #3
0
    def test_fine_tune_runtime_errors_with_vocab_expansion(self):
        """Expanding the vocab over data with new tokens raises a RuntimeError."""
        params = Params.from_file(self.config_file)
        snli2_path = self.FIXTURES_ROOT / u'data' / u'snli2.jsonl'
        params[u"train_data_path"] = unicode(snli2_path)

        model = load_archive(self.model_archive).model

        # Vocab expansion fails at runtime because of the embedding layer.
        with pytest.raises(RuntimeError):
            fine_tune_model(model, params, self.serialization_dir,
                            extend_vocab=True)
Example #4
0
    def test_fine_tune_does_not_expand_vocab_by_default(self):
        """Without ``extend_vocab``, fine-tuning must not grow the vocabulary."""
        params = Params.from_file(self.config_file)
        # snli2 contains a token the original vocab does not have.
        data_path = self.FIXTURES_ROOT / u'data' / u'snli2.jsonl'
        params[u"train_data_path"] = unicode(data_path)

        model = load_archive(self.model_archive).model

        # With vocab expansion off (the default) this runs without error.
        fine_tune_model(model, params, self.serialization_dir)
Example #5
0
    def test_fine_tune_runtime_errors_with_vocab_expansion(self):
        """Vocab expansion over data with unseen tokens must raise RuntimeError."""
        params = Params.from_file(self.config_file)
        params["train_data_path"] = str(
            self.FIXTURES_ROOT / 'data' / 'snli2.jsonl')

        model = load_archive(self.model_archive).model

        # Requesting expansion fails at runtime because of the embedding.
        with pytest.raises(RuntimeError):
            fine_tune_model(model, params, self.serialization_dir,
                            extend_vocab=True)
Example #6
0
    def test_fine_tune_does_not_expand_vocab_by_default(self):
        """By default fine-tuning leaves the vocabulary untouched."""
        params = Params.from_file(self.config_file)
        # snli2 introduces a token that is absent from the original vocab.
        new_train_path = self.FIXTURES_ROOT / 'data' / 'snli2.jsonl'
        params["train_data_path"] = str(new_train_path)

        model = load_archive(self.model_archive).model

        # No extend_vocab flag -> no expansion, so this must succeed.
        fine_tune_model(model, params, self.serialization_dir)
def grid_search(model_archive_path: str,
                config_file: str,
                serialization_dir: str,
                file_friendly_logging: bool = False,
                step: float = 0.2):
    """Grid-search the MultiBiDAF decision thresholds by repeated fine-tuning.

    For every combination of ``span_threshold``/``true_threshold`` and a swept
    ``false_threshold``, reload the archived model with those thresholds
    overridden, fine-tune it with ``config_file``, and record the best
    validation F1_m. Finally log the best-scoring setting.

    Parameters
    ----------
    model_archive_path : path to the trained model archive to reload each run.
    config_file : training configuration used for each fine-tuning run.
    serialization_dir : directory wiped and reused for every run's outputs.
    file_friendly_logging : forwarded to ``fine_tune_model``.
    step : step size (and inclusive margin) of the ``false_threshold`` sweep.
    """
    thresholds = []
    f1_m_scores = []
    for span_threshold in [0.25, 0.5, 0.75]:
        for true_threshold in [0.3, 0.5, 0.7]:
            # Sweep false_threshold from 0.3 up to (and including)
            # true_threshold; "+ step" makes the upper bound inclusive.
            for false_threshold in np.arange(0.3, true_threshold + step, step):
                # np.arange yields np.float64; cast to a plain float so the
                # stringified overrides dict below prints clean values.
                false_threshold = float(false_threshold)
                thresholds.append(
                    (span_threshold, true_threshold, false_threshold))
                logger.info("#" * 100)
                logger.info("-" * 100)
                logger.info(
                    "The current setting is (span_threshold={}, true_threshold={}, false_threshold={})"
                    .format(span_threshold, true_threshold, false_threshold))
                logger.info("-" * 100)

                # Delete the serialization directory if it exists and is
                # non-empty, so every run starts from a clean slate.
                if os.path.exists(serialization_dir) and os.listdir(
                        serialization_dir):
                    shutil.rmtree(serialization_dir, ignore_errors=True)

                # Reload the model with the current thresholds overridden.
                # NOTE(review): str(dict) is a Python repr, not strict JSON;
                # it is accepted here because overrides are parsed as jsonnet.
                threshold_params = {
                    "model": {
                        "span_threshold": span_threshold,
                        "true_threshold": true_threshold,
                        "false_threshold": false_threshold
                    }
                }
                archive = load_archive(model_archive_path,
                                       overrides=str(threshold_params))

                params = Params.from_file(config_file)
                fine_tune_model(archive.model, params, serialization_dir,
                                file_friendly_logging)

                # Record the best validation F1_m score of this run.
                with open(os.path.join(serialization_dir,
                                       "metrics.json")) as f:
                    metrics = json.load(f)
                f1_m_scores.append(metrics["best_validation_f1_m"])
                logger.info(pformat(metrics))

    # Find the best setting. Reuse the already-computed maximum instead of
    # recomputing max() inside index(); ties resolve to the earliest run.
    max_f1_m = max(f1_m_scores)
    argmax = f1_m_scores.index(max_f1_m)
    best_setting = thresholds[argmax]

    logger.info("*" * 100)
    logger.info(
        "The best setting is (span_threshold={}, true_threshold={}, false_threshold={}) with F1_m={}"
        .format(best_setting[0], best_setting[1], best_setting[2], max_f1_m))
    logger.info("*" * 100)
Example #8
0
 def test_fine_tune_extended_model_is_loadable(self):
     """A model fine-tuned with vocab expansion reloads from its archive."""
     params = Params.from_file(self.config_file)
     # snli2 introduces a new token (seahorse), forcing vocab extension.
     params["train_data_path"] = str(
         self.FIXTURES_ROOT / 'data' / 'snli2.jsonl')
     trained_model = load_archive(self.model_archive).model
     shutil.rmtree(self.serialization_dir, ignore_errors=True)
     fine_tune_model(trained_model, params.duplicate(),
                     self.serialization_dir, extend_vocab=True)
     # Reloading the archive produced by fine-tuning must not raise.
     load_archive(str(self.TEST_DIR / 'fine_tune' / "model.tar.gz"))
Example #9
0
 def check_embedding_extension(user_pretrained_file,
                               saved_pretrained_file, use_pretrained):
     """Fine-tune with an extended vocab and verify the embedding gained
     exactly one row, preserved the old weights, and filled the new row
     from the pretrained vector only when one was actually used."""
     trained_model = load_archive(self.model_archive).model
     embedder = trained_model._text_field_embedder.token_embedder_tokens
     original_weight = embedder.weight
     # Simulate the behavior of an unavailable pretrained_file that was
     # stored as an attribute when the model was originally trained.
     embedder._pretrained_file = saved_pretrained_file
     embedding_sources_mapping = {
         "_text_field_embedder.token_embedder_tokens": user_pretrained_file,
     }
     shutil.rmtree(self.serialization_dir, ignore_errors=True)
     fine_tuned_model = fine_tune_model(
         trained_model,
         params.duplicate(),
         self.serialization_dir,
         extend_vocab=True,
         embedding_sources_mapping=embedding_sources_mapping)
     extended_weight = (fine_tuned_model._text_field_embedder
                        .token_embedder_tokens.weight)
     # Exactly one new vocabulary row: 24 -> 25.
     assert original_weight.shape[0] + 1 == extended_weight.shape[0] == 25
     # The original rows carry over unchanged.
     assert torch.all(original_weight == extended_weight[:24, :])
     if use_pretrained:
         # The new row is initialized from the pretrained vector.
         assert torch.all(extended_weight[24, :] == extra_token_vector)
     else:
         assert torch.all(extended_weight[24, :] != extra_token_vector)
Example #10
0
def fine_tune(**params):
    """Fine-tune an archived model with a given training configuration.

    Required keys in ``params``: ``config_file``, ``serialization_dir``,
    ``include_package`` (iterable of package names to import) and
    ``model_file``. Optional keys: ``overrides``, ``recover``, ``force``.

    Returns whatever ``fine_tune_model`` returns.
    """
    param_is_exist(["config_file", "serialization_dir", "include_package", "model_file"], params)
    for package_name in params["include_package"]:
        import_submodules(package_name)
    # dict.get with a default replaces the repetitive LBYL pattern
    # `params[k] if k in params else ""`.
    overrides = params.get("overrides", "")
    # NOTE(review): "" stands in for a missing boolean here (both falsy);
    # confirm fine_tune_model treats recover/force by plain truthiness.
    recover = params.get("recover", "")
    force = params.get("force", "")
    config_params = Params.from_file(params["config_file"], overrides)
    archive = load_archive(params["model_file"])
    return fine_tune_model(archive.model, config_params,
                           params["serialization_dir"], recover, force)
Example #11
0
    def test_fine_tune_works_with_vocab_expansion(self):
        """With ``extend_vocab=True`` the embedding grows to fit new tokens."""
        params = Params.from_file(self.config_file)
        # snli2 contains a token that is new to the trained model's vocab.
        params["train_data_path"] = str(
            self.FIXTURES_ROOT / 'data' / 'snli2.jsonl')

        trained_model = load_archive(self.model_archive).model
        original_weight = (trained_model._text_field_embedder
                           .token_embedder_tokens.weight)

        # Vocab expansion should now succeed rather than error out.
        fine_tuned_model = fine_tune_model(trained_model, params,
                                           self.serialization_dir,
                                           extend_vocab=True)
        extended_weight = (fine_tuned_model._text_field_embedder
                           .token_embedder_tokens.weight)

        # One extra row for the new token; existing rows are untouched.
        assert tuple(original_weight.shape) == (24, 300)
        assert tuple(extended_weight.shape) == (25, 300)
        assert torch.all(original_weight == extended_weight[:24, :])