def generate_entity_path_files(
    entity_storage: AbstractEntityStorage,
    entities_by_type: Dict[str, Dictionary],
    relation_type_storage: AbstractRelationTypeStorage,
    relation_types: Dictionary,
    dynamic_relations: bool,
) -> None:
    """Save per-partition entity counts and names, plus relation names if dynamic.

    Both storages are prepared first. For every entity type and each of its
    partitions, the partition size and the partition's entity list are saved
    to ``entity_storage``. Only when ``dynamic_relations`` is true are the
    relation type count and names saved to ``relation_type_storage``.

    Args:
        entity_storage: Destination for entity counts and names.
        entities_by_type: Maps each entity type name to its dictionary.
        relation_type_storage: Destination for relation type count and names.
        relation_types: Dictionary of relation type names.
        dynamic_relations: Whether relation type data should be written.
    """
    print("Preparing counts and dictionaries for entities and relation types:")
    entity_storage.prepare()
    relation_type_storage.prepare()

    for type_name, type_dict in entities_by_type.items():
        for part_idx in range(type_dict.num_parts):
            print(
                f"- Writing count of entity type {type_name} and partition {part_idx}"
            )
            part_size = type_dict.part_size(part_idx)
            entity_storage.save_count(type_name, part_idx, part_size)
            part_names = type_dict.get_part_list(part_idx)
            entity_storage.save_names(type_name, part_idx, part_names)

    # Relation type metadata is only persisted in dynamic-relations mode.
    if not dynamic_relations:
        return

    print("- Writing count of dynamic relations")
    relation_type_storage.save_count(relation_types.size())
    relation_type_storage.save_names(relation_types.get_list())
def make_tsv_for_relation_types(
    model: MultiRelationEmbedder,
    relation_type_storage: AbstractRelationTypeStorage,
    relation_types_tf: TextIO,
) -> None:
    """Export every relation type's operator parameters to ``relation_types_tf``.

    Relation type names are loaded from ``relation_type_storage``. Each exported
    parameter is passed to ``write`` together with a key tuple
    ``(relation_name, side, operator_name, parameter_name, shape)``, where
    ``shape`` is the parameter's dimensions joined with ``"x"``.

    Two model layouts are handled:

    * Dynamic relations (``model.num_dynamic_rels > 0``): the model has exactly
      one relation config and exactly one lhs and one rhs operator. Each
      operator parameter is iterated row by row, pairing row ``i`` with the
      ``i``-th stored name. Both sides are exported.
    * Static relations: relation configs and rhs operators are paired
      positionally with the stored names, and only the rhs side is exported.

    Args:
        model: The embedder whose relation operators are exported.
        relation_type_storage: Storage the relation type names are loaded from.
        relation_types_tf: Text stream the rows are written to.

    Raises:
        ValueError: If the number of stored names differs from the number of
            parameter rows (dynamic) or relation configs (static), or if a
            stored name differs from the config name at the same position
            (static).
    """
    print("Writing relation type parameters...")
    # Materialize once: the dynamic branch iterates the names once per
    # (side, parameter) pair, so a one-shot iterator would be exhausted
    # after the first pass.
    relation_types = list(relation_type_storage.load_names())

    if model.num_dynamic_rels > 0:
        (rel_t_config,) = model.relations
        op_name = rel_t_config.operator
        (lhs_operator,) = model.lhs_operators
        (rhs_operator,) = model.rhs_operators
        for side, operator in [("lhs", lhs_operator), ("rhs", rhs_operator)]:
            for param_name, all_params in operator.named_parameters():
                # zip() would silently drop names or rows on a mismatch,
                # producing a partial export without any error.
                if len(all_params) != len(relation_types):
                    raise ValueError(
                        f"Mismatch in number of relations: got "
                        f"{len(relation_types)} names in the dictionary and "
                        f"{len(all_params)} rows in parameter {param_name} "
                        f"of the {side} operator."
                    )
                for rel_t_name, param in zip(relation_types, all_params):
                    shape = "x".join(f"{d}" for d in param.shape)
                    write(
                        relation_types_tf,
                        (rel_t_name, side, op_name, param_name, shape),
                        param,
                    )
    else:
        relations = list(model.relations)
        # Same concern as above: without this check, extra names or configs
        # would be silently skipped by zip().
        if len(relation_types) != len(relations):
            raise ValueError(
                f"Mismatch in number of relations: got {len(relation_types)} "
                f"in the dictionary and {len(relations)} in the config."
            )
        for rel_t_name, rel_t_config, operator in zip(
            relation_types, relations, model.rhs_operators
        ):
            if rel_t_name != rel_t_config.name:
                raise ValueError(
                    f"Mismatch in relations names: got {rel_t_name} in the "
                    f"dictionary and {rel_t_config.name} in the config."
                )
            op_name = rel_t_config.operator
            for param_name, param in operator.named_parameters():
                shape = "x".join(f"{d}" for d in param.shape)
                write(
                    relation_types_tf,
                    (rel_t_name, "rhs", op_name, param_name, shape),
                    param,
                )

    # Not every text stream has a ``name`` (e.g. in-memory buffers).
    relation_types_output_filename = getattr(
        relation_types_tf, "name", "the output file"
    )
    print(
        f"Done exporting relation type data to {relation_types_output_filename}"
    )