Exemplo n.º 1
0
    def collect_neighbours(self, coordinates, x, mask):
        """Aggregate distance-weighted features of each vertex's neighbours.

        The ``self.n_neighbours`` closest vertices (by ``euclidean_squared``
        distance in ``coordinates``) are found once. The closest hit is
        dropped on the assumption that it is the vertex itself, leaving N-1
        neighbours. Features are then aggregated over
        ``len(self.message_passing_layers) + 1`` rounds:

        * round 0 starts from ``self.input_feature_transform(x)`` and weights
          neighbours with ``gauss_of_lin(distance * 10.)``;
        * round ``i > 0`` passes the previous result concatenated with ``x``
          through ``self.message_passing_layers[i - 1]`` and weights
          neighbours with ``gauss_of_lin(w * distance)``, where
          ``w = self.message_parsing_distance_weights[i - 1]``.

        Each round replaces the features with the max- and mean-pooled
        weighted neighbour features, concatenated on the last axis. The
        result is multiplied by ``mask`` when one is given.

        Shape legend: B = batch, V = vertices, N = self.n_neighbours,
        F = features per vertex in the current round.

        Args:
            coordinates: per-vertex coordinates passed to
                ``euclidean_squared``.
            x: raw per-vertex input features; re-injected in every
                message-passing round.
            mask: tensor multiplied onto the aggregated features each round,
                presumably (B, V, 1) — TODO confirm; ``None`` skips masking.

        Returns:
            (B, V, 2F) tensor from the last round.
        """
        # TODO (original FIXME): consider tf.ragged for variable vertex counts.
        # euclidean_squared is defined in caloGraphNN.py.
        distance_matrix = euclidean_squared(coordinates, coordinates)
        # top_k on the negated matrix returns the nearest vertices first; the
        # returned values are therefore negated distances.
        ranked_distances, ranked_indices = tf.nn.top_k(-distance_matrix,
                                                       self.n_neighbours)

        # Rank 0 is assumed to be the vertex itself and is dropped.
        neighbour_indices = ranked_indices[:, :, 1:]  # (B, V, N-1)
        # Undo the negation so the distances are non-negative again.
        distance = -ranked_distances[:, :, 1:]  # (B, V, N-1)

        features = self.input_feature_transform(x)

        # Dynamic shapes so an unknown (None) batch size / vertex count works.
        n_batches = tf.shape(features)[0]
        n_vertices = tf.shape(features)[1]

        # Pair every neighbour index with its batch index for tf.gather_nd.
        batch_range = tf.reshape(tf.range(0, n_batches),
                                 [-1, 1, 1, 1])  # (B, 1, 1, 1)
        batch_indices = tf.tile(
            batch_range,
            [1, n_vertices, self.n_neighbours - 1, 1])  # (B, V, N-1, 1)
        vertex_indices = tf.expand_dims(neighbour_indices,
                                        axis=3)  # (B, V, N-1, 1)
        indices = tf.concat([batch_indices, vertex_indices],
                            axis=-1)  # (B, V, N-1, 2)

        # NOTE(review): `mask` only scales the aggregated output; masked
        # vertices can still be selected as neighbours by top_k above.
        for i in range(len(self.message_passing_layers) + 1):
            if i == 0:
                # First round: fixed distance scale.
                weights = gauss_of_lin(distance * 10.)
            else:
                # Later rounds: refine the features (re-injecting the raw
                # input x) and use a per-round distance scale.
                features = self.message_passing_layers[i - 1](tf.concat(
                    [features, x], axis=-1))
                w = self.message_parsing_distance_weights[i - 1]
                weights = gauss_of_lin(w * distance)
            weights = tf.expand_dims(weights, axis=-1)  # (B, V, N-1, 1)

            if 0 < self.feature_dropout < 1:
                # NOTE(review): a new Dropout layer is built on every call. It
                # holds no weights, so this works, but it would conventionally
                # be created once in __init__/build.
                features = keras.layers.Dropout(self.feature_dropout)(features)

            neighbour_features = tf.gather_nd(features,
                                              indices)  # (B, V, N-1, F)
            neighbour_features *= weights

            neighbours_max = tf.reduce_max(neighbour_features, axis=2)
            neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)

            features = tf.concat([neighbours_max, neighbours_mean], axis=-1)
            if mask is not None:
                features *= mask

        return features
