def _update_connections(self, knowledge: BaseKnowledge):
    """Merge the connections of ``knowledge`` into ``self.connections_per_idx``.

    Each connection reported by ``knowledge`` is looked up in the reversed
    extended vocabulary of ``knowledge`` and the result is used as a key into
    ``self.extended_vocab``. The translated connections are then added to the
    entry of ``self.connections_per_idx`` that belongs to the same label in
    ``self.vocab``.

    Args:
        knowledge: Source of the vocabulary, the extended vocabulary and the
            per-index connections that should be merged in.
    """
    # Presumably maps the indices used by `knowledge` back to their labels;
    # verify against `_reverse_vocab`.
    foreign_reverse_vocab = self._reverse_vocab(knowledge.get_extended_vocab())

    for label, foreign_idx in knowledge.get_vocab().items():
        translated = []
        for neighbour in knowledge.get_connections_for_idx(foreign_idx):
            # Re-express the neighbour in terms of this object's extended vocab.
            translated.append(self.extended_vocab[foreign_reverse_vocab[neighbour]])
        self.connections_per_idx[self.vocab[label]].update(translated)
Beispiel #2
0
    def _initialize_connections_from_knowledge(self, knowledge: BaseKnowledge):
        """Build forward and reverse connection maps from ``knowledge``.

        After the call:

        * ``self.connections[idx]`` holds the object returned by
          ``knowledge.get_connections_for_idx(idx)`` for every vocabulary index.
        * ``self.reverse_connections[target]`` holds every index that lists
          ``target`` as a connection, excluding self-connections.
        * ``self.num_connections`` counts every listed connection, including
          self-connections.

        Args:
            knowledge: Provides the vocabulary and the per-index connections.
        """
        self.num_connections = 0
        self.reverse_connections: Dict[int, Set[int]] = {}
        self.connections: Dict[int, Set[int]] = {}

        for source_idx in knowledge.get_vocab().values():
            targets = knowledge.get_connections_for_idx(source_idx)
            self.connections[source_idx] = targets
            for target_idx in targets:
                # Self-connections are counted but never mirrored.
                self.num_connections += 1
                if target_idx != source_idx:
                    self.reverse_connections.setdefault(target_idx, set()).add(source_idx)
Beispiel #3
0
    def _init_connection_information(self, knowledge: BaseKnowledge):
        """Collect per-feature connections and their flattened index views.

        Sets the following attributes:

        * ``connections``: feature index -> sorted list of connected indices.
        * ``connection_partition``: for every connection, the feature index it
          belongs to (``connection_partition[i] = j`` -> connection ``i`` is
          relevant for feature ``j``).
        * ``connection_indices``: the connection lists ordered by feature index
          (``connection_indices[i][j] = k`` -> feature ``i`` is connected to
          feature ``k``).
        * ``flattened_connection_indices``: ``connection_indices`` flattened, so
          connection ``k`` is between ``connection_partition[k]`` and
          ``flattened_connection_indices[k]``.
        * ``num_connections``: total number of connections.

        Args:
            knowledge: Provides the connections for each feature index in
                ``range(self.num_features)``.
        """
        logging.info("Initializing %s connection information", self.embedding_name)
        self.connections: Dict[int, List[int]] = {}
        partition: List[int] = []

        for idx in tqdm(
            range(self.num_features),
            desc="Initializing {} connections".format(self.embedding_name),
        ):
            connected_idxs = sorted(knowledge.get_connections_for_idx(idx))
            self.connections[idx] = connected_idxs
            # Grow in place: rebuilding the list with `+` on every iteration
            # copies all earlier entries and makes the loop quadratic.
            partition.extend([idx] * len(connected_idxs))

        # connection_partition[i] = j -> {connection i relevant for j}
        self.connection_partition: List[int] = partition

        # connection_indices[i,j] = k -> feature i is connected to feature k
        self.connection_indices = [
            self.connections[key] for key in sorted(self.connections)
        ]
        # connection k is between connection_partition[k] and flattened_connection_indices[k]
        # connection_indices[i,j] = k -> connection_partition[l]=i, flattened_connection_indices[l]=k
        self.flattened_connection_indices = [
            connected_idx
            for connected in self.connection_indices
            for connected_idx in connected
        ]
        self.num_connections = len(self.flattened_connection_indices)
Beispiel #4
0
    def __init__(self, knowledge: BaseKnowledge, config: ModelConfig):
        """Set up an embedding with one feature per vocabulary entry.

        The embedding starts with no hidden features and no connections; the
        basic embedding variables are created as the last step.

        Args:
            knowledge: Provides the vocabulary that determines ``num_features``.
            config: Model configuration stored on the instance.
        """
        super().__init__()
        self.config = config

        vocab = knowledge.get_vocab()
        self.num_features = len(vocab)
        # This embedding has neither hidden features nor connections.
        self.num_hidden_features = self.num_connections = 0

        self._init_basic_embedding_variables(knowledge)
