Code example #1
    def call(self, inputs):
        # Assumed module-level context: `import tensorflow as tf` and
        # Euler's sampling ops (e.g. `from tf_euler.python import euler_ops`).
        # Sample the multi-hop neighborhood of `inputs` along the metapath:
        # nodes[i] holds the hop-i nodes, adjs[i] the sparse adjacency
        # between hop i and hop i + 1.
        nodes, adjs = euler_ops.get_multi_hop_neighbor(inputs, self.metapath)
        hidden = [self.node_encoder(node) for node in nodes]

        # Collect one projected embedding of the target nodes per depth,
        # starting with depth 0 (the raw encoded features).
        h_t = [self.depth_fc[0](hidden[0])]
        for layer in range(self.num_layers):
            aggregator = self.aggregators[layer]
            next_hidden = []
            # Each layer consumes one hop, so one fewer hop remains to
            # update after every iteration.
            for hop in range(self.num_layers - layer):
                if self.use_residual:
                    h = hidden[hop] + \
                        aggregator((hidden[hop], hidden[hop + 1], adjs[hop]))
                else:
                    h = aggregator((hidden[hop], hidden[hop + 1], adjs[hop]))
                next_hidden.append(h)
            hidden = next_hidden
            h_t.append(self.depth_fc[layer + 1](hidden[0]))

        # Stack the per-depth embeddings into a [batch, num_layers + 1, dim]
        # sequence and run an LSTM over the depth axis.
        lstm_cell = tf.nn.rnn_cell.LSTMCell(self.dim)
        initial_state = \
            lstm_cell.zero_state(tf.shape(inputs)[0], dtype=tf.float32)
        h_t = tf.concat([tf.reshape(i, [tf.shape(i)[0], 1, self.dim])
                         for i in h_t], 1)
        outputs, _ = tf.nn.dynamic_rnn(lstm_cell, h_t,
                                       initial_state=initial_state,
                                       dtype=tf.float32)
        # Keep the LSTM output at the first step as the node embedding.
        outputs = tf.reshape(outputs[:, 0, :], [-1, outputs.shape[2]])

        # Restore the leading dimensions of `inputs` in front of the
        # embedding dimension (unknown dimensions become -1).
        output_shape = inputs.shape.concatenate(outputs.shape[-1])
        output_shape = [d if d is not None else -1
                        for d in output_shape.as_list()]
        return tf.reshape(outputs, output_shape)
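
The second half of this method stacks the per-depth embeddings of the target nodes into a sequence and runs an LSTM over the depth axis, in the spirit of a jumping-knowledge readout. Below is a self-contained TensorFlow 1.x sketch of just that readout, with random stand-in tensors in place of the GCN layer outputs; the sizes (`batch`, `dim`, `num_layers`) are illustrative values, not taken from the snippet:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x, matching the snippet above

# Stand-ins for the per-depth target-node embeddings h_t, each [batch, dim].
batch, dim, num_layers = 4, 8, 2
h_t = [tf.constant(np.random.rand(batch, dim), dtype=tf.float32)
       for _ in range(num_layers + 1)]

# Stack into [batch, num_layers + 1, dim] and run an LSTM over depth.
lstm_cell = tf.nn.rnn_cell.LSTMCell(dim)
initial_state = lstm_cell.zero_state(batch, dtype=tf.float32)
seq = tf.concat([tf.reshape(h, [batch, 1, dim]) for h in h_t], 1)
outputs, _ = tf.nn.dynamic_rnn(lstm_cell, seq,
                               initial_state=initial_state,
                               dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(outputs).shape)  # (4, 3, 8)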
Code example #2
    def call(self, inputs):
        # Sample the multi-hop neighborhood along the metapath and encode
        # the raw features of every hop.
        nodes, adjs = euler_ops.get_multi_hop_neighbor(inputs, self.metapath)
        hidden = [self.node_encoder(node) for node in nodes]
        for layer in range(self.num_layers):
            aggregator = self.aggregators[layer]
            next_hidden = []
            # Each layer consumes one hop, so fewer hops remain to update.
            for hop in range(self.num_layers - layer):
                if self.use_residual:
                    h = hidden[hop] + \
                        aggregator((hidden[hop], hidden[hop + 1], adjs[hop]))
                else:
                    h = aggregator((hidden[hop], hidden[hop + 1], adjs[hop]))
                next_hidden.append(h)
            hidden = next_hidden

        # hidden[0] now holds the final embeddings of the target nodes;
        # restore the leading dimensions of `inputs` before returning.
        output_shape = inputs.shape.concatenate(hidden[0].shape[-1])
        output_shape = [d if d is not None else -1
                        for d in output_shape.as_list()]
        return tf.reshape(hidden[0], output_shape)
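
Each `aggregator` above is invoked as `aggregator((self_vecs, neigh_vecs, adj))`, and with `use_residual` on it must return a tensor with the same last dimension as its input. A minimal sketch of a layer matching that call convention, assuming TensorFlow 1.x and a sparse adjacency whose values are (typically row-normalized) edge weights; this is an illustrative stand-in, not Euler's own aggregator:

import tensorflow as tf  # TensorFlow 1.x

class MeanAggregator(tf.layers.Layer):
    """Illustrative GCN-style aggregator for the call convention
    aggregator((self_vecs, neigh_vecs, adj)) used in the loops above."""

    def __init__(self, dim, **kwargs):
        super(MeanAggregator, self).__init__(**kwargs)
        self.dense = tf.layers.Dense(dim, use_bias=False)

    def call(self, inputs):
        self_vecs, neigh_vecs, adj = inputs
        # adj is assumed to be a SparseTensor of shape
        # [num_self, num_neigh]; the matmul sums each node's neighbor
        # embeddings weighted by the adjacency values.
        neigh_agg = tf.sparse_tensor_dense_matmul(adj, neigh_vecs)
        return self.dense(tf.concat([self_vecs, neigh_agg], axis=-1))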
Code example #3
    def call(self, inputs, training=None):
        # At inference time, fall back to the plain full multi-hop GCN
        # forward pass of the parent class.
        if not training:
            return super(ScalableGCNEncoder, self).call(inputs)

        # At training time, only the one-hop neighborhood is sampled;
        # deeper context comes from the cached stores below.
        (node, neighbor), (adj,) = \
            euler_ops.get_multi_hop_neighbor(inputs, [self.edge_type])
        node_embedding = self.node_encoder(node)
        neigh_embedding = self.node_encoder(neighbor)

        node_embeddings = []
        neigh_embeddings = []
        for layer in range(self.num_layers):
            aggregator = self.aggregators[layer]

            if self.use_residual:
                node_embedding += aggregator((node_embedding,
                                              neigh_embedding,
                                              adj))
            else:
                node_embedding = aggregator((node_embedding,
                                             neigh_embedding,
                                             adj))
            node_embeddings.append(node_embedding)

            if layer < self.num_layers - 1:
                # Instead of recursively expanding the neighborhood, read
                # the neighbors' layer output from a store of historical
                # embeddings.
                neigh_embedding = \
                    tf.nn.embedding_lookup(self.stores[layer], neighbor)
                neigh_embeddings.append(neigh_embedding)

        # Side ops: write this batch's fresh embeddings back into the
        # stores and expose gradient/optimization hooks for the caller.
        self.update_store_op = self._update_store(node, node_embeddings)
        store_loss, self.optimize_store_op = \
            self._optimize_store(node, node_embeddings)
        self.get_update_gradient_op = lambda loss: \
            self._update_gradient(loss + store_loss,
                                  neighbor,
                                  neigh_embeddings)

        # Restore the leading dimensions of `inputs` before returning.
        output_shape = inputs.shape.concatenate(node_embedding.shape[-1])
        output_shape = [d if d is not None else -1
                        for d in output_shape.as_list()]
        return tf.reshape(node_embedding, output_shape)
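
The `tf.nn.embedding_lookup(self.stores[layer], neighbor)` call suggests each store is a `[max_node_id, dim]` variable caching one layer's historical embeddings, indexed by node id. A minimal sketch of the write-back that `_update_store` presumably performs; the real helper is not shown in the snippet and may differ:

import tensorflow as tf  # TensorFlow 1.x

def update_store(stores, node, node_embeddings):
    # For each layer, overwrite the cached rows for the nodes seen in
    # this mini-batch with their freshly computed embeddings.
    ops = [tf.scatter_update(store, node, embedding)
           for store, embedding in zip(stores, node_embeddings)]
    return tf.group(*ops)

Reading slightly stale neighbor embeddings back from the stores keeps each training step to a one-hop sample, which appears to be what makes this encoder "scalable" relative to the full multi-hop pass it falls back to at inference time.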