Example no. 1
0
    def call(self, inputs, mask=None):
        """
        Core logic of the layer: Set2Set-style attentive pooling of per-node
        features into one vector per graph.

        Args:
            inputs (tuple): (features, feature_graph_index); the index maps
                each feature row to the id of the graph it belongs to
            mask (tf.Tensor): not used here

        Returns: q_star tensor with last dimension 2 * self.n_hidden
        """
        features, feature_graph_index = inputs
        feature_graph_index = tf.reshape(feature_graph_index, (-1,))
        _, _, count = tf.unique_with_counts(feature_graph_index)
        m = kb.dot(features, self.m_weight)
        if self.use_bias:
            m += self.m_bias

        self.h = tf.zeros(tf.stack(
            [tf.shape(input=features)[0], tf.shape(input=count)[0], self.n_hidden]))
        self.c = tf.zeros(tf.stack(
            [tf.shape(input=features)[0], tf.shape(input=count)[0], self.n_hidden]))
        q_star = tf.zeros(tf.stack(
            [tf.shape(input=features)[0], tf.shape(input=count)[0], 2 * self.n_hidden]))
        for i in range(self.T):
            # BUG FIX: the updated cell state used to be bound to a local `c`
            # and discarded, so self.c stayed zero for every iteration.
            self.h, self.c = self._lstm(q_star, self.c)
            e_i_t = tf.reduce_sum(
                input_tensor=m * repeat_with_index(self.h, feature_graph_index), axis=-1)
            # Stable softmax: subtract each graph's max logit before exp so
            # large logits cannot overflow (result is mathematically unchanged).
            maxes = tf.math.segment_max(e_i_t[0], feature_graph_index)
            e_i_t -= tf.expand_dims(tf.gather(maxes, feature_graph_index, axis=0), axis=0)
            exp = tf.exp(e_i_t)
            # Per-graph sum of the exponentials (softmax denominator).
            seg_sum = tf.transpose(
                a=tf.math.segment_sum(
                    tf.transpose(a=exp, perm=[1, 0]),
                    feature_graph_index),
                perm=[1, 0])
            seg_sum = tf.expand_dims(seg_sum, axis=-1)
            a_i_t = exp / tf.squeeze(
                repeat_with_index(seg_sum, feature_graph_index))
            # Attention-weighted readout per graph.
            r_t = tf.transpose(a=tf.math.segment_sum(
                tf.transpose(a=tf.multiply(m, a_i_t[:, :, None]), perm=[1, 0, 2]),
                feature_graph_index), perm=[1, 0, 2])
            q_star = kb.concatenate([self.h, r_t], axis=-1)
        return q_star
Example no. 2
0
    def call(self, inputs, mask=None):
        """
        Core logic of the layer

        Args:
            inputs (tuple): input tuple of length 3:
                (features, weights, feature_graph_index)
            mask (tf.Tensor): not used here

        Returns: q_star tensor with last dimension 2 * self.n_hidden
        """
        features, weights, feature_graph_index = inputs
        feature_graph_index = tf.reshape(feature_graph_index, (-1, ))
        _, _, count = tf.unique_with_counts(feature_graph_index)
        m = kb.dot(features, self.m_weight)
        if self.use_bias:
            m += self.m_bias

        self.h = tf.zeros(
            tf.stack([
                tf.shape(input=features)[0],
                tf.shape(input=count)[0], self.n_hidden
            ]))
        self.c = tf.zeros(
            tf.stack([
                tf.shape(input=features)[0],
                tf.shape(input=count)[0], self.n_hidden
            ]))
        q_star = tf.zeros(
            tf.stack([
                tf.shape(input=features)[0],
                tf.shape(input=count)[0], 2 * self.n_hidden
            ]))
        for i in range(self.T):
            # BUG FIX: the updated cell state used to be bound to a local `c`
            # and discarded, so self.c stayed zero for every iteration.
            self.h, self.c = self._lstm(q_star, self.c)
            e_i_t = tf.reduce_sum(
                input_tensor=m *
                repeat_with_index(self.h, feature_graph_index),
                axis=-1)
            # Stable softmax: subtract each graph's max logit before exp so
            # large logits cannot overflow; `weights` then rescales both the
            # numerator and (via seg_sum) the denominator, so the result is
            # mathematically unchanged.
            maxes = tf.math.segment_max(e_i_t[0], feature_graph_index)
            e_i_t -= tf.expand_dims(
                tf.gather(maxes, feature_graph_index, axis=0), axis=0)
            exp = tf.exp(e_i_t) * weights
            # Per-graph sum of the weighted exponentials (softmax denominator).
            seg_sum = tf.transpose(a=tf.math.segment_sum(
                tf.transpose(a=exp, perm=[1, 0]), feature_graph_index),
                                   perm=[1, 0])
            seg_sum = tf.expand_dims(seg_sum, axis=-1)
            interm = repeat_with_index(seg_sum, feature_graph_index)
            a_i_t = exp / interm[..., 0]
            # Attention-weighted readout per graph.
            r_t = tf.transpose(
                a=tf.math.segment_sum(
                    tf.transpose(a=tf.multiply(m, a_i_t[:, :, None]),
                                 perm=[1, 0, 2]), feature_graph_index),
                perm=[1, 0, 2],
            )
            q_star = kb.concatenate([self.h, r_t], axis=-1)
        return q_star
Example no. 3
0
 def test_repeat_with_index(self):
     """repeat_with_index should repeat entries per index group along axis 1."""
     out = repeat_with_index(self.x, self.index, axis=1).numpy()
     self.assertListEqual(list(out.shape), [1, 6, 4])
     first = out[0, 0, 0]
     # First three positions belong to one group, so they share a value.
     self.assertEqual(first, out[0, 1, 0])
     self.assertEqual(first, out[0, 2, 0])
     # Position 3 starts a new group with a different value.
     self.assertNotEqual(first, out[0, 3, 0])
     self.assertEqual(out[0, 3, 0], out[0, 4, 0])
