def _update_connections(self, knowledge: BaseKnowledge):
    """Merge the connections of ``knowledge`` into ``self.connections_per_idx``.

    Every connected index reported by ``knowledge`` is first passed through
    the mapping returned by ``self._reverse_vocab`` for the knowledge's
    extended vocab, and the result is then looked up in
    ``self.extended_vocab``. The translated indices are added to the entry
    of ``self.connections_per_idx`` keyed by ``self.vocab[label]``.

    NOTE(review): ``_reverse_vocab`` presumably inverts a label -> index
    mapping into index -> label; it is defined outside this block, confirm.
    """
    knowledge_idx_to_label = self._reverse_vocab(knowledge.get_extended_vocab())

    for label, knowledge_idx in knowledge.get_vocab().items():
        # Build the full translated list before touching the stored entry.
        translated = [
            self.extended_vocab[knowledge_idx_to_label[neighbour]]
            for neighbour in knowledge.get_connections_for_idx(knowledge_idx)
        ]
        own_idx = self.vocab[label]
        self.connections_per_idx[own_idx].update(translated)
def _initialize_connections_from_knowledge(self, knowledge: BaseKnowledge):
    """Build forward and reverse connection tables from ``knowledge``.

    Populates three attributes:

    * ``self.connections``: vocab index -> the connections object returned
      by ``knowledge.get_connections_for_idx`` (stored as-is, not copied).
    * ``self.reverse_connections``: connected index -> set of vocab indices
      that list it as a connection; self-connections are left out.
    * ``self.num_connections``: total number of connections seen,
      self-connections included.
    """
    self.connections: Dict[int, Set[int]] = {}
    self.reverse_connections: Dict[int, Set[int]] = {}
    self.num_connections = 0

    for source_idx in knowledge.get_vocab().values():
        targets = knowledge.get_connections_for_idx(source_idx)
        self.connections[source_idx] = targets

        for target_idx in targets:
            # Self-connections are counted but never reverse-indexed.
            self.num_connections += 1
            if target_idx != source_idx:
                self.reverse_connections.setdefault(target_idx, set()).add(
                    source_idx
                )
def _init_connection_information(self, knowledge: BaseKnowledge):
    """Precompute the connection lookup structures for this embedding.

    For every feature index in ``range(self.num_features)`` the connected
    indices are fetched from ``knowledge`` and stored sorted. The following
    attributes are set:

    * ``self.connections``: feature index -> sorted list of connected indices.
    * ``self.connection_partition``: ``connection_partition[l] = i`` means
      flattened connection ``l`` belongs to feature ``i``.
    * ``self.connection_indices``: per-feature connection lists, ordered by
      feature index (``connection_indices[i][j] = k`` -> feature ``i`` is
      connected to feature ``k``).
    * ``self.flattened_connection_indices``: ``connection_indices``
      concatenated, so connection ``l`` is between
      ``connection_partition[l]`` and ``flattened_connection_indices[l]``.
    * ``self.num_connections``: length of the flattened list.
    """
    logging.info("Initializing %s connection information", self.embedding_name)
    self.connections: Dict[int, List[int]] = {}
    # connection_partition[i] = j -> {connection i relevant for j}
    self.connection_partition: List[int] = []

    for idx in tqdm(
        range(self.num_features),
        desc="Initializing {} connections".format(self.embedding_name),
    ):
        connected_idxs = sorted(knowledge.get_connections_for_idx(idx))
        self.connections[idx] = connected_idxs
        # Extend in place; rebuilding the list with `+` on every iteration
        # re-copies all earlier entries and is quadratic overall.
        self.connection_partition.extend([idx] * len(connected_idxs))

    # connection_indices[i,j] = k -> feature i is connected to feature k
    self.connection_indices = [
        self.connections[feature_idx] for feature_idx in sorted(self.connections)
    ]
    # connection k is between connection_partition[k] and
    # flattened_connection_indices[k]:
    # connection_indices[i,j] = k -> connection_partition[l]=i,
    #                                flattened_connection_indices[l]=k
    self.flattened_connection_indices = [
        connected_idx
        for per_feature in self.connection_indices
        for connected_idx in per_feature
    ]
    self.num_connections = len(self.flattened_connection_indices)
def __init__(self, knowledge: BaseKnowledge, config: ModelConfig):
    """Set up an embedding with one row per vocab entry of ``knowledge``.

    The hidden-feature and connection counts are fixed to zero here; only
    the basic embedding variables are initialised.
    """
    super(SimpleEmbedding, self).__init__()

    # No hidden features and no connections for this embedding type.
    self.num_hidden_features = 0
    self.num_connections = 0

    self.config = config
    vocab = knowledge.get_vocab()
    self.num_features = len(vocab)

    self._init_basic_embedding_variables(knowledge)