Beispiel #5
0
 def _init_basic_embedding_variables(self, knowledge: BaseKnowledge):
     """Create the trainable base weights for features and hidden features.

     Two weights are registered through ``self.add_weight``:

     * ``basic_feature_embeddings`` with shape
       ``(num_features, config.embedding_dim)``.
     * ``basic_hidden_embeddings`` with shape
       ``(num_hidden_features, config.embedding_dim)``.

     Both weights use the kernel regularizer for the ``"base_embeddings"`` scope.

     Args:
         knowledge: Provides the vocabulary and the description vocabulary
             passed to the initializers.
     """
     logging.info("Initializing %s basic embedding variables", self.embedding_name)

     feature_descriptions = knowledge.get_description_vocab(
         set(knowledge.get_vocab().values())
     )
     feature_initializer = self._get_feature_initializer(feature_descriptions)
     self.basic_feature_embeddings = self.add_weight(
         initializer=feature_initializer,
         trainable=self.config.base_feature_embeddings_trainable,
         name=f"{self.embedding_name}/basic_feature_embeddings",
         shape=(self.num_features, self.config.embedding_dim),
         regularizer=super()._get_kernel_regularizer(scope="base_embeddings"),
     )

     # NOTE(review): the hidden initializer is given the descriptions of the
     # plain vocabulary ids, exactly like the feature initializer - confirm
     # that this is intended.
     hidden_descriptions = knowledge.get_description_vocab(
         set(knowledge.get_vocab().values())
     )
     hidden_initializer = self._get_hidden_initializer(hidden_descriptions)
     self.basic_hidden_embeddings = self.add_weight(
         initializer=hidden_initializer,
         trainable=self.config.base_hidden_embeddings_trainable,
         name=f"{self.embedding_name}/basic_hidden_embeddings",
         shape=(self.num_hidden_features, self.config.embedding_dim),
         regularizer=super()._get_kernel_regularizer(scope="base_embeddings"),
     )
Beispiel #6
0
 def _init_basic_embedding_variables(self, knowledge: BaseKnowledge):
     """Create the base feature weight; this variant has no hidden weight.

     ``basic_feature_embeddings`` is registered through ``self.add_weight``
     with shape ``(num_features, config.embedding_dim)``. Its initializer
     receives the vocabulary inverted to an index -> name mapping.
     ``basic_hidden_embeddings`` is set to ``None``.

     Args:
         knowledge: Provides the name -> index vocabulary.
     """
     logging.info("Initializing SIMPLE basic embedding variables")

     # Invert the vocabulary: the initializer is given index -> name.
     name_by_idx = {}
     for name, idx in knowledge.get_vocab().items():
         name_by_idx[idx] = name

     feature_initializer = self._get_feature_initializer(name_by_idx)
     base_regularizer = super()._get_kernel_regularizer(scope="base_embeddings")
     self.basic_feature_embeddings = self.add_weight(
         initializer=feature_initializer,
         trainable=self.config.base_feature_embeddings_trainable,
         name="simple_embedding/basic_feature_embeddings",
         shape=(self.num_features, self.config.embedding_dim),
         regularizer=base_regularizer,
     )
     self.basic_hidden_embeddings = None
Beispiel #7
0
    def __init__(
        self, knowledge: BaseKnowledge, config: ModelConfig, embedding_name: str
    ):
        """Set up sizes, attention layers and knowledge-derived state.

        ``num_features`` is the size of the vocabulary and
        ``num_hidden_features`` is the number of extended-vocabulary entries
        beyond it. Two dense layers (``w`` and ``u``) are created with the
        kernel regularizer for the ``"attention"`` scope, after which the basic
        embedding variables and the connection information are initialised.

        Args:
            knowledge: Provides the vocabulary, the extended vocabulary and
                the data used by the two initialisation steps.
            config: Model configuration stored on the instance.
            embedding_name: Name stored on the instance as ``embedding_name``.
        """
        super().__init__()
        self.embedding_name = embedding_name
        self.config = config

        self.num_features = len(knowledge.get_vocab())
        extended_size = len(knowledge.get_extended_vocab())
        self.num_hidden_features = extended_size - len(knowledge.get_vocab())

        w_regularizer = super()._get_kernel_regularizer(scope="attention")
        self.w = tf.keras.layers.Dense(
            self.config.attention_dim,
            use_bias=True,
            activation="tanh",
            kernel_regularizer=w_regularizer,
        )
        u_regularizer = super()._get_kernel_regularizer(scope="attention")
        self.u = tf.keras.layers.Dense(
            1,
            use_bias=False,
            kernel_regularizer=u_regularizer,
        )

        self._init_basic_embedding_variables(knowledge)
        self._init_connection_information(knowledge)
 def add_knowledge(self, knowledge: BaseKnowledge):
     """Merge another knowledge object into this one.

     The vocabulary is added first, then the extended vocabulary, and finally
     the connections are updated from ``knowledge``.

     Args:
         knowledge: The knowledge object whose contents are merged in.
     """
     vocab = knowledge.get_vocab()
     self._add_vocab(vocab)

     extended_vocab = knowledge.get_extended_vocab()
     self._add_extended_vocab(extended_vocab)

     self._update_connections(knowledge)