def gumbel_softmax(logits, temperature, hard, axis=-1, eps=1e-20):
    """Draw a Gumbel-Softmax sample for ``logits``.

    Uniform noise with the shape of ``logits`` is turned into Gumbel(0, 1)
    noise, added to ``logits``, divided by ``temperature`` and passed through
    a softmax along ``axis``.

    Args:
        logits: Tensor of scores; its ``shape`` determines the noise shape.
        temperature: Divisor applied to the perturbed logits before the
            softmax.
        hard: If truthy, return the one-hot encoding of the arg-max of the
            soft sample (straight-through form); otherwise return the soft
            sample itself.
        axis: Axis used for the softmax, the arg-max and the one-hot
            expansion.
        eps: Constant added to the argument of both logarithms so that
            neither is evaluated at exactly zero.

    Returns:
        The soft sample, or its straight-through one-hot counterpart when
        ``hard`` is truthy.
    """
    uniform = ops.UniformReal()(logits.shape)

    # -log(-log(U)) with U ~ Uniform(0, 1) follows a Gumbel(0, 1) law.
    neg_log_uniform = -ops.log(uniform + eps)
    gumbel_noise = -ops.log(neg_log_uniform + eps)

    perturbed = (logits + gumbel_noise) / temperature
    soft_sample = ops.Softmax(axis)(perturbed)

    if not hard:
        # Reparametrization trick: the relaxed sample is returned unchanged.
        return soft_sample

    # Straight-through: the one-hot value is emitted, but the difference to
    # the soft sample is wrapped in stop_gradient so that only the soft
    # sample term takes part in differentiation.
    winners = soft_sample.argmax(axis)
    depth = soft_sample.shape[axis]
    on_value = ops.scalar_to_array(1.0)
    off_value = ops.scalar_to_array(0.0)
    hard_sample = ops.OneHot(axis)(winners, depth, on_value, off_value)
    return ops.stop_gradient(hard_sample - soft_sample) + soft_sample
def construct(self, logit, label):
    """Compute the softmax cross-entropy loss of ``logit`` against ``label``.

    The log-probabilities are evaluated directly as
    ``(logit - max) - log(sum(exp(logit - max)))`` instead of taking the
    logarithm of an explicitly normalised softmax.  Both forms are
    mathematically equal, but in the explicit form a probability that
    underflows to zero produces ``log(0) = -inf``, and ``-inf`` multiplied
    by a zero label entry turns the whole loss into NaN.

    Args:
        logit: Unnormalised class scores; every reduction runs over the
            last axis.
        label: Targets.  When ``self.sparse`` is truthy they are expanded
            with ``self.onehot`` to a depth equal to the size of the last
            axis of ``logit``; otherwise they are multiplied element-wise
            with the log-probabilities as given.

    Returns:
        The output of ``self.mean`` taken over the last axis of the negated
        per-sample sums of ``label * log_softmax``.
    """
    # Subtracting the maximum keeps every exponent <= 0, so exp() cannot
    # overflow.
    # NOTE(review): assumes self.max and self.sum keep the reduced axis so
    # their results broadcast against ``logit`` -- confirm in __init__.
    logit_max = self.max(logit, -1)
    shifted = self.sub(logit, logit_max)
    exp_sum = self.sum(self.exp(shifted), -1)

    # Log-softmax without the intermediate division: the argument of log()
    # is a sum of exponentials that includes exp(0) for the maximum entry,
    # so it cannot underflow to zero.
    log_softmax = self.sub(shifted, self.log(exp_sum))

    if self.sparse:
        # Depth is read from the last axis, the same axis the reductions
        # above operate on (identical to index 1 for 2-D logits).
        label = self.onehot(label, ops.shape(logit)[-1], self.on_value, self.off_value)

    loss = self.sum_cross_entropy(self.mul(log_softmax, label), -1)
    loss = self.mul2(ops.scalar_to_array(-1.0), loss)
    return self.mean(loss, -1)