def test():
    """Smoke-check sampling and log-density of a standard ``Normal`` on a 2x2 grid.

    ``Normal`` is presumably ``torch.distributions.Normal`` (imported
    elsewhere in the file); it is constructed here with zero location and
    unit scale.

    Prints, in order:
      1. the shape of the ``mean`` template tensor;
      2. the log-probability of one sample drawn from the distribution;
      3. the element-wise square of that sample;
      4. the square minus one;
      5. the square summed over the last axis, keeping that axis.

    Only the *shapes* of ``mean`` and ``sig`` are used; their values are
    never handed to the distribution. Because a fresh sample is drawn, the
    printed numbers are random unless torch's RNG is seeded by the caller.
    Returns ``None``.
    """
    mean = torch.tensor([[0, 0], [0.5, 0.5]], dtype=torch.float32)
    sig = torch.tensor([[1, 1], [2, 2]], dtype=torch.float32)
    print(mean.shape)

    # Zero location / unit scale, shaped like the two templates above.
    std_normal = Normal(torch.zeros(mean.shape), torch.ones(sig.shape))
    draw = std_normal.sample()
    print(std_normal.log_prob(draw))

    squared = draw ** 2
    for shown in (squared, squared - 1, squared.sum(-1, keepdim=True)):
        print(shown)