Esempio n. 1
0
def make_tsv_for_entities(
    model: MultiRelationEmbedder,
    checkpoint_manager: CheckpointManager,
    entity_storage: AbstractEntityStorage,
    entities_tf: TextIO,
) -> None:
    """Write the embedding of every entity in the checkpoint to ``entities_tf``.

    Walks each entity type in ``model.entities`` and each of its partitions,
    loads the entity names from ``entity_storage`` and the matching embeddings
    from ``checkpoint_manager``, and passes one (name, embedding row) pair at
    a time to ``write``. Progress is reported on stdout.

    Args:
        model: Supplies the entity types (``model.entities``) and, when
            ``model.global_embs`` is not ``None``, a per-type term that is
            added to the embeddings before they are written.
        checkpoint_manager: Source of the per-partition embeddings.
        entity_storage: Source of the per-partition entity names.
        entities_tf: Text stream that receives the output rows.
    """
    print("Writing entity embeddings...")
    for entity_type, type_config in model.entities.items():
        for part in range(type_config.num_partitions):
            print(
                f"Reading embeddings for entity type {entity_type} partition "
                f"{part} from checkpoint..."
            )
            names = entity_storage.load_names(entity_type, part)
            vectors, _ = checkpoint_manager.read(entity_type, part)

            if model.global_embs is not None:
                # Augmented assignment on purpose: the global term is folded
                # into the object returned by the checkpoint read.
                vectors += model.global_embs[model.EMB_PREFIX + entity_type]

            print(
                f"Writing embeddings for entity type {entity_type} partition "
                f"{part} to output file..."
            )
            total = len(vectors)
            for row in range(total):
                write(entities_tf, (names[row],), vectors[row])
                done = row + 1
                # Report progress once every 5000 rows.
                if done % 5000 == 0:
                    print(f"- Processed {done}/{total} entities so far...")
            print(f"- Processed all {total} entities")

    # Streams without a ``name`` attribute get a generic label in the message.
    destination = getattr(entities_tf, "name", "the output file")
    print(f"Done exporting entity data to {destination}")
Esempio n. 2
0
def generate_entity_path_files(
    entity_storage: AbstractEntityStorage,
    entities_by_type: Dict[str, Dictionary],
    relation_type_storage: AbstractRelationTypeStorage,
    relation_types: Dictionary,
    dynamic_relations: bool,
) -> None:
    """Save entity counts and names, plus relation types when they are dynamic.

    Both storages are prepared first. Then, for every entity type and each of
    its partitions, the partition size and the partition's name list are saved
    through ``entity_storage``. The relation type count and names are saved
    through ``relation_type_storage`` only when ``dynamic_relations`` is true.

    Args:
        entity_storage: Destination for per-partition entity counts and names.
        entities_by_type: Maps each entity type name to its dictionary.
        relation_type_storage: Destination for the relation type count/names.
        relation_types: Dictionary of relation types.
        dynamic_relations: Whether relation type data should be written.
    """
    print("Preparing counts and dictionaries for entities and relation types:")
    entity_storage.prepare()
    relation_type_storage.prepare()

    for type_name, dictionary in entities_by_type.items():
        for part_idx in range(dictionary.num_parts):
            print(
                f"- Writing count of entity type {type_name} "
                f"and partition {part_idx}"
            )
            part_count = dictionary.part_size(part_idx)
            entity_storage.save_count(type_name, part_idx, part_count)
            part_names = dictionary.get_part_list(part_idx)
            entity_storage.save_names(type_name, part_idx, part_names)

    # Relation type data is only written for dynamic relations.
    if not dynamic_relations:
        return

    print("- Writing count of dynamic relations")
    relation_count = relation_types.size()
    relation_type_storage.save_count(relation_count)
    relation_names = relation_types.get_list()
    relation_type_storage.save_names(relation_names)