Ejemplo n.º 1
0
def gumbel_softmax(logits, temperature, hard, axis=-1, eps=1e-20):
    """Draw a Gumbel-softmax sample from ``logits`` along ``axis``.

    Gumbel(0, 1) noise is added to ``logits``. The sum is divided by
    ``temperature`` and passed through a softmax over ``axis``.

    Args:
        logits: unnormalized scores (a MindSpore tensor).
        temperature: the noisy logits are divided by this value before the
            softmax.
        hard: if true, the forward value is a one-hot tensor (argmax of the
            soft sample), while gradients flow through the soft sample
            (straight-through estimator).
        axis: axis holding the categories. Any in-range negative axis is
            accepted, not only -1.
        eps: small constant added inside both logarithms so that a uniform
            sample of 0 does not produce ``log(0)``.

    Returns:
        A tensor with the same shape as ``logits``: the soft sample, or the
        straight-through one-hot sample when ``hard`` is true.
    """
    # Resolve an in-range negative axis to its non-negative index.
    # ops.OneHot only accepts an axis in [-1, ndim(index)], so e.g. axis=-2
    # would be rejected in the hard branch even though Softmax/argmax accept
    # it. Out-of-range values are left unchanged so Softmax still rejects
    # them, as before.
    ndim = len(logits.shape)
    if axis < 0 and axis + ndim >= 0:
        axis += ndim

    uniform_samples = ops.UniformReal()(logits.shape)
    gumbels = -ops.log(-ops.log(uniform_samples + eps) + eps)  # ~Gumbel(0, 1)
    gumbels = (logits + gumbels) / temperature
    y_soft = ops.Softmax(axis)(gumbels)

    if hard:
        # Straight through: the forward value equals y_hard, while the
        # stop_gradient term makes the gradient that of y_soft.
        index = y_soft.argmax(axis)
        # argmax removes `axis`; OneHot re-inserts the category dimension at
        # the same (now non-negative) position.
        y_hard = ops.OneHot(axis)(index, y_soft.shape[axis],
                                  ops.scalar_to_array(1.0),
                                  ops.scalar_to_array(0.0))
        ret = ops.stop_gradient(y_hard - y_soft) + y_soft
    else:
        # Reparametrization trick: return the differentiable soft sample.
        ret = y_soft
    return ret
Ejemplo n.º 2
0
    def construct(self, logit, label):
        """Compute the softmax cross-entropy loss between ``logit`` and ``label``.

        Log-probabilities are formed with the log-sum-exp identity
        ``log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x))))``
        rather than ``log(softmax(x))``. With the latter, a class probability
        that underflows to 0 gives ``log(0) = -inf``. Multiplying that by a 0
        label entry yields NaN and poisons the whole loss.

        Args:
            logit: class scores; the class dimension is the last axis (all
                reductions below use axis -1). Assumes ``self.max`` and
                ``self.sum`` keep the reduced dimension so their results
                broadcast back against ``logit`` — TODO confirm in ``__init__``.
            label: when ``self.sparse`` is true, class indices that are one-hot
                encoded here with ``self.on_value``/``self.off_value``;
                otherwise presumably target probabilities shaped like ``logit``.

        Returns:
            The negated sum over the last axis of ``log_probs * label``,
            reduced with ``self.mean`` along axis -1.
        """
        # Shift by the per-row max so exp() cannot overflow.
        logit_max = self.max(logit, -1)
        shifted = self.sub(logit, logit_max)
        exp_sum = self.sum(self.exp(shifted), -1)
        # The max entry of `shifted` is 0, so exp_sum >= 1 and its log is
        # finite. This log-softmax stays finite even where exp(shifted)
        # underflows to 0.
        log_probs = self.sub(shifted, self.log(exp_sum))
        if self.sparse:
            # One-hot depth is the size of the class (last) axis, matching
            # the axis the reductions operate on.
            label = self.onehot(label,
                                ops.shape(logit)[-1], self.on_value,
                                self.off_value)
        loss = self.sum_cross_entropy(self.mul(log_probs, label), -1)
        # Cross-entropy is the negated sum of label-weighted log-probabilities.
        loss = self.mul2(ops.scalar_to_array(-1.0), loss)
        loss = self.mean(loss, -1)

        return loss