Ejemplo n.º 1
0
    def compute_loss_SWD(self,
                         discriminator,
                         optimizer,
                         minibatch,
                         rand_dist,
                         num_projection,
                         p=2):
        """Adversarial sliced-Wasserstein loss with a discriminator update.

        Trains `discriminator` one step on real data (label 1) and one step
        on decoded prior samples (label 0) with BCE, then returns the
        sliced Wasserstein distance between the discriminator features of
        real and fake batches.

        Args:
            discriminator: callable returning (probability, features).
            optimizer: optimizer over the discriminator's parameters.
            minibatch: batch of real data, first dim is batch size.
            rand_dist: sampler, e.g. torch.randn, taking a shape tuple.
            num_projection: number of random projections for the SWD.
            p: order of the Wasserstein distance (default 2).

        Returns:
            Scalar tensor: sliced Wasserstein distance between feature sets.
        """
        # BCELoss requires a floating-point target; filling with the int 1
        # would create a long tensor and fail inside the loss on modern
        # PyTorch, so fill with 1.0 explicitly.
        label = torch.full((minibatch.shape[0], ), 1.0, device=self.device)
        criterion = nn.BCELoss()
        data = minibatch.to(self.device)
        z_prior = rand_dist((data.shape[0], self.latent_size)).to(self.device)
        data_fake = self.decoder(z_prior)
        # Discriminator step on real data; `data` is rebound to the
        # discriminator's feature output for the SWD computation below.
        y_data, data = discriminator(data)
        errD_real = criterion(y_data, label)
        optimizer.zero_grad()
        # retain_graph: the generator graph is reused for the SWD below.
        errD_real.backward(retain_graph=True)
        optimizer.step()
        # Discriminator step on generated data with target 0.
        y_fake, data_fake = discriminator(data_fake)
        label.fill_(0)
        errD_fake = criterion(y_fake, label)
        optimizer.zero_grad()
        errD_fake.backward(retain_graph=True)
        optimizer.step()
        _swd = sliced_wasserstein_distance(data.view(data.shape[0], -1),
                                           data_fake.view(data.shape[0], -1),
                                           num_projection, p, self.device)

        return _swd
Ejemplo n.º 2
0
    def compute_loss_SWD(self, minibatch, rand_dist, num_projection, p=2):
        """Sliced Wasserstein distance between a real batch and decoded
        samples drawn from the prior.

        Args:
            minibatch: batch of real data, first dim is batch size.
            rand_dist: sampler, e.g. torch.randn, taking a shape tuple.
            num_projection: number of random projections for the SWD.
            p: order of the Wasserstein distance (default 2).

        Returns:
            Scalar tensor: the sliced Wasserstein distance.
        """
        real = minibatch.to(self.device)
        batch_size = real.shape[0]
        # Sample latent codes from the prior and decode them.
        latent = rand_dist((batch_size, self.latent_size)).to(self.device)
        generated = self.decoder(latent)
        # Flatten both batches to (batch, features) before projecting.
        return sliced_wasserstein_distance(real.view(batch_size, -1),
                                           generated.view(batch_size, -1),
                                           num_projection, p, self.device)
Ejemplo n.º 3
0
                # Fragment of a training step: encode a subset of tiles
                # (selected by `inds`) plus identity one-hots, decode a
                # full image, and score it tile-by-tile with a sliced
                # Wasserstein loss. Assumes batch-of-64 tiles of shape
                # 3x4x4 — TODO confirm against the caller.
                one_hot = one_hot_identity[inds].float().to(device)

                r = encoder(one_hot, tiles)
                fake_image = decoder(
                    r, one_hot_identity[all_inds].float().to(device))
                # loss = torch.mean(torch.tensor([wasserstein_distance(x, y) for (x,y) in zip(fake_image.view(64,3*4*4).detach().cpu(), tiled_ground_truth.view(64,3*4*4).detach().cpu())], requires_grad=True))

                # loss = criterion(fake_image, weird_ground_truth)
                # Flatten to (64 tiles, 4*4*3 features) for per-tile SWD.
                fake_image = fake_image.view(64, 4 * 4 * 3)
                weird_ground_truth = weird_ground_truth.view(64, 4 * 4 * 3)
                loss = 0.0
                # Accumulate the sliced Wasserstein distance of each tile
                # against its ground truth (one "distribution" per tile).
                for tile in range(64):
                    loss += sliced_wasserstein_distance(
                        fake_image[tile].unsqueeze(0),
                        weird_ground_truth[tile].unsqueeze(0),
                        num_projections=5,
                        device=device)

                # NOTE(review): no optimizer.zero_grad() is visible before
                # this backward in the fragment — gradients may accumulate
                # across iterations unless it is called earlier; verify.
                loss.backward()
                optimizer.step()
            # plot_grad_flow(encoder.named_parameters())
            # plot_grad_flow(decoder.named_parameters())

        # for t in range(5):
        # optimizer_critic.zero_grad()
        # for image, target in zip(ground_truth_images, targets):
        #     num_tiles = np.random.randint(4, 64)
        #     tiles, inds, _, all_inds, tiled_ground_truth = utils.get_tiles(image, num_tiles=num_tiles)

        #     image = image.to(device)
