Example #1
def main(args):
    # Initialize Emmental
    config = parse_args_to_config(args)
    emmental.init(log_dir=config["meta_config"]["log_path"], config=config)

    # Log configuration into files
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(f"{emmental.Meta.log_path}/cmd.txt", cmd_msg)

    logger.info(f"Config: {emmental.Meta.config}")
    write_to_file(f"{emmental.Meta.log_path}/config.txt", emmental.Meta.config)

    # Create dataloaders
    dataloaders = get_dataloaders(args)

    # Assign transforms to dataloaders
    aug_dataloaders = []
    if args.augment_policy:
        for idx in range(len(dataloaders)):
            if dataloaders[idx].split in args.train_split:
                dataloaders[idx].dataset.transform_cls = Augmentation(
                    args=args)

    config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = AugScheduler(augment_k=args.augment_k,
                                         enlarge=args.augment_enlarge)
    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = config["learner_config"]["task_scheduler_config"][
            "task_scheduler"]

    # Create tasks
    model = EmmentalModel(name=f"{args.task}_task")
    model.add_task(create_task(args))

    # Set cudnn benchmark
    cudnn.benchmark = True

    # Load pretrained weights if a model path is provided
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders + aug_dataloaders)

    # Remove all extra augmentation policies
    for idx in range(len(dataloaders)):
        dataloaders[idx].dataset.transform_cls = None

    scores = model.score(dataloaders)

    # Save metrics and models
    logger.info(f"Metrics: {scores}")
    scores["log_path"] = emmental.Meta.log_path
    write_to_json_file(f"{emmental.Meta.log_path}/metrics.txt", scores)
    model.save(f"{emmental.Meta.log_path}/last_model.pth")
Example #2
    def load_best_model(self, model: EmmentalModel) -> EmmentalModel:
        """Load the best model from the checkpoint.

        Args:
          model: The current model.

        Returns:
          The best model loaded from the checkpoint.
        """
        if list(self.checkpoint_metric.keys())[0] not in self.best_metric_dict:
            logger.info("No best model found, use the original model.")
        else:
            # Load the best model of checkpoint_metric
            metric = list(self.checkpoint_metric.keys())[0]
            best_model_path = (f"{self.checkpoint_path}/best_model_"
                               f"{metric.replace('/', '_')}.model.pth")
            model.load(best_model_path, verbose=False)
            logger.info(f"Loading the best model from {best_model_path}.")

        return model
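
# Usage sketch (not part of the original snippet; the call pattern follows the
# other examples and assumes Emmental's checkpointer exposes load_best_model):
#
#     learner = EmmentalLearner()
#     learner.learn(model, dataloaders)
#     model = learner.logging_manager.checkpointer.load_best_model(model)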
Example #3
def predict_and_write(task_name, path, data_dir, submit_subdir, batch_size):
    bert_model_name, max_seq_len = extract_from_cmd(path)
    msg = (
        f"Using {bert_model_name} and max_sequence_len={max_seq_len} for task "
        f"{task_name}")
    logger.info(msg)

    # Build model
    task = build_model[task_name](bert_model_name)
    model = EmmentalModel(name=f"SuperGLUE_{task_name}", tasks=[task])
    try:
        model.load(path)
    except UnboundLocalError:
        msg = (
            "Failed to load state dict; confirm that your model was saved with "
            "a command such as 'torch.save(model.state_dict(), PATH)'")
        logging.error(msg)
        raise

    # Build dataloaders
    dataloaders = get_dataloaders(
        data_dir,
        task_name=task_name,
        splits=["val",
                "test"],  # TODO: replace with ['split'] and update below
        max_data_samples=None,
        max_sequence_length=max_seq_len,
        tokenizer_name=bert_model_name,
        batch_size=batch_size,
        uid="uids",
    )
    # TEMP: Sanity check val performance
    logger.info(f"Valid score: {model.score(dataloaders[0])}")
    # TEMP

    filename = f"{task_name}.jsonl"
    filepath = os.path.join(submit_subdir, filename)
    make_submission_file(model, dataloaders[-1], task_name, filepath)
Example #4
def main(args):
    # Ensure that global state is fresh
    Meta.reset()

    # Initialize Emmental
    config = parse_arg_to_config(args)
    emmental.init(config["meta_config"]["log_path"], config=config)

    # Save command line argument into file
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(Meta.log_path, "cmd.txt", cmd_msg)

    # Save Emmental config into file
    logger.info(f"Config: {Meta.config}")
    write_to_file(Meta.log_path, "config.txt", Meta.config)

    Meta.config["learner_config"]["global_evaluation_metric_dict"] = {
        f"model/SuperGLUE/{split}/score": partial(superglue_scorer,
                                                  split=split)
        for split in ["val"]
    }

    # Construct dataloaders and tasks and load slices
    dataloaders = []
    tasks = []

    for task_name in args.task:
        task_dataloaders = get_dataloaders(
            data_dir=args.data_dir,
            task_name=task_name,
            splits=["train", "val", "test"],
            max_sequence_length=args.max_sequence_length,
            max_data_samples=args.max_data_samples,
            tokenizer_name=args.bert_model,
            batch_size=args.batch_size,
            augment=args.augmentations,
        )
        task = models.model[task_name](
            args.bert_model,
            last_hidden_dropout_prob=args.last_hidden_dropout_prob)
        if args.slices:
            logger.info("Initializing task-specific slices")
            slice_func_dict = slicing.slice_func_dict[task_name]
            # Include general purpose slices
            if args.general_slices:
                logger.info("Including general slices")
                slice_func_dict.update(slicing.slice_func_dict["general"])

            task_dataloaders = slicing.add_slice_labels(
                task_name, task_dataloaders, slice_func_dict)

            slice_tasks = slicing.add_slice_tasks(task_name, task,
                                                  slice_func_dict,
                                                  args.slice_hidden_dim)
            tasks.extend(slice_tasks)
        else:
            tasks.append(task)

        dataloaders.extend(task_dataloaders)

    # Build Emmental model
    model = EmmentalModel(name="SuperGLUE", tasks=tasks)

    # Load pretrained model if necessary
    if Meta.config["model_config"]["model_path"]:
        model.load(Meta.config["model_config"]["model_path"])

    # Training
    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # If model is slice-aware, slice scores will be calculated from slice heads
    # If model is not slice-aware, manually calculate performance on slices
    if not args.slices:
        slice_func_dict = {}
        slice_keys = args.task
        if args.general_slices:
            slice_keys.append("general")

        for k in slice_keys:
            slice_func_dict.update(slicing.slice_func_dict[k])

        scores = slicing.score_slices(model, dataloaders, args.task,
                                      slice_func_dict)
    else:
        scores = model.score(dataloaders)

    # Save metrics into file
    logger.info(f"Metrics: {scores}")
    write_to_file(Meta.log_path, "metrics.txt", scores)

    # Save best metrics into file
    if args.train:
        logger.info(
            f"Best metrics: "
            f"{emmental_learner.logging_manager.checkpointer.best_metric_dict}"
        )
        write_to_file(
            Meta.log_path,
            "best_metrics.txt",
            emmental_learner.logging_manager.checkpointer.best_metric_dict,
        )

    # Save submission file
    for task_name in args.task:
        test_dataloaders = [d for d in dataloaders if d.split == "test"]
        assert len(test_dataloaders) == 1
        filepath = os.path.join(Meta.log_path, f"{task_name}.jsonl")
        make_submission_file(model, test_dataloaders[0], task_name, filepath)
Example #5
def main(args):
    # Initialize Emmental
    config = parse_args_to_config(args)
    emmental.init(log_dir=config["meta_config"]["log_path"], config=config)

    # Log configuration into files
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(f"{emmental.Meta.log_path}/cmd.txt", cmd_msg)

    logger.info(f"Config: {emmental.Meta.config}")
    write_to_file(f"{emmental.Meta.log_path}/config.txt", emmental.Meta.config)

    # Create dataloaders
    dataloaders = get_dataloaders(args)

    config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = AugScheduler(augment_k=args.augment_k,
                                         enlarge=args.augment_enlarge)
    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = config["learner_config"]["task_scheduler_config"][
            "task_scheduler"]

    # Specify optimizer parameter groups: no weight decay for bias and LayerNorm weights
    def grouped_parameters(model):
        no_decay = ["bias", "LayerNorm.weight"]
        return [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                emmental.Meta.config["learner_config"]["optimizer_config"]
                ["l2"],
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]

    emmental.Meta.config["learner_config"]["optimizer_config"][
        "parameters"] = grouped_parameters

    # Create tasks
    model = EmmentalModel(name=f"{args.task}_task")
    model.add_task(create_task(args))

    # Load pretrained weights if a model path is provided
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # Remove all extra augmentation policies
    for idx in range(len(dataloaders)):
        dataloaders[idx].dataset.transform_cls = None
        dataloaders[idx].dataset.k = 1

    scores = model.score(dataloaders)

    # Save metrics and models
    logger.info(f"Metrics: {scores}")
    scores["log_path"] = emmental.Meta.log_path
    write_to_json_file(f"{emmental.Meta.log_path}/metrics.txt", scores)
    model.save(f"{emmental.Meta.log_path}/last_model.pth")
Example #6
    for task_name in args.task:
        dataloaders += create_dataloaders(
            task_name, datasets[task_name], args.batch_size, emb_layer.word2id
        )

    tasks = {
        task_name: create_task(
            task_name, args, datasets[task_name]["nclasses"], emb_layer
        )
        for task_name in args.task
    }

    model = EmmentalModel(name="TC_task")

    if Meta.config["model_config"]["model_path"]:
        model.load(Meta.config["model_config"]["model_path"])
    else:
        for task_name, task in tasks.items():
            model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)
    logger.info(f"Metrics: {scores}")
    write_to_json_file(f"{Meta.log_path}/metrics.txt", scores)

    if args.checkpointing:
        logger.info(
            f"Best metrics: "
            f"{emmental_learner.logging_manager.checkpointer.best_metric_dict}"
Example #7
def test_model(caplog):
    """Unit test of model."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_model"

    Meta.reset()
    emmental.init(dirpath)

    def ce_loss(module_name, immediate_output_dict, Y, active):
        return F.cross_entropy(immediate_output_dict[module_name][0][active],
                               (Y.view(-1))[active])

    def output(module_name, immediate_output_dict):
        return F.softmax(immediate_output_dict[module_name][0], dim=1)

    task1 = EmmentalTask(
        name="task_1",
        module_pool=nn.ModuleDict({
            "m1": nn.Linear(10, 10, bias=False),
            "m2": nn.Linear(10, 2, bias=False)
        }),
        task_flow=[
            {
                "name": "m1",
                "module": "m1",
                "inputs": [("_input_", "data")]
            },
            {
                "name": "m2",
                "module": "m2",
                "inputs": [("m1", 0)]
            },
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    new_task1 = EmmentalTask(
        name="task_1",
        module_pool=nn.ModuleDict({
            "m1": nn.Linear(10, 5, bias=False),
            "m2": nn.Linear(5, 2, bias=False)
        }),
        task_flow=[
            {
                "name": "m1",
                "module": "m1",
                "inputs": [("_input_", "data")]
            },
            {
                "name": "m2",
                "module": "m2",
                "inputs": [("m1", 0)]
            },
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    task2 = EmmentalTask(
        name="task_2",
        module_pool=nn.ModuleDict({
            "m1": nn.Linear(10, 5, bias=False),
            "m2": nn.Linear(5, 2, bias=False)
        }),
        task_flow=[
            {
                "name": "m1",
                "module": "m1",
                "inputs": [("_input_", "data")]
            },
            {
                "name": "m2",
                "module": "m2",
                "inputs": [("m1", 0)]
            },
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    config = {"model_config": {"dataparallel": False}}
    emmental.Meta.update_config(config)

    model = EmmentalModel(name="test", tasks=task1)

    assert repr(model) == "EmmentalModel(name=test)"
    assert model.name == "test"
    assert model.task_names == set(["task_1"])
    assert model.module_pool["m1"].weight.data.size() == (10, 10)
    assert model.module_pool["m2"].weight.data.size() == (2, 10)

    model.update_task(new_task1)

    assert model.module_pool["m1"].weight.data.size() == (5, 10)
    assert model.module_pool["m2"].weight.data.size() == (2, 5)

    model.update_task(task2)

    assert model.task_names == set(["task_1"])

    model.add_task(task2)

    assert model.task_names == set(["task_1", "task_2"])

    model.remove_task("task_1")
    assert model.task_names == set(["task_2"])

    model.remove_task("task_1")
    assert model.task_names == set(["task_2"])

    model.save(f"{dirpath}/saved_model.pth")

    model.load(f"{dirpath}/saved_model.pth")

    # Test add_tasks
    model = EmmentalModel(name="test")

    model.add_tasks([task1, task2])
    assert model.task_names == set(["task_1", "task_2"])

    shutil.rmtree(dirpath)
Example #8
                ],
                "weight_decay":
                0.0,
            },
        ]

    emmental.Meta.config["learner_config"]["optimizer_config"][
        "parameters"] = grouped_parameters

    # Create tasks
    model = EmmentalModel(name="TACRED_task")
    model.add_task(create_task(args))

    # Load pretrained weights if a model path is provided
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # Remove all extra augmentation policies
    for idx in range(len(dataloaders)):
        dataloaders[idx].dataset.transform_cls = None

    scores = model.score(dataloaders)

    # Save metrics and models
    logger.info(f"Metrics: {scores}")
    scores["log_path"] = emmental.Meta.log_path
    write_to_json_file(f"{emmental.Meta.log_path}/metrics.txt", scores)
Example #9
class BootlegAnnotator(object):
    """BootlegAnnotator class: convenient wrapper of preprocessing and model
    eval to allow for annotating single sentences at a time for quick
    experimentation, e.g. in notebooks.

    Args:
        config: model config (default None)
        device: model device, -1 for CPU (default None)
        max_alias_len: maximum alias length (default 6)
        cand_map: alias candidate map (default None)
        threshold: probability threshold (default 0.0)
        cache_dir: cache directory (default None)
        model_name: model name (default None)
        verbose: verbose boolean (default False)
    """
    def __init__(
        self,
        config=None,
        device=None,
        max_alias_len=6,
        cand_map=None,
        threshold=0.0,
        cache_dir=None,
        model_name=None,
        verbose=False,
    ):
        self.max_alias_len = max_alias_len
        self.verbose = verbose
        # minimum probability of prediction to return a mention
        self.threshold = threshold

        if not cache_dir:
            self.cache_dir = get_default_cache()
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"
        else:
            self.cache_dir = Path(cache_dir)
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"

        if not model_name:
            model_name = "bootleg_uncased"

        assert model_name in {
            "bootleg_cased",
            "bootleg_cased_mini",
            "bootleg_uncased",
            "bootleg_uncased_mini",
        }, (f"model_name must be one of [bootleg_cased, bootleg_cased_mini, "
            f"bootleg_uncased_mini, bootleg_uncased]. You have {model_name}.")

        if not config:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.model_path.mkdir(parents=True, exist_ok=True)
            self.data_path.mkdir(parents=True, exist_ok=True)
            create_sources(self.model_path, self.data_path, model_name)
            self.config = create_config(self.model_path, self.data_path,
                                        model_name)
        else:
            if "emmental" in config:
                config = parse_boot_and_emm_args(config)
            self.config = config
            # Ensure some of the critical annotator args are the correct type
            self.config.data_config.max_aliases = int(
                self.config.data_config.max_aliases)
            self.config.run_config.eval_batch_size = int(
                self.config.run_config.eval_batch_size)
            self.config.data_config.max_seq_len = int(
                self.config.data_config.max_seq_len)
            self.config.data_config.train_in_candidates = bool(
                self.config.data_config.train_in_candidates)

        if not device:
            device = 0 if torch.cuda.is_available() else -1

        if self.verbose:
            self.config.run_config.log_level = "DEBUG"
        else:
            self.config.run_config.log_level = "INFO"

        self.torch_device = (torch.device(device)
                             if device != -1 else torch.device("cpu"))
        self.config.model_config.device = device

        log_level = logging.getLevelName(
            self.config["run_config"]["log_level"].upper())
        emmental.init(
            log_dir=self.config["meta_config"]["log_path"],
            config=self.config,
            use_exact_log_path=self.config["meta_config"]
            ["use_exact_log_path"],
            level=log_level,
        )

        logger.debug("Reading entity database")
        self.entity_db = EntitySymbols.load_from_cache(
            os.path.join(
                self.config.data_config.entity_dir,
                self.config.data_config.entity_map_dir,
            ),
            alias_cand_map_file=self.config.data_config.alias_cand_map,
            alias_idx_file=self.config.data_config.alias_idx_map,
        )
        logger.debug("Reading word tokenizers")
        self.tokenizer = BertTokenizer.from_pretrained(
            self.config.data_config.word_embedding.bert_model,
            do_lower_case="uncased"
            in self.config.data_config.word_embedding.bert_model,
            cache_dir=self.config.data_config.word_embedding.cache_dir,
        )

        # Create tasks
        tasks = [NED_TASK]
        if self.config.data_config.type_prediction.use_type_pred is True:
            tasks.append(TYPE_PRED_TASK)
        self.task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}

        # Create model and add tasks
        self.model = EmmentalModel(name="Bootleg")
        self.model.add_task(ned_task.create_task(self.config, self.entity_db))
        if TYPE_PRED_TASK in tasks:
            self.model.add_task(
                type_pred_task.create_task(self.config, self.entity_db))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(self.model)

        logger.debug("Loading model")
        # Load the pretrained model (the annotator requires a model path)
        assert (
            self.config["model_config"]["model_path"] is not None
        ), f"Must have a model to load in the model_path for the BootlegAnnotator"
        self.model.load(self.config["model_config"]["model_path"])
        self.model.eval()
        if cand_map is None:
            alias_map = self.entity_db.get_alias2qids()
        else:
            logger.debug(f"Loading candidate map")
            alias_map = ujson.load(open(cand_map))

        self.all_aliases_trie = get_all_aliases(alias_map, verbose)

        logger.debug("Reading in alias table")
        self.alias2cands = AliasEntityTable(
            data_config=self.config.data_config, entity_symbols=self.entity_db)

        # get batch_on_the_fly embeddings
        self.batch_on_the_fly_embs = get_dataloader_embeddings(
            self.config, self.entity_db)

    def extract_mentions(self, text, label_func):
        """Wrapper function for mention extraction.

        Args:
            text: text to extract mentions from
            label_func: function that performs extraction; input is (text, alias trie,
                        max alias length) and output is a list of found aliases and found spans

        Returns: JSON object of sentence to be used in eval
        """
        found_aliases, found_spans = label_func(text, self.all_aliases_trie,
                                                self.max_alias_len)
        return {
            "sentence": text,
            "aliases": found_aliases,
            "spans": found_spans,
            # we don't know the true QID
            "qids": ["Q-1" for i in range(len(found_aliases))],
            "gold": [True for i in range(len(found_aliases))],
        }

    def set_threshold(self, value):
        """Sets threshold.

        Args:
            value: threshold value

        Returns:
        """
        self.threshold = value

    def label_mentions(self,
                       text_list,
                       label_func=find_aliases_in_sentence_tag):
        """Extracts mentions and runs disambiguation.

        Args:
            text_list: list of text to disambiguate (or single sentence)
            label_func: mention extraction function (optional)

        Returns: Dict of

            * ``qids``: final predicted QIDs,
            * ``probs``: final predicted probs,
            * ``titles``: final predicted titles,
            * ``cands``: all entity candidates,
            * ``cand_probs``: probabilities of all candidates,
            * ``spans``: final extracted word spans,
            * ``aliases``: final extracted aliases,
        """
        if type(text_list) is str:
            text_list = [text_list]
        else:
            assert (type(text_list) is list and len(text_list) > 0
                    and type(text_list[0]) is str
                    ), f"We only accept inputs of strings and lists of strings"

        ebs = int(self.config.run_config.eval_batch_size)
        self.config.data_config.max_aliases = int(
            self.config.data_config.max_aliases)
        total_start_exs = 0
        total_final_exs = 0
        dropped_by_thresh = 0

        final_char_spans = []

        batch_example_aliases = []
        batch_example_aliases_locs_start = []
        batch_example_aliases_locs_end = []
        batch_example_alias_list_pos = []
        batch_example_true_entities = []
        batch_word_indices = []
        batch_spans_arr = []
        batch_aliases_arr = []
        batch_idx_unq = []
        batch_subsplit_idx = []
        for idx_unq, text in tqdm(
                enumerate(text_list),
                desc="Prepping data",
                total=len(text_list),
                disable=not self.verbose,
        ):
            sample = self.extract_mentions(text, label_func)
            total_start_exs += len(sample["aliases"])
            char_spans = self.get_char_spans(sample["spans"], text)

            final_char_spans.append(char_spans)

            (
                idxs_arr,
                aliases_to_predict_per_split,
                spans_arr,
                phrase_tokens_arr,
                pos_idxs,
            ) = sentence_utils.split_sentence(
                max_aliases=self.config.data_config.max_aliases,
                phrase=sample["sentence"],
                spans=sample["spans"],
                aliases=sample["aliases"],
                aliases_seen_by_model=list(range(len(sample["aliases"]))),
                seq_len=self.config.data_config.max_seq_len,
                is_bert=True,
                tokenizer=self.tokenizer,
            )
            aliases_arr = [[sample["aliases"][idx] for idx in idxs]
                           for idxs in idxs_arr]
            old_spans_arr = [[sample["spans"][idx] for idx in idxs]
                             for idxs in idxs_arr]
            qids_arr = [[sample["qids"][idx] for idx in idxs]
                        for idxs in idxs_arr]
            word_indices_arr = [
                self.tokenizer.convert_tokens_to_ids(pt)
                for pt in phrase_tokens_arr
            ]
            # iterate over each sample in the split

            for sub_idx in range(len(idxs_arr)):
                # ====================================================
                # GENERATE MODEL INPUTS
                # ====================================================
                aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]

                assert (
                    len(aliases_to_predict_arr) >= 0
                ), f"There are no aliases to predict for an example. This should not happen at this point."
                assert (
                    len(aliases_arr[sub_idx]) <=
                    self.config.data_config.max_aliases
                ), f"{sample} should have no more than {self.config.data_config.max_aliases} aliases."

                example_aliases = np.ones(
                    self.config.data_config.max_aliases) * PAD_ID
                example_aliases_locs_start = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_aliases_locs_end = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_alias_list_pos = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_true_entities = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)

                for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                    span_start_idx, span_end_idx = spans_arr[sub_idx][
                        mention_idx]
                    # generate indexes into alias table.
                    alias_trie_idx = self.entity_db.get_alias_idx(alias)
                    alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                    if not qids_arr[sub_idx][mention_idx] in alias_qids:
                        # assert not data_args.train_in_candidates
                        if not self.config.data_config.train_in_candidates:
                            # set class label to be "not in candidate set"
                            true_entity_idx = 0
                        else:
                            true_entity_idx = -2
                    else:
                        # Here we are getting the correct class label for training.
                        # Our training is "which of the max_entities entity candidates is the right one
                        # (class labels 1 to max_entities) or is it none of these (class label 0)".
                        # + (not discard_noncandidate_entities) is to ensure label 0 is
                        # reserved for "not in candidate set" class
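                        # Worked example (illustration added, not in the original):
                        # if alias_qids == ["Q1", "Q5", "Q9"], the gold QID is "Q9",
                        # and train_in_candidates is False, np.nonzero yields index 2,
                        # so true_entity_idx = 2 + 1 = 3.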
                        true_entity_idx = np.nonzero(
                            alias_qids == qids_arr[sub_idx][mention_idx]
                        )[0][0] + (
                            not self.config.data_config.train_in_candidates)
                    example_aliases[mention_idx] = alias_trie_idx
                    example_aliases_locs_start[mention_idx] = span_start_idx
                    # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                    example_aliases_locs_end[mention_idx] = span_end_idx - 1
                    example_alias_list_pos[mention_idx] = idxs_arr[sub_idx][
                        mention_idx]
                    # leave as -1 if it's not an alias we want to predict; we get these if we split a sentence
                    # and need to only predict subsets
                    if mention_idx in aliases_to_predict_arr:
                        example_true_entities[mention_idx] = true_entity_idx

                # get word indices
                word_indices = word_indices_arr[sub_idx]

                batch_example_aliases.append(example_aliases)
                batch_example_aliases_locs_start.append(
                    example_aliases_locs_start)
                batch_example_aliases_locs_end.append(example_aliases_locs_end)
                batch_example_alias_list_pos.append(example_alias_list_pos)
                batch_example_true_entities.append(example_true_entities)
                batch_word_indices.append(word_indices)
                batch_aliases_arr.append(aliases_arr[sub_idx])
                # Add the original sample spans because spans_arr is w.r.t. BERT subword tokens
                batch_spans_arr.append(old_spans_arr[sub_idx])
                batch_idx_unq.append(idx_unq)
                batch_subsplit_idx.append(sub_idx)

        batch_example_aliases = torch.tensor(batch_example_aliases).long()
        batch_example_aliases_locs_start = torch.tensor(
            batch_example_aliases_locs_start, device=self.torch_device)
        batch_example_aliases_locs_end = torch.tensor(
            batch_example_aliases_locs_end, device=self.torch_device)
        batch_example_true_entities = torch.tensor(batch_example_true_entities,
                                                   device=self.torch_device)
        batch_word_indices = torch.tensor(batch_word_indices,
                                          device=self.torch_device)

        final_pred_cands = [[] for _ in range(len(text_list))]
        final_all_cands = [[] for _ in range(len(text_list))]
        final_cand_probs = [[] for _ in range(len(text_list))]
        final_pred_probs = [[] for _ in range(len(text_list))]
        final_titles = [[] for _ in range(len(text_list))]
        final_spans = [[] for _ in range(len(text_list))]
        final_aliases = [[] for _ in range(len(text_list))]
        for b_i in tqdm(
                range(0, batch_example_aliases.shape[0], ebs),
                desc="Evaluating model",
                disable=not self.verbose,
        ):
            start_span_idx = batch_example_aliases_locs_start[b_i:b_i + ebs]
            end_span_idx = batch_example_aliases_locs_end[b_i:b_i + ebs]
            word_indices = batch_word_indices[b_i:b_i + ebs]
            alias_indices = batch_example_aliases[b_i:b_i + ebs]
            x_dict = self.get_forward_batch(start_span_idx, end_span_idx,
                                            word_indices, alias_indices)
            x_dict["guid"] = torch.arange(b_i,
                                          b_i + ebs,
                                          device=self.torch_device)

            (uid_bdict, _, prob_bdict, _) = self.model(  # type: ignore
                uids=x_dict["guid"],
                X_dict=x_dict,
                Y_dict=None,
                task_to_label_dict=self.task_to_label_dict,
                return_action_outputs=False,
            )
            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================
            # recover predictions
            probs = prob_bdict[NED_TASK]
            max_probs = probs.max(2)
            max_probs_indices = probs.argmax(2)
            for ex_i in range(probs.shape[0]):
                idx_unq = batch_idx_unq[b_i + ex_i]
                entity_cands = eval_utils.map_aliases_to_candidates(
                    self.config.data_config.train_in_candidates,
                    self.config.data_config.max_aliases,
                    self.entity_db.get_alias2qids(),
                    batch_aliases_arr[b_i + ex_i],
                )
                # batch size is 1 so we can reshape
                probs_ex = probs[ex_i].reshape(
                    self.config.data_config.max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(
                        batch_example_true_entities[b_i + ex_i]):
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(
                                probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            final_aliases[idx_unq].append(
                                batch_aliases_arr[b_i + ex_i][alias_idx])
                            final_spans[idx_unq].append(
                                batch_spans_arr[b_i + ex_i][alias_idx])
                            final_titles[idx_unq].append(
                                self.entity_db.get_title(pred_qid)
                                if pred_qid != "NC" else "NC")
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
        assert total_final_exs + dropped_by_thresh == total_start_exs, (
            f"Something went wrong and we have predicted fewer mentions than extracted. "
            f"Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
        )
        res_dict = {
            "qids": final_pred_cands,
            "probs": final_pred_probs,
            "titles": final_titles,
            "cands": final_all_cands,
            "cand_probs": final_cand_probs,
            "spans": final_spans,
            "aliases": final_aliases,
        }
        return res_dict

    def get_forward_batch(self, start_span_idx, end_span_idx, token_ids,
                          alias_idx):
        """Preps the forward batch for disambiguation.

        Args:
            start_span_idx: start span tensor
            end_span_idx: end span tensor
            token_ids: word token tensor
            alias_idx: alias index used for extracting candidate eids

        Returns: X_dict used in Emmental
        """
        entity_cand_eid = self.alias2cands(alias_idx).long()
        entity_cand_eid_mask = entity_cand_eid == -1
        entity_cand_eid_noneg = torch.where(
            entity_cand_eid >= 0,
            entity_cand_eid,
            (torch.ones_like(entity_cand_eid, dtype=torch.long) *
             (self.entity_db.num_entities_with_pad_and_nocand - 1)),
        )

        kg_prepped_embs = {}
        for emb_key in self.batch_on_the_fly_embs:
            kg_adj = self.batch_on_the_fly_embs[emb_key]["kg_adj"]
            prep_func = self.batch_on_the_fly_embs[emb_key][
                "kg_adj_process_func"]
            batch_prep = []
            for j in range(entity_cand_eid_noneg.shape[0]):
                batch_prep.append(
                    prep_func(entity_cand_eid_noneg[j].cpu(),
                              kg_adj).reshape(1, -1))
            kg_prepped_embs[emb_key] = torch.tensor(batch_prep,
                                                    device=self.torch_device)

        X_dict = {
            "guids": [],
            "start_span_idx": start_span_idx,
            "end_span_idx": end_span_idx,
            "token_ids": token_ids,
            "entity_cand_eid": entity_cand_eid_noneg,
            "entity_cand_eid_mask": entity_cand_eid_mask,
            "batch_on_the_fly_kg_adj": kg_prepped_embs,
        }
        return X_dict

    def get_char_spans(self, spans, text):
        """Helper function to get character spans instead of default word
        spans.

        Args:
            spans: word spans
            text: text

        Returns: character spans
        """
        query_toks = text.split()
        char_spans = []
        for span in spans:
            space_btwn_toks = (len(" ".join(query_toks[0:span[0] + 1])) -
                               len(" ".join(query_toks[0:span[0]])) -
                               len(query_toks[span[0]]))
            char_b = len(" ".join(query_toks[0:span[0]])) + space_btwn_toks
            char_e = char_b + len(" ".join(query_toks[span[0]:span[1]]))
            char_spans.append([char_b, char_e])
        return char_spans
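
# Usage sketch (not part of the original snippet; assumes the bootleg package is
# installed, the default "bootleg_uncased" model can be downloaded into the cache,
# and the sentence is illustrative only):
#
#     ann = BootlegAnnotator(device=-1, verbose=False)
#     preds = ann.label_mentions("Abraham Lincoln was born in Kentucky.")
#     for alias, qid, title in zip(preds["aliases"][0], preds["qids"][0], preds["titles"][0]):
#         print(alias, "->", qid, title)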
Example #10
def run_model(mode, config, run_config_path=None):
    """
    Main run method for Emmental Bootleg models.
    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns:
        Scores for the train/eval modes, or the final label file and embedding
        file paths for the dump modes.
    """

    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(config.data_config.entity_dir,
                              config.data_config.entity_map_dir),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )
    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # Dumping is only allowed when use_weak_label is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[
                    f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __getitem__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (config.data_config.type_prediction.use_type_pred is
                False), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(
            ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(
            ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols,
                                           slice_datasets))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=True,
                                        logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=False,
                                        logger=logger)
        log_rank_0_info(logger,
                        f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (config["learner_config"]["local_rank"] == -1
                and config["model_config"]["dataparallel"]):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt",
                          scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv",
                scores)
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename))
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
                eval_utils.batched_pred_iter(
                    model,
                    dataloaders[0],
                    config.run_config.eval_accumulation_steps,
                    sentidx2num_mentions,
                )):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0 for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
            )
        empty_sentidx2row = {
            k: v
            for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(number_dumped_batches,
                                                      subeval_folder)
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger,
            f"Finished dumping. Merging results across accumulation steps.")
        # Final result files for labels and embeddings
        final_result_file = os.path.join(eval_folder,
                                         config.run_config.result_label_file)
        # Copy labels
        output = open(final_result_file, "wb")
        for file in all_result_files:
            shutil.copyfileobj(open(file, "rb"), output)
        output.close()
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file)
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"Is that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = []
            for i, npfile in enumerate(all_out_emb_files):
                all_arrays.append(np.load(npfile))
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(
                logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
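
# Invocation sketch (not part of the original snippet; assumes `config` is a parsed
# Bootleg model config object and the mode strings match those listed in the docstring):
#
#     scores = run_model(mode="train", config=config)
#     result_file, emb_file = run_model(mode="dump_embs", config=config)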