Example #1
    def forward(self, *batch):
        sender_input = batch[0]
        sender_input = sender_input.view(-1, 784)  # flatten 28x28 MNIST images into 784-d vectors
        mu, logvar = self.sender(sender_input)

        if self.training:  # nn.Module's training flag; self.train is a method and is always truthy
            message = self.reparameterize(mu, logvar)
        else:
            message = mu

        receiver_output = self.receiver(message)

        # Reconstruction term: binary cross-entropy summed over pixels and batch
        BCE = F.binary_cross_entropy(receiver_output, sender_input, reduction="sum")

        # KL divergence between the approximate posterior N(mu, sigma^2) and the standard normal prior
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = BCE + KLD

        log = core.Interaction(
            sender_input=sender_input,
            receiver_input=None,
            labels=None,
            aux_input=None,
            receiver_output=receiver_output.detach(),
            message=message.detach(),
            message_length=torch.ones(message.size(0)),
            aux={},
        )

        return loss.mean(), log
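Example #1 calls self.reparameterize, which is not shown in the snippet. A minimal sketch of the standard VAE reparameterization trick it presumably implements (the method name comes from the call site; the body is an assumption):

    def reparameterize(self, mu, logvar):
        # Sample z = mu + sigma * eps with eps ~ N(0, I); drawing eps
        # separately keeps the sample differentiable w.r.t. mu and logvar.
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2), so std = exp(0.5 * logvar)
        eps = torch.randn_like(std)    # one standard-normal draw per latent dimension
        return mu + eps * std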
Example #2
    def forward(self, *batch):
        sender_input = batch[0]
        _ = batch[1]  # latent_values
        label = batch[2]

        distributions = self.sender(sender_input)
        mu = distributions[:, :self.z_dim]      # first z_dim entries: posterior mean
        logvar = distributions[:, self.z_dim:]  # remaining entries: posterior log-variance

        if self.training:  # same fix as in Example #1: self.train is a bound method
            message = reparametrize(mu, logvar)
        else:
            message = mu

        receiver_output = self.receiver(message)

        recon_loss = reconstruction_loss(sender_input, receiver_output)
        total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)

        # beta > 1 up-weights the KL term to encourage disentangled latents
        beta_vae_loss = recon_loss + self.beta * total_kld

        log = core.Interaction(
            sender_input=label,  # the ground-truth label is logged in place of the raw input
            receiver_input=None,
            receiver_output=receiver_output.detach(),
            aux_input=None,
            message=message.detach(),
            labels=None,
            message_length=torch.ones(message.size(0)),
            aux={},
        )

        return beta_vae_loss.mean(), log
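Example #2 relies on the free functions reparametrize, reconstruction_loss, and kl_divergence, none of which are shown; reparametrize would mirror the reparameterize sketch after Example #1. Below is a minimal sketch of the other two, assuming Bernoulli (sigmoid) decoder outputs and the common beta-VAE convention of returning the total, per-dimension, and mean KL terms; the exact reductions used by the original helpers may differ:

import torch
import torch.nn.functional as F

def reconstruction_loss(x, x_recon):
    # Bernoulli reconstruction term: binary cross-entropy summed over
    # features, averaged over the batch.
    batch_size = x.size(0)
    return F.binary_cross_entropy(x_recon, x, reduction="sum") / batch_size

def kl_divergence(mu, logvar):
    # KL(q(z|x) || N(0, I)), computed per latent dimension first.
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())  # shape: [batch, z_dim]
    total_kld = klds.sum(1).mean(0, keepdim=True)  # sum over dims, average over batch
    dim_wise_kld = klds.mean(0)                    # batch average per latent dimension
    mean_kld = klds.sum(1).mean()                  # scalar, convenient for logging
    return total_kld, dim_wise_kld, mean_kld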