def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
  """Compute mixture-of-diagonal-Gaussians negative-evidence-style loss.

  The last axis of `preds` packs, per example: `ngauss` mixture logits,
  then `ngauss * ndims` means, then `ngauss * ndims` (pre-activation)
  sigmas, where `ndims = targets.shape[-1]`.

  Args:
    preds: array of shape (..., ngauss * (2 * ndims + 1)) with packed
      mixture parameters as described above.
    targets: array whose last dimension is `ndims`; flattened to
      (-1, 1, ndims) for broadcasting against the mixture components.
    ngauss: number of mixture components.

  Returns:
    Per-example log-probability of `targets` under the mixture,
    shape (batch,).
  """
  ndims = targets.shape[-1]
  logits = preds[:, :ngauss]
  mus = preds[:, ngauss:ngauss * (ndims + 1)]
  # Fixed: original read `ngauss(ndims + 1)`, calling the int `ngauss` as a
  # function (TypeError). The slice must start where the mus slice ends.
  sigmas = preds[:, ngauss * (ndims + 1):]
  sigmas = sigmas * sigmas + 1e-6  # Make positive.
  # Normalize mixture logits into log mixture weights.
  loglogits = logits - math.logsumexp(logits, axis=-1, keepdims=True)
  mus = np.reshape(mus, [-1, ngauss, ndims])
  sigmas = np.reshape(sigmas, [-1, ngauss, ndims])
  targets = np.reshape(targets, [-1, 1, ndims])
  # Component log-densities broadcast over the `ngauss` axis.
  glogprobs = log_gaussian_diag_pdf(targets, mus, sigmas)
  # log sum_k w_k * N_k(target) computed stably in log space.
  return math.logsumexp(loglogits + glogprobs, axis=-1)
def Softmax(x, axis=-1, **unused_kwargs):
  """Softmax of x along `axis`, computed stably via log-space normalization."""
  # Subtracting logsumexp normalizes in log space without overflow, so the
  # exponential yields values that sum to one along `axis`.
  log_normalized = x - math.logsumexp(x, axis, keepdims=True)
  return np.exp(log_normalized)
def LogSoftmax(x, axis=-1, **unused_kwargs):
  """Log of the softmax of x along `axis` (log-normalization)."""
  # Equivalent to log(softmax(x)) but computed directly in log space.
  log_normalizer = math.logsumexp(x, axis, keepdims=True)
  return x - log_normalizer