def compute_log_probs(probs, labels):
    """Return the log-probability of each row's labelled entry.

    For each row ``i`` this gathers ``probs[i, labels[i]]`` via
    ``tf.gather_nd`` and applies ``safe_log`` to the result.

    Args:
        probs: Tensor indexed as ``probs[row, label]``. Presumably shaped
            (batch, num_classes). If it has rank > 2, ``gather_nd`` returns
            the trailing slices instead of scalars (TODO confirm callers).
        labels: Integer tensor (int32 or int64) with one label per row of
            ``probs``. Negative entries are treated as "unused" and are
            clamped to 0 before gathering.

    Returns:
        ``safe_log`` of the gathered probabilities, one entry per label.
    """
    # Select an arbitrary (valid) element for unused, negative labels; per the
    # original note, the log probs at those positions are masked by the caller.
    labels = tf.maximum(labels, 0)
    # Build the row indices in the same dtype as `labels`. A plain tf.range over
    # tf.shape(...) is always int32, so tf.stack would fail for int64 labels.
    row_ids = tf.range(tf.shape(labels, out_type=labels.dtype)[0])
    indices = tf.stack([row_ids, labels], axis=1)
    # TODO: tf.log should suffice here instead of safe_log.
    return safe_log(tf.gather_nd(probs, indices))
def compute_entropy(probs):
    """Compute the entropy of each distribution in ``probs``.

    The reduction runs over the last axis: ``-sum(p * safe_log(p))``.
    The units follow whatever base ``safe_log`` uses (presumably the natural
    log, i.e. nats; confirm against ``safe_log``).

    Args:
        probs: Tensor whose last axis holds probability values.

    Returns:
        Tensor with the last axis reduced away, holding the entropy.
    """
    log_terms = safe_log(probs)
    weighted = log_terms * probs
    return -tf.reduce_sum(weighted, axis=-1)