Ejemplo n.º 1
0
def weightedCoordLoss(fracs, r_energy, coords):
    '''
    Per-vertex penalty for disagreement between pairwise distances measured
    in coordinate space and in fraction space.

    For each pair of vertices (i, j), the squared difference between their
    coordinate distance and their fraction distance is accumulated over j.
    Pairs whose two distances both exceed 0.5 contribute nothing: the
    intent is that a large fraction distance should come with a large
    coordinate distance.

    fracs:    B x V x F
    r_energy: B x V   (used only through `r_energy > 0`: vertices with
                       non-positive energy get a loss of 0; this is a
                       binary mask, not an energy weighting)
    coords:   B x V x C

    returns:  B x V
    '''
    from caloGraphNN import euclidean_squared

    # 1.0 where the vertex carries positive energy, 0.0 elsewhere.
    zeros = tf.zeros_like(r_energy)
    has_energy = tf.where(r_energy > 0, zeros + 1., zeros)

    # Pairwise (presumably squared-Euclidean) distances, B x V x V.
    coord_dist = euclidean_squared(coords, coords)
    frac_dist = euclidean_squared(fracs, fracs)

    mismatch = (coord_dist - frac_dist) ** 2
    # A pair that is far apart in both spaces is already consistent, so it is
    # not penalised.  (Original author's note: fraction distances are at most
    # 1 and coordinate distances are of order 1 — depends on the inputs.)
    both_far = tf.logical_and(frac_dist > 0.5, coord_dist > 0.5)
    mismatch = tf.where(both_far, tf.zeros_like(mismatch), mismatch)

    return has_energy * tf.reduce_sum(mismatch, axis=2)
Ejemplo n.º 2
0
    def collect_neighbours(self, coordinates, x, mask):
        """Aggregate distance-weighted neighbour features, with message passing.

        Each vertex's ``self.n_neighbours - 1`` nearest neighbours are
        selected by distance in ``coordinates``. The top-ranked entry is
        skipped; it is presumably the vertex itself. Their features are
        gathered, weighted by ``gauss_of_lin`` of the scaled distance, and
        reduced to a per-vertex ``[max, mean]`` concatenation.

        The first pass aggregates ``self.input_feature_transform(x)`` with a
        fixed distance scale of 10. Each later pass runs the next layer in
        ``self.message_passing_layers`` on ``[previous result, x]``. It
        weights neighbours using the matching entry of
        ``self.message_parsing_distance_weights`` as the distance scale.

        Args:
            coordinates: per-vertex coordinates, assumed (B, V, C) - TODO
                confirm.
            x: per-vertex input features, assumed (B, V, F_in) - TODO confirm.
            mask: tensor multiplied into the result of every pass, or
                ``None`` to disable masking.

        Returns:
            The ``[max, mean]`` aggregate of the final pass, concatenated on
            the last axis.
        """
        # tf.ragged FIXME?
        # for euclidean_squared see caloGraphNN.py
        distance_matrix = euclidean_squared(coordinates, coordinates)
        # top_k on the negated matrix returns the closest vertices first.
        ranked_distances, ranked_indices = tf.nn.top_k(-distance_matrix,
                                                       self.n_neighbours)

        # Skip rank 0 (presumably the vertex itself).
        neighbour_indices = ranked_indices[:, :, 1:]  # (B, V, N-1)
        # Undo the negation used for top_k so distances are non-negative.
        distance = -ranked_distances[:, :, 1:]  # (B, V, N-1)

        features = self.input_feature_transform(x)

        # tf.ragged FIXME? or could that work?
        n_batches = tf.shape(features)[0]
        n_vertices = tf.shape(features)[1]

        # Pair each neighbour index with its batch index so tf.gather_nd can
        # look up features[b, neighbour] directly.
        batch_range = tf.range(0, n_batches)
        batch_range = tf.expand_dims(batch_range, axis=1)
        batch_range = tf.expand_dims(batch_range, axis=1)
        batch_range = tf.expand_dims(batch_range, axis=1)  # (B, 1, 1, 1)
        batch_indices = tf.tile(
            batch_range,
            [1, n_vertices, self.n_neighbours - 1, 1])  # (B, V, N-1, 1)
        vertex_indices = tf.expand_dims(neighbour_indices,
                                        axis=3)  # (B, V, N-1, 1)
        indices = tf.concat([batch_indices, vertex_indices],
                            axis=-1)  # (B, V, N-1, 2)

        use_dropout = 0 < self.feature_dropout < 1

        # Pass 0 uses a fixed distance scale of 10.
        weights = tf.expand_dims(gauss_of_lin(distance * 10.), axis=-1)

        for i in range(len(self.message_passing_layers) + 1):
            if i > 0:
                # Later passes: transform [previous aggregate, raw input] and
                # switch to this pass's distance scale.
                features = self.message_passing_layers[i - 1](tf.concat(
                    [features, x], axis=-1))
                w = self.message_parsing_distance_weights[i - 1]
                weights = tf.expand_dims(gauss_of_lin(w * distance), axis=-1)

            if use_dropout:
                # NOTE(review): a new Dropout layer is built on every call and
                # invoked without an explicit `training` argument - confirm it
                # is disabled at inference time.
                features = keras.layers.Dropout(self.feature_dropout)(features)

            neighbour_features = tf.gather_nd(features,
                                              indices)  # (B, V, N-1, F)
            neighbour_features *= weights

            neighbours_max = tf.reduce_max(neighbour_features, axis=2)
            neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)

            features = tf.concat([neighbours_max, neighbours_mean], axis=-1)
            if mask is not None:
                features *= mask

        return features
