Ejemplo n.º 1
0
 def test_prune_to_entities(self):
     """Pruning keeps only the requested QIDs in the per-entity maps.

     After pruning to Q123 and Q345 the reverse map is expected to still
     list all four type names, with empty sets for the types that no
     remaining entity uses.
     """
     symbols = TypeSymbols(
         {
             "Q123": ["animal"],
             "Q345": ["dog"],
             "Q567": ["animal", "animall", "drop"],
             "Q789": [],
         },
         max_types=3,
         edit_mode=True,
     )
     symbols.prune_to_entities({"Q123", "Q345"})

     expected_names = {"Q123": ["animal"], "Q345": ["dog"]}
     expected_ids = {"Q123": [1], "Q345": [3]}
     # Type names survive the prune even when their QID set becomes empty.
     expected_reverse = dict(
         animal={"Q123"},
         animall=set(),
         dog={"Q345"},
         drop=set(),
     )

     self.assertDictEqual(expected_names, symbols._qid2typenames)
     self.assertDictEqual(expected_reverse, symbols._typename2qids)
     self.assertDictEqual(expected_ids, symbols._qid2typeid)
Ejemplo n.º 2
0
 def test_reidentify_entity(self):
     """Renaming Q567 to Q911 moves its types to the new QID in all maps."""
     symbols = TypeSymbols(
         {
             "Q123": ["animal"],
             "Q345": ["dog"],
             "Q567": ["animal", "animall", "drop"],
             "Q789": [],
         },
         max_types=3,
         edit_mode=True,
     )
     symbols.reidentify_entity("Q567", "Q911")

     # Q567 must be gone everywhere; Q911 carries its former types and ids.
     expected_names = {
         "Q123": ["animal"],
         "Q345": ["dog"],
         "Q911": ["animal", "animall", "drop"],
         "Q789": [],
     }
     expected_ids = {"Q123": [1], "Q345": [3], "Q911": [1, 2, 4], "Q789": []}
     expected_reverse = dict(
         animal={"Q123", "Q911"},
         animall={"Q911"},
         dog={"Q345"},
         drop={"Q911"},
     )

     self.assertDictEqual(expected_names, symbols._qid2typenames)
     self.assertDictEqual(expected_reverse, symbols._typename2qids)
     self.assertDictEqual(expected_ids, symbols._qid2typeid)
Ejemplo n.º 3
0
def main():
    """Compute statistics over the training file and save them.

    Parses the arguments, loads the entity symbols from
    ``<data_dir>/<entity_symbols_dir>``, and calls ``compute_occurrences``
    on ``<data_dir>/<train_file>`` with output going to ``<save_dir>/stats``.
    Unless ``args.no_types`` is set, it also builds the type symbols and
    calls ``compute_type_occurrences`` for them.
    """
    args = parse_args()
    # json.dumps cannot serialise an attribute-style object such as
    # argparse.Namespace; dump its attribute dict instead. A dict (or dict
    # subclass) is passed through unchanged.
    printable_args = args if isinstance(args, dict) else vars(args)
    logging.info(json.dumps(printable_args, indent=4))

    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_dir, args.entity_symbols_dir))
    train_file = os.path.join(args.data_dir, args.train_file)
    save_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {save_dir}")
    utils.ensure_dir(save_dir)

    # NOTE: histogram computation is currently disabled.
    # compute_histograms(save_dir, entity_symbols)
    compute_occurrences(save_dir,
                        train_file,
                        entity_symbols,
                        args.lower,
                        args.strip,
                        num_workers=args.num_workers)

    if args.no_types:
        return

    type_symbols = TypeSymbols(
        entity_symbols=entity_symbols,
        emb_dir=args.emb_dir,
        max_types=args.max_types,
        emb_file="hyena_type_emb.pkl",
        type_vocab_file="hyena_type_graph.vocab.pkl",
        type_file="hyena_types.txt")
    compute_type_occurrences(save_dir, "orig", entity_symbols,
                             type_symbols.qid2typenames, train_file)
