Example #1
    def load_regularization_mapping(cls, data_config, entity_symbols,
                                    reg_file):
        """Reads in a csv file with columns [qid, regularization].

        In the forward pass, the embedding for the entity ID associated with
        each QID is regularized (dropped) with the listed probability.

        Args:
            data_config: data config
            entity_symbols: entity symbols
            reg_file: regularization csv file

        Returns: Tensor where the value at each EID index is the regularization probability for that entity
        """
        reg_str = os.path.splitext(os.path.basename(reg_file.replace("/",
                                                                     "_")))[0]
        prep_dir = data_utils.get_data_prep_dir(data_config)
        prep_file = os.path.join(
            prep_dir, f"entity_regularization_mapping_{reg_str}.pt")
        utils.ensure_dir(os.path.dirname(prep_file))
        log_rank_0_debug(logger,
                         f"Looking for regularization mapping in {prep_file}")
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(
                logger,
                f"Loading existing entity regularization mapping from {prep_file}",
            )
            start = time.time()
            eid2reg = torch.load(prep_file)
            log_rank_0_debug(
                logger,
                f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s",
            )
        else:
            start = time.time()
            log_rank_0_info(
                logger,
                f"Building entity regularization mapping from {reg_file}")
            qid2reg = pd.read_csv(reg_file)
            assert (
                "qid" in qid2reg.columns
                and "regularization" in qid2reg.columns
            ), f"Expected qid and regularization as the column names for {reg_file}"
            # default of no mask
            eid2reg_arr = [0.0
                           ] * entity_symbols.num_entities_with_pad_and_nocand
            for row_idx, row in qid2reg.iterrows():
                if entity_symbols.qid_exists(row["qid"]):
                    eid = entity_symbols.get_eid(row["qid"])
                    eid2reg_arr[eid] = row["regularization"]
            eid2reg = torch.tensor(eid2reg_arr)
            torch.save(eid2reg, prep_file)
            log_rank_0_debug(
                logger,
                f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.",
            )
        return eid2reg
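
A minimal sketch (not from the source) of how a per-EID dropout-probability tensor like the returned eid2reg could be consumed in a forward pass; entity_embs and entity_ids are hypothetical batch inputs:

import torch

def apply_eid_regularization(entity_embs, entity_ids, eid2reg, training=True):
    """Drop each entity embedding with its own probability from eid2reg (sketch)."""
    if not training:
        return entity_embs
    drop_prob = eid2reg[entity_ids]                   # (batch,) per-entity dropout probability
    keep = torch.rand_like(drop_prob) >= drop_prob    # Bernoulli keep mask
    return entity_embs * keep.unsqueeze(-1)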
Example #2
    def unfreeze_params(self):
        """Unfreezes the parameters of the module.

        Returns:
        """
        for name, param in self.named_parameters():
            param.requires_grad = True
            log_rank_0_info(logger, f"Unfreezing {name}")
        return
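
The symmetric freeze operation is a one-line variation; a sketch in the same style (not necessarily present in the source module):

    def freeze_params(self):
        """Freezes the parameters of the module (mirror of unfreeze_params)."""
        for name, param in self.named_parameters():
            param.requires_grad = False
            log_rank_0_info(logger, f"Freezing {name}")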
Example #3
    def prep(
        cls,
        data_config,
        emb_args,
        entity_symbols,
        threshold,
        log_weight,
    ):
        """Preps the KG information.

        Args:
            data_config: data config
            emb_args: embedding args
            entity_symbols: entity symbols
            threshold: weight threshold for counting an edge
            log_weight: whether to take the log of the weight value after the threshold

        Returns: scipy sparse KG adjacency matrix, prep file
        """
        file_tag = os.path.splitext(emb_args.kg_adj.replace("/", "_"))[0]
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        prep_file = os.path.join(prep_dir, f"kg_adj_file_{file_tag}.npz")
        utils.ensure_dir(os.path.dirname(prep_file))
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(logger,
                             f"Loading existing KG adj from {prep_file}")
            start = time.time()
            kg_adj = scipy.sparse.load_npz(prep_file)
            log_rank_0_debug(
                logger,
                f"Loaded existing KG adj in {round(time.time() - start, 2)}s")
        else:
            start = time.time()
            kg_adj_file = os.path.join(data_config.emb_dir, emb_args.kg_adj)
            log_rank_0_info(logger, f"Building KG adj from {kg_adj_file}")
            kg_adj = cls.build_kg_adj(kg_adj_file, entity_symbols, threshold,
                                      log_weight)
            scipy.sparse.save_npz(prep_file, kg_adj)
            log_rank_0_debug(
                logger,
                f"Finished building and saving KG adj in {round(time.time() - start, 2)}s.",
            )
        return kg_adj, prep_file
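
Usage sketch: the cached adjacency written by prep() can be reloaded straight from the returned prep_file (names as above):

import scipy.sparse

kg_adj = scipy.sparse.load_npz(prep_file)
print(kg_adj.shape, kg_adj.nnz)  # (num entities, num entities), number of stored edges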
Example #4
File: run.py Project: lorr1/bootleg
def setup(config, run_config_path=None):
    """
    Set up the distributed backend and save configuration files.
    Args:
        config: config
        run_config_path: path for original run config

    Returns:
    """
    # torch.multiprocessing.set_sharing_strategy("file_system")
    # spawn method must be fork to work with Meta.config
    torch.multiprocessing.set_start_method("fork", force=True)
    """
    ulimit -n 500000
    python3 -m torch.distributed.launch --nproc_per_node=2  bootleg/run.py --config_script ...
    """
    log_level = logging.getLevelName(config.run_config.log_level.upper())
    emmental.init(
        log_dir=config["meta_config"]["log_path"],
        config=config,
        use_exact_log_path=config["meta_config"]["use_exact_log_path"],
        local_rank=config.learner_config.local_rank,
        level=log_level,
    )
    log = logging.getLogger()
    # Remove streaming handlers and use rich
    log.handlers = [
        h for h in log.handlers if not type(h) is logging.StreamHandler
    ]
    log.addHandler(RichHandler())
    # Set up distributed backend
    emmental.Meta.init_distributed_backend()

    cmd_msg = " ".join(sys.argv)
    # Log configuration into files
    if config.learner_config.local_rank in [0, -1]:
        write_to_file(f"{emmental.Meta.log_path}/cmd.txt", cmd_msg)
        dump_yaml_file(f"{emmental.Meta.log_path}/parsed_config.yaml",
                       emmental.Meta.config)
        # Dump the run config (does not contain defaults)
        if run_config_path is not None:
            dump_yaml_file(
                f"{emmental.Meta.log_path}/run_config.yaml",
                load_yaml_file(run_config_path),
            )

    log_rank_0_info(logger, f"COMMAND: {cmd_msg}")
    log_rank_0_info(
        logger,
        f"Saving config to {emmental.Meta.log_path}/parsed_config.yaml")

    git_hash = "Not able to retrieve git hash"
    try:
        git_hash = subprocess.check_output([
            "git", "log", "-n", "1", "--pretty=tformat:%h-%ad", "--date=short"
        ]).strip().decode("utf-8")
    except subprocess.CalledProcessError:
        pass
    log_rank_0_info(logger, f"Git Hash: {git_hash}")
Example #5
def try_rmtree(rm_dir):
    """In the case a resource is open, rmtree will fail. This retries to rmtree
    after 1 second waits for 5 times.

    Args:
        rm_dir: directory to remove

    Returns:
    """
    num_retries = 0
    max_retries = 5
    while num_retries < max_retries:
        try:
            shutil.rmtree(rm_dir)
            break
        except OSError:
            time.sleep(1)
            num_retries += 1
            if num_retries >= max_retries:
                log_rank_0_info(
                    logger,
                    f"{rm_dir} was not able to be deleted. This is okay but will have to manually be removed.",
                )
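
Typical call site (the path here is illustrative):

try_rmtree("/tmp/bootleg_prep/batch_results")  # hypothetical temp directory

Where a best-effort delete without retries is acceptable, shutil.rmtree(rm_dir, ignore_errors=True) is the simpler alternative used elsewhere in these examples.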
Example #6
    def __init__(
        self,
        main_args,
        dataset,
        use_weak_label,
        entity_symbols,
        dataset_threads,
        split="train",
    ):
        global_start = time.time()
        log_rank_0_info(logger,
                        f"Building slice dataset for {split} from {dataset}.")
        spawn_method = main_args.run_config.spawn_method
        data_config = main_args.data_config
        orig_spawn = multiprocessing.get_start_method()
        multiprocessing.set_start_method(spawn_method, force=True)
        self.slice_names = data_utils.get_eval_slices(data_config.eval_slices)
        self.get_slice_dt = lambda max_a2p: np.dtype([
            ("sent_idx", int),
            ("subslice_idx", int),
            ("alias_slice_incidence", int, (max_a2p, )),
            ("prob_labels", float, (max_a2p, )),
        ])
        self.get_storage = lambda max_a2p: np.dtype(
            [(slice_name, self.get_slice_dt(max_a2p))
             for slice_name in self.slice_names])
        # Folder for all mmap saved files
        save_dataset_folder = data_utils.get_save_data_folder(
            data_config, use_weak_label, dataset)
        utils.ensure_dir(save_dataset_folder)
        # Folder for temporary output files
        temp_output_folder = os.path.join(data_config.data_dir,
                                          data_config.data_prep_dir,
                                          f"prep_{split}_slice_files")
        utils.ensure_dir(temp_output_folder)
        # Input step 1
        create_ex_indir = os.path.join(temp_output_folder,
                                       "create_examples_input")
        utils.ensure_dir(create_ex_indir)
        # Input step 2
        create_ex_outdir = os.path.join(temp_output_folder,
                                        "create_examples_output")
        utils.ensure_dir(create_ex_outdir)
        # Meta data saved files
        meta_file = os.path.join(temp_output_folder, "meta_data.json")
        # File for standard training data
        slice_hash = hashlib.sha1(str(
            self.slice_names).encode("UTF-8")).hexdigest()[:10]
        self.save_dataset_name = os.path.join(save_dataset_folder,
                                              f"ned_slices_{slice_hash}.bin")
        self.save_data_config_name = os.path.join(save_dataset_folder,
                                                  "ned_slices_config.json")

        # =======================================================================================
        # SLICE DATA
        # =======================================================================================
        log_rank_0_debug(logger, "Loading dataset...")
        log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists")
        if data_config.overwrite_preprocessed_data or (not os.path.exists(
                self.save_dataset_name)):
            st_time = time.time()
            try:
                log_rank_0_info(
                    logger,
                    f"Building dataset from scratch. Saving to {save_dataset_folder}",
                )
                create_examples(
                    dataset,
                    create_ex_indir,
                    create_ex_outdir,
                    meta_file,
                    data_config,
                    dataset_threads,
                    self.slice_names,
                    use_weak_label,
                    split,
                )
                max_alias2pred = utils.load_json_file(
                    meta_file)["max_alias2pred"]
                convert_examples_to_features_and_save(
                    meta_file,
                    dataset_threads,
                    self.slice_names,
                    self.save_dataset_name,
                    self.get_storage(max_alias2pred),
                )
                utils.dump_json_file(self.save_data_config_name,
                                     {"max_alias2pred": max_alias2pred})

                log_rank_0_debug(
                    logger,
                    f"Finished prepping data in {time.time() - st_time}")
            except Exception as e:
                tb = traceback.TracebackException.from_exception(e)
                logger.error(e)
                logger.error("\n".join(tb.stack.format()))
                shutil.rmtree(save_dataset_folder, ignore_errors=True)
                raise

        log_rank_0_info(
            logger,
            f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}",
        )
        max_alias2pred = utils.load_json_file(
            self.save_data_config_name)["max_alias2pred"]
        self.data, self.sent_to_row_id_dict = self.build_data_dict(
            self.save_dataset_name, self.get_storage(max_alias2pred))
        assert len(self.data) > 0
        assert len(self.sent_to_row_id_dict) > 0
        log_rank_0_debug(logger, f"Removing temporary output files")
        shutil.rmtree(temp_output_folder, ignore_errors=True)
        # Set spawn back to original/default, which is "fork" or "spawn". This is needed for the Meta.config to
        # be correctly passed in the collate_fn.
        multiprocessing.set_start_method(orig_spawn, force=True)
        log_rank_0_info(
            logger,
            f"Final slice data initialization time from {split} is {time.time() - global_start}s",
        )
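
A small sketch of the structured record layout that get_slice_dt and get_storage build above, with a hypothetical max_a2p and stand-in slice names:

import numpy as np

max_a2p = 3  # hypothetical max aliases per example
slice_dt = np.dtype([
    ("sent_idx", int),
    ("subslice_idx", int),
    ("alias_slice_incidence", int, (max_a2p,)),
    ("prob_labels", float, (max_a2p,)),
])
storage = np.dtype([("slice_a", slice_dt), ("slice_b", slice_dt)])  # stand-in slice names
rows = np.zeros(2, dtype=storage)
rows["slice_a"]["sent_idx"][:] = -1  # same sentinel used during feature prep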
Example #7
def convert_examples_to_features_and_save(meta_file, dataset_threads,
                                          slice_names, save_dataset_name,
                                          storage):
    """Converts the prepped examples into input features and saves in memmap
    files. These are used in the __get_item__ method.

    Args:
        meta_file: metadata file where input file paths are saved
        dataset_threads: number of threads
        slice_names: list of slice names to evaluate on
        save_dataset_name: data file name to save
        storage: data storage type (for memmap)

    Returns:
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    num_processes = min(dataset_threads,
                        int(0.8 * multiprocessing.cpu_count()))

    log_rank_0_info(
        logger,
        f"Starting to build and save features with {num_processes} threads")

    log_rank_0_debug(logger, f"Counting lines")
    meta = utils.load_json_file(meta_file)
    total_input = meta["num_mentions"]
    max_alias2pred = meta["max_alias2pred"]
    files_and_counts = meta["files_and_counts"]

    # IMPORTANT: for distributed writing to memmap files, you must create them in w+ mode before
    # being opened in r+ mode by workers
    memmap_file = np.memmap(save_dataset_name,
                            dtype=storage,
                            mode="w+",
                            shape=(total_input, ),
                            order="C")
    # Save -1 in sent_idx to check that things are loaded correctly later
    memmap_file[slice_names[0]]["sent_idx"][:] = -1

    input_args = []
    # Saves where in the memmap file to start writing
    offset = 0
    for i, in_file_name in enumerate(files_and_counts.keys()):
        input_args.append({
            "file_name": in_file_name,
            "in_file_lines": files_and_counts[in_file_name],
            "save_file_offset": offset,
            "ex_print_mod": int(np.ceil(total_input / 20)),
            "slice_names": slice_names,
            "max_alias2pred": max_alias2pred,
        })
        offset += files_and_counts[in_file_name]

    if num_processes == 1:
        assert len(input_args) == 1
        total_output = convert_examples_to_features_and_save_single(
            input_args[0], memmap_file)
    else:
        log_rank_0_debug(
            logger,
            "Initializing pool. This make take a few minutes.",
        )
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=convert_examples_to_features_and_save_initializer,
            initargs=[save_dataset_name, storage],
        )

        total_output = 0
        for res in pool.imap_unordered(
                convert_examples_to_features_and_save_hlp, input_args,
                chunksize=1):
            total_output += res
        pool.close()

    # Verify that sentences are unique and saved correctly
    mmap_file = np.memmap(save_dataset_name, dtype=storage, mode="r")
    all_uniq_ids = set()
    for i in tqdm(range(total_input), desc="Checking sentence uniqueness"):
        assert (mmap_file[slice_names[0]]["sent_idx"][i] !=
                -1), f"Index {i} has -1 sent idx"
        uniq_id = str(
            f"{mmap_file[slice_names[0]]['sent_idx'][i]}.{mmap_file[slice_names[0]]['subslice_idx'][i]}"
        )
        assert (uniq_id not in all_uniq_ids
                ), f"Idx {uniq_id} is not unique and already in data"
        all_uniq_ids.add(uniq_id)

    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}",
    )
    return
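
The w+/r+ handoff flagged in the IMPORTANT comment above is the load-bearing trick; a minimal, self-contained sketch of the same pattern (names illustrative):

import numpy as np

def parent_create(path, dtype, total_rows):
    # The parent allocates the full file once in "w+" mode.
    mm = np.memmap(path, dtype=dtype, mode="w+", shape=(total_rows,))
    mm.flush()

def worker_write(path, dtype, offset, rows):
    # Workers reopen the same file in "r+" and write only their own row range.
    mm = np.memmap(path, dtype=dtype, mode="r+")
    mm[offset:offset + len(rows)] = rows
    mm.flush()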
Example #8
def create_examples(
    dataset,
    create_ex_indir,
    create_ex_outdir,
    meta_file,
    data_config,
    dataset_threads,
    slice_names,
    use_weak_label,
    split,
):
    """Creates examples from the raw input data.

    Args:
        dataset: dataset file
        create_ex_indir: temporary directory where input files are stored
        create_ex_outdir: temporary directory to store output files from method
        meta_file: metadata file to save the file names/paths for the next step in prep pipeline
        data_config: data config
        dataset_threads: number of threads
        slice_names: list of slices to evaluate on
        use_weak_label: whether to use weak labeling or not
        split: data split

    Returns:
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    num_processes = min(dataset_threads,
                        int(0.8 * multiprocessing.cpu_count()))

    log_rank_0_debug(logger, f"Counting lines")
    with open(dataset) as in_f:
        total_input = sum(1 for _ in in_f)
    if num_processes == 1:
        out_file_name = os.path.join(create_ex_outdir,
                                     os.path.basename(dataset))
        constants_dict = {
            "slice_names": slice_names,
            "use_weak_label": use_weak_label,
            "max_aliases": data_config.max_aliases,
            "split": split,
            "train_in_candidates": data_config.train_in_candidates,
        }
        files_and_counts = {}
        res = create_examples_single(dataset, total_input, out_file_name,
                                     constants_dict)
        total_output = res["total_lines"]
        max_alias2pred = res["max_alias2pred"]
        files_and_counts[res["output_filename"]] = res["total_lines"]
    else:
        log_rank_0_info(
            logger,
            f"Strating to extract examples with {num_processes} threads")
        log_rank_0_debug(
            logger, "Parallelizing with " + str(num_processes) + " threads.")
        chunk_input = int(np.ceil(total_input / num_processes))
        log_rank_0_debug(
            logger,
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines",
        )
        total_input_from_chunks, input_files_dict = utils.chunk_file(
            dataset, create_ex_indir, chunk_input)

        input_files = list(input_files_dict.keys())
        input_file_lines = [input_files_dict[k] for k in input_files]
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        assert (
            total_input == total_input_from_chunks
        ), f"Lengths of files {total_input} doesn't mathc {total_input_from_chunks}"
        log_rank_0_debug(logger, f"Done chunking files")

        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=create_examples_initializer,
            initargs=[
                data_config,
                slice_names,
                use_weak_label,
                data_config.max_aliases,
                split,
                data_config.train_in_candidates,
            ],
        )
        total_output = 0
        max_alias2pred = 0
        input_args = list(zip(input_files, input_file_lines, output_files))
        # Store output files and counts for saving in next step
        files_and_counts = {}
        for res in pool.imap_unordered(create_examples_hlp,
                                       input_args,
                                       chunksize=1):
            total_output += res["total_lines"]
            max_alias2pred = max(max_alias2pred, res["max_alias2pred"])
            files_and_counts[res["output_filename"]] = res["total_lines"]
        pool.close()
    utils.dump_json_file(
        meta_file,
        {
            "num_mentions": total_output,
            "files_and_counts": files_and_counts,
            "max_alias2pred": max_alias2pred,
        },
    )
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time()-start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}.",
    )
    return
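
utils.chunk_file is project-specific; a rough stand-in showing the contract it must satisfy here (return the total line count plus a {chunk_path: line_count} map), illustrative only:

import os

def chunk_file_sketch(in_path, out_dir, lines_per_chunk):
    counts, buf, idx = {}, [], 0

    def flush():
        nonlocal buf, idx
        out = os.path.join(out_dir, f"chunk_{idx}.jsonl")
        with open(out, "w") as g:
            g.writelines(buf)
        counts[out] = len(buf)
        buf, idx = [], idx + 1

    with open(in_path) as f:
        for line in f:
            buf.append(line)
            if len(buf) == lines_per_chunk:
                flush()
    if buf:
        flush()
    return sum(counts.values()), counts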
Example #9
File: data.py Project: pombredanne/bootleg
def get_dataloaders(
    args,
    tasks,
    splits,
    entity_symbols,
    batch_on_the_fly_kg_adj,
):
    """Gets the dataloaders.

    Args:
        args: main args
        tasks: task names
        splits: data splits to generate dataloaders for
        entity_symbols: entity symbols
        batch_on_the_fly_kg_adj: kg embeddings metadata for the __getitem__ method (see get_dataloader_embeddings)

    Returns: list of dataloaders
    """
    task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}
    is_bert = len(args.data_config.word_embedding.bert_model) > 0
    tokenizer = BertTokenizer.from_pretrained(
        args.data_config.word_embedding.bert_model,
        do_lower_case="uncased" in args.data_config.word_embedding.bert_model,
        cache_dir=args.data_config.word_embedding.cache_dir,
    )

    datasets = {}
    for split in splits:
        dataset_path = os.path.join(args.data_config.data_dir,
                                    args.data_config[f"{split}_dataset"].file)
        datasets[split] = BootlegDataset(
            main_args=args,
            name=f"Bootleg",
            dataset=dataset_path,
            use_weak_label=args.data_config[f"{split}_dataset"].use_weak_label,
            tokenizer=tokenizer,
            entity_symbols=entity_symbols,
            dataset_threads=args.run_config.dataset_threads,
            split=split,
            is_bert=is_bert,
            batch_on_the_fly_kg_adj=batch_on_the_fly_kg_adj,
        )

    dataloaders = []
    for split, dataset in datasets.items():
        if split in args.learner_config.train_split:
            dataset_sampler = (RandomSampler(dataset)
                               if Meta.config["learner_config"]["local_rank"]
                               == -1 else DistributedSampler(dataset))
        else:
            dataset_sampler = None
            if Meta.config["learner_config"]["local_rank"] != -1:
                log_rank_0_info(
                    logger,
                    f"You are using distributed computing for eval. We are not using a distributed sampler. "
                    f"Please use DataParallel and not DDP.",
                )
        dataloaders.append(
            EmmentalDataLoader(
                task_to_label_dict=task_to_label_dict,
                dataset=dataset,
                sampler=dataset_sampler,
                split=split,
                collate_fn=bootleg_collate_fn,
                batch_size=args.train_config.batch_size
                if split in args.learner_config.train_split
                or args.run_config.eval_batch_size is None else
                args.run_config.eval_batch_size,
                num_workers=args.run_config.dataloader_threads,
                pin_memory=False,
            ))
        log_rank_0_info(
            logger,
            f"Built dataloader for {split} set with {len(dataset)} and {args.run_config.dataloader_threads} threads "
            f"samples (Shuffle={split in args.learner_config.train_split}, "
            f"Batch size={dataloaders[-1].batch_size}).",
        )

    return dataloaders
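
The inline conditionals above pack two decisions into the loader construction; restated as a hypothetical helper (RandomSampler and DistributedSampler as imported for the code above):

def pick_sampler_and_batch_size(args, dataset, split, local_rank):
    # Train splits shuffle (distributed sampler under DDP); eval splits do not.
    if split in args.learner_config.train_split:
        sampler = (RandomSampler(dataset)
                   if local_rank == -1 else DistributedSampler(dataset))
        batch_size = args.train_config.batch_size
    else:
        sampler = None
        batch_size = (args.run_config.eval_batch_size
                      if args.run_config.eval_batch_size is not None
                      else args.train_config.batch_size)
    return sampler, batch_size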
Example #10
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(TypeEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "max_types",
            "type_dim",
            "type_labels",
            "type_vocab",
            "merge_func",
            "attn_hidden_size",
            "regularize_mapping",
        }
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert (
            "max_types"
            in emb_args), "Type embedding requires max_types to be set in args"
        assert (
            "type_dim"
            in emb_args), "Type embedding requires type_dim to be set in args"
        assert (
            "type_labels" in emb_args
        ), "Type embedding requires type_labels to be set in args. A Dict from QID -> TypeId or TypeName"
        assert (
            "type_vocab" in emb_args
        ), "Type embedding requires type_vocab to be set in args. A Dict from TypeName -> TypeId"
        assert (self.cpu is False
                ), f"We don't support putting type embeddings on CPU right now"
        self.merge_func = self.average_types
        self.orig_dim = emb_args.type_dim
        self.add_attn = None
        # Function for merging multiple types
        if "merge_func" in emb_args:
            assert emb_args.merge_func in [
                "average",
                "addattn",
            ], (f"{key}: You have set the type merge_func to be {emb_args.merge_func} but"
                f" that is not in the allowable list of [average, addattn]")
            if emb_args.merge_func == "addattn":
                if "attn_hidden_size" in emb_args:
                    attn_hidden_size = emb_args.attn_hidden_size
                else:
                    attn_hidden_size = 100
                # Softmax of types using the sentence context
                self.add_attn = PositionAwareAttention(
                    input_size=self.orig_dim,
                    attn_size=attn_hidden_size,
                    feature_size=0)
                self.merge_func = self.add_attn_merge
        self.max_types = emb_args.max_types
        (
            eid2typeids_table,
            self.type2row_dict,
            num_types_with_unk,
            self.prep_file,
        ) = self.prep(
            data_config=main_args.data_config,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
        )
        self.register_buffer("eid2typeids_table",
                             eid2typeids_table,
                             persistent=False)
        # self.eid2typeids_table.requires_grad = False
        self.num_types_with_pad_and_unk = num_types_with_unk + 1

        # Regularization mapping goes from typeid to 2d dropout percent
        if "regularize_mapping" in emb_args:
            typeid2reg = torch.zeros(self.num_types_with_pad_and_unk)
        else:
            typeid2reg = None

        if not self.from_pretrained:
            if "regularize_mapping" in emb_args:
                if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                    logger.warning(
                        f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?"
                    )
                log_rank_0_info(
                    logger,
                    f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}",
                )
                typeid2reg = self.load_regularization_mapping(
                    main_args.data_config,
                    self.num_types_with_pad_and_unk,
                    self.type2row_dict,
                    emb_args.regularize_mapping,
                )
        self.register_buffer("typeid2reg", typeid2reg)
        assert self.eid2typeids_table.shape[1] == emb_args.max_types, (
            f"Something went wrong with loading type file."
            f" The given max types {emb_args.max_types} does not match that "
            f"of type table {self.eid2typeids_table.shape[1]}")
        log_rank_0_debug(
            logger,
            f"{key}: Type embedding with {self.max_types} types with dim {self.orig_dim}. "
            f"Setting merge_func to be {self.merge_func.__name__} in type emb.",
        )
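
A hedged sketch of what the default average merge plausibly computes (the real average_types lives elsewhere in the class); the shapes are assumptions, and pad type rows are assumed to embed to zero vectors:

import torch

def average_types_sketch(type_embs, num_types):
    # type_embs: (batch, max_types, type_dim); num_types: (batch,) real types per entity.
    summed = type_embs.sum(dim=1)
    return summed / num_types.clamp(min=1).unsqueeze(-1).float()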
Example #11
def run_model(mode, config, run_config_path=None):
    """
    Main run method for Emmental Bootleg models.
    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns:

    """

    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(config.data_config.entity_dir,
                              config.data_config.entity_map_dir),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )
    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if weak labels is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[
                    f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __getitem__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (config.data_config.type_prediction.use_type_pred is
                False), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(
            ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(
            ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols,
                                           slice_datasets))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=True,
                                        logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=False,
                                        logger=logger)
        log_rank_0_info(logger,
                        f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (config["learner_config"]["local_rank"] == -1
                and config["model_config"]["dataparallel"]):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt",
                          scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv",
                scores)
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename))
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
                eval_utils.batched_pred_iter(
                    model,
                    dataloaders[0],
                    config.run_config.eval_accumulation_steps,
                    sentidx2num_mentions,
                )):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0 for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
            )
        empty_sentidx2row = {
            k: v
            for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(number_dumped_batches,
                                                      subeval_folder)
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger,
            f"Finished dumping. Merging results across accumulation steps.")
        # Final result files for labels and embeddings
        final_result_file = os.path.join(eval_folder,
                                         config.run_config.result_label_file)
        # Copy labels
        with open(final_result_file, "wb") as output:
            for file in all_result_files:
                with open(file, "rb") as in_f:
                    shutil.copyfileobj(in_f, output)
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file)
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"Is that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = []
            for npfile in all_out_emb_files:
                all_arrays.append(np.load(npfile))
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(
                logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
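
End to end, a caller looks roughly like the following (config construction and paths are assumptions, not shown in the source):

# Hypothetical driver: `config` is a parsed Bootleg/Emmental model config.
scores = run_model("train", config, run_config_path="configs/my_run.yaml")
final_labels, final_embs = run_model("dump_embs", config)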