def weightedCoordLoss(fracs, r_energy, coords):
    '''
    Per-vertex loss comparing pairwise coordinate distances with pairwise
    fraction differences.

    fracs:    B x V x F
    r_energy: B x V   -- rows with r_energy <= 0 get a loss of zero
    coords:   B x V x C
    returns:  B x V
    '''
    from caloGraphNN import euclidean_squared

    # 1 where the vertex has positive r_energy, 0 otherwise (same dtype).
    active = tf.where(r_energy > 0, tf.ones_like(r_energy), tf.zeros_like(r_energy))

    coord_sep = euclidean_squared(coords, coords)  # B x V x V
    frac_sep = euclidean_squared(fracs, fracs)     # B x V x V

    # Per the original author: if the fraction difference is large, the
    # distance should be large too (fracdiff is at most 1, distances are of
    # order 1). Pairs already separated in both spaces (> 0.5) cost nothing.
    already_apart = tf.logical_and(frac_sep > 0.5, coord_sep > 0.5)
    residual = (coord_sep - frac_sep) ** 2
    residual = tf.where(already_apart, tf.zeros_like(residual), residual)

    # NOTE(review): only the row vertex is masked by `active`; pairs whose
    # partner vertex has r_energy <= 0 still contribute to the sum.
    per_vertex = tf.reduce_sum(residual, axis=2)
    return active * per_vertex
def collect_neighbours(self, coordinates, x, mask):
    """Message-passing neighbour aggregation over a k-nearest-neighbour graph.

    The neighbour graph is built once from ``coordinates``: the
    ``self.n_neighbours`` nearest vertices by ``euclidean_squared``, with the
    first-ranked entry dropped (presumably the vertex itself). Then, for
    ``len(self.message_passing_layers) + 1`` rounds, neighbour features are
    gathered, scaled by ``gauss_of_lin`` of the neighbour distance, and
    reduced to the concatenation of their max and mean over neighbours.

    Args:
        coordinates: per-vertex coordinates fed to ``euclidean_squared``;
            presumably (B, V, C) -- TODO confirm.
        x: raw per-vertex input. It is transformed by
            ``self.input_feature_transform`` before the first round and
            concatenated to the aggregated features before every later round.
        mask: optional tensor multiplied into the aggregated features after
            each round; assumed to broadcast against them, e.g. (B, V, 1)
            -- TODO confirm. ``None`` disables masking.

    Returns:
        The aggregated features of the last round (neighbour max and mean,
        concatenated on the last axis).
    """
    # tf.ragged FIXME?
    # for euclidean_squared see caloGraphNN.py
    distance_matrix = euclidean_squared(coordinates, coordinates)
    # top_k of the negated matrix ranks vertices nearest-first, so the
    # returned values are the negated squared distances.
    ranked_distances, ranked_indices = tf.nn.top_k(-distance_matrix, self.n_neighbours)
    neighbour_indices = ranked_indices[:, :, 1:]
    features = self.input_feature_transform(x)
    n_batches = tf.shape(features)[0]
    # tf.ragged FIXME? or could that work?
    n_vertices = tf.shape(features)[1]
    n_features = tf.shape(features)[2]  # NOTE(review): unused below
    batch_range = tf.range(0, n_batches)
    batch_range = tf.expand_dims(batch_range, axis=1)
    batch_range = tf.expand_dims(batch_range, axis=1)
    batch_range = tf.expand_dims(batch_range, axis=1)  # (B, 1, 1, 1)
    # tf.ragged FIXME? n_vertices
    batch_indices = tf.tile(
        batch_range, [1, n_vertices, self.n_neighbours - 1, 1])  # (B, V, N-1, 1)
    vertex_indices = tf.expand_dims(neighbour_indices, axis=3)  # (B, V, N-1, 1)
    # (batch, vertex) index pairs so gather_nd can address features[b, neighbour]
    indices = tf.concat([batch_indices, vertex_indices], axis=-1)
    # undo the negation from top_k to get the squared distances back
    distance = -ranked_distances[:, :, 1:]
    # round 0 uses a fixed distance scale of 10
    weights = gauss_of_lin(distance * 10.)
    weights = tf.expand_dims(weights, axis=-1)
    for i in range(len(self.message_passing_layers) + 1):
        if i:
            # later rounds: transform [previous aggregate, raw input] and
            # re-weight distances with this round's entry of
            # message_parsing_distance_weights
            features = self.message_passing_layers[i - 1](tf.concat(
                [features, x], axis=-1))
            w = self.message_parsing_distance_weights[i - 1]
            weights = gauss_of_lin(w * distance)
            weights = tf.expand_dims(weights, axis=-1)
        if self.feature_dropout > 0 and self.feature_dropout < 1:
            # NOTE(review): a fresh Dropout layer is created on every call and
            # invoked without an explicit `training` argument, so whether it
            # drops anything depends on Keras' global learning phase. Consider
            # creating it once and passing `training` explicitly.
            features = keras.layers.Dropout(self.feature_dropout)(features)
        neighbour_features = tf.gather_nd(features, indices)  # (B, V, N-1, F)
        # weight the neighbour_features
        neighbour_features *= weights
        neighbours_max = tf.reduce_max(neighbour_features, axis=2)
        neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)
        features = tf.concat([neighbours_max, neighbours_mean], axis=-1)
        if mask is not None:
            features *= mask
    return features