Ejemplo n.º 4
0
    def test_type_init(self):
        """Check the maps built by the constructor, with and without edit mode.

        With ``max_types=2`` and the default edit mode, Q567's type list is
        expected to be cut to its first two entries and the reverse map to be
        ``None``. With ``max_types=4`` and ``edit_mode=True`` nothing is cut
        and the reverse map is populated.
        """
        raw_types = {
            "Q123": ["animal"],
            "Q345": ["dog"],
            "Q567": ["animal", "animall", "drop"],
            "Q789": [],
        }

        # --- default (non-edit) mode, truncating to two types per entity ---
        truncated = TypeSymbols(raw_types, max_types=2)

        expected_vocab = {"animal": 1, "animall": 2, "dog": 3, "drop": 4}
        expected_vocab_inv = {idx: name for name, idx in expected_vocab.items()}

        self.assertDictEqual(
            {
                "Q123": ["animal"],
                "Q345": ["dog"],
                "Q567": ["animal", "animall"],
                "Q789": [],
            },
            truncated._qid2typenames,
        )
        self.assertIsNone(truncated._typename2qids)
        self.assertDictEqual(
            {"Q123": [1], "Q345": [3], "Q567": [1, 2], "Q789": []},
            truncated._qid2typeid,
        )
        self.assertDictEqual(expected_vocab, truncated._type_vocab)
        self.assertDictEqual(expected_vocab_inv, truncated._type_vocab_inv)

        # --- edit mode, with room for every type ---
        editable = TypeSymbols(raw_types, max_types=4, edit_mode=True)

        self.assertDictEqual(
            {
                "Q123": ["animal"],
                "Q345": ["dog"],
                "Q567": ["animal", "animall", "drop"],
                "Q789": [],
            },
            editable._qid2typenames,
        )
        self.assertDictEqual(
            {
                "animal": {"Q123", "Q567"},
                "animall": {"Q567"},
                "dog": {"Q345"},
                "drop": {"Q567"},
            },
            editable._typename2qids,
        )
Ejemplo n.º 5
0
    def test_type_load_and_save(self):
        """A saved TypeSymbols reloads from cache with identical state."""
        original = TypeSymbols(
            {
                "Q123": ["animal"],
                "Q345": ["dog"],
                "Q567": ["animal", "animall", "drop"],
                "Q789": [],
            },
            max_types=2,
        )
        original.save(self.save_dir, prefix="test")
        restored = TypeSymbols.load_from_cache(self.save_dir, prefix="test")

        self.assertEqual(restored.max_types, original.max_types)
        self.assertDictEqual(restored._qid2typenames, original._qid2typenames)
        # Neither object was built in edit mode, so no reverse map is expected.
        self.assertIsNone(original._typename2qids)
        self.assertIsNone(restored._typename2qids)
        self.assertDictEqual(restored._qid2typeid, original._qid2typeid)
        self.assertDictEqual(restored._type_vocab, original._type_vocab)
        self.assertDictEqual(restored._type_vocab_inv, original._type_vocab_inv)
Ejemplo n.º 6
0
    def load_from_jsonl(
        cls,
        profile_file,
        max_candidates=30,
        max_types=10,
        max_kg_connections=100,
        edit_mode=False,
    ):
        """Build an entity profile from a raw jsonl file.

        Every line of the file is a JSON object holding the metadata of one
        entity, for example::

            {
                "entity_id": "C000",
                "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
                "title": "Dog",
                "types": {"hyena": ["animal"], "wiki": ["dog"]},
                "relations": [
                    {"relation": "sibling", "object": "Q345"},
                    {"relation": "sibling", "object": "Q567"},
                ],
            }

        Args:
            profile_file: path of the jsonl file to read
            max_candidates: maximum entity candidates
            max_types: maximum types per entity
            max_kg_connections: maximum KG connections per entity
            edit_mode: edit mode flag, forwarded to every symbols object

        Returns: entity profile object
        """
        parsed = cls._read_profile_file(profile_file)
        qid2title, alias2qids, type_systems, qid2relations = parsed

        entity_symbols = EntitySymbols(
            alias2qids=alias2qids,
            qid2title=qid2title,
            max_candidates=max_candidates,
            edit_mode=edit_mode,
        )

        # One TypeSymbols object per type system found in the file.
        all_type_symbols = {}
        for system_name, qid2types in type_systems.items():
            all_type_symbols[system_name] = TypeSymbols(
                qid2typenames=qid2types,
                max_types=max_types,
                edit_mode=edit_mode,
            )

        kg_symbols = KGSymbols(
            qid2relations,
            max_connections=max_kg_connections,
            edit_mode=edit_mode,
        )
        return cls(entity_symbols, all_type_symbols, kg_symbols, edit_mode)