Exemplo n.º 2
0
    def collect_neighbours(self, features, neighbour_indices, distancesq):
        """Pool distance-weighted neighbour features with max and mean.

        ``features`` is gathered at ``neighbour_indices`` via ``tf.gather_nd``.
        Every gathered entry is scaled by ``gauss_of_lin(10. * distancesq)``
        (broadcast over a trailing axis). The result is reduced over axis 1
        with both max and mean.

        Args:
            features: tensor to gather neighbour features from.
            neighbour_indices: ``tf.gather_nd`` indices into ``features``.
            distancesq: distances matching the gathered neighbours; presumably
                (SV, N) squared distances — TODO confirm.

        Returns:
            ``concat([max, mean], axis=-1)`` of the weighted features, reduced
            over axis 1.
        """
        scale = tf.expand_dims(gauss_of_lin(10. * distancesq),
                               axis=-1)  # [SV, N, 1]
        weighted = tf.gather_nd(features, neighbour_indices) * scale
        pooled_max = tf.reduce_max(weighted, axis=1)
        pooled_mean = tf.reduce_mean(weighted, axis=1)
        return tf.concat([pooled_max, pooled_mean], axis=-1)
Exemplo n.º 3
0
    def collect_neighbours_fullmatrix(self, coordinates, features):
        """Aggregate distance-weighted neighbour features via dense masking.

        Instead of index-gathering, each vertex's view of the full (V, F)
        feature table is tiled and the N-1 nearest neighbours are selected
        with a boolean mask. The implementation was changed with respect to
        caloGraphNN so the batch size (B) may be unknown (None). Note that the
        tiled tensor is (B, V, V, F), so memory grows quadratically in V.

        Shape legend: B = batch, V = vertices, N = self.n_neighbours,
        F = features per vertex.

        Args:
            coordinates: per-vertex coordinates passed to ``euclidean_squared``.
            features: per-vertex features, (B, V, F).

        Returns:
            (B, V, 2F) tensor: max- and mean-pooled weighted neighbour
            features concatenated on the last axis.

        Note:
            The closest ``top_k`` hit is dropped on the assumption that it is
            the vertex itself. ``tf.boolean_mask`` yields neighbours in
            vertex-index order, not distance order. Features and distances
            share that order, and max/mean are order-independent.
        """
        # distance_matrix is the dense (B, V, V) output of euclidean_squared.
        distance_matrix = euclidean_squared(coordinates, coordinates)
        _, ranked_indices = tf.nn.top_k(-distance_matrix, self.n_neighbours)

        # Drop rank 0, assumed to be the vertex itself.
        neighbour_indices = ranked_indices[:, :, 1:]  # (B, V, N-1)

        n_vertices = tf.shape(features)[1]
        n_features = tf.shape(features)[2]

        # Boolean neighbour mask of shape (B, V, V):
        # (B, V, N-1) -[one_hot over V]> (B, V, N-1, V) -[sum axis 2]> (B, V, V).
        # top_k indices are distinct, so every row has exactly N-1 True entries.
        neighbour_mask = tf.one_hot(neighbour_indices,
                                    depth=n_vertices,
                                    axis=-1,
                                    dtype=tf.int32)
        neighbour_mask = tf.reduce_sum(neighbour_mask, axis=2)
        neighbour_mask = tf.cast(neighbour_mask, tf.bool)

        # (B, V, F) -[tile]> (B, V, V, F) -[mask]> (B*V*(N-1), F)
        #           -[reshape]> (B, V, N-1, F)
        neighbour_features = tf.expand_dims(features, axis=1)
        neighbour_features = tf.tile(neighbour_features, [1, n_vertices, 1, 1])
        neighbour_features = tf.boolean_mask(neighbour_features,
                                             neighbour_mask)
        neighbour_features = tf.reshape(
            neighbour_features,
            [-1, n_vertices, self.n_neighbours - 1, n_features])

        # (B, V, V) -[mask]> (B*V*(N-1),) -[reshape]> (B, V, N-1)
        distance = tf.boolean_mask(distance_matrix, neighbour_mask)
        distance = tf.reshape(distance,
                              [-1, n_vertices, self.n_neighbours - 1])

        weights = gauss_of_lin(distance * 10.)
        weights = tf.expand_dims(weights, axis=-1)  # (B, V, N-1, 1)

        # weight the neighbour_features
        neighbour_features *= weights

        neighbours_max = tf.reduce_max(neighbour_features, axis=2)
        neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)

        return tf.concat([neighbours_max, neighbours_mean], axis=-1)
    def collect_neighbours(self, coordinates, features):
        """Aggregate distance-weighted features of each vertex's neighbours.

        Finds the ``self.n_neighbours`` closest vertices per vertex using
        ``euclidean_squared`` on ``coordinates``. The closest hit is dropped
        (assumed to be the vertex itself). The remaining N-1 neighbours'
        features are gathered, weighted with ``gauss_of_lin(distance * 10.)``,
        and pooled with max and mean.

        Uses only the locally imported ``tensorflow`` and ``caloGraphNN``
        names. Previously it also relied on a Keras backend ``K`` that this
        method never imported.

        Shape legend: B = batch, V = vertices, N = self.n_neighbours,
        F = features per vertex.

        Args:
            coordinates: per-vertex coordinates passed to ``euclidean_squared``.
            features: per-vertex features, (B, V, F).

        Returns:
            (B, V, 2F) tensor: ``concat([max, mean], axis=-1)`` over the
            neighbour axis.
        """
        import tensorflow as tf
        from caloGraphNN import euclidean_squared, gauss_of_lin

        # TODO (original FIXME): consider tf.ragged for variable vertex counts.
        distance_matrix = euclidean_squared(coordinates, coordinates)
        # top_k on the negated matrix returns the nearest vertices first; the
        # returned values are therefore negated distances.
        ranked_distances, ranked_indices = tf.nn.top_k(-distance_matrix,
                                                       self.n_neighbours)

        # Rank 0 is assumed to be the vertex itself and is dropped.
        neighbour_indices = ranked_indices[:, :, 1:]  # (B, V, N-1)
        # Undo the negation so the distances are non-negative again.
        distance = -ranked_distances[:, :, 1:]  # (B, V, N-1)

        # Dynamic shapes so an unknown (None) batch size / vertex count works.
        n_batches = tf.shape(features)[0]
        n_vertices = tf.shape(features)[1]

        # Pair every neighbour index with its batch index for tf.gather_nd.
        # tf.range on a shape dimension is already non-negative, so no clamp
        # (as K.arange applies) is needed.
        batch_range = tf.reshape(tf.range(n_batches),
                                 [-1, 1, 1, 1])  # (B, 1, 1, 1)
        batch_indices = tf.tile(
            batch_range,
            [1, n_vertices, self.n_neighbours - 1, 1])  # (B, V, N-1, 1)
        vertex_indices = tf.expand_dims(neighbour_indices,
                                        axis=3)  # (B, V, N-1, 1)
        indices = tf.concat([batch_indices, vertex_indices],
                            axis=-1)  # (B, V, N-1, 2)

        neighbour_features = tf.gather_nd(features, indices)  # (B, V, N-1, F)

        weights = gauss_of_lin(distance * 10.)
        weights = tf.expand_dims(weights, axis=-1)  # (B, V, N-1, 1)

        # weight the neighbour_features
        neighbour_features *= weights

        neighbours_max = tf.reduce_max(neighbour_features, axis=2)
        neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)

        return tf.concat([neighbours_max, neighbours_mean], axis=-1)