Ejemplo n.º 3
0
    def collect_neighbours_fullmatrix(self, coordinates, features):
        """Dense-mask variant of nearest-neighbour feature aggregation.

        Each vertex's ``self.n_neighbours - 1`` nearest neighbours are found
        by distance in ``coordinates``. The top-ranked entry is skipped; it is
        presumably the vertex itself. They are selected with a boolean
        (B, V, V) membership mask applied to ``features`` tiled to
        (B, V, V, F). No batch-index tensor is built, so an unknown batch
        size (B = None) is fine. The price is a (B, V, V, F) intermediate.

        ``tf.boolean_mask`` keeps neighbours in vertex-index order rather than
        distance order. Features and distances go through the same mask, so
        they stay aligned. Neighbour features are weighted by
        ``gauss_of_lin(distance * 10.)`` and reduced over the neighbour axis.

        Returns:
            Neighbour max and mean, concatenated on the last axis.
        """
        # Implementation changed wrt caloGraphNN to account for the batch size
        # (B) being unknown (None).
        # V = number of vertices, N = number of neighbours,
        # F = number of features per vertex.
        sq_dist = euclidean_squared(coordinates, coordinates)  # (B, V, V)
        _, ranked = tf.nn.top_k(-sq_dist, self.n_neighbours)
        nearest = ranked[:, :, 1:]  # (B, V, N-1)

        num_vertices = tf.shape(features)[1]
        num_features = tf.shape(features)[2]
        num_kept = self.n_neighbours - 1

        # one_hot -> (B, V, N-1, V); summing out the neighbour axis gives a
        # (B, V, V) mask that is True where column j is a neighbour of row i.
        hits = tf.one_hot(nearest, depth=num_vertices, axis=-1, dtype=tf.int32)
        is_neighbour = tf.cast(tf.reduce_sum(hits, axis=2), tf.bool)

        # (B, V, F) -[tile]> (B, V, V, F) -[mask]> (B, V, N-1, F)
        all_pairs = tf.tile(tf.expand_dims(features, axis=1),
                            [1, num_vertices, 1, 1])
        picked = tf.reshape(tf.boolean_mask(all_pairs, is_neighbour),
                            [-1, num_vertices, num_kept, num_features])

        # (B, V, V) -[mask]> (B, V, N-1)
        picked_dist = tf.reshape(tf.boolean_mask(sq_dist, is_neighbour),
                                 [-1, num_vertices, num_kept])

        scale = tf.expand_dims(gauss_of_lin(picked_dist * 10.), axis=-1)
        weighted = picked * scale

        return tf.concat([tf.reduce_max(weighted, axis=2),
                          tf.reduce_mean(weighted, axis=2)], axis=-1)
    def collect_neighbours(self, coordinates, features):
        """Aggregate distance-weighted features of each vertex's neighbours.

        For every vertex, the ``self.n_neighbours - 1`` closest vertices by
        distance in ``coordinates`` are gathered from ``features``. The
        top-ranked entry is skipped; it is presumably the vertex itself. The
        gathered features are weighted by ``gauss_of_lin(distance * 10.)``
        and reduced to max and mean over the neighbours.

        Args:
            coordinates: per-vertex coordinates, assumed (B, V, C) - TODO
                confirm.
            features: per-vertex features, (B, V, F).

        Returns:
            (B, V, 2F): neighbour max and mean concatenated on the last axis.
        """
        import tensorflow as tf
        from caloGraphNN import euclidean_squared, gauss_of_lin

        # tf.ragged FIXME?
        # for euclidean_squared see caloGraphNN.py
        distance_matrix = euclidean_squared(coordinates, coordinates)

        # top_k on the negated matrix ranks the closest vertices first.
        ranked_distances, ranked_indices = tf.nn.top_k(-distance_matrix,
                                                       self.n_neighbours)

        # Skip rank 0 (presumably the vertex itself).
        neighbour_indices = ranked_indices[:, :, 1:]  # (B, V, N-1)

        # tf.ragged FIXME? or could that work?
        n_batches = tf.shape(features)[0]
        n_vertices = tf.shape(features)[1]

        # Pair each neighbour index with its batch index for tf.gather_nd.
        batch_range = tf.range(0, n_batches)
        batch_range = tf.expand_dims(batch_range, axis=1)
        batch_range = tf.expand_dims(batch_range, axis=1)
        batch_range = tf.expand_dims(batch_range, axis=1)  # (B, 1, 1, 1)

        # tf.ragged FIXME? n_vertices
        batch_indices = tf.tile(
            batch_range,
            [1, n_vertices, self.n_neighbours - 1, 1])  # (B, V, N-1, 1)
        vertex_indices = tf.expand_dims(neighbour_indices,
                                        axis=3)  # (B, V, N-1, 1)
        indices = tf.concat([batch_indices, vertex_indices],
                            axis=-1)  # (B, V, N-1, 2)

        neighbour_features = tf.gather_nd(features, indices)  # (B, V, N-1, F)

        # Undo the negation used for top_k so distances are non-negative.
        distance = -ranked_distances[:, :, 1:]

        weights = gauss_of_lin(distance * 10.)
        weights = tf.expand_dims(weights, axis=-1)

        # weight the neighbour_features
        neighbour_features *= weights

        neighbours_max = tf.reduce_max(neighbour_features, axis=2)
        neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)

        return tf.concat([neighbours_max, neighbours_mean], axis=-1)
