def main():
    """Train a sliced-Wasserstein autoencoder on MNIST.

    Parses CLI arguments, builds the data loaders and auxiliary objects
    (projection nets / GSW modules) for the requested SW variant, then runs
    the training loop, logging true-WD/SWD metrics every 100 iterations and
    saving samples and checkpoints each epoch.
    """
    # ---- CLI arguments ----
    parser = argparse.ArgumentParser(
        description='Disributional Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir',
                        default='/user/HS229/xc00414/condor_examples/DSW/DSW/',
                        help='path to dataset')
    parser.add_argument('--outdir',
                        default='/user/HS229/xc00414/condor_examples/DSW/DSW/result',
                        help='directory to output images')
    parser.add_argument('--batch-size', type=int, default=512, metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument('--lrpsw', type=float, default=0.001, metavar='LR',
                        help='learning rate psw (default: 0.001)')
    parser.add_argument('--num-workers', type=int, default=16, metavar='N',
                        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed', type=int, default=16, metavar='S',
                        help='random seed (default: 16)')
    parser.add_argument('--g', type=str, default='circular', help='g')
    parser.add_argument('--num-projection', type=int, default=1000,
                        help='number projection')
    parser.add_argument('--lam', type=float, default=1,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter', type=int, default=10,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--kappa', type=float, default=50, help='R')
    parser.add_argument('--k', type=int, default=10, help='R')
    parser.add_argument('--e', type=float, default=1000, help='R')
    parser.add_argument('--latent-size', type=int, default=32,
                        help='Latent size')
    parser.add_argument('--hsize', type=int, default=100, help='h size')
    parser.add_argument('--dim', type=int, default=100, help='subspace size')
    parser.add_argument('--dataset', type=str, default='MNIST',
                        help='(MNIST|FMNIST)')
    parser.add_argument('--model-type', type=str, required=True,
                        help='(ASWD|SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD)')
    args = parser.parse_args()
    # torch.random.manual_seed(args.seed)

    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    assert dataset in ['MNIST']
    assert model_type in [
        'MGSWNN', 'JMGSWNN', 'SWD', 'MSWD', 'PSTW', 'PSW', 'MGSWD', 'DSWD',
        'GSWD', 'DGSWD', 'JSWD', 'JMSWD', 'JMGSWD', 'JDSWD', 'JGSWD',
        'JDGSWD', 'DGSWNN', 'JDGSWNN', 'GSWNN', 'JGSWNN', 'ASWD'
    ]

    # ---- output directory, encoding the hyper-parameters in the name ----
    if (model_type == 'SWD' or model_type == 'JSWD'):
        model_dir = os.path.join(args.outdir,
                                 model_type + '_n' + str(num_projection))
    elif model_type == 'GSWNN':
        model_dir = os.path.join(args.outdir,
                                 model_type + '_n' + str(num_projection))
    elif (model_type == 'GSWD' or model_type == 'JGSWD'):
        model_dir = os.path.join(
            args.outdir,
            model_type + '_n' + str(num_projection) + '_' + args.g + str(args.r))
    elif (model_type == 'DSWD' or model_type == 'JDSWD' or model_type == 'ASWD'):
        model_dir = os.path.join(
            args.outdir, model_type + '_iter' + str(args.niter) + '_n' +
            str(num_projection) + '_lam' + str(args.lam))
    elif (model_type == 'DGSWD' or model_type == 'JDGSWD'):
        model_dir = os.path.join(
            args.outdir, model_type + '_iter' + str(args.niter) + '_n' +
            str(num_projection) + '_lam' + str(args.lam) + '_' + args.g +
            str(args.r))
    elif (model_type == 'MSWD' or model_type == 'JMSWD'):
        model_dir = os.path.join(args.outdir, model_type)
    elif (model_type == 'MGSWNN' or model_type == 'JMGSWNN'):
        model_dir = os.path.join(args.outdir,
                                 model_type + '_size' + str(args.hsize))
    elif (model_type == 'MGSWD' or model_type == 'JMGSWD'):
        model_dir = os.path.join(args.outdir, model_type + '_' + args.g)
    elif (model_type == 'PSW' or model_type == 'JPSW'):
        model_dir = os.path.join(
            args.outdir, model_type + '_e' + str(args.e) + '_iter' +
            str(args.niter) + '_d' + str(args.dim) + '_lr' + str(args.lrpsw))
    elif (model_type == 'PSTW' or model_type == 'JPSTW'):
        model_dir = os.path.join(
            args.outdir, model_type + '_iter' + str(args.niter) + '_d' +
            str(args.dim) + '_lr' + str(args.lrpsw))
    else:
        # FIX: DGSWNN/JDGSWNN/JGSWNN pass the assert above but had no branch
        # here, leaving model_dir unbound (UnboundLocalError). Fall back to a
        # plain per-model directory, matching the MSWD style.
        model_dir = os.path.join(args.outdir, model_type)

    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):  # FIX: removed duplicated check
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))

    # ---- build train and test set data loaders ----
    if (dataset == 'MNIST'):
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)
        # full test set in one batch, for the WD/SWD metrics
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=10000,
            shuffle=False,
            num_workers=args.num_workers)
        # small batch for reconstruction grids
        test_loader2 = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=64,
            shuffle=False,
            num_workers=args.num_workers)
        model = MnistAutoencoder(image_size=28,
                                 latent_size=args.latent_size,
                                 hidden_size=100,
                                 device=device).to(device)

    # ---- per-variant auxiliary modules / optimizers ----
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # op_trannet = optim.Adam(transform_net.parameters(), lr=1e-4)
        # train_net(28 * 28, 1000, transform_net, op_trannet)
    elif (model_type == 'JDSWD' or model_type == 'JDSWD2'
          or model_type == 'JDGSWD'):
        # joint variants slice the concatenation [latent, image]
        transform_net = TransformNet(args.latent_size + 28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(args.latent_size + 28 * 28, 1000, transform_net, op_trannet)
    elif (model_type == 'ASWD'):
        phi = Mapping(28 * 28).to(device)
        phi_op = optim.Adam(phi.parameters(), lr=0.0005, betas=(0.9, 0.999))
    if (model_type == 'MGSWNN'):
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if (model_type == 'GSWNN' or model_type == 'DGSWNN'):
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=num_projection,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if (model_type == 'JMGSWNN'):
        gsw = GSW_NN(din=28 * 28 + 32,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if (model_type == 'MSWD' or model_type == 'JMSWD'):
        gsw = GSW()
    if (model_type == 'MGSWD'):
        theta = torch.randn((1, 784), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # opt_theta = optim.Adam(transform_net.parameters(), lr=args.lr, betas=(0.5, 0.999))
    if (model_type == 'JMGSWD'):
        theta = torch.randn((1, 784 + 32), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # FIX: the optimizer must update theta itself; the original passed
        # transform_net.parameters(), but transform_net is never created for
        # JMGSWD (NameError) and would not touch theta anyway.
        opt_theta = torch.optim.Adam([theta], lr=args.lr, betas=(0.5, 0.999))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    fixednoise = torch.randn((64, latent_size)).to(device)
    ite = 0
    wd_list = []
    swd_list = []
    # timestamp-derived suffix so concurrent runs don't clobber each other's CSVs
    save_idx = str(time.time()).split('.')
    save_idx = save_idx[0] + save_idx[1]

    # ---- training loop ----
    for epoch in range(args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            tic = time.time()
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif (model_type == 'ASWD'):
                loss = model.compute_lossTGSWD(data,
                                               torch.randn,
                                               num_projection,
                                               phi,
                                               phi_op,
                                               p=2,
                                               max_iter=args.niter,
                                               lam=args.lam,
                                               net_type='fc')
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=2)
            elif (model_type == 'GSWNN'):
                loss = model.compute_loss_GSWNN(data, torch.randn, gsw, p=2)
            elif (model_type == 'DGSWNN'):
                loss = model.compute_loss_DGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 args.lam,
                                                 args.lr,
                                                 p=2)
            elif (model_type == 'PSW'):
                # FIX: was n_iter_sinkhorn=args.niters — the parser defines
                # --niter, so args.niters raised AttributeError.
                loss = model.compute_loss_PSW(data,
                                              torch.randn,
                                              pdim=args.dim,
                                              p=args.p,
                                              n_iter=args.niter,
                                              n_iter_sinkhorn=args.niter,
                                              e=args.e,
                                              lr=args.lrpsw)
            elif (model_type == 'PSTW'):
                loss = model.compute_loss_PSTW(data,
                                               torch.randn,
                                               pdim=args.dim,
                                               p=args.p,
                                               n_iter=args.niter,
                                               lr=args.lrpsw)
            elif (model_type == 'MGSWNN'):
                loss = model.compute_loss_MGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 p=args.p)
            elif (model_type == 'JMGSWNN'):
                loss = model.compute_loss_JMGSWNN(data,
                                                  torch.randn,
                                                  gsw,
                                                  args.niter,
                                                  p=args.p)
            elif (model_type == 'MSWD'):
                loss = model.compute_loss_MSWD(data, torch.randn, gsw,
                                               args.niter)
            elif (model_type == 'MGSWD'):
                loss = model.compute_loss_MGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                p=args.p,
                                                max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=args.lam)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               r=args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JSWD'):
                loss = model.compute_loss_JSWD(data,
                                               torch.randn,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'JGSWD'):
                loss = model.compute_loss_JGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                num_projection,
                                                p=args.p)
            elif (model_type == 'JDSWD'):
                loss = model.compute_lossJDSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JDGSWD'):
                loss = model.compute_lossJDGSWD(data,
                                                torch.randn,
                                                num_projection,
                                                transform_net,
                                                op_trannet,
                                                g_function,
                                                r=args.r,
                                                p=args.p,
                                                max_iter=args.niter,
                                                lam=args.lam)
            elif (model_type == 'JMSWD'):
                loss = model.compute_loss_JMSWD(data, torch.randn, gsw,
                                                args.niter)
            elif (model_type == 'JMGSWD'):
                loss = model.compute_loss_JMGSWD(data,
                                                 torch.randn,
                                                 theta,
                                                 opt_theta,
                                                 g_function,
                                                 args.r,
                                                 p=args.p,
                                                 max_iter=args.niter)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            toc = time.time()
            print('execution_time:', toc - tic, model_type)
            total_loss += loss.item()

            # periodic evaluation on the (single-batch) test set
            if (ite % 100 == 0):
                model.eval()
                for _, (input, y) in enumerate(test_loader, start=0):
                    fixednoise_wd = torch.randn(
                        (10000, latent_size)).to(device)
                    data = input.to(device)
                    data = data.view(data.shape[0], -1)
                    fake = model.decoder(fixednoise_wd)
                    wd_list.append(
                        compute_true_Wasserstein(data.to('cpu'),
                                                 fake.to('cpu')))
                    swd_list.append(
                        sliced_wasserstein_distance(data, fake, 10000).item())
                    print("Iter:" + str(ite) + " WD: " + str(wd_list[-1]))
                    np.savetxt(model_dir + "/wd" + save_idx + ".csv",
                               wd_list,
                               delimiter=",")
                    print("Iter:" + str(ite) + " SWD: " + str(swd_list[-1]))
                    np.savetxt(model_dir + "/swd" + save_idx + ".csv",
                               swd_list,
                               delimiter=",")
                    break  # only the first (full) test batch is needed
                model.train()
            ite = ite + 1

        total_loss /= (batch_idx + 1)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))
        if (epoch % 1 == 0):
            model.eval()
            sampling(model_dir + '/sample_epoch_' + str(epoch) + ".png",
                     fixednoise, model.decoder, 64, image_size, num_chanel)
            # joint ('J*') models also write a reconstruction grid
            if (model_type[0] == 'J'):
                for _, (input, y) in enumerate(test_loader2, start=0):
                    input = input.to(device)
                    input = input.view(-1, image_size**2)
                    reconstruct(
                        model_dir + '/reconstruction_epoch_' + str(epoch) +
                        ".png", input, model.encoder, model.decoder,
                        image_size, num_chanel, device)
                    break
            model.train()
        save_dmodel(model, optimizer, None, None, None, None, epoch,
                    model_dir)
        if (epoch == args.epochs - 1):
            model.eval()
            sampling_eps(model_dir + '/sample_epoch_' + str(epoch),
                         fixednoise, model.decoder, 64, image_size,
                         num_chanel)
            model.train()