Example no. 4
0
 def phi_e(self, inputs):
     """Edge update: concatenate sender/receiver node, edge and expanded global features, then apply the edge MLP."""
     nodes, edges, u, index1, index2, gnode, gbond = inputs
     sender_idx = tf.reshape(index1, (-1,))
     receiver_idx = tf.reshape(index2, (-1,))
     sender = tf.gather(nodes, sender_idx, axis=1)
     receiver = tf.gather(nodes, receiver_idx, axis=1)
     # Broadcast the global state to one row per bond.
     u_per_bond = repeat_with_index(u, gbond, axis=1)
     stacked = tf.concat([sender, receiver, edges, u_per_bond], axis=-1)
     return self._mlp(stacked, self.phi_e_weights, self.phi_e_biases)
Example no. 5
0
    def call(self, inputs, mask=None):
        """
        Main logic
        Args:
            inputs (tuple of tensor): (features, feature_graph_index); the
                index maps each feature row to its graph id
            mask (tensor): mask tensor (not used here)

        Returns: output tensor q_star with last dimension 2 * self.n_hidden

        """
        features, feature_graph_index = inputs
        feature_graph_index = tf.reshape(feature_graph_index, (-1,))
        _, _, count = tf.unique_with_counts(feature_graph_index)
        m = kb.dot(features, self.m_weight)
        if self.use_bias:
            m += self.m_bias

        self.h = tf.zeros(tf.stack([tf.shape(input=features)[0], tf.shape(input=count)[0], self.n_hidden]))
        self.c = tf.zeros(tf.stack([tf.shape(input=features)[0], tf.shape(input=count)[0], self.n_hidden]))
        q_star = tf.zeros(tf.stack([tf.shape(input=features)[0], tf.shape(input=count)[0], 2 * self.n_hidden]))
        for i in range(self.T):
            # BUG FIX: the updated cell state used to be bound to a local `c`
            # and discarded, so self.c stayed zero for every iteration.
            self.h, self.c = self._lstm(q_star, self.c)
            e_i_t = tf.reduce_sum(input_tensor=m * repeat_with_index(self.h, feature_graph_index), axis=-1)
            # Stable softmax: subtract each graph's max logit before exp.
            maxes = tf.math.segment_max(e_i_t[0], feature_graph_index)
            e_i_t -= tf.expand_dims(tf.gather(maxes, feature_graph_index, axis=0), axis=0)
            exp = tf.exp(e_i_t)
            # Per-graph sum of the exponentials (softmax denominator).
            seg_sum = tf.transpose(
                a=tf.math.segment_sum(tf.transpose(a=exp, perm=[1, 0]), feature_graph_index), perm=[1, 0]
            )
            seg_sum = tf.expand_dims(seg_sum, axis=-1)
            interm = repeat_with_index(seg_sum, feature_graph_index)
            a_i_t = exp / interm[..., 0]
            # Attention-weighted readout per graph.
            r_t = tf.transpose(
                a=tf.math.segment_sum(
                    tf.transpose(a=tf.multiply(m, a_i_t[:, :, None]), perm=[1, 0, 2]), feature_graph_index
                ),
                perm=[1, 0, 2],
            )
            q_star = kb.concatenate([self.h, r_t], axis=-1)
        return q_star
Example no. 6
0
    def phi_v(self, b_ei_p, inputs):
        """
        Node update function
        Args:
            b_ei_p (tensor): edge aggregated tensor
            inputs (tuple of tensors): other graph inputs

        Returns: updated node tensor

        """
        nodes, _, u, _, _, gnode, _ = inputs
        # Broadcast the global state to one row per node.
        u_per_node = repeat_with_index(u, gnode, axis=1)
        combined = tf.concat([b_ei_p, nodes, u_per_node], -1)
        return self._mlp(combined, self.phi_v_weights, self.phi_v_biases)
Example no. 7
0
 def phi_e(self, inputs):
     """
     Edge update function
     Args:
         inputs (tuple of tensor)
     Returns:
         output tensor
     """
     nodes, edges, u, index1, index2, gnode, gbond = inputs
     sender_idx = tf.reshape(index1, (-1,))
     receiver_idx = tf.reshape(index2, (-1,))
     sender = tf.gather(nodes, sender_idx, axis=1)
     receiver = tf.gather(nodes, receiver_idx, axis=1)
     # Broadcast the global state to one row per bond.
     u_per_bond = repeat_with_index(u, gbond, axis=1)
     stacked = tf.concat([sender, receiver, edges, u_per_bond], -1)
     return self._mlp(stacked, self.phi_e_weights, self.phi_e_biases)
Example no. 8
0
 def phi_v(self, b_ei_p, inputs):
     """Node update: combine aggregated edges, node features and the broadcast global state, then apply the node MLP."""
     nodes, _, u, _, _, gnode, _ = inputs
     u_per_node = repeat_with_index(u, gnode, axis=1)
     combined = tf.concat([b_ei_p, nodes, u_per_node], axis=-1)
     return self._mlp(combined, self.phi_v_weights, self.phi_v_biases)
Example no. 9
0
 def phi_v(self, b_ei_p, inputs):
     """Node update: combine aggregated edges, node features and the broadcast global state, then apply the node MLP."""
     node, _, u, _, _, gnode, _ = inputs
     u_per_node = repeat_with_index(u, gnode, axis=1)
     combined = tf.concat([b_ei_p, node, u_per_node], axis=-1)
     return self._mlp(combined, self.phi_v_weight, self.phi_v_bias)