Ejemplo n.º 7
0
    def test_add_entity(self):
        """Adding an entity rejects unknown types and caps the type list.

        An unknown type name must raise ``ValueError``. A valid call for the
        new QID Q910 with four types is expected to keep only the first
        ``max_types`` (three) of them.
        """
        qid2typenames = {
            "Q123": ["animal"],
            "Q345": ["dog"],
            "Q567": ["animal", "animall", "drop"],
            "Q789": [],
        }
        max_types = 3
        type_symbols = TypeSymbols(qid2typenames, max_types=max_types, edit_mode=True)

        # Check the invalid type fails (this type isn't in our set of types)
        with self.assertRaises(ValueError) as context:
            type_symbols.add_entity("Q910", ["annnimal", "animal", "dog", "drop"])
        # assertIs (unlike a bare assert) still runs under `python -O` and
        # reports both operands on failure.
        self.assertIs(type(context.exception), ValueError)

        # Add a QID that was not in the type map before
        type_symbols.add_entity("Q910", ["animall", "animal", "dog", "drop"])
        gold_qid2typenames = {
            "Q123": ["animal"],
            "Q345": ["dog"],
            "Q567": ["animal", "animall", "drop"],
            "Q789": [],
            "Q910": ["animall", "animal", "dog"],  # Max types limits new types added
        }
        gold_qid2typeid = {
            "Q123": [1],
            "Q345": [3],
            "Q567": [1, 2, 4],
            "Q789": [],
            "Q910": [2, 1, 3],
        }
        gold_typename2qids = {
            "animal": {"Q123", "Q567", "Q910"},
            "animall": {"Q567", "Q910"},
            "dog": {"Q345", "Q910"},
            "drop": {"Q567"},  # "drop" was cut off by max_types, so no Q910 here
        }
        self.assertDictEqual(gold_qid2typenames, type_symbols._qid2typenames)
        self.assertDictEqual(gold_typename2qids, type_symbols._typename2qids)
        self.assertDictEqual(gold_qid2typeid, type_symbols._qid2typeid)
Ejemplo n.º 8
0
    def load_from_cache(
        cls,
        load_dir,
        edit_mode=False,
        verbose=False,
        no_kg=False,
        no_type=False,
        type_systems_to_load=None,
    ):
        """Load a pre-saved profile.

        Args:
            load_dir: load directory
            edit_mode: edit mode flag, default False
            verbose: verbose flag, default False
            no_kg: load kg or not flag, default False
            no_type: load types or not flag, default False. If True, no type system is loaded and the
                type subfolder is not scanned (``type_systems_to_load`` is still validated if given).
            type_systems_to_load: list of type systems to load, default is None which means all types systems

        Returns: entity profile object

        Raises:
            ValueError: if ``type_systems_to_load`` is not a list, or names an entry that is not
                present in the type subfolder
        """
        load_dir = Path(load_dir)
        type_subfolder = load_dir / TYPE_SUBFOLDER

        # Check type system input
        if type_systems_to_load is not None:
            if not isinstance(type_systems_to_load, list):
                raise ValueError(
                    f"`type_systems` must be a list of subfolders in {type_subfolder}"
                )
            if type_systems_to_load:
                # List the folder once instead of once per requested system.
                available = [p.name for p in type_subfolder.iterdir()]
                for type_system in type_systems_to_load:
                    if type_system not in available:
                        raise ValueError(
                            f"`type_systems` must be a list of subfolders in {type_subfolder}. "
                            f"{type_system} is not one."
                        )

        if verbose:
            print("Loading Entity Symbols")
        entity_symbols = EntitySymbols.load_from_cache(
            load_dir / ENTITY_SUBFOLDER,
            edit_mode=edit_mode,
            verbose=verbose,
        )

        type_sys_dict = {}
        if no_type:
            print(
                "Not loading type information. We will act as if there is no types associated with any entity "
                "and will not modify the types in any way, even if calling `add`."
            )
        else:
            # Only touch the type folder when types are wanted, so a profile
            # saved without any type data can still be loaded with no_type=True.
            for fold in type_subfolder.iterdir():
                if not fold.is_dir():
                    continue
                if (type_systems_to_load is not None
                        and fold.name not in type_systems_to_load):
                    continue
                if verbose:
                    print(f"Loading Type Symbols from {fold}")
                type_sys_dict[fold.name] = TypeSymbols.load_from_cache(
                    type_subfolder / fold.name,
                    edit_mode=edit_mode,
                    verbose=verbose,
                )

        if verbose:
            print("Loading KG Symbols")
        kg_symbols = None
        if no_kg:
            print(
                "Not loading KG information. We will act as if there is not KG connections between entities. "
                "We will not modify the KG information in any way, even if calling `add`."
            )
        else:
            kg_symbols = KGSymbols.load_from_cache(
                load_dir / KG_SUBFOLDER,
                edit_mode=edit_mode,
                verbose=verbose,
            )
        return cls(entity_symbols, type_sys_dict, kg_symbols, edit_mode,
                   verbose)
