def get_partitioned_types(
    config: ConfigSchema,
    side: Side,
) -> Tuple[int, Set[EntityName], Set[EntityName]]:
    """Summarize how the entity types on one side of the relations are split.

    Every entity type that occurs on the requested side (LHS or RHS) of some
    relation type is divided into a number of partitions. A type with exactly
    one partition is "unpartitioned": it behaves as though all of its
    entities belonged to every bucket. Any other type is "properly"
    partitioned, and for now all such types have to share one common
    partition count.

    Args:
        config: Provides ``relations`` (each with ``lhs`` and ``rhs`` entity
            names) and ``entities`` (name -> config with ``num_partitions``).
        side: Selects, via ``side.pick(lhs, rhs)``, which end of each
            relation is inspected.

    Returns:
        A triple ``(num_partitions, unpartitioned, partitioned)``: the common
        partition count (``1`` when nothing is properly partitioned), the
        names of the unpartitioned entity types, and the names of the
        properly partitioned entity types.

    Raises:
        RuntimeError: If the properly partitioned entity types do not all use
            the same number of partitions.
    """
    # Bucket the entity names seen on this side by their partition count.
    names_by_count: Dict[int, Set[EntityName]] = defaultdict(set)
    for relation in config.relations:
        name = side.pick(relation.lhs, relation.rhs)
        names_by_count[config.entities[name].num_partitions].add(name)

    # Whatever is left after removing the count-1 bucket is properly
    # partitioned.
    unpartitioned = names_by_count.pop(1, set())
    if not names_by_count:
        return 1, unpartitioned, set()
    if len(names_by_count) > 1:
        raise RuntimeError("Currently num_partitions must be a single "
                           "value across all partitioned entities.")

    # Exactly one bucket remains; unpack its count and its names.
    [(num_partitions, partitioned)] = names_by_count.items()
    return num_partitions, unpartitioned, partitioned
def set_embeddings(self, entity: str, side: Side, weights: nn.Parameter) -> None:
    """Wrap ``weights`` in an embedding module and register it for ``entity``.

    The wrapper is ``FeaturizedEmbedding`` when ``self.entities[entity]`` is
    flagged as featurized and ``SimpleEmbedding`` otherwise; both receive
    ``max_norm=self.max_norm``. The result is stored under the key
    ``self.EMB_PREFIX + entity`` in whichever of ``self.lhs_embs`` /
    ``self.rhs_embs`` is selected by ``side``, replacing any previous entry.
    """
    embedding_cls = (
        FeaturizedEmbedding
        if self.entities[entity].featurized
        else SimpleEmbedding
    )
    # Build the module before touching the container, so a failing
    # constructor leaves the current registration untouched.
    module = embedding_cls(weights, max_norm=self.max_norm)
    table = side.pick(self.lhs_embs, self.rhs_embs)
    table[self.EMB_PREFIX + entity] = module
def get_embeddings(self, entity: str, side: Side) -> nn.Parameter:
    """Look up the weights currently registered for ``entity`` on ``side``.

    The lookup uses the key ``self.EMB_PREFIX + entity`` in whichever of
    ``self.lhs_embs`` / ``self.rhs_embs`` is selected by ``side``.

    Returns:
        The ``weight`` attribute of the registered embedding module, or
        ``None`` when nothing is registered under that key (note that the
        declared return annotation does not mention the ``None`` case).
    """
    table = side.pick(self.lhs_embs, self.rhs_embs)
    key = self.EMB_PREFIX + entity
    try:
        module = table[key]
    except KeyError:
        # Nothing registered for this entity on this side.
        return None
    # Read ``weight`` outside the ``try`` so only a missing key is swallowed.
    return module.weight
def _can_acquire( self, rank: Rank, part: Partition, locked_entities_parts: Dict[Tuple[EntityName, Partition], Rank], side: Side, ) -> bool: for entity in side.pick(self.entities_lhs, self.entities_rhs): if locked_entities_parts.get((entity, part), rank) != rank: return False return True
def set_embeddings(self, entity: str, weights: nn.Parameter, side: Side,
                   shuffle_mode='all', shuffle_size=1, shuffle_order=2):
    """Wrap ``weights`` in an embedding module and register it for ``entity``.

    A featurized entity (``self.entities[entity].featurized``) gets a
    ``FeaturizedEmbedding``; any other entity gets a ``SimpleEmbedding``.
    Both receive ``max_norm=self.max_norm``. The module is stored under
    ``self.EMB_PREFIX + entity`` in whichever of ``self.lhs_embs`` /
    ``self.rhs_embs`` is selected by ``side``, replacing any previous entry.

    ``shuffle_mode``, ``shuffle_size`` and ``shuffle_order`` are forwarded
    unchanged to ``SimpleEmbedding`` only; they have no effect for featurized
    entities. NOTE(review): their meaning is defined by ``SimpleEmbedding``
    and is not visible from this method.
    """
    is_featurized = self.entities[entity].featurized
    if not is_featurized:
        module = SimpleEmbedding(
            weights,
            max_norm=self.max_norm,
            shuffle_mode=shuffle_mode,
            shuffle_size=shuffle_size,
            shuffle_order=shuffle_order,
        )
    else:
        module = FeaturizedEmbedding(weights, max_norm=self.max_norm)

    table = side.pick(self.lhs_embs, self.rhs_embs)
    table[self.EMB_PREFIX + entity] = module
def clear_embeddings(self, entity: str, side: Side) -> None:
    """Drop the embedding registered for ``entity`` on ``side``, if any.

    The entry keyed by ``self.EMB_PREFIX + entity`` is deleted from whichever
    of ``self.lhs_embs`` / ``self.rhs_embs`` is selected by ``side``. Calling
    this when no such entry exists is a no-op.
    """
    table = side.pick(self.lhs_embs, self.rhs_embs)
    key = self.EMB_PREFIX + entity
    try:
        del table[key]
    except KeyError:
        # Already absent: clearing is intentionally idempotent.
        pass