# assumes module-level imports: import torch; import torch.nn.functional as F; import egg.core as core
def forward(self, *batch):
    sender_input = batch[0]
    # flatten 28x28 MNIST images into 784-dim vectors
    sender_input = sender_input.view(-1, 784)
    mu, logvar = self.sender(sender_input)
    # sample the latent message during training, use the posterior mean at eval time
    if self.training:
        message = self.reparameterize(mu, logvar)
    else:
        message = mu
    receiver_output = self.receiver(message)

    # standard VAE objective: reconstruction term + KL to the unit Gaussian prior
    BCE = F.binary_cross_entropy(receiver_output, sender_input, reduction="sum")
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = BCE + KLD

    log = core.Interaction(
        sender_input=sender_input,
        receiver_input=None,
        labels=None,
        aux_input=None,
        receiver_output=receiver_output.detach(),
        message=message.detach(),
        message_length=torch.ones(message.size(0)),
        aux={},
    )
    return loss.mean(), log
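# The VAE forward above calls self.reparameterize, which is not part of this
# snippet. The sketch below implements the standard reparameterization trick;
# the method name and its placement on the game class are assumptions, not
# necessarily the original implementation.
def reparameterize(self, mu, logvar):
    # z = mu + sigma * eps with sigma = exp(logvar / 2) and eps ~ N(0, I),
    # so gradients flow through mu and logvar while sampling stays stochastic
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std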
def forward(self, *batch):
    sender_input = batch[0]
    _ = batch[1]  # latent_values, unused here
    label = batch[2]

    # the sender outputs mean and log-variance concatenated along the feature dimension
    distributions = self.sender(sender_input)
    mu = distributions[:, :self.z_dim]
    logvar = distributions[:, self.z_dim:]
    # sample the latent message during training, use the posterior mean at eval time
    if self.training:
        message = reparametrize(mu, logvar)
    else:
        message = mu
    receiver_output = self.receiver(message)

    # beta-VAE objective: reconstruction term + beta-weighted KL term
    recon_loss = reconstruction_loss(sender_input, receiver_output)
    total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
    beta_vae_loss = recon_loss + self.beta * total_kld

    log = core.Interaction(
        sender_input=label,
        receiver_input=None,
        receiver_output=receiver_output.detach(),
        aux_input=None,
        message=message.detach(),
        labels=None,
        message_length=torch.ones(message.size(0)),
        aux={},
    )
    return beta_vae_loss.mean(), log
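# The beta-VAE forward above relies on three helpers not shown here:
# reparametrize, reconstruction_loss, and kl_divergence. The sketches below
# follow the standard beta-VAE formulation (Higgins et al., 2017) and are
# assumptions, not necessarily the original helper implementations; in
# particular, reconstruction_loss assumes the receiver returns pre-sigmoid logits.
def reparametrize(mu, logvar):
    # same reparameterization trick as above, written as a free function
    std = logvar.div(2).exp()
    eps = torch.randn_like(std)
    return mu + std * eps

def reconstruction_loss(x, x_recon):
    # Bernoulli reconstruction term: binary cross-entropy on logits,
    # summed over pixels and averaged over the batch
    batch_size = x.size(0)
    return F.binary_cross_entropy_with_logits(x_recon, x, reduction="sum").div(batch_size)

def kl_divergence(mu, logvar):
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior, reported as the
    # batch-mean total KL, the per-dimension KL, and the mean KL per dimension
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dim_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)
    return total_kld, dim_wise_kld, mean_kld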