Esempio n. 1
0
def entropy_loss(arch_params):
    """Return the average Bernoulli entropy over a collection of logit tensors.

    Each element of ``arch_params`` is passed as ``logits`` to a
    ``Bernoulli`` distribution. The element-wise entropy is averaged per
    tensor, and those per-tensor averages are then averaged again, so every
    tensor contributes equally regardless of how many elements it holds.

    Args:
        arch_params: Iterable of tensors interpreted as Bernoulli logits.

    Returns:
        A scalar tensor holding the mean of the per-tensor mean entropies.
    """
    per_param_entropy = [
        Bernoulli(logits=logits).entropy().mean() for logits in arch_params
    ]
    return torch.stack(per_param_entropy).mean()
Esempio n. 2
0
 def entropy(self, x):
     """Return the mean Bernoulli entropy of this object's output for ``x``.

     The result of ``self.forward(x)`` is handed to ``Bernoulli`` as its
     first positional argument, and the entropy of the resulting
     distribution is averaged over all elements.

     NOTE(review): presumably ``Bernoulli`` is
     ``torch.distributions.Bernoulli``, whose first positional parameter is
     ``probs``, so ``forward`` is expected to return probabilities in
     [0, 1] rather than logits -- confirm against the model definition.
     """
     return Bernoulli(self.forward(x)).entropy().mean()