def _init_basic_embedding_variables(self, knowledge: BaseKnowledge):
    """Create the basic feature and hidden embedding weights.

    Two weights are registered via ``self.add_weight``:

    * ``self.basic_feature_embeddings`` with shape
      ``(num_features, embedding_dim)``.
    * ``self.basic_hidden_embeddings`` with shape
      ``(num_hidden_features, embedding_dim)``.

    Both initializers receive the description vocab that ``knowledge``
    returns for the set of base vocab indices, and both weights use the
    ``"base_embeddings"`` kernel regularizer.
    """
    logging.info("Initializing %s basic embedding variables", self.embedding_name)

    feature_descriptions = knowledge.get_description_vocab(
        set(knowledge.get_vocab().values())
    )
    feature_initializer = self._get_feature_initializer(feature_descriptions)
    feature_regularizer = super()._get_kernel_regularizer(scope="base_embeddings")
    self.basic_feature_embeddings = self.add_weight(
        initializer=feature_initializer,
        trainable=self.config.base_feature_embeddings_trainable,
        name="{}/basic_feature_embeddings".format(self.embedding_name),
        shape=(self.num_features, self.config.embedding_dim),
        regularizer=feature_regularizer,
    )

    # NOTE(review): the hidden initializer is fed the descriptions of the
    # base vocab indices (same argument as above) -- confirm this is intended.
    hidden_descriptions = knowledge.get_description_vocab(
        set(knowledge.get_vocab().values())
    )
    hidden_initializer = self._get_hidden_initializer(hidden_descriptions)
    hidden_regularizer = super()._get_kernel_regularizer(scope="base_embeddings")
    self.basic_hidden_embeddings = self.add_weight(
        initializer=hidden_initializer,
        trainable=self.config.base_hidden_embeddings_trainable,
        name="{}/basic_hidden_embeddings".format(self.embedding_name),
        shape=(self.num_hidden_features, self.config.embedding_dim),
        regularizer=hidden_regularizer,
    )
def _init_basic_embedding_variables(self, knowledge: BaseKnowledge):
    """Create the basic feature embedding weight; no hidden embeddings.

    The feature initializer is built from an index -> label mapping obtained
    by inverting ``knowledge.get_vocab()``. ``self.basic_hidden_embeddings``
    is explicitly set to ``None``.
    """
    logging.info("Initializing SIMPLE basic embedding variables")

    label_by_idx = {idx: label for label, idx in knowledge.get_vocab().items()}
    feature_initializer = self._get_feature_initializer(label_by_idx)
    feature_regularizer = super()._get_kernel_regularizer(scope="base_embeddings")

    self.basic_feature_embeddings = self.add_weight(
        initializer=feature_initializer,
        trainable=self.config.base_feature_embeddings_trainable,
        name="simple_embedding/basic_feature_embeddings",
        shape=(self.num_features, self.config.embedding_dim),
        regularizer=feature_regularizer,
    )
    self.basic_hidden_embeddings = None
def __init__(
    self, knowledge: BaseKnowledge, config: ModelConfig, embedding_name: str
):
    """Set up sizes, the two dense layers, weights and connection tables.

    ``num_features`` is the size of the base vocab of ``knowledge``;
    ``num_hidden_features`` is the number of entries the extended vocab has
    beyond the base vocab. ``self.w`` is a tanh dense layer with
    ``config.attention_dim`` units and a bias; ``self.u`` is a bias-free
    dense layer with a single unit. Both use the ``"attention"`` kernel
    regularizer.
    """
    super(KnowledgeEmbedding, self).__init__()
    self.embedding_name = embedding_name
    self.config = config

    self.num_features = len(knowledge.get_vocab())
    extended_size = len(knowledge.get_extended_vocab())
    self.num_hidden_features = extended_size - len(knowledge.get_vocab())

    w_regularizer = super()._get_kernel_regularizer(scope="attention")
    self.w = tf.keras.layers.Dense(
        self.config.attention_dim,
        use_bias=True,
        activation="tanh",
        kernel_regularizer=w_regularizer,
    )

    u_regularizer = super()._get_kernel_regularizer(scope="attention")
    self.u = tf.keras.layers.Dense(
        1,
        use_bias=False,
        kernel_regularizer=u_regularizer,
    )

    # Weights first, then the connection lookup structures.
    self._init_basic_embedding_variables(knowledge)
    self._init_connection_information(knowledge)
def add_knowledge(self, knowledge: BaseKnowledge):
    """Incorporate ``knowledge`` into this object.

    Passes the base vocab to ``self._add_vocab``, then the extended vocab to
    ``self._add_extended_vocab``, and finally updates the connections via
    ``self._update_connections``. The helpers are defined outside this
    block; the order of the three steps is kept as in the original.
    """
    base_vocab = knowledge.get_vocab()
    self._add_vocab(base_vocab)

    extended_vocab = knowledge.get_extended_vocab()
    self._add_extended_vocab(extended_vocab)

    self._update_connections(knowledge)