def main():
    """Train a sliced-Wasserstein autoencoder on MNIST (double-quote variant).

    Parses CLI arguments, builds loaders and per-variant auxiliary objects,
    then trains, logging true-WD/SWD metrics every 100 iterations and saving
    samples/checkpoints each epoch.
    """
    # ---- CLI arguments ----
    parser = argparse.ArgumentParser(
        description="Disributional Sliced Wasserstein Autoencoder")
    parser.add_argument("--datadir", default="./", help="path to dataset")
    parser.add_argument("--outdir",
                        default="./result",
                        help="directory to output images")
    parser.add_argument("--batch-size", type=int, default=512, metavar="N",
                        help="input batch size for training (default: 512)")
    parser.add_argument("--epochs", type=int, default=200, metavar="N",
                        help="number of epochs to train (default: 200)")
    parser.add_argument("--lr", type=float, default=0.0005, metavar="LR",
                        help="learning rate (default: 0.0005)")
    # FIX: --lr2 was read via args.lr2 in the MGSWD branch below but never
    # declared, raising AttributeError; it is the ascent rate for the
    # max-slice search.
    parser.add_argument("--lr2", type=float, default=0.0005, metavar="LR",
                        help="learning rate for max-slice search (default: 0.0005)")
    parser.add_argument(
        "--num-workers",
        type=int,
        default=16,
        metavar="N",
        help="number of dataloader workers if device is CPU (default: 16)",
    )
    parser.add_argument("--seed", type=int, default=16, metavar="S",
                        help="random seed (default: 16)")
    parser.add_argument("--g", type=str, default="circular", help="g")
    parser.add_argument("--num-projection", type=int, default=1000,
                        help="number projection")
    parser.add_argument("--lam", type=float, default=1,
                        help="Regularization strength")
    parser.add_argument("--p", type=int, default=2, help="Norm p")
    parser.add_argument("--niter", type=int, default=10,
                        help="number of iterations")
    parser.add_argument("--r", type=float, default=1000, help="R")
    parser.add_argument("--kappa", type=float, default=50, help="R")
    parser.add_argument("--k", type=int, default=10, help="R")
    parser.add_argument("--e", type=float, default=1000, help="R")
    parser.add_argument("--latent-size", type=int, default=32,
                        help="Latent size")
    parser.add_argument("--hsize", type=int, default=100, help="h size")
    parser.add_argument("--dataset", type=str, default="MNIST",
                        help="(MNIST|FMNIST)")
    parser.add_argument(
        "--model-type",
        type=str,
        required=True,
        help="(SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD)")
    args = parser.parse_args()
    torch.random.manual_seed(args.seed)

    if args.g == "circular":
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    assert dataset in ["MNIST"]
    assert model_type in [
        "MGSWNN",
        "JMGSWNN",
        "SWD",
        "MSWD",
        "MGSWD",
        "DSWD",
        "GSWD",
        "DGSWD",
        "JSWD",
        "JMSWD",
        "JMGSWD",
        "JDSWD",
        "JGSWD",
        "JDGSWD",
        "DGSWNN",
        "JDGSWNN",
        "GSWNN",
        "JGSWNN",
    ]

    # ---- output directory, encoding the hyper-parameters in the name ----
    if model_type == "SWD" or model_type == "JSWD":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_n" + str(num_projection))
    elif model_type == "GSWD" or model_type == "JGSWD":
        model_dir = os.path.join(
            args.outdir,
            model_type + "_n" + str(num_projection) + "_" + args.g +
            str(args.r))
    elif model_type == "DSWD" or model_type == "JDSWD":
        model_dir = os.path.join(
            args.outdir, model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam))
    elif model_type == "DGSWD" or model_type == "JDGSWD":
        model_dir = os.path.join(
            args.outdir,
            model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam) + "_" + args.g +
            str(args.r),
        )
    elif model_type == "MSWD" or model_type == "JMSWD":
        model_dir = os.path.join(args.outdir, model_type)
    elif model_type == "MGSWNN" or model_type == "JMGSWNN":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_size" + str(args.hsize))
    elif model_type == "MGSWD" or model_type == "JMGSWD":
        model_dir = os.path.join(args.outdir, model_type + "_" + args.g)
    else:
        # FIX: GSWNN/JGSWNN/DGSWNN/JDGSWNN pass the assert above but had no
        # branch here, leaving model_dir unbound (UnboundLocalError).
        model_dir = os.path.join(args.outdir, model_type)

    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):  # FIX: removed duplicated check
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("batch size {}\nepochs {}\nAdam lr {} \n using device {}\n".format(
        args.batch_size, args.epochs, args.lr, device.type))

    # ---- build train and test set data loaders ----
    if dataset == "MNIST":
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )
        # full test set in one batch, for the WD/SWD metrics
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=10000,
            shuffle=False,
            num_workers=args.num_workers,
        )
        # small batch for reconstruction grids
        test_loader2 = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=64,
            shuffle=False,
            num_workers=args.num_workers,
        )
    model = MnistAutoencoder(image_size=28,
                             latent_size=args.latent_size,
                             hidden_size=100,
                             device=device).to(device)

    # ---- per-variant auxiliary modules / optimizers ----
    if model_type == "DSWD" or model_type == "DGSWD":
        transform_net = TransformNet(28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # op_trannet = optim.Adam(transform_net.parameters(), lr=1e-4)
        # train_net(28 * 28, 1000, transform_net, op_trannet)
    elif model_type == "JDSWD" or model_type == "JDSWD2" or model_type == "JDGSWD":
        # joint variants slice the concatenation [latent, image]
        transform_net = TransformNet(args.latent_size + 28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(args.latent_size + 28 * 28, 1000, transform_net, op_trannet)
    if model_type == "MGSWNN":
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "GSWNN" or model_type == "DGSWNN":
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=num_projection,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "JMGSWNN":
        gsw = GSW_NN(din=28 * 28 + 32,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "MSWD" or model_type == "JMSWD":
        gsw = GSW()
    if model_type == "MGSWD":
        theta = torch.randn((1, 784), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # FIX: the optimizer must update theta; the original passed
        # transform_net.parameters(), but transform_net is never created for
        # MGSWD (NameError) and would not touch theta anyway.
        opt_theta = optim.Adam([theta], lr=args.lr, betas=(0.5, 0.999))
    if model_type == "JMGSWD":
        theta = torch.randn((1, 784 + 32), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # FIX: same transform_net -> [theta] correction as in MGSWD above.
        opt_theta = torch.optim.Adam([theta], lr=args.lr, betas=(0.5, 0.999))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    fixednoise = torch.randn((64, latent_size)).to(device)
    ite = 0
    wd_list = []
    swd_list = []

    # ---- training loop ----
    for epoch in range(args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if model_type == "SWD":
                loss = model.compute_loss_SWD(data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif model_type == "GSWD":
                loss = model.compute_loss_GSWD(data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=2)
            elif model_type == "GSWNN":
                loss = model.compute_loss_GSWNN(data, torch.randn, gsw, p=2)
            elif model_type == "DGSWNN":
                loss = model.compute_loss_DGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 args.lam,
                                                 args.lr,
                                                 p=2)
            elif model_type == "MGSWNN":
                loss = model.compute_loss_MGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 p=args.p)
            elif model_type == "JMGSWNN":
                loss = model.compute_loss_JMGSWNN(data,
                                                  torch.randn,
                                                  gsw,
                                                  args.niter,
                                                  p=args.p)
            elif model_type == "MSWD":
                loss = model.compute_loss_MSWD(data, torch.randn, gsw,
                                               args.niter)
            elif model_type == "MGSWD":
                loss = model.compute_loss_MGSWD(data,
                                                torch.randn,
                                                theta,
                                                opt_theta,
                                                g_function,
                                                args.r,
                                                args.lr2,
                                                p=args.p,
                                                max_iter=args.niter)
            elif model_type == "DSWD":
                loss = model.compute_lossDSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "DGSWD":
                loss = model.compute_lossDGSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    g_function,
                    r=args.r,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "JSWD":
                loss = model.compute_loss_JSWD(data,
                                               torch.randn,
                                               num_projection,
                                               p=args.p)
            elif model_type == "JGSWD":
                loss = model.compute_loss_JGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                num_projection,
                                                p=args.p)
            elif model_type == "JDSWD":
                loss = model.compute_lossJDSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "JDGSWD":
                loss = model.compute_lossJDGSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    g_function,
                    r=args.r,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "JMSWD":
                loss = model.compute_loss_JMSWD(data, torch.randn, gsw,
                                                args.niter)
            elif model_type == "JMGSWD":
                loss = model.compute_loss_JMGSWD(data,
                                                 torch.randn,
                                                 theta,
                                                 opt_theta,
                                                 g_function,
                                                 args.r,
                                                 p=args.p,
                                                 max_iter=args.niter)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # periodic evaluation on the (single-batch) test set
            if ite % 100 == 0:
                model.eval()
                for _, (input, y) in enumerate(test_loader, start=0):
                    fixednoise_wd = torch.randn(
                        (10000, latent_size)).to(device)
                    data = input.to(device)
                    data = data.view(data.shape[0], -1)
                    fake = model.decoder(fixednoise_wd)
                    wd_list.append(
                        compute_true_Wasserstein(data.to("cpu"),
                                                 fake.to("cpu")))
                    swd_list.append(
                        sliced_wasserstein_distance(data, fake, 10000).item())
                    print("Iter:" + str(ite) + " WD: " + str(wd_list[-1]))
                    np.savetxt(model_dir + "/wd.csv", wd_list, delimiter=",")
                    print("Iter:" + str(ite) + " SWD: " + str(swd_list[-1]))
                    np.savetxt(model_dir + "/swd.csv", swd_list, delimiter=",")
                    break  # only the first (full) test batch is needed
                model.train()
            ite = ite + 1

        total_loss /= batch_idx + 1
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))
        if epoch % 1 == 0:
            model.eval()
            sampling(
                model_dir + "/sample_epoch_" + str(epoch) + ".png",
                fixednoise,
                model.decoder,
                64,
                image_size,
                num_chanel,
            )
            # joint ('J*') models also write a reconstruction grid
            if model_type[0] == "J":
                for _, (input, y) in enumerate(test_loader2, start=0):
                    input = input.to(device)
                    input = input.view(-1, image_size**2)
                    reconstruct(
                        model_dir + "/reconstruction_epoch_" + str(epoch) +
                        ".png",
                        input,
                        model.encoder,
                        model.decoder,
                        image_size,
                        num_chanel,
                        device,
                    )
                    break
            model.train()
        save_dmodel(model, optimizer, None, None, None, None, epoch,
                    model_dir)
        if epoch == args.epochs - 1:
            model.eval()
            sampling_eps(model_dir + "/sample_epoch_" + str(epoch),
                         fixednoise, model.decoder, 64, image_size,
                         num_chanel)
            model.train()
def main():
    """Train an augmented sliced-Wasserstein autoencoder on CIFAR10/CelebA.

    Parses CLI arguments, builds the DCGAN-style autoencoder plus an adversarial
    discriminator, trains with the requested SW-variant loss, periodically
    computes FID on generated samples, and saves sample grids, checkpoints and
    the loss curve.
    """
    # ---- CLI arguments ----
    parser = argparse.ArgumentParser(
        description='Augmented Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir', default='./', help='path to dataset')
    parser.add_argument('--outdir',
                        default='./result/',
                        help='directory to output images')
    parser.add_argument('--batch-size', type=int, default=512, metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    # FIX: help text said "default: 16" but the default is 11
    parser.add_argument('--seed', type=int, default=11, metavar='S',
                        help='random seed (default: 11)')
    parser.add_argument('--g', type=str, default='circular', help='g')
    parser.add_argument('--num-projection', type=int, default=1000,
                        help='number projection')
    parser.add_argument('--lam', type=float, default=0.5,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter', type=int, default=5,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--latent-size', type=int, default=32,
                        help='Latent size')
    parser.add_argument('--dataset', type=str, default='CIFAR',
                        help='(CELEBA|CIFAR)')
    parser.add_argument('--model-type', type=str, required=True,
                        help='(ASWD|SWD|MSWD|DSWD|GSWD|)')
    parser.add_argument('--gpu', type=str, required=False, default=0)
    args = parser.parse_args()
    torch.random.manual_seed(args.seed)

    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ['CELEBA', 'CIFAR']
    assert model_type in ['ASWD', 'SWD', 'MSWD', 'DSWD', 'GSWD']
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.gpu) if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))

    # ---- data loaders (64x64 RGB for both datasets) ----
    if (dataset == 'CIFAR'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                args.datadir,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.ToTensor(),
                    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)
    elif (dataset == 'CELEBA'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        dataset = CustomDataset(
            root=args.datadir + 'img_align_celeba',
            transform=transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        # Create the dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)

    model = DCGANAE(image_size=64,
                    latent_size=latent_size,
                    num_chanel=3,
                    hidden_chanels=64,
                    device=device).to(device)
    #model=nn.DataParallel(model)
    #model.to(device)
    dis = Discriminator(64, args.latent_size, 3, 64).to(device)
    disoptimizer = optim.Adam(dis.parameters(), lr=args.lr, betas=(0.5, 0.999))

    # ---- per-variant auxiliary modules ----
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(64 * 8 * 4 * 4).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=0.0005,
                                betas=(0.5, 0.999))
        train_net(64 * 8 * 4 * 4, 1000, transform_net, op_trannet,
                  device=device)
    if model_type == 'ASWD':
        phi = Mapping(64 * 8 * 4 * 4).to(device)
        phi_op = optim.Adam(phi.parameters(), lr=0.001, betas=(0.5, 0.999))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    epoch_cont = 0
    generated_sample_number = 64
    fixednoise = torch.randn(
        (generated_sample_number, latent_size)).to(device)
    loss_recorder = []
    generated_sample_size = 64

    # timestamp-derived run id so concurrent runs don't collide
    save_idx = str(time.time()).split('.')
    save_idx = save_idx[0] + save_idx[1]
    path_0 = model_dir + '/' + args.dataset + '/fid/' + save_idx
    # FIX: os.mkdir fails when the intermediate dirs
    # (<model_dir>/<dataset>/fid) do not exist yet; makedirs creates them.
    os.makedirs(path_0)
    if args.dataset == 'CIFAR':
        interval_ = 10
        fid_stats_file = args.datadir + '/fid_stats_cifar10_train.npz'
    else:
        interval_ = 5
        fid_stats_file = args.datadir + '/fid_stats_celeba.npz'
    fid_recorder = np.zeros(args.epochs // interval_ + 1)

    # ---- training loop ----
    for epoch in range(epoch_cont, args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p,
                                              epoch=epoch,
                                              batch_idx=batch_idx)
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(dis, disoptimizer, data,
                                               torch.randn, g_function,
                                               args.r, num_projection,
                                               p=args.p)
            elif (model_type == 'MSWD'):
                loss, v = model.compute_loss_MSWD(dis,
                                                  disoptimizer,
                                                  data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=10)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=10)
            elif (model_type == 'SINKHORN'):
                loss = model.compute_loss_sinkhorn(dis,
                                                   disoptimizer,
                                                   data,
                                                   torch.randn,
                                                   p=2,
                                                   n_iter=100,
                                                   e=100)
            elif (model_type == 'CRAMER'):
                loss = model.compute_loss_cramer(dis, disoptimizer, data,
                                                 torch.randn)
            elif model_type == 'ASWD':
                loss = model.compute_lossASWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              phi,
                                              phi_op,
                                              p=2,
                                              max_iter=args.niter,
                                              lam=args.lam,
                                              epoch=epoch,
                                              batch_idx=batch_idx)
            optimizer.zero_grad()
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        # ---- periodic FID evaluation on 20k generated samples ----
        if epoch == 0 or (epoch + 1) % interval_ == 0:
            path_1 = path_0 + '/' + str('%03d' % epoch)
            os.mkdir(path_1)
            for j in range(20):
                fixednoise_ = torch.randn((1000, 32)).to(device)
                imgs = model.decoder(fixednoise_)
                for i, img in enumerate(imgs):
                    # CHW -> HWC, then scale to uint8 for imageio
                    img = img.transpose(0, -1).transpose(
                        0, 1).cpu().detach().numpy()
                    img = (img * 255).astype(np.uint8)
                    imageio.imwrite(
                        path_1 + '/' + args.model_type + '_' +
                        str(args.num_projection) + '_' +
                        str('%03d' % epoch) + '_' + str(1000 * j + i) +
                        '.png', img)
            fid_value = calculate_fid_given_paths(
                [path_1 + '/', fid_stats_file], 50, True, 2048)
            fid_recorder[(epoch + 1) // interval_] = fid_value
            np.save(
                path_0 + '/fid_recorder_' + 'np_' + str(num_projection) +
                '.npy', fid_recorder)
            print('fid score:', fid_value)
            # FIX: best-effort cleanup without shelling out to `rm -rf`
            import shutil
            shutil.rmtree(path_1, ignore_errors=True)

        total_loss /= (batch_idx + 1)
        loss_recorder.append(total_loss)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))
        if (epoch % 1 == 0 or epoch == args.epochs - 1):
            sampling(
                model_dir + '/' + args.dataset + '/sample_epoch_' +
                str(epoch) + ".png", fixednoise, model.decoder,
                generated_sample_number, generated_sample_size, num_chanel)
            torch.save(
                model.state_dict(),
                args.outdir + args.dataset + '/' + model_type + '_' +
                str(args.batch_size) + '_' + str(num_projection) + '_' +
                str(latent_size) + '_model.pth')
            torch.save(
                dis.state_dict(),
                args.outdir + args.dataset + '/' + model_type + '_' +
                str(args.batch_size) + '_' + str(num_projection) + '_' +
                str(latent_size) + '_discriminator.pth')
            np.save(
                args.outdir + args.dataset + '/' + model_type + '_' +
                str(args.batch_size) + '_' + str(num_projection) + '_' +
                str(latent_size) + '_loss.npy', loss_recorder)