def collect_neighbours_fullmatrix(self, coordinates, features):
    """Aggregate distance-weighted features of each vertex's nearest neighbours.

    Dense variant that avoids building explicit batch indices, so the batch
    size (B) may be unknown (None); the implementation differs from
    caloGraphNN for that reason. Neighbours are the ``self.n_neighbours``
    nearest vertices by ``euclidean_squared``, minus the first-ranked one
    (presumably the vertex itself).

    It materialises a (B, V, V, F) tiled copy of ``features`` and selects
    neighbours with a (B, V, V) boolean mask. ``tf.boolean_mask`` keeps the
    selected entries in row-major order, so features and distances stay
    aligned (ordered by vertex index, not by distance). Max and mean are
    order-independent, so that ordering does not matter.

    Shapes (V = vertices, N = neighbours, F = features per vertex):
        features: (B, V, F)
        returns:  (B, V, 2F) -- neighbour max then neighbour mean
    """
    k_others = self.n_neighbours - 1

    sq_dist = euclidean_squared(coordinates, coordinates)  # (B, V, V)
    _, nearest = tf.nn.top_k(-sq_dist, self.n_neighbours)
    nearest = nearest[:, :, 1:]  # (B, V, N-1)

    num_vertices = tf.shape(features)[1]
    num_features = tf.shape(features)[2]

    # (B, V, N-1) indices -> (B, V, N-1, V) one-hot -> (B, V, V) boolean selection
    one_hot = tf.one_hot(nearest, depth=num_vertices, axis=-1, dtype=tf.int32)
    selected = tf.cast(tf.reduce_sum(one_hot, axis=2), tf.bool)

    # (B, V, F) -> (B, V, V, F) -> (B, V, N-1, F)
    broadcast = tf.tile(tf.expand_dims(features, axis=1), [1, num_vertices, 1, 1])
    gathered = tf.reshape(tf.boolean_mask(broadcast, selected),
                          [-1, num_vertices, k_others, num_features])

    # (B, V, V) -> (B, V, N-1), same selection order as `gathered`
    picked = tf.reshape(tf.boolean_mask(sq_dist, selected),
                        [-1, num_vertices, k_others])

    gathered *= tf.expand_dims(gauss_of_lin(picked * 10.), axis=-1)

    return tf.concat([tf.reduce_max(gathered, axis=2),
                      tf.reduce_mean(gathered, axis=2)], axis=-1)
def collect_neighbours(self, coordinates, features):
    """Aggregate distance-weighted features of each vertex's nearest neighbours.

    Neighbours are the ``self.n_neighbours`` nearest vertices by
    ``euclidean_squared``, minus the first-ranked one (presumably the vertex
    itself). Their features are gathered with ``tf.gather_nd``, scaled by
    ``gauss_of_lin(10 * squared_distance)`` and reduced over neighbours.

    Only the locally imported ``tf`` is used. The previous version mixed in
    ``K.*`` (Keras backend) calls although ``K`` is never imported here.

    Shapes (V = vertices, N = neighbours, F = features per vertex):
        coordinates: presumably (B, V, C) -- TODO confirm
        features:    (B, V, F)
        returns:     (B, V, 2F) -- neighbour max then neighbour mean
    """
    import tensorflow as tf
    from caloGraphNN import euclidean_squared, gauss_of_lin

    n_others = self.n_neighbours - 1  # first-ranked entry is dropped below

    # tf.ragged FIXME?
    distance_matrix = euclidean_squared(coordinates, coordinates)  # (B, V, V)
    # top_k on the negated matrix ranks vertices nearest-first, so its values
    # are the negated squared distances.
    ranked_distances, ranked_indices = tf.nn.top_k(-distance_matrix, self.n_neighbours)
    neighbour_indices = ranked_indices[:, :, 1:]  # (B, V, N-1)

    # tf.ragged FIXME? both sizes assume a dense (non-ragged) batch
    n_batches = tf.shape(features)[0]
    n_vertices = tf.shape(features)[1]

    # Pair each neighbour index with its batch index so gather_nd can address
    # features[b, neighbour].
    batch_range = tf.reshape(tf.range(n_batches), [-1, 1, 1, 1])  # (B, 1, 1, 1)
    batch_indices = tf.tile(batch_range, [1, n_vertices, n_others, 1])  # (B, V, N-1, 1)
    vertex_indices = tf.expand_dims(neighbour_indices, axis=3)  # (B, V, N-1, 1)
    indices = tf.concat([batch_indices, vertex_indices], axis=-1)  # (B, V, N-1, 2)

    neighbour_features = tf.gather_nd(features, indices)  # (B, V, N-1, F)

    # Undo the negation from top_k to recover the squared distances.
    distance = -ranked_distances[:, :, 1:]
    weights = tf.expand_dims(gauss_of_lin(distance * 10.), axis=-1)
    neighbour_features *= weights

    neighbours_max = tf.reduce_max(neighbour_features, axis=2)
    neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)
    return tf.concat([neighbours_max, neighbours_mean], axis=-1)
def collect_neighbours(self, coordinates, features):
    """Aggregate distance-weighted features of each vertex's nearest neighbours.

    Neighbours are the ``self.n_neighbours`` nearest vertices by
    ``euclidean_squared``, minus the first-ranked one (presumably the vertex
    itself). Their features are gathered with ``tf.gather_nd``, scaled by
    ``gauss_of_lin(10 * squared_distance)`` and reduced over neighbours.

    Shapes (V = vertices, N = neighbours, F = features per vertex):
        coordinates: presumably (B, V, C) -- TODO confirm
        features:    (B, V, F)
        returns:     (B, V, 2F) -- neighbour max then neighbour mean
    """
    distance_matrix = euclidean_squared(coordinates, coordinates)  # (B, V, V)
    # top_k on the negated matrix ranks vertices nearest-first, so its values
    # are the *negated* squared distances.
    ranked_distances, ranked_indices = tf.nn.top_k(-distance_matrix, self.n_neighbours)
    neighbour_indices = ranked_indices[:, :, 1:]  # (B, V, N-1)

    n_batches = tf.shape(features)[0]
    n_vertices = tf.shape(features)[1]

    batch_range = tf.range(0, n_batches)
    batch_range = tf.expand_dims(batch_range, axis=1)
    batch_range = tf.expand_dims(batch_range, axis=1)
    batch_range = tf.expand_dims(batch_range, axis=1)  # (B, 1, 1, 1)

    batch_indices = tf.tile(
        batch_range, [1, n_vertices, self.n_neighbours - 1, 1])  # (B, V, N-1, 1)
    vertex_indices = tf.expand_dims(neighbour_indices, axis=3)  # (B, V, N-1, 1)
    # (batch, vertex) index pairs so gather_nd can address features[b, neighbour]
    indices = tf.concat([batch_indices, vertex_indices], axis=-1)

    neighbour_features = tf.gather_nd(features, indices)  # (B, V, N-1, F)

    # Negate to recover the squared distances, as the sibling
    # implementations do; without the minus sign gauss_of_lin would receive
    # negated distances.
    distance = -ranked_distances[:, :, 1:]
    weights = gauss_of_lin(distance * 10.)
    weights = tf.expand_dims(weights, axis=-1)

    # weight the neighbour_features
    neighbour_features *= weights

    neighbours_max = tf.reduce_max(neighbour_features, axis=2)
    neighbours_mean = tf.reduce_mean(neighbour_features, axis=2)
    return tf.concat([neighbours_max, neighbours_mean], axis=-1)