Ejemplo n.º 1
0
def generate_entity_path_files(
    entity_storage: AbstractEntityStorage,
    entities_by_type: Dict[str, Dictionary],
    relation_type_storage: AbstractRelationTypeStorage,
    relation_types: Dictionary,
    dynamic_relations: bool,
) -> None:
    """Save per-partition entity counts and names, plus dynamic relation types.

    Both storages are prepared first. Then, for every entity type and every
    one of its partitions, the partition's size and its list of names are
    handed to ``entity_storage``. The relation type count and names are only
    saved when ``dynamic_relations`` is true.

    Args:
        entity_storage: Destination for entity counts and names.
        entities_by_type: Maps each entity type name to its dictionary.
        relation_type_storage: Destination for relation type count and names.
        relation_types: Dictionary of relation types.
        dynamic_relations: Whether to save the relation type count and names.
    """
    print("Preparing counts and dictionaries for entities and relation types:")
    entity_storage.prepare()
    relation_type_storage.prepare()

    for type_name, dictionary in entities_by_type.items():
        for partition in range(dictionary.num_parts):
            print(
                f"- Writing count of entity type {type_name} "
                f"and partition {partition}"
            )
            # The count is saved before the names are fetched, so each
            # partition is handled in the order: size, count, list, names.
            partition_size = dictionary.part_size(partition)
            entity_storage.save_count(type_name, partition, partition_size)
            partition_names = dictionary.get_part_list(partition)
            entity_storage.save_names(type_name, partition, partition_names)

    if not dynamic_relations:
        # Nothing is saved to the relation type storage in this case.
        return

    print("- Writing count of dynamic relations")
    num_relation_types = relation_types.size()
    relation_type_storage.save_count(num_relation_types)
    relation_type_names = relation_types.get_list()
    relation_type_storage.save_names(relation_type_names)
Ejemplo n.º 2
0
def make_tsv_for_relation_types(
    model: MultiRelationEmbedder,
    relation_type_storage: AbstractRelationTypeStorage,
    relation_types_tf: TextIO,
) -> None:
    """Export the operator parameters of every relation type.

    Relation type names are loaded from ``relation_type_storage``. Each
    parameter is handed to the ``write`` helper (defined elsewhere in this
    file) together with ``relation_types_tf`` and a tuple of
    ``(relation type name, side, operator name, parameter name, shape)``,
    where the shape is rendered as dimensions joined by ``"x"``.

    When the model has dynamic relations, it must have exactly one relation
    config, one left-hand side operator and one right-hand side operator;
    each of their parameters is iterated along its first axis and paired with
    the relation type names. Otherwise, names, relation configs and
    right-hand side operators are walked in lockstep and only the ``"rhs"``
    side is exported.

    Args:
        model: Model whose relation operators are exported.
        relation_type_storage: Source of the relation type names.
        relation_types_tf: Text file object passed through to ``write``.

    Raises:
        ValueError: If, without dynamic relations, a loaded name differs from
            the name in the matching relation config; or, with dynamic
            relations, if the model does not hold exactly one relation config
            or one operator per side (unpacking failure).
    """

    def format_shape(tensor) -> str:
        # E.g. a 3-by-4 parameter is rendered as "3x4".
        return "x".join(f"{dim}" for dim in tensor.shape)

    print("Writing relation type parameters...")
    names = relation_type_storage.load_names()
    has_dynamic_relations = model.num_dynamic_rels > 0

    if not has_dynamic_relations:
        triples = zip(names, model.relations, model.rhs_operators)
        for name, config, rhs_op in triples:
            config_name = config.name
            if name != config_name:
                raise ValueError(
                    f"Mismatch in relations names: got {name} in the "
                    f"dictionary and {config_name} in the config.")
            operator_name = config.operator
            for parameter_name, parameter in rhs_op.named_parameters():
                row_key = (name, "rhs", operator_name, parameter_name,
                           format_shape(parameter))
                write(relation_types_tf, row_key, parameter)
    else:
        # With dynamic relations a single config/operator pair is shared by
        # all relation types; the unpacking enforces that there is only one.
        [shared_config] = model.relations
        operator_name = shared_config.operator
        [lhs_op] = model.lhs_operators
        [rhs_op] = model.rhs_operators
        for side, side_op in zip(("lhs", "rhs"), (lhs_op, rhs_op)):
            for parameter_name, stacked in side_op.named_parameters():
                # Iterating the parameter yields one slice per relation type.
                for name, parameter in zip(names, stacked):
                    row_key = (name, side, operator_name, parameter_name,
                               format_shape(parameter))
                    write(relation_types_tf, row_key, parameter)

    # Not every file-like object has a ``name``; fall back to a generic label.
    destination = getattr(relation_types_tf, "name", "the output file")
    print(f"Done exporting relation type data to {destination}")