def main():
    """Train a sliced-Wasserstein autoencoder on CIFAR10/CelebA (baseline).

    Parses CLI arguments, builds the DCGAN-style autoencoder plus an
    adversarial discriminator, trains with the requested SW-variant loss, and
    saves a sample grid each epoch.
    """
    # ---- CLI arguments ----
    parser = argparse.ArgumentParser(
        description='Disributional Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir', default='./', help='path to dataset')
    parser.add_argument('--outdir',
                        default='./result',
                        help='directory to output images')
    parser.add_argument('--batch-size', type=int, default=512, metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed', type=int, default=16, metavar='S',
                        help='random seed (default: 16)')
    parser.add_argument('--g', type=str, default='circular', help='g')
    parser.add_argument('--num-projection', type=int, default=1000,
                        help='number projection')
    parser.add_argument('--lam', type=float, default=1,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter', type=int, default=10,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--latent-size', type=int, default=32,
                        help='Latent size')
    # FIX: the default was 'MNIST', which always failed the
    # `assert dataset in ['CELEBA', 'CIFAR']` below — the script could never
    # run without an explicit --dataset. Default to CIFAR per the help text.
    parser.add_argument('--dataset', type=str, default='CIFAR',
                        help='(CELEBA|CIFAR)')
    parser.add_argument('--model-type', type=str, required=True,
                        help='(SWD|MSWD|DSWD|GSWD|DGSWD|CRAMER|)')
    args = parser.parse_args()
    torch.random.manual_seed(args.seed)

    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ['CELEBA', 'CIFAR']
    assert model_type in ['SWD', 'MSWD', 'DSWD', 'GSWD', 'DGSWD', 'CRAMER']
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):  # FIX: removed duplicated check
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))

    # ---- data loaders (64x64 RGB for both datasets) ----
    if (dataset == 'CIFAR'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                args.datadir,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.ToTensor(),
                    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)
    elif (dataset == 'CELEBA'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        dataset = CustomDataset(
            root=args.datadir + '/img_align_celeba',
            transform=transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        # Create the dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)

    model = DCGANAE(image_size=64,
                    latent_size=latent_size,
                    num_chanel=3,
                    hidden_chanels=64,
                    device=device).to(device)
    dis = Discriminator(64, args.latent_size, 3, 64).to(device)
    disoptimizer = optim.Adam(dis.parameters(), lr=args.lr, betas=(0.5, 0.999))

    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(64 * 8 * 4 * 4).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        train_net(64 * 8 * 4 * 4, 1000, transform_net, op_trannet)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    epoch_cont = 0
    fixednoise = torch.randn((64, latent_size)).to(device)

    # ---- training loop ----
    for epoch in range(epoch_cont, args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(dis, disoptimizer, data,
                                               torch.randn, g_function,
                                               args.r, num_projection,
                                               p=args.p)
            elif (model_type == 'MSWD'):
                loss, v = model.compute_loss_MSWD(dis,
                                                  disoptimizer,
                                                  data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=args.lam)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'CRAMER'):
                loss = model.compute_loss_cramer(dis, disoptimizer, data,
                                                 torch.randn)
            optimizer.zero_grad()
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        total_loss /= (batch_idx + 1)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))
        if (epoch % 1 == 0 or epoch == args.epochs - 1):
            sampling(model_dir + '/sample_epoch_' + str(epoch) + ".png",
                     fixednoise, model.decoder, 64, image_size, num_chanel)