Ejemplo n.º 9
0
    def test_type_add_remove_typemap(self):
        """Exercise ``add_type`` / ``remove_type`` and check all three maps.

        Covers: rejection outside edit mode, rejection of an unknown type
        name, adding to an entity with no types, removing a type the entity
        does not have (expected to be a no-op), removing an existing type,
        and adding to an entity that already holds ``max_types`` types.
        """
        qid2typenames = {
            "Q123": ["animal"],
            "Q345": ["dog"],
            "Q567": ["animal", "animall", "drop"],
            "Q789": [],
        }
        max_types = 3

        def check_state(gold_qid2typenames, gold_qid2typeid, gold_typename2qids):
            # Compare the three internal maps against their expected values.
            self.assertDictEqual(gold_qid2typenames, type_symbols._qid2typenames)
            self.assertDictEqual(gold_typename2qids, type_symbols._typename2qids)
            self.assertDictEqual(gold_qid2typeid, type_symbols._qid2typeid)

        type_symbols = TypeSymbols(qid2typenames, max_types=max_types, edit_mode=False)
        # Check if fails with edit_mode = False
        with self.assertRaises(AttributeError) as context:
            type_symbols.add_type("Q789", "animal")
        # assertIs (unlike a bare assert) still runs under `python -O` and
        # reports both operands on failure.
        self.assertIs(type(context.exception), AttributeError)

        type_symbols = TypeSymbols(qid2typenames, max_types=max_types, edit_mode=True)
        # Check the invalid type fails (this type isn't in our set of types)
        with self.assertRaises(ValueError) as context:
            type_symbols.add_type("Q789", "annnimal")
        self.assertIs(type(context.exception), ValueError)

        # Add to a previously empty QID
        type_symbols.add_type("Q789", "animal")
        gold_qid2typenames = {
            "Q123": ["animal"],
            "Q345": ["dog"],
            "Q567": ["animal", "animall", "drop"],
            "Q789": ["animal"],
        }
        gold_qid2typeid = {"Q123": [1], "Q345": [3], "Q567": [1, 2, 4], "Q789": [1]}
        gold_typename2qids = {
            "animal": {"Q123", "Q567", "Q789"},
            "animall": {"Q567"},
            "dog": {"Q345"},
            "drop": {"Q567"},
        }
        check_state(gold_qid2typenames, gold_qid2typeid, gold_typename2qids)

        # Removing a type the entity does not have must go through and leave
        # every map unchanged
        type_symbols.remove_type("Q345", "animal")
        check_state(gold_qid2typenames, gold_qid2typeid, gold_typename2qids)

        # Now actually remove something
        type_symbols.remove_type("Q789", "animal")
        check_state(
            {
                "Q123": ["animal"],
                "Q345": ["dog"],
                "Q567": ["animal", "animall", "drop"],
                "Q789": [],
            },
            {"Q123": [1], "Q345": [3], "Q567": [1, 2, 4], "Q789": []},
            {
                "animal": {"Q123", "Q567"},
                "animall": {"Q567"},
                "dog": {"Q345"},
                "drop": {"Q567"},
            },
        )

        # Add to a full QID where we must replace. We do not bring back the old type if we remove the replace one.
        type_symbols.add_type("Q567", "dog")
        check_state(
            {
                "Q123": ["animal"],
                "Q345": ["dog"],
                "Q567": ["animal", "animall", "dog"],
                "Q789": [],
            },
            {"Q123": [1], "Q345": [3], "Q567": [1, 2, 3], "Q789": []},
            {
                "animal": {"Q123", "Q567"},
                "animall": {"Q567"},
                "dog": {"Q345", "Q567"},
                "drop": set(),
            },
        )

        # Removing the replacement leaves Q567 with two types; "drop" stays gone
        type_symbols.remove_type("Q567", "dog")
        check_state(
            {
                "Q123": ["animal"],
                "Q345": ["dog"],
                "Q567": ["animal", "animall"],
                "Q789": [],
            },
            {"Q123": [1], "Q345": [3], "Q567": [1, 2], "Q789": []},
            {
                "animal": {"Q123", "Q567"},
                "animall": {"Q567"},
                "dog": {"Q345"},
                "drop": set(),
            },
        )