x_tfo_hat = torch.clamp(x_tfo_hat, 1e-9, 1 - 1e-9)
    BCE_loss = F.binary_cross_entropy(x_hat.view(-1, 28 * 28),
                                      x.view(-1, 28 * 28),
                                      reduction='sum')
    BCE_loss_tfo = F.binary_cross_entropy(x_tfo_hat.view(-1, 28 * 28),
                                          x_tfo.view(-1, 28 * 28),
                                          reduction='sum')
    KLD_loss = -0.5 * torch.sum(1 + torch.log(1e-8 + sigma.pow(2)) -
                                mean.pow(2) - sigma.pow(2))

    return (0.5 * BCE_loss + 0.5 * BCE_loss_tfo + KLD_loss) / x.size(0)


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Training data: shuffled mini-batches of 128 items read from the processed
# training archive. Each item goes through ToTensor.
train_loader = torch.utils.data.DataLoader(
    dataset=MNISTDataset_pairs(
        '../Data Processing/processed_train_mnist.npz',
        transform=torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor()]),
    ),
    batch_size=128,
    shuffle=True,
)

# Test data: same pipeline on the processed test archive, in batches of 500.
test_loader = torch.utils.data.DataLoader(
    dataset=MNISTDataset_pairs(
        '../Data Processing/processed_test_mnist.npz',
        transform=torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor()]),
    ),
    batch_size=500,
    shuffle=True,
)
# max_translation = 6
# max_rotation = 45

run_updated_ver = 1

# Build the model for 28 * 28 inputs, then move it to the selected device.
# NOTE(review): [250, 100] and 50 look like hidden-layer widths and latent
# size -- confirm against CVAE_MLP_latent_tfo_modif in model.py.
network = CVAE_MLP_latent_tfo_modif(28 * 28, [250, 100], 50, norm_context=1)
network = network.to(device)
# Example no. 2
# 0
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from MNISTDataset import MNISTDataset_pairs
from model import *

def VAE_loss(x, x_hat, mean, sigma):
    """Return the per-sample VAE loss: reconstruction BCE plus KL term.

    The reconstruction term is the summed binary cross-entropy between the
    flattened ``x_hat`` and ``x``. The KL term is
    ``-0.5 * sum(1 + log(1e-8 + sigma**2) - mean**2 - sigma**2)``.
    The total is divided by ``x.size(0)``.

    Args:
        x: Target tensor with values in [0, 1]; its first dimension is used
            as the batch size for the final normalisation.
        x_hat: Reconstruction with the same number of elements as ``x``.
        mean: Latent means.
        sigma: Latent standard deviations (squared inside the KL term).

    Returns:
        A zero-dimensional tensor holding the loss.
    """
    # 1 - 1e-9 rounds to exactly 1.0 in float32, so using it as the upper
    # bound would leave saturated outputs (x_hat == 1.0) unclamped. Fall back
    # to the dtype's machine epsilon whenever 1e-9 is too small to represent;
    # for float64 the bound is still 1 - 1e-9.
    upper_margin = max(1e-9, torch.finfo(x_hat.dtype).eps)
    x_hat = torch.clamp(x_hat, 1e-9, 1 - upper_margin)

    # With reduction='sum' only the element pairing matters, so flatten both
    # tensors completely instead of assuming 28 * 28 pixels per sample.
    BCE_loss = F.binary_cross_entropy(x_hat.reshape(-1),
                                      x.reshape(-1),
                                      reduction='sum')

    # The 1e-8 keeps the logarithm finite when sigma is exactly zero.
    KLD_loss = -0.5 * torch.sum(1 + torch.log(1e-8 + sigma.pow(2))
                                - mean.pow(2) - sigma.pow(2))

    return (BCE_loss + KLD_loss) / x.size(0)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Training data: shuffled mini-batches of 128 items read from the processed
# training archive. Each item goes through ToTensor.
train_loader = torch.utils.data.DataLoader(
    dataset=MNISTDataset_pairs(
        '../Data Processing/processed_train_mnist.npz',
        transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
    ),
    batch_size=128,
    shuffle=True,
)

# Test data: same pipeline on the processed test archive, in batches of 500.
test_loader = torch.utils.data.DataLoader(
    dataset=MNISTDataset_pairs(
        '../Data Processing/processed_test_mnist.npz',
        transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
    ),
    batch_size=500,
    shuffle=True,
)
# max_translation = 6
# max_rotation = 45

# Conditioning flags handed to the model constructor.
# NOTE(review): presumably 0 disables and 1 enables conditioning of the
# encoder / decoder -- confirm against CVAE_MLP in model.py.
enc_condition = 0
dec_condition = 1

# Build the model for 28 * 28 inputs, then move it to the selected device.
network = CVAE_MLP(28 * 28, [250, 100], 50,
                   enc_condition=enc_condition,
                   dec_condition=dec_condition)
network = network.to(device)

# Optimisation settings; only learning_rate is consumed in the code below.
learning_rate = 0.002
learning_rate_decay = 0.001
optimizer = optim.Adam(params=network.parameters(), lr=learning_rate)

# Training-loop bookkeeping.
n_epochs, log_interval, test_interval = 50, 200, 5
train_losses = []