def log_mvn_likelihood(
    mean: torch.FloatTensor,
    covariance: torch.FloatTensor,
    observation: torch.FloatTensor,
) -> torch.FloatTensor:
    """Return the log-density of a multivariate normal with diagonal covariance.

    Only the diagonal of ``covariance`` is read; all off-diagonal elements
    are assumed to be zero. The computation uses torch primitives only, so
    the result stays differentiable with respect to the tensor inputs.

    The value computed is::

        -0.5 * sum(log(var_i))
        -0.5 * k * log(2 * pi)
        -0.5 * sum((x_i - mu_i) ** 2 / var_i)

    Args:
        mean: Mean vector; its first dimension defines ``k``. Expected to be
            1-D of length ``k``.
        covariance: Square matrix whose diagonal holds the per-dimension
            variances.
        observation: Point at which the density is evaluated; expected to
            have the same shape as ``mean``.

    Returns:
        A scalar tensor holding the log-likelihood of ``observation``.
    """
    k = mean.shape[0]
    variances = covariance.diag()
    residual = observation - mean

    # The normalising constant depends only on the dimensionality, so it is
    # applied exactly once (not once per dimension).
    log_normalizer = -0.5 * k * math.log(2 * math.pi)

    return (
        log_normalizer
        - 0.5 * torch.log(variances).sum()
        - 0.5 * (residual ** 2 / variances).sum()
    )