Ejemplo n.º 4
0
def main():
    """Train a sliced-Wasserstein autoencoder variant on MNIST.

    Parses CLI arguments selecting the loss (``--model-type``), builds the
    MNIST data loaders and the autoencoder, then trains for ``--epochs``
    epochs, periodically logging true/sliced Wasserstein distances on the
    test set and saving generated samples, reconstructions and checkpoints.
    """
    # train args
    parser = argparse.ArgumentParser(
        description='Disributional Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir',
                        default='/user/HS229/xc00414/condor_examples/DSW/DSW/',
                        help='path to dataset')
    parser.add_argument(
        '--outdir',
        default='/user/HS229/xc00414/condor_examples/DSW/DSW/result',
        help='directory to output images')
    parser.add_argument('--batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument('--lrpsw',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate psw (default: 0.001)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed',
                        type=int,
                        default=16,
                        metavar='S',
                        help='random seed (default: 16)')
    parser.add_argument('--g', type=str, default='circular', help='g')
    parser.add_argument('--num-projection',
                        type=int,
                        default=1000,
                        help='number projection')
    parser.add_argument('--lam',
                        type=float,
                        default=1,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter',
                        type=int,
                        default=10,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--kappa', type=float, default=50, help='R')
    parser.add_argument('--k', type=int, default=10, help='R')
    parser.add_argument('--e', type=float, default=1000, help='R')
    parser.add_argument('--latent-size',
                        type=int,
                        default=32,
                        help='Latent size')
    parser.add_argument('--hsize', type=int, default=100, help='h size')
    parser.add_argument('--dim', type=int, default=100, help='subspace size')
    parser.add_argument('--dataset',
                        type=str,
                        default='MNIST',
                        help='(MNIST|FMNIST)')
    parser.add_argument(
        '--model-type',
        type=str,
        required=True,
        help='(ASWD|SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD)')
    args = parser.parse_args()

    #torch.random.manual_seed(args.seed)
    # NOTE(review): g_function is only bound for the default 'circular';
    # any other --g value makes the G*-loss branches fail with NameError.
    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    assert dataset in ['MNIST']
    assert model_type in [
        'MGSWNN', 'JMGSWNN', 'SWD', 'MSWD', 'PSTW', 'PSW', 'MGSWD', 'DSWD',
        'GSWD', 'DGSWD', 'JSWD', 'JMSWD', 'JMGSWD', 'JDSWD', 'JGSWD', 'JDGSWD',
        'DGSWNN', 'JDGSWNN', 'GSWNN', 'JGSWNN', 'ASWD'
    ]
    # Per-run output directory: encode the hyper-parameters that
    # distinguish this configuration in the directory name.
    if (model_type == 'SWD' or model_type == 'JSWD'):
        model_dir = os.path.join(args.outdir,
                                 model_type + '_n' + str(num_projection))
    elif model_type == 'GSWNN':
        model_dir = os.path.join(args.outdir,
                                 model_type + '_n' + str(num_projection))
    elif (model_type == 'GSWD' or model_type == 'JGSWD'):
        model_dir = os.path.join(
            args.outdir, model_type + '_n' + str(num_projection) + '_' +
            args.g + str(args.r))
    elif (model_type == 'DSWD' or model_type == 'JDSWD'
          or model_type == 'ASWD'):
        model_dir = os.path.join(
            args.outdir, model_type + '_iter' + str(args.niter) + '_n' +
            str(num_projection) + '_lam' + str(args.lam))
    elif (model_type == 'DGSWD' or model_type == 'JDGSWD'):
        model_dir = os.path.join(
            args.outdir, model_type + '_iter' + str(args.niter) + '_n' +
            str(num_projection) + '_lam' + str(args.lam) + '_' + args.g +
            str(args.r))
    elif (model_type == 'MSWD' or model_type == 'JMSWD'):
        model_dir = os.path.join(args.outdir, model_type)
    elif (model_type == 'MGSWNN' or model_type == 'JMGSWNN'):
        model_dir = os.path.join(args.outdir,
                                 model_type + '_size' + str(args.hsize))
    elif (model_type == 'MGSWD' or model_type == 'JMGSWD'):
        model_dir = os.path.join(args.outdir, model_type + '_' + args.g)
    elif (model_type == 'PSW' or model_type == 'JPSW'):
        model_dir = os.path.join(
            args.outdir, model_type + '_e' + str(args.e) + '_iter' +
            str(args.niter) + '_d' + str(args.dim)) + '_lr' + str(args.lrpsw)
    elif (model_type == 'PSTW' or model_type == 'JPSTW'):
        model_dir = os.path.join(
            args.outdir, model_type + '_iter' + str(args.niter) + '_d' +
            str(args.dim)) + '_lr' + str(args.lrpsw)
    else:
        # Fallback so model_dir is always bound: model types such as
        # DGSWNN/JDGSWNN/JGSWNN pass the assert above but were not covered
        # by the chain, which previously crashed with NameError below.
        model_dir = os.path.join(args.outdir, model_type)
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))
    # build train and test set data loaders
    if (dataset == 'MNIST'):
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)
        # Full test set in one batch for distance evaluation.
        test_loader = torch.utils.data.DataLoader(datasets.MNIST(
            args.datadir,
            train=False,
            download=True,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                  batch_size=10000,
                                                  shuffle=False,
                                                  num_workers=args.num_workers)
        # Small fixed batch for reconstruction snapshots.
        test_loader2 = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=64,
            shuffle=False,
            num_workers=args.num_workers)
        model = MnistAutoencoder(image_size=28,
                                 latent_size=args.latent_size,
                                 hidden_size=100,
                                 device=device).to(device)
    # Per-model-type auxiliary networks / parameters.
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # op_trannet = optim.Adam(transform_net.parameters(), lr=1e-4)
        # train_net(28 * 28, 1000, transform_net, op_trannet)
    elif (model_type == 'JDSWD' or model_type == 'JDSWD2'
          or model_type == 'JDGSWD'):
        transform_net = TransformNet(args.latent_size + 28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(args.latent_size + 28 * 28, 1000, transform_net, op_trannet)
    elif (model_type == 'ASWD'):
        phi = Mapping(28 * 28).to(device)
        phi_op = optim.Adam(phi.parameters(), lr=0.0005, betas=(0.9, 0.999))
    if (model_type == 'MGSWNN'):
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if (model_type == 'GSWNN' or model_type == 'DGSWNN'):
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=num_projection,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if (model_type == 'JMGSWNN'):
        gsw = GSW_NN(din=28 * 28 + 32,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if (model_type == 'MSWD' or model_type == 'JMSWD'):
        gsw = GSW()
    if (model_type == 'MGSWD'):
        theta = torch.randn((1, 784), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        #opt_theta = optim.Adam(transform_net.parameters(), lr=args.lr, betas=(0.5, 0.999))
    if (model_type == 'JMGSWD'):
        theta = torch.randn((1, 784 + 32), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # Optimize the projection direction itself. The original built this
        # optimizer over transform_net.parameters(), but transform_net is
        # never defined on the JMGSWD path (NameError); theta is the
        # parameter actually passed alongside opt_theta to the loss.
        opt_theta = torch.optim.Adam([theta],
                                     lr=args.lr,
                                     betas=(0.5, 0.999))
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    # Fixed noise so sample grids are comparable across epochs.
    fixednoise = torch.randn((64, latent_size)).to(device)
    ite = 0
    wd_list = []
    swd_list = []
    # Timestamp-derived id so concurrent runs don't overwrite CSV logs.
    save_idx = str(time.time()).split('.')
    save_idx = save_idx[0] + save_idx[1]

    for epoch in range(args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            tic = time.time()
            # Dispatch to the loss implementation for the chosen model type.
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif (model_type == 'ASWD'):
                loss = model.compute_lossTGSWD(data,
                                               torch.randn,
                                               num_projection,
                                               phi,
                                               phi_op,
                                               p=2,
                                               max_iter=args.niter,
                                               lam=args.lam,
                                               net_type='fc')
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=2)
            elif (model_type == 'GSWNN'):
                loss = model.compute_loss_GSWNN(data, torch.randn, gsw, p=2)
            elif (model_type == 'DGSWNN'):
                loss = model.compute_loss_DGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 args.lam,
                                                 args.lr,
                                                 p=2)
            elif (model_type == 'PSW'):
                loss = model.compute_loss_PSW(data,
                                              torch.randn,
                                              pdim=args.dim,
                                              p=args.p,
                                              n_iter=args.niter,
                                              # fixed: was args.niters, which
                                              # is not a defined argument
                                              n_iter_sinkhorn=args.niter,
                                              e=args.e,
                                              lr=args.lrpsw)
            elif (model_type == 'PSTW'):
                loss = model.compute_loss_PSTW(data,
                                               torch.randn,
                                               pdim=args.dim,
                                               p=args.p,
                                               n_iter=args.niter,
                                               lr=args.lrpsw)
            elif (model_type == 'MGSWNN'):
                loss = model.compute_loss_MGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 p=args.p)
            elif (model_type == 'JMGSWNN'):
                loss = model.compute_loss_JMGSWNN(data,
                                                  torch.randn,
                                                  gsw,
                                                  args.niter,
                                                  p=args.p)
            elif (model_type == 'MSWD'):
                loss = model.compute_loss_MSWD(data, torch.randn, gsw,
                                               args.niter)
            elif (model_type == 'MGSWD'):
                loss = model.compute_loss_MGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                p=args.p,
                                                max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=args.lam)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               r=args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JSWD'):
                loss = model.compute_loss_JSWD(data,
                                               torch.randn,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'JGSWD'):
                loss = model.compute_loss_JGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                num_projection,
                                                p=args.p)
            elif (model_type == 'JDSWD'):
                loss = model.compute_lossJDSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JDGSWD'):
                loss = model.compute_lossJDGSWD(data,
                                                torch.randn,
                                                num_projection,
                                                transform_net,
                                                op_trannet,
                                                g_function,
                                                r=args.r,
                                                p=args.p,
                                                max_iter=args.niter,
                                                lam=args.lam)
            elif (model_type == 'JMSWD'):
                loss = model.compute_loss_JMSWD(data, torch.randn, gsw,
                                                args.niter)
            elif (model_type == 'JMGSWD'):
                loss = model.compute_loss_JMGSWD(data,
                                                 torch.randn,
                                                 theta,
                                                 opt_theta,
                                                 g_function,
                                                 args.r,
                                                 p=args.p,
                                                 max_iter=args.niter)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            toc = time.time()
            print('execution_time:', toc - tic, model_type)
            total_loss += loss.item()
            # Every 100 iterations: evaluate true WD and SWD on one
            # full-test-set batch and append to the CSV logs.
            if (ite % 100 == 0):
                model.eval()
                for _, (input, y) in enumerate(test_loader, start=0):
                    fixednoise_wd = torch.randn(
                        (10000, latent_size)).to(device)
                    data = input.to(device)
                    data = data.view(data.shape[0], -1)
                    fake = model.decoder(fixednoise_wd)
                    wd_list.append(
                        compute_true_Wasserstein(data.to('cpu'),
                                                 fake.to('cpu')))
                    swd_list.append(
                        sliced_wasserstein_distance(data, fake, 10000).item())
                    print("Iter:" + str(ite) + " WD: " + str(wd_list[-1]))
                    np.savetxt(model_dir + "/wd" + save_idx + ".csv",
                               wd_list,
                               delimiter=",")
                    print("Iter:" + str(ite) + " SWD: " + str(swd_list[-1]))
                    np.savetxt(model_dir + "/swd" + save_idx + ".csv",
                               swd_list,
                               delimiter=",")
                    break
                model.train()
            ite = ite + 1
        total_loss /= (batch_idx + 1)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))

        if (epoch % 1 == 0):
            model.eval()
            sampling(model_dir + '/sample_epoch_' + str(epoch) + ".png",
                     fixednoise, model.decoder, 64, image_size, num_chanel)
            # Joint ('J*') model types also autoencode, so save
            # reconstructions of a fixed test batch.
            if (model_type[0] == 'J'):
                for _, (input, y) in enumerate(test_loader2, start=0):
                    input = input.to(device)
                    input = input.view(-1, image_size**2)
                    reconstruct(
                        model_dir + '/reconstruction_epoch_' + str(epoch) +
                        ".png", input, model.encoder, model.decoder,
                        image_size, num_chanel, device)
                    break
            model.train()
        save_dmodel(model, optimizer, None, None, None, None, epoch, model_dir)
        if (epoch == args.epochs - 1):
            model.eval()
            sampling_eps(model_dir + '/sample_epoch_' + str(epoch), fixednoise,
                         model.decoder, 64, image_size, num_chanel)
            model.train()