def main():
    """CLI entry point: train a sliced-Wasserstein autoencoder on MNIST or
    FashionMNIST and periodically write sample / reconstruction images.

    The distance used for training is selected by ``--model-type``; joint
    variants (names starting with ``J``) also produce reconstructions from
    the test set each epoch.
    """
    # train args
    parser = argparse.ArgumentParser(
        description='Disributional Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir', default='./', help='path to dataset')
    parser.add_argument('--outdir',
                        default='./result',
                        help='directory to output images')
    parser.add_argument('--batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed',
                        type=int,
                        default=16,
                        metavar='S',
                        help='random seed (default: 16)')
    parser.add_argument('--g', type=str, default='circular', help='g')
    parser.add_argument('--num-projection',
                        type=int,
                        default=1000,
                        help='number projection')
    parser.add_argument('--lam',
                        type=float,
                        default=1,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter',
                        type=int,
                        default=10,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--latent-size',
                        type=int,
                        default=32,
                        help='Latent size')
    parser.add_argument('--dataset',
                        type=str,
                        default='MNIST',
                        help='(MNIST|FMNIST)')
    parser.add_argument(
        '--model-type',
        type=str,
        required=True,
        help=
        '(SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD|CRAMER|JCRAMER|SINKHORN|JSINKHORN)'
    )
    args = parser.parse_args()
    torch.random.manual_seed(args.seed)
    # Only the circular defining function is wired up; any other --g value
    # leaves g_function unbound and the GSWD branches would raise NameError.
    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ['MNIST', 'FMNIST']
    assert model_type in [
        'SWD', 'MSWD', 'DSWD', 'GSWD', 'DGSWD', 'JSWD', 'JMSWD', 'JDSWD',
        'JGSWD', 'JDGSWD', 'CRAMER', 'JCRAMER', 'SINKHORN', 'JSINKHORN'
    ]
    # Create the data / output / per-model directories on demand.
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    # NOTE(review): duplicate of the outdir check above — harmless but redundant.
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))
    # build train and test set data loaders
    if (dataset == 'MNIST'):
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)
        test_loader = torch.utils.data.DataLoader(datasets.MNIST(
            args.datadir,
            train=False,
            download=True,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                  batch_size=64,
                                                  shuffle=False,
                                                  num_workers=args.num_workers)
        model = MnistAutoencoder(image_size=28,
                                 latent_size=args.latent_size,
                                 hidden_size=100,
                                 device=device).to(device)
    elif (dataset == 'FMNIST'):
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST(args.datadir,
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose(
                                      [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)
        test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(
            args.datadir,
            train=False,
            download=True,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                  batch_size=64,
                                                  shuffle=False,
                                                  num_workers=args.num_workers)
        model = MnistAutoencoder(image_size=28,
                                 latent_size=args.latent_size,
                                 hidden_size=100,
                                 device=device).to(device)
    # Distributional variants learn a projection network; joint variants
    # operate on the concatenation of latent code and flattened image.
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(28 * 28, 1000, transform_net, op_trannet)
    elif (model_type == 'JDSWD' or model_type == 'JDSWD2'
          or model_type == 'JDGSWD'):
        transform_net = TransformNet(args.latent_size + 28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(args.latent_size + 28 * 28, 1000, transform_net, op_trannet)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    # Fixed noise so sample grids are comparable across epochs.
    fixednoise = torch.randn((64, latent_size)).to(device)
    for epoch in range(args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            # Dispatch to the loss implementation selected by --model-type.
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'MSWD'):
                loss, v = model.compute_loss_MSWD(data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=args.lam)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               r=args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JSWD'):
                loss = model.compute_loss_JSWD(data,
                                               torch.randn,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'JGSWD'):
                loss = model.compute_loss_JGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                num_projection,
                                                p=args.p)
            elif (model_type == 'JDSWD'):
                loss = model.compute_lossJDSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JDGSWD'):
                loss = model.compute_lossJDGSWD(data,
                                                torch.randn,
                                                num_projection,
                                                transform_net,
                                                op_trannet,
                                                g_function,
                                                r=args.r,
                                                p=args.p,
                                                max_iter=args.niter,
                                                lam=args.lam)
            elif (model_type == 'JMSWD'):
                # NOTE(review): reuses the non-joint compute_loss_MSWD —
                # looks like a joint variant may be missing; confirm intent.
                loss, v = model.compute_loss_MSWD(data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'CRAMER'):
                loss = model.compute_loss_cramer(data, torch.randn)
            elif (model_type == 'JCRAMER'):
                loss = model.compute_loss_join_cramer(data, torch.randn)
            elif (model_type == 'SINKHORN'):
                loss = model.compute_wasserstein_vi_loss(data,
                                                         torch.randn,
                                                         n_iter=args.niter,
                                                         p=args.p,
                                                         e=1)
            elif (model_type == 'JSINKHORN'):
                loss = model.compute_join_wasserstein_vi_loss(
                    data, torch.randn, n_iter=args.niter, p=args.p, e=1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Average the summed batch losses over the number of batches seen.
        total_loss /= (batch_idx + 1)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))
        # Every epoch: dump a sample grid; for joint models also dump
        # reconstructions of the first test batch.
        if (epoch % 1 == 0):
            model.eval()
            sampling(model_dir + '/sample_epoch_' + str(epoch) + ".png",
                     fixednoise, model.decoder, 64, image_size, num_chanel)
            if (model_type[0] == 'J'):
                for _, (input, y) in enumerate(test_loader, start=0):
                    input = input.to(device)
                    input = input.view(-1, image_size**2)
                    reconstruct(
                        model_dir + '/reconstruction_epoch_' + str(epoch) +
                        ".png", input, model.encoder, model.decoder,
                        image_size, num_chanel, device)
                    break
            model.train()
def main():
    """CLI entry point: train a DCGAN-style sliced-Wasserstein autoencoder on
    CIFAR, CELEBA, or LSUN, with optional checkpoint resume (``--cont``) and
    per-epoch model saving plus sample-image dumps.
    """
    # train args
    parser = argparse.ArgumentParser(
        description="Disributional Sliced Wasserstein Autoencoder")
    parser.add_argument("--datadir", default="./", help="path to dataset")
    parser.add_argument("--outdir",
                        default="./result",
                        help="directory to output images")
    parser.add_argument("--batch-size",
                        type=int,
                        default=512,
                        metavar="N",
                        help="input batch size for training (default: 512)")
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        metavar="N",
                        help="number of epochs to train (default: 200)")
    parser.add_argument("--lr",
                        type=float,
                        default=0.0005,
                        metavar="LR",
                        help="learning rate (default: 0.0005)")
    parser.add_argument(
        "--num-workers",
        type=int,
        default=16,
        metavar="N",
        help="number of dataloader workers if device is CPU (default: 16)",
    )
    parser.add_argument("--seed",
                        type=int,
                        default=16,
                        metavar="S",
                        help="random seed (default: 16)")
    parser.add_argument("--g", type=str, default="circular", help="g")
    parser.add_argument("--num-projection",
                        type=int,
                        default=1000,
                        help="number projection")
    parser.add_argument("--lam",
                        type=float,
                        default=1,
                        help="Regularization strength")
    parser.add_argument("--p", type=int, default=2, help="Norm p")
    parser.add_argument("--niter",
                        type=int,
                        default=10,
                        help="number of iterations")
    parser.add_argument("--r", type=float, default=1000, help="R")
    parser.add_argument("--latent-size",
                        type=int,
                        default=32,
                        help="Latent size")
    parser.add_argument("--hsize", type=int, default=100, help="Latent size")
    parser.add_argument("--dataset",
                        type=str,
                        default="MNIST",
                        help="(CELEBA|CIFAR)")
    parser.add_argument("--model-type",
                        type=str,
                        required=True,
                        help="(SWD|MSWD|DSWD|GSWD|DGSWD|CRAMER|)")
    # NOTE(review): argparse type=bool is a known pitfall — any non-empty
    # string (including "False") parses as True; consider action="store_true".
    parser.add_argument("--cont", type=bool, help="")
    parser.add_argument("--dim", type=int, default=100, help="subspace size")
    parser.add_argument("--e", type=float, default=1000, help="R")
    args = parser.parse_args()
    torch.random.manual_seed(args.seed)
    # Only the circular defining function is wired up; any other --g value
    # leaves g_function unbound and the GSWD/DGSWD branches would raise.
    if args.g == "circular":
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ["CELEBA", "CIFAR", "LSUN"]
    assert model_type in ["SWD", "MSWD", "DSWD", "GSWD", "DGSWD", "MGSWNN"]
    # Encode the hyperparameters that matter per model type into the output
    # directory name (overrides the plain model_dir set above).
    if model_type == "SWD":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_n" + str(num_projection))
    elif model_type == "DSWD":
        model_dir = os.path.join(
            args.outdir, model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam))
    elif model_type == "MSWD":
        model_dir = os.path.join(args.outdir, model_type)
    elif model_type == "MGSWNN":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_size" + str(args.hsize))
    elif model_type == "GSWD":
        model_dir = os.path.join(
            args.outdir, model_type + "_n" + str(num_projection) + "_" +
            args.g + str(args.r))
    elif model_type == "DGSWD":
        model_dir = os.path.join(
            args.outdir,
            model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam) + "_" + args.g +
            str(args.r),
        )
    print(model_dir)
    # Create the data / output / per-model directories on demand.
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    # NOTE(review): duplicate of the outdir check above — harmless but redundant.
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("batch size {}\nepochs {}\nAdam lr {} \n using device {}\n".format(
        args.batch_size, args.epochs, args.lr, device.type))
    # All three datasets are resized/cropped to 64x64 RGB and normalized
    # to [-1, 1].
    if dataset == "CIFAR":
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                args.datadir,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
            ),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )
    elif dataset == "LSUN":
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.LSUN(
                args.datadir + "/lsun",
                classes=["bedroom_train"],
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.CenterCrop(64),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
            ),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )
    elif dataset == "CELEBA":
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        # NOTE(review): rebinds the local `dataset` (previously the dataset
        # name string) to the Dataset object — works, but shadows the name.
        dataset = CustomDataset(
            root=args.datadir + "/img_align_celeba",
            transform=transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
        )
        # Create the dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)
    # test_loader = torch.utils.data.DataLoader(
    #     dataset, batch_size=1000, shuffle=True, num_workers=args.num_workers, pin_memory=True
    # )
    model = DCGANAE(image_size=64,
                    latent_size=latent_size,
                    num_chanel=3,
                    hidden_chanels=64,
                    device=device).to(device)
    dis = Discriminator(64, args.latent_size, 3, 64).to(device)
    disoptimizer = optim.Adam(dis.parameters(), lr=args.lr, betas=(0.5, 0.999))
    # Auxiliary modules required by specific model types: a projection
    # network for the distributional variants, a GSW helper for max-sliced.
    if model_type == "DSWD" or model_type == "DGSWD":
        transform_net = TransformNet(64 * 8 * 4 * 4).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(64 * 8 * 4 * 4, 1000, transform_net, op_trannet)
    if model_type == "MGSWNN":
        gsw = GSW_NN(din=64 * 8 * 4 * 4,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "MSWD":
        gsw = GSW()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    epoch_cont = 0
    # Resume model/optimizer/discriminator state from the latest checkpoint
    # in model_dir and continue from the following epoch.
    if args.cont:
        epoch_cont, modelstate, optimizerstate, tnetstate, optnetstate, disstate, opdistate = load_dmodel(
            model_dir)
        model.load_state_dict(modelstate)
        optimizer.load_state_dict(optimizerstate)
        dis.load_state_dict(disstate)
        disoptimizer.load_state_dict(opdistate)
        epoch_cont = epoch_cont + 1
        print("Continue from epoch " + str(epoch_cont))
    # Fixed noise so sample grids are comparable across epochs.
    fixednoise = torch.randn((64, latent_size)).to(device)
    for epoch in range(epoch_cont, args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            # Dispatch to the loss implementation selected by --model-type.
            if model_type == "SWD":
                loss = model.compute_loss_SWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif model_type == "GSWD":
                loss = model.compute_loss_GSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=args.p)
            elif model_type == "MGSWNN":
                loss = model.compute_loss_MGSWNN(dis,
                                                 disoptimizer,
                                                 data,
                                                 torch.randn,
                                                 gsw,
                                                 p=args.p)
            elif model_type == "MSWD":
                loss = model.compute_loss_MSWD(dis, disoptimizer, data,
                                               torch.randn, gsw)
            elif model_type == "DSWD":
                loss = model.compute_lossDSWD(
                    dis,
                    disoptimizer,
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "DGSWD":
                loss = model.compute_lossDGSWD(
                    dis,
                    disoptimizer,
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    g_function,
                    args.r,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            optimizer.zero_grad()
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        # Average the summed batch losses over the number of batches seen.
        total_loss /= batch_idx + 1
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))
        # Checkpoint and dump a sample grid every epoch.
        save_dmodel(model, optimizer, dis, disoptimizer, None, None, epoch,
                    model_dir)
        sampling(model_dir + "/sample_epoch_" + str(epoch) + ".png",
                 fixednoise, model.decoder, 64, image_size, num_chanel)
        # On the final epoch, additionally run the eps-sampling dump in
        # eval mode.
        if epoch == args.epochs - 1:
            model.eval()
            sampling_eps(model_dir + "/sample_epoch_" + str(epoch),
                         fixednoise, model.decoder, 64, image_size,
                         num_chanel)
            model.train()