Ejemplo n.º 5
0
    def collect_neighbours(self, coordinates, features):
        """Aggregate distance-weighted features of each vertex's neighbours.

        For every vertex, the ``self.n_neighbours - 1`` closest vertices by
        distance in ``coordinates`` are gathered from ``features``. The
        top-ranked entry is skipped; it is presumably the vertex itself. The
        gathered features are weighted by ``gauss_of_lin(distance * 10.)``
        and reduced to max and mean over the neighbours.

        Args:
            coordinates: per-vertex coordinates, assumed (B, V, C) - TODO
                confirm.
            features: per-vertex features, (B, V, F).

        Returns:
            (B, V, 2F): neighbour max and mean concatenated on the last axis.
        """
        # V = number of vertices
        # N = number of neighbours
        # F = number of features per vertex

        distance_matrix = euclidean_squared(coordinates, coordinates)
        # top_k on the negated matrix ranks the closest vertices first.
        ranked_distances, ranked_indices = tf.nn.top_k(-distance_matrix,
                                                       self.n_neighbours)

        # Skip rank 0 (presumably the vertex itself).
        neighbour_indices = ranked_indices[:, :, 1:]  # (B, V, N-1)

        n_batches = tf.shape(features)[0]
        n_vertices = tf.shape(features)[1]

        # Pair each neighbour index with its batch index for tf.gather_nd.
        batch_range = tf.range(0, n_batches)
        batch_range = tf.expand_dims(batch_range, axis=1)
        batch_range = tf.expand_dims(batch_range, axis=1)
        batch_range = tf.expand_dims(batch_range, axis=1)  # (B, 1, 1, 1)

        batch_indices = tf.tile(
            batch_range,
            [1, n_vertices, self.n_neighbours - 1, 1])  # (B, V, N-1, 1)
        vertex_indices = tf.expand_dims(neighbour_indices,
                                        axis=3)  # (B, V, N-1, 1)
        indices = tf.concat([batch_indices, vertex_indices],
                            axis=-1)  # (B, V, N-1, 2)

        neighbour_features = tf.gather_nd(features, indices)  # (B, V, N-1, F)

        # top_k was run on -distance_matrix, so negate again to get back the
        # non-negative distances before turning them into weights.
        distance = -ranked_distances[:, :, 1:]

        weights = gauss_of_lin(distance * 10.)
        weights = tf.expand_dims(weights, axis=-1)

        # weight the neighbour_features
        neighbour_features *= weights

        neighbours_max = tf.reduce_max(neighbour_features, axis=2)
        neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)

        return tf.concat([neighbours_max, neighbours_mean], axis=-1)