# NOTE(review): this chunk begins MID-FUNCTION — the `def` line of the loss
# function is above the visible region, so its exact name/signature is unknown.
# Presumably something like VAE_loss_tfo(x, x_hat, x_tfo, x_tfo_hat, mean, sigma)
# over a pair of images (original + transformed) — confirm against the full file.
# The statements up to the `return` belong to that function body.

# Clamp keeps log() inside binary_cross_entropy finite (avoids log(0)/log(1)).
x_tfo_hat = torch.clamp(x_tfo_hat, 1e-9, 1 - 1e-9)
# Summed reconstruction BCE for the original image, flattened to 28*28 pixels.
BCE_loss = F.binary_cross_entropy(x_hat.view(-1, 28 * 28), x.view(-1, 28 * 28), reduction='sum')
# Summed reconstruction BCE for the transformed image.
BCE_loss_tfo = F.binary_cross_entropy(x_tfo_hat.view(-1, 28 * 28), x_tfo.view(-1, 28 * 28), reduction='sum')
# KL divergence of N(mean, sigma^2) from N(0, I); 1e-8 guards log of a zero variance.
KLD_loss = -0.5 * torch.sum(1 + torch.log(1e-8 + sigma.pow(2)) - mean.pow(2) - sigma.pow(2))
# Average the two reconstruction terms equally, add KL, normalize by batch size.
return (0.5 * BCE_loss + 0.5 * BCE_loss_tfo + KLD_loss) / x.size(0)

# ---- script-level setup (runs at import time) ----

# Prefer the first CUDA device when available, else fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Paired-MNIST loaders built from preprocessed .npz archives; ToTensor converts to float tensors.
train_loader = torch.utils.data.DataLoader(MNISTDataset_pairs(
    '../Data Processing/processed_train_mnist.npz', transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])), batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(MNISTDataset_pairs(
    '../Data Processing/processed_test_mnist.npz', transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])), batch_size=500, shuffle=True)
# max_translation = 6
# max_rotation = 45
# Flag presumably selecting the updated model variant below — TODO confirm where it is read.
run_updated_ver = 1
# CVAE with a 784-d input, hidden layers [250, 100], and a 50-d latent space.
# norm_context=1: semantics defined in model.py — verify there.
network = CVAE_MLP_latent_tfo_modif(28 * 28, [250, 100], 50, norm_context=1).to(device)
# Explicit import: `torch` is used throughout (torch.clamp, torch.device,
# torch.utils.data, torch.cuda) but was previously only reachable if
# `from model import *` happened to leak a `torch` binding — fragile.
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from MNISTDataset import MNISTDataset_pairs
from model import *


def VAE_loss(x, x_hat, mean, sigma):
    """Per-batch VAE objective: summed reconstruction BCE plus KL divergence.

    Args:
        x:     target image batch, reshaped to (B, 784) — assumes 28x28 MNIST.
        x_hat: decoder output in (0, 1), same total size as ``x``.
        mean:  latent Gaussian means, one row per sample.
        sigma: latent Gaussian standard deviations, same shape as ``mean``.

    Returns:
        Scalar tensor ``(BCE_sum + KLD_sum) / batch_size``.
    """
    # Clamp keeps log() inside binary_cross_entropy finite (avoids log(0)/log(1)).
    x_hat = torch.clamp(x_hat, 1e-9, 1 - 1e-9)
    BCE_loss = F.binary_cross_entropy(x_hat.view(-1, 28 * 28), x.view(-1, 28 * 28), reduction='sum')
    # KL divergence of N(mean, sigma^2) from N(0, I); 1e-8 guards log of a zero variance.
    KLD_loss = -0.5 * torch.sum(1 + torch.log(1e-8 + sigma.pow(2)) - mean.pow(2) - sigma.pow(2))
    return (BCE_loss + KLD_loss) / x.size(0)


# ---- script-level setup (runs at import time) ----

# Prefer the first CUDA device when available, else fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Paired-MNIST loaders built from preprocessed .npz archives; ToTensor converts to float tensors.
train_loader = torch.utils.data.DataLoader(
    MNISTDataset_pairs('../Data Processing/processed_train_mnist.npz',
                       transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])),
    batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    MNISTDataset_pairs('../Data Processing/processed_test_mnist.npz',
                       transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])),
    batch_size=500, shuffle=True)

# max_translation = 6
# max_rotation = 45

# Conditioning switches forwarded to the CVAE: decoder is conditioned, encoder is not.
# Exact semantics are defined in model.py — verify there.
enc_condition = 0
dec_condition = 1
# CVAE with a 784-d input, hidden layers [250, 100], and a 50-d latent space.
network = CVAE_MLP(28 * 28, [250, 100], 50,
                   enc_condition=enc_condition, dec_condition=dec_condition).to(device)

learning_rate = 0.002
learning_rate_decay = 0.001  # NOTE(review): not consumed in this chunk — presumably used later
optimizer = optim.Adam(network.parameters(), lr=learning_rate)

n_epochs = 50
log_interval = 200   # batches between training-progress logs (by convention — confirm in loop below)
test_interval = 5    # epochs between test-set evaluations (by convention — confirm in loop below)
train_losses = []