Beispiel #1
0
def get_partitioned_types(
        config: ConfigSchema,
        side: Side) -> Tuple[int, Set[EntityName], Set[EntityName]]:
    """Return the partition count and the split of entity types on one side.

    Each of the entity types that appear on the given side (LHS or RHS) of a
    relation type is split into some number of partitions. The ones that are
    split into one partition are called "unpartitioned" and behave as if all of
    their entities belonged to all buckets. The other ones are the "properly"
    partitioned ones. Currently, they must all be partitioned into the same
    number of partitions.

    Args:
        config: schema whose ``relations`` and ``entities`` are inspected.
        side: selects whether each relation's ``lhs`` or ``rhs`` entity type
            is considered (via ``side.pick(lhs, rhs)``).

    Returns:
        A ``(num_partitions, unpartitioned, partitioned)`` tuple: the partition
        count shared by the properly partitioned types (``1`` if there are
        none), the names of the unpartitioned entity types, and the names of
        the properly partitioned entity types.

    Raises:
        RuntimeError: if the properly partitioned entity types on this side do
            not all use the same number of partitions.
    """
    entity_names_by_num_parts: Dict[int, Set[EntityName]] = defaultdict(set)
    for relation_config in config.relations:
        entity_name = side.pick(relation_config.lhs, relation_config.rhs)
        entity_config = config.entities[entity_name]
        entity_names_by_num_parts[entity_config.num_partitions].add(
            entity_name)

    # Single-partition types are split off; whatever remains is properly
    # partitioned.
    unpartitioned_entity_names = entity_names_by_num_parts.pop(1, set())

    if not entity_names_by_num_parts:
        return 1, unpartitioned_entity_names, set()
    if len(entity_names_by_num_parts) > 1:
        # Report every conflicting count and the entity types using it, so the
        # offending config entries can be found without a debugger.
        details = "; ".join(
            f"{num_parts} partitions: {sorted(map(str, names))}"
            for num_parts, names in sorted(entity_names_by_num_parts.items()))
        raise RuntimeError("Currently num_partitions must be a single "
                           "value across all partitioned entities. "
                           f"Found: {details}.")

    # Exactly one (num_partitions, names) entry remains at this point.
    [(num_partitions, partitioned_entity_names)] = (
        entity_names_by_num_parts.items())
    return num_partitions, unpartitioned_entity_names, partitioned_entity_names
Beispiel #2
0
 def set_embeddings(self, entity: str, side: Side,
                    weights: nn.Parameter) -> None:
     """Wrap ``weights`` in an embedding module and register it for ``entity``.

     Featurized entity types get a ``FeaturizedEmbedding``; all others get a
     ``SimpleEmbedding``. Both are built with ``self.max_norm``. The module is
     stored under ``EMB_PREFIX + entity`` in the LHS or RHS container chosen
     by ``side``.
     """
     embedding_cls = (FeaturizedEmbedding
                      if self.entities[entity].featurized
                      else SimpleEmbedding)
     module = embedding_cls(weights, max_norm=self.max_norm)
     container = side.pick(self.lhs_embs, self.rhs_embs)
     container[self.EMB_PREFIX + entity] = module
Beispiel #3
0
 def get_embeddings(self, entity: str, side: Side) -> nn.Parameter:
     """Return the ``weight`` of the embedding registered for ``entity``.

     The lookup key is ``EMB_PREFIX + entity`` in the LHS or RHS container
     chosen by ``side``. Returns ``None`` when nothing is registered under
     that key.
     """
     # NOTE(review): the annotation says nn.Parameter, but None is also
     # returned; Optional[nn.Parameter] would describe the contract better.
     container = side.pick(self.lhs_embs, self.rhs_embs)
     key = self.EMB_PREFIX + entity
     try:
         module = container[key]
     except KeyError:
         return None
     return module.weight
 def _can_acquire(
     self,
     rank: Rank,
     part: Partition,
     locked_entities_parts: Dict[Tuple[EntityName, Partition], Rank],
     side: Side,
 ) -> bool:
     """Tell whether ``rank`` may take partition ``part`` on ``side``.

     Each entity type on the chosen side (``entities_lhs`` or
     ``entities_rhs``) is checked. The answer is False as soon as one of them
     has ``(entity, part)`` recorded in ``locked_entities_parts`` under a
     different rank. Unrecorded pairs default to ``rank`` and so count as free.
     """
     candidate_entities = side.pick(self.entities_lhs, self.entities_rhs)
     return not any(
         locked_entities_parts.get((name, part), rank) != rank
         for name in candidate_entities
     )
Beispiel #5
0
 def set_embeddings(self,
                    entity: str,
                    weights: nn.Parameter,
                    side: Side,
                    shuffle_mode='all',
                    shuffle_size=1,
                    shuffle_order=2) -> None:
     """Wrap ``weights`` in an embedding module and register it for ``entity``.

     Args:
         entity: name of the entity type whose embedding is being set.
         weights: the embedding weights to wrap.
         side: selects the LHS or RHS container the module is stored in,
             under the key ``EMB_PREFIX + entity``.
         shuffle_mode, shuffle_size, shuffle_order: forwarded to
             ``SimpleEmbedding`` only. They are NOT passed on when the entity
             type is featurized.
     """
     if self.entities[entity].featurized:
         # NOTE(review): shuffle_* options are dropped on this path; confirm
         # that FeaturizedEmbedding is meant to ignore them.
         emb = FeaturizedEmbedding(weights, max_norm=self.max_norm)
     else:
         emb = SimpleEmbedding(weights,
                               max_norm=self.max_norm,
                               shuffle_mode=shuffle_mode,
                               shuffle_size=shuffle_size,
                               shuffle_order=shuffle_order)
     side.pick(self.lhs_embs, self.rhs_embs)[self.EMB_PREFIX + entity] = emb
Beispiel #6
0
 def clear_embeddings(self, entity: str, side: Side) -> None:
     """Unregister the embedding stored for ``entity`` on ``side``, if any.

     The entry ``EMB_PREFIX + entity`` is deleted from the LHS or RHS container
     chosen by ``side``. A missing entry is ignored, so clearing an entity
     that has no embedding is a no-op.
     """
     container = side.pick(self.lhs_embs, self.rhs_embs)
     try:
         del container[self.EMB_PREFIX + entity]
     except KeyError:
         # Deliberately best-effort: nothing registered, nothing to remove.
         return