Ejemplo n.º 5
0
def main():
    # train args
    parser = argparse.ArgumentParser(
        description="Disributional Sliced Wasserstein Autoencoder")
    parser.add_argument("--datadir", default="./", help="path to dataset")
    parser.add_argument("--outdir",
                        default="./result",
                        help="directory to output images")
    parser.add_argument("--batch-size",
                        type=int,
                        default=512,
                        metavar="N",
                        help="input batch size for training (default: 512)")
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        metavar="N",
                        help="number of epochs to train (default: 200)")
    parser.add_argument("--lr",
                        type=float,
                        default=0.0005,
                        metavar="LR",
                        help="learning rate (default: 0.0005)")
    parser.add_argument(
        "--num-workers",
        type=int,
        default=16,
        metavar="N",
        help="number of dataloader workers if device is CPU (default: 16)",
    )
    parser.add_argument("--seed",
                        type=int,
                        default=16,
                        metavar="S",
                        help="random seed (default: 16)")
    parser.add_argument("--g", type=str, default="circular", help="g")
    parser.add_argument("--num-projection",
                        type=int,
                        default=1000,
                        help="number projection")
    parser.add_argument("--lam",
                        type=float,
                        default=1,
                        help="Regularization strength")
    parser.add_argument("--p", type=int, default=2, help="Norm p")
    parser.add_argument("--niter",
                        type=int,
                        default=10,
                        help="number of iterations")
    parser.add_argument("--r", type=float, default=1000, help="R")
    parser.add_argument("--kappa", type=float, default=50, help="R")
    parser.add_argument("--k", type=int, default=10, help="R")
    parser.add_argument("--e", type=float, default=1000, help="R")
    parser.add_argument("--latent-size",
                        type=int,
                        default=32,
                        help="Latent size")
    parser.add_argument("--hsize", type=int, default=100, help="h size")
    parser.add_argument("--dataset",
                        type=str,
                        default="MNIST",
                        help="(MNIST|FMNIST)")
    parser.add_argument(
        "--model-type",
        type=str,
        required=True,
        help="(SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD)")
    args = parser.parse_args()

    torch.random.manual_seed(args.seed)
    if args.g == "circular":
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    assert dataset in ["MNIST"]
    assert model_type in [
        "MGSWNN",
        "JMGSWNN",
        "SWD",
        "MSWD",
        "MGSWD",
        "DSWD",
        "GSWD",
        "DGSWD",
        "JSWD",
        "JMSWD",
        "JMGSWD",
        "JDSWD",
        "JGSWD",
        "JDGSWD",
        "DGSWNN",
        "JDGSWNN",
        "GSWNN",
        "JGSWNN",
    ]
    if model_type == "SWD" or model_type == "JSWD":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_n" + str(num_projection))
    elif model_type == "GSWD" or model_type == "JGSWD":
        model_dir = os.path.join(
            args.outdir, model_type + "_n" + str(num_projection) + "_" +
            args.g + str(args.r))
    elif model_type == "DSWD" or model_type == "JDSWD":
        model_dir = os.path.join(
            args.outdir, model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam))
    elif model_type == "DGSWD" or model_type == "JDGSWD":
        model_dir = os.path.join(
            args.outdir,
            model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam) + "_" + args.g +
            str(args.r),
        )
    elif model_type == "MSWD" or model_type == "JMSWD":
        model_dir = os.path.join(args.outdir, model_type)
    elif model_type == "MGSWNN" or model_type == "JMGSWNN":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_size" + str(args.hsize))
    elif model_type == "MGSWD" or model_type == "JMGSWD":
        model_dir = os.path.join(args.outdir, model_type + "_" + args.g)
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("batch size {}\nepochs {}\nAdam lr {} \n using device {}\n".format(
        args.batch_size, args.epochs, args.lr, device.type))
    # build train and test set data loaders
    if dataset == "MNIST":
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=10000,
            shuffle=False,
            num_workers=args.num_workers,
        )
        test_loader2 = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=64,
            shuffle=False,
            num_workers=args.num_workers,
        )
        model = MnistAutoencoder(image_size=28,
                                 latent_size=args.latent_size,
                                 hidden_size=100,
                                 device=device).to(device)
    # Per-variant auxiliary objects. Each sliced-Wasserstein flavor needs its
    # own projection machinery:
    #   - D*SWD variants: a learnable projection network (TransformNet) with
    #     its own Adam optimizer;
    #   - *GSWNN variants: a neural-network-parameterized generalized sliced
    #     Wasserstein distance (GSW_NN);
    #   - MSWD/JMSWD: the plain GSW object (max-sliced).
    # Input dimensionality is the flattened image (28*28), plus the latent
    # size for the joint ("J*") variants.
    if model_type == "DSWD" or model_type == "DGSWD":
        transform_net = TransformNet(28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # op_trannet = optim.Adam(transform_net.parameters(), lr=1e-4)
        # train_net(28 * 28, 1000, transform_net, op_trannet)
    elif model_type == "JDSWD" or model_type == "JDSWD2" or model_type == "JDGSWD":
        # Joint variants project the concatenation [latent code, image].
        transform_net = TransformNet(args.latent_size + 28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(args.latent_size + 28 * 28, 1000, transform_net, op_trannet)
    if model_type == "MGSWNN":
        # Max (single-projection) generalized SW with a neural defining function.
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "GSWNN" or model_type == "DGSWNN":
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=num_projection,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "JMGSWNN":
        # NOTE(review): input dim hard-codes latent size 32 here (28*28 + 32),
        # unlike the J* branches above which use args.latent_size — confirm.
        gsw = GSW_NN(din=28 * 28 + 32,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "MSWD" or model_type == "JMSWD":
        gsw = GSW()
    if model_type == "MGSWD":
        # Max-generalized-sliced WD: a single learnable projection direction
        # over the flattened 28x28 image, initialized on the unit sphere.
        theta = torch.randn((1, 784), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # BUG FIX: the optimizer must update `theta` (it is passed alongside
        # opt_theta into compute_loss_MGSWD below). The original code passed
        # transform_net.parameters(), but transform_net is never created in
        # the MGSWD branch, so that line raised NameError — and even with a
        # transform_net in scope, theta would never have been optimized.
        opt_theta = optim.Adam([theta],
                               lr=args.lr,
                               betas=(0.5, 0.999))
    if model_type == "JMGSWD":
        # Joint max-generalized-sliced WD: one learnable direction over the
        # concatenated [image (784), latent (32)] vector, unit-normalized.
        # NOTE(review): latent size is hard-coded as 32 here rather than
        # args.latent_size — confirm against the rest of the script.
        theta = torch.randn((1, 784 + 32), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # BUG FIX: optimize `theta` (handed to compute_loss_JMGSWD together
        # with opt_theta), not transform_net, which is undefined in this
        # branch and would raise NameError.
        opt_theta = torch.optim.Adam([theta],
                                     lr=args.lr,
                                     betas=(0.5, 0.999))
    # Main optimizer for the autoencoder itself.
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    # Fixed latent noise so the per-epoch sample grids are comparable.
    fixednoise = torch.randn((64, latent_size)).to(device)
    ite = 0        # global iteration counter (across epochs)
    wd_list = []   # true Wasserstein distances logged every 100 iterations
    swd_list = []  # sliced Wasserstein distances logged every 100 iterations
    # Main training loop: for each minibatch, compute the loss for the chosen
    # sliced-Wasserstein variant, then take one Adam step on the autoencoder.
    # Some loss methods (the D*/M* variants) additionally run inner
    # optimization of their projection parameters inside compute_loss_*.
    for epoch in range(args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if model_type == "SWD":
                loss = model.compute_loss_SWD(data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif model_type == "GSWD":
                loss = model.compute_loss_GSWD(data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=2)
            elif model_type == "GSWNN":
                loss = model.compute_loss_GSWNN(data, torch.randn, gsw, p=2)
            elif model_type == "DGSWNN":
                loss = model.compute_loss_DGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 args.lam,
                                                 args.lr,
                                                 p=2)
            elif model_type == "MGSWNN":
                loss = model.compute_loss_MGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 p=args.p)
            elif model_type == "JMGSWNN":
                loss = model.compute_loss_JMGSWNN(data,
                                                  torch.randn,
                                                  gsw,
                                                  args.niter,
                                                  p=args.p)
            elif model_type == "MSWD":
                loss = model.compute_loss_MSWD(data, torch.randn, gsw,
                                               args.niter)
            elif model_type == "MGSWD":
                loss = model.compute_loss_MGSWD(data,
                                                torch.randn,
                                                theta,
                                                opt_theta,
                                                g_function,
                                                args.r,
                                                args.lr2,
                                                p=args.p,
                                                max_iter=args.niter)
            elif model_type == "DSWD":
                loss = model.compute_lossDSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "DGSWD":
                loss = model.compute_lossDGSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    g_function,
                    r=args.r,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "JSWD":
                loss = model.compute_loss_JSWD(data,
                                               torch.randn,
                                               num_projection,
                                               p=args.p)
            elif model_type == "JGSWD":
                loss = model.compute_loss_JGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                num_projection,
                                                p=args.p)
            elif model_type == "JDSWD":
                loss = model.compute_lossJDSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "JDGSWD":
                loss = model.compute_lossJDGSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    g_function,
                    r=args.r,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "JMSWD":
                loss = model.compute_loss_JMSWD(data, torch.randn, gsw,
                                                args.niter)
            elif model_type == "JMGSWD":
                loss = model.compute_loss_JMGSWD(data,
                                                 torch.randn,
                                                 theta,
                                                 opt_theta,
                                                 g_function,
                                                 args.r,
                                                 p=args.p,
                                                 max_iter=args.niter)

            # One Adam step on the autoencoder for this minibatch.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Every 100 iterations: evaluate true WD and SWD between the full
            # 10k-sample test set and 10k generated samples, append to the CSV
            # logs. The loop breaks after the first (and only) test batch.
            if ite % 100 == 0:
                model.eval()
                for _, (input, y) in enumerate(test_loader, start=0):
                    fixednoise_wd = torch.randn(
                        (10000, latent_size)).to(device)
                    data = input.to(device)
                    data = data.view(data.shape[0], -1)
                    fake = model.decoder(fixednoise_wd)
                    # Exact Wasserstein is computed on CPU (expensive at 10k
                    # points — this dominates evaluation time).
                    wd_list.append(
                        compute_true_Wasserstein(data.to("cpu"),
                                                 fake.to("cpu")))
                    swd_list.append(
                        sliced_wasserstein_distance(data, fake, 10000).item())
                    print("Iter:" + str(ite) + " WD: " + str(wd_list[-1]))
                    np.savetxt(model_dir + "/wd.csv", wd_list, delimiter=",")
                    print("Iter:" + str(ite) + " SWD: " + str(swd_list[-1]))
                    np.savetxt(model_dir + "/swd.csv", swd_list, delimiter=",")
                    break
                model.train()
            ite = ite + 1
        # Average over the number of minibatches seen this epoch.
        # NOTE(review): raises NameError if train_loader is empty — batch_idx
        # would be unbound.
        total_loss /= batch_idx + 1
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))

        if epoch % 1 == 0:
            model.eval()
            sampling(
                model_dir + "/sample_epoch_" + str(epoch) + ".png",
                fixednoise,
                model.decoder,
                64,
                image_size,
                num_chanel,
            )
            if model_type[0] == "J":
                for _, (input, y) in enumerate(test_loader2, start=0):
                    input = input.to(device)
                    input = input.view(-1, image_size**2)
                    reconstruct(
                        model_dir + "/reconstruction_epoch_" + str(epoch) +
                        ".png",
                        input,
                        model.encoder,
                        model.decoder,
                        image_size,
                        num_chanel,
                        device,
                    )
                    break
            model.train()
        save_dmodel(model, optimizer, None, None, None, None, epoch, model_dir)
        if epoch == args.epochs - 1:
            model.eval()
            sampling_eps(model_dir + "/sample_epoch_" + str(epoch), fixednoise,
                         model.decoder, 64, image_size, num_chanel)
            model.train()