Exemplo n.º 5
0
    def collect_neighbours(self, coordinates, features):
        """Aggregate distance-weighted features of each vertex's neighbours.

        Finds the ``self.n_neighbours`` closest vertices per vertex using
        ``euclidean_squared`` on ``coordinates``. The closest hit is dropped
        (assumed to be the vertex itself). The remaining N-1 neighbours'
        features are gathered, weighted with ``gauss_of_lin(distance * 10.)``,
        and pooled with max and mean.

        Shape legend: B = batch, V = vertices, N = self.n_neighbours,
        F = features per vertex.

        Args:
            coordinates: per-vertex coordinates passed to ``euclidean_squared``.
            features: per-vertex features, (B, V, F).

        Returns:
            (B, V, 2F) tensor: ``concat([max, mean], axis=-1)`` over the
            neighbour axis.
        """
        distance_matrix = euclidean_squared(coordinates, coordinates)
        # top_k on the negated matrix returns the nearest vertices first, so
        # the returned values are *negated* distances.
        negated_distances, ranked_indices = tf.nn.top_k(-distance_matrix,
                                                        self.n_neighbours)

        # Rank 0 is assumed to be the vertex itself and is dropped.
        neighbour_indices = ranked_indices[:, :, 1:]  # (B, V, N-1)
        # Flip the sign back so the weights are computed from non-negative
        # distances, as the other collect_neighbours variants do. If
        # gauss_of_lin is symmetric in its argument this changes nothing; if
        # it is not, the negated values would have inverted the weighting.
        distance = -negated_distances[:, :, 1:]  # (B, V, N-1)

        # Dynamic shapes so an unknown (None) batch size / vertex count works.
        n_batches = tf.shape(features)[0]
        n_vertices = tf.shape(features)[1]

        # Pair every neighbour index with its batch index for tf.gather_nd.
        batch_range = tf.reshape(tf.range(0, n_batches),
                                 [-1, 1, 1, 1])  # (B, 1, 1, 1)
        batch_indices = tf.tile(
            batch_range,
            [1, n_vertices, self.n_neighbours - 1, 1])  # (B, V, N-1, 1)
        vertex_indices = tf.expand_dims(neighbour_indices,
                                        axis=3)  # (B, V, N-1, 1)
        indices = tf.concat([batch_indices, vertex_indices],
                            axis=-1)  # (B, V, N-1, 2)

        neighbour_features = tf.gather_nd(features, indices)  # (B, V, N-1, F)

        weights = gauss_of_lin(distance * 10.)
        weights = tf.expand_dims(weights, axis=-1)  # (B, V, N-1, 1)

        # weight the neighbour_features
        neighbour_features *= weights

        neighbours_max = tf.reduce_max(neighbour_features, axis=2)
        neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)

        return tf.concat([neighbours_max, neighbours_mean], axis=-1)