def nn_loss(self, reference, target, neighborhood_size=(3, 3)):
    """
    Nearest-neighbour L1 loss between `target` and shifted copies of
    `reference`.

    `reference` is padded along axes 1 and 2, then one shifted view is taken
    for every offset inside the neighbourhood. For each offset the absolute
    difference to `target` is summed over the last input axis, and the
    smallest of those sums over all offsets is returned.

    :param reference: rank 4 tensor (the padding spec has four axis entries);
    presumably laid out as (batch, height, width, channels) -- confirm
    against the caller;
    :param target: tensor compared element-wise against every shifted view of
    `reference`; presumably the same shape as `reference`;
    :param neighborhood_size: (vertical, horizontal) extent of the search
    window. Sizes are expected to be odd: with an even size the shifted views
    come out one element longer than the unpadded input along that axis;
    :return: tensor with the last input axis and the neighbourhood axis
    reduced away, holding the minimum summed absolute difference.
    """
    n_rows = neighborhood_size[0]
    n_cols = neighborhood_size[1]
    v_pad = n_rows // 2
    h_pad = n_cols // 2

    # Pad with a large-magnitude constant; presumably this is meant to make
    # out-of-bounds neighbours lose the reduce_min below -- this relies on
    # real values being far from -10000.
    val_pad = ktf.pad(
        reference,
        [[0, 0], [v_pad, v_pad], [h_pad, h_pad], [0, 0]],
        mode='CONSTANT',
        constant_values=-10000)

    reference_tensors = []
    for i_begin in range(n_rows):
        # A negative end index trims the tail of the padded axis; an end of 0
        # would give an empty slice, so it is replaced by None (open-ended).
        i_end = (i_begin - n_rows + 1) or None
        for j_begin in range(n_cols):
            # The horizontal end must be derived from the horizontal size
            # (index 1), otherwise non-square neighbourhoods slice wrongly.
            j_end = (j_begin - n_cols + 1) or None
            sub_tensor = val_pad[:, i_begin:i_end, j_begin:j_end, :]
            reference_tensors.append(ktf.expand_dims(sub_tensor, -1))

    # Stack all shifted views along a new trailing "neighbour" axis.
    shifted = ktf.concat(reference_tensors, axis=-1)
    target = ktf.expand_dims(target, axis=-1)

    abs_diff = ktf.abs(shifted - target)
    # Axis -2 of the stacked tensor is the last axis of the inputs.
    norms = ktf.reduce_sum(abs_diff, axis=[-2])
    # Keep only the best-matching neighbour for every position.
    loss = ktf.reduce_min(norms, axis=[-1])
    return loss
def top_k(scores, I, ratio, top_k_var):
    """
    Select, segment by segment, the indices of the highest entries of
    `scores`. The number of entries kept per segment is not constant: it is
    `ceil(ratio * segment_size)`.

    :param scores: a rank 1 tensor with scores;
    :param I: a rank 1 tensor with segment IDs;
    :param ratio: float, ratio of elements to keep for each segment;
    :param top_k_var: a tf.Variable without shape validation (e.g.,
    `tf.Variable(0.0, validate_shape=False)`), used as scratch storage for
    the padded score table;
    :return: a rank 1 tensor containing the indices to get the top K values
    of each segment in `scores`.
    """
    # Per-segment sizes and the offset at which each segment starts.
    seg_sizes = tf.segment_sum(tf.ones_like(I), I)
    seg_ends = tf.cumsum(seg_sizes)
    seg_offsets = seg_ends - seg_sizes
    n_segments = tf.shape(seg_sizes)[0]
    widest = tf.reduce_max(seg_sizes)
    total = tf.shape(I)[0]

    # How many entries survive in every segment.
    keep_counts = tf.cast(
        tf.ceil(ratio * tf.cast(seg_sizes, tf.float32)), tf.int32)

    # Position of every entry inside a flattened (n_segments, widest) table:
    # offset within its own segment plus the start of that segment's row.
    flat_pos = tf.range(total)
    flat_pos = (flat_pos - tf.gather(seg_offsets, I)) + (I * widest)

    # Fill the table with a value strictly below every real score so that
    # padding slots are sorted last.
    lowest = tf.reduce_min(scores)
    padded = tf.ones((n_segments * widest, ))
    padded = padded * tf.cast(lowest - 1, tf.float32)
    # Route the table through the shape-unvalidated variable so the scores
    # can be scattered into it.
    padded = tf.assign(top_k_var, padded, validate_shape=False)
    padded = tf.scatter_update(padded, flat_pos, scores)
    padded = tf.reshape(padded, (n_segments, widest))

    # Row-wise descending order, shifted by each segment's start offset and
    # flattened.
    order = tf.argsort(padded, direction='DESCENDING')
    order = tf.reshape(order + seg_offsets[:, None], (-1, ))

    # Build a keep/drop selector: for every row, `keep_counts` ones followed
    # by `widest - keep_counts` zeros (tf_repeat_1d presumably repeats each
    # element of its first argument by the matching count -- defined
    # elsewhere).
    pattern = tf.tile(tf.constant([1., 0.]), (n_segments, ))
    drop_counts = widest - keep_counts
    run_lengths = tf.concat((keep_counts[:, None], drop_counts[:, None]), -1)
    run_lengths = tf.reshape(run_lengths, (-1, ))
    selector = tf_repeat_1d(pattern, run_lengths)

    return tf.boolean_mask(order, selector)