def sample(self, noise=True, temperature=1.0):
    """Return a relaxed (Gumbel-softmax) sample built from ``self.logits``.

    With ``noise=True`` the logits are perturbed by ``-log(-log(u))``, where
    ``u`` is drawn uniformly per element. This is Gumbel(0, 1) noise. The
    result is the softmax over the last axis. With ``noise=False`` no noise is
    added, so the output is deterministic.

    Args:
        noise: If True, add Gumbel noise before the softmax.
        temperature: Divisor applied to the (possibly perturbed) logits
            before the softmax. The default 1.0 reproduces the original
            behaviour exactly. Expected to be positive.

    Returns:
        The value of ``U.softmax(..., axis=-1)`` applied to the (scaled,
        possibly perturbed) logits.
    """
    logits = self.logits
    if noise:
        # Draw the noise in the logits' own dtype. Otherwise float64 logits
        # would be combined with TF's default float32 noise and fail.
        # A tiny positive minval keeps u away from exactly 0, where
        # log(u) = -inf. (1e-20 is representable in float32/float64; it
        # underflows to 0 in float16.)
        u = tf.random.uniform(
            tf.shape(logits), minval=1e-20, maxval=1.0, dtype=logits.dtype
        )
        logits = logits - tf.math.log(-tf.math.log(u))
    return U.softmax(logits / temperature, axis=-1)
def mode(self):
    """Return the noise-free softmax of ``self.logits`` over the last axis.

    Note: this is the probability vector itself, not a one-hot argmax.
    """
    probabilities = U.softmax(self.logits, axis=-1)
    return probabilities