Example #1
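Training entry point for a Distributional Sliced Wasserstein autoencoder on MNIST. The script parses hyperparameters, builds train/test loaders and an MnistAutoencoder, selects one of the sliced-Wasserstein loss variants via --model-type, and periodically logs exact and sliced Wasserstein distances between decoded noise and test images to CSV files under the model directory.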
def main():
    # train args
    parser = argparse.ArgumentParser(
        description='Distributional Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir',
                        default='/user/HS229/xc00414/condor_examples/DSW/DSW/',
                        help='path to dataset')
    parser.add_argument(
        '--outdir',
        default='/user/HS229/xc00414/condor_examples/DSW/DSW/result',
        help='directory to output images')
    parser.add_argument('--batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument('--lrpsw',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate psw (default: 0.001)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed',
                        type=int,
                        default=16,
                        metavar='S',
                        help='random seed (default: 16)')
    parser.add_argument('--g', type=str, default='circular', help='defining function g (circular)')
    parser.add_argument('--num-projection',
                        type=int,
                        default=1000,
                        help='number of projections')
    parser.add_argument('--lam',
                        type=float,
                        default=1,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter',
                        type=int,
                        default=10,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--kappa', type=float, default=50, help='kappa')
    parser.add_argument('--k', type=int, default=10, help='k')
    parser.add_argument('--e', type=float, default=1000, help='e')
    parser.add_argument('--latent-size',
                        type=int,
                        default=32,
                        help='Latent size')
    parser.add_argument('--hsize', type=int, default=100, help='h size')
    parser.add_argument('--dim', type=int, default=100, help='subspace size')
    parser.add_argument('--dataset',
                        type=str,
                        default='MNIST',
                        help='(MNIST|FMNIST)')
    parser.add_argument(
        '--model-type',
        type=str,
        required=True,
        help='(ASWD|SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD)')
    args = parser.parse_args()
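    # Example invocation (a sketch, not from the source; assumes this function lives in a
    # main.py that is run with an `if __name__ == '__main__': main()` guard):
    #   python main.py --model-type DSWD --dataset MNIST --num-projection 1000 --niter 10 --lam 1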

    #torch.random.manual_seed(args.seed)
    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    assert dataset in ['MNIST']
    assert model_type in [
        'MGSWNN', 'JMGSWNN', 'SWD', 'MSWD', 'PSTW', 'PSW', 'MGSWD', 'DSWD',
        'GSWD', 'DGSWD', 'JSWD', 'JMSWD', 'JMGSWD', 'JDSWD', 'JGSWD', 'JDGSWD',
        'DGSWNN', 'JDGSWNN', 'GSWNN', 'JGSWNN', 'ASWD'
    ]
    if (model_type == 'SWD' or model_type == 'JSWD'):
        model_dir = os.path.join(args.outdir,
                                 model_type + '_n' + str(num_projection))
    elif model_type == 'GSWNN':
        model_dir = os.path.join(args.outdir,
                                 model_type + '_n' + str(num_projection))
    elif (model_type == 'GSWD' or model_type == 'JGSWD'):
        model_dir = os.path.join(
            args.outdir, model_type + '_n' + str(num_projection) + '_' +
            args.g + str(args.r))
    elif (model_type == 'DSWD' or model_type == 'JDSWD'
          or model_type == 'ASWD'):
        model_dir = os.path.join(
            args.outdir, model_type + '_iter' + str(args.niter) + '_n' +
            str(num_projection) + '_lam' + str(args.lam))
    elif (model_type == 'DGSWD' or model_type == 'JDGSWD'):
        model_dir = os.path.join(
            args.outdir, model_type + '_iter' + str(args.niter) + '_n' +
            str(num_projection) + '_lam' + str(args.lam) + '_' + args.g +
            str(args.r))
    elif (model_type == 'MSWD' or model_type == 'JMSWD'):
        model_dir = os.path.join(args.outdir, model_type)
    elif (model_type == 'MGSWNN' or model_type == 'JMGSWNN'):
        model_dir = os.path.join(args.outdir,
                                 model_type + '_size' + str(args.hsize))
    elif (model_type == 'MGSWD' or model_type == 'JMGSWD'):
        model_dir = os.path.join(args.outdir, model_type + '_' + args.g)
    elif (model_type == 'PSW' or model_type == 'JPSW'):
        model_dir = os.path.join(
            args.outdir, model_type + '_e' + str(args.e) + '_iter' +
            str(args.niter) + '_d' + str(args.dim)) + '_lr' + str(args.lrpsw)
    elif (model_type == 'PSTW' or model_type == 'JPSTW'):
        model_dir = os.path.join(
            args.outdir, model_type + '_iter' + str(args.niter) + '_d' +
            str(args.dim)) + '_lr' + str(args.lrpsw)
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))
    # build train and test set data loaders
    if (dataset == 'MNIST'):
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)
        test_loader = torch.utils.data.DataLoader(datasets.MNIST(
            args.datadir,
            train=False,
            download=True,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                  batch_size=10000,
                                                  shuffle=False,
                                                  num_workers=args.num_workers)
        test_loader2 = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=64,
            shuffle=False,
            num_workers=args.num_workers)
        model = MnistAutoencoder(image_size=28,
                                 latent_size=args.latent_size,
                                 hidden_size=100,
                                 device=device).to(device)
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # op_trannet = optim.Adam(transform_net.parameters(), lr=1e-4)
        # train_net(28 * 28, 1000, transform_net, op_trannet)
    elif (model_type == 'JDSWD' or model_type == 'JDSWD2'
          or model_type == 'JDGSWD'):
        transform_net = TransformNet(args.latent_size + 28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(args.latent_size + 28 * 28, 1000, transform_net, op_trannet)
    elif (model_type == 'ASWD'):
        phi = Mapping(28 * 28).to(device)
        phi_op = optim.Adam(phi.parameters(), lr=0.0005, betas=(0.9, 0.999))
    if (model_type == 'MGSWNN'):
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if (model_type == 'GSWNN' or model_type == 'DGSWNN'):
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=num_projection,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if (model_type == 'JMGSWNN'):
        gsw = GSW_NN(din=28 * 28 + 32,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if (model_type == 'MSWD' or model_type == 'JMSWD'):
        gsw = GSW()
    if (model_type == 'MGSWD'):
        theta = torch.randn((1, 784), device=device, requires_grad=True)
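        # rescale the random direction to unit norm so theta lies on the unit sphere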
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        #opt_theta = optim.Adam(transform_net.parameters(), lr=args.lr, betas=(0.5, 0.999))
    if (model_type == 'JMGSWD'):
        theta = torch.randn((1, 784 + 32), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # optimize the projection direction itself; the original referenced transform_net,
        # which is not defined in this branch
        opt_theta = torch.optim.Adam([theta], lr=args.lr, betas=(0.5, 0.999))
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
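    # fixed latent noise, reused every epoch so the sampled image grids are comparable across epochs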
    fixednoise = torch.randn((64, latent_size)).to(device)
    ite = 0
    wd_list = []
    swd_list = []
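    # timestamp-derived run id, appended to the wd/swd CSV filenames so separate runs do not overwrite each other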
    save_idx = str(time.time()).split('.')
    save_idx = save_idx[0] + save_idx[1]

    for epoch in range(args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            tic = time.time()
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif (model_type == 'ASWD'):
                loss = model.compute_lossTGSWD(data,
                                               torch.randn,
                                               num_projection,
                                               phi,
                                               phi_op,
                                               p=2,
                                               max_iter=args.niter,
                                               lam=args.lam,
                                               net_type='fc')
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=2)
            elif (model_type == 'GSWNN'):
                loss = model.compute_loss_GSWNN(data, torch.randn, gsw, p=2)
            elif (model_type == 'DGSWNN'):
                loss = model.compute_loss_DGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 args.lam,
                                                 args.lr,
                                                 p=2)
            elif (model_type == 'PSW'):
                loss = model.compute_loss_PSW(data,
                                              torch.randn,
                                              pdim=args.dim,
                                              p=args.p,
                                              n_iter=args.niter,
                                              n_iter_sinkhorn=args.niter,  # the parser defines --niter, not --niters
                                              e=args.e,
                                              lr=args.lrpsw)
            elif (model_type == 'PSTW'):
                loss = model.compute_loss_PSTW(data,
                                               torch.randn,
                                               pdim=args.dim,
                                               p=args.p,
                                               n_iter=args.niter,
                                               lr=args.lrpsw)
            elif (model_type == 'MGSWNN'):
                loss = model.compute_loss_MGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 p=args.p)
            elif (model_type == 'JMGSWNN'):
                loss = model.compute_loss_JMGSWNN(data,
                                                  torch.randn,
                                                  gsw,
                                                  args.niter,
                                                  p=args.p)
            elif (model_type == 'MSWD'):
                loss = model.compute_loss_MSWD(data, torch.randn, gsw,
                                               args.niter)
            elif (model_type == 'MGSWD'):
                loss = model.compute_loss_MGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                p=args.p,
                                                max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=args.lam)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               r=args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JSWD'):
                loss = model.compute_loss_JSWD(data,
                                               torch.randn,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'JGSWD'):
                loss = model.compute_loss_JGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                num_projection,
                                                p=args.p)
            elif (model_type == 'JDSWD'):
                loss = model.compute_lossJDSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JDGSWD'):
                loss = model.compute_lossJDGSWD(data,
                                                torch.randn,
                                                num_projection,
                                                transform_net,
                                                op_trannet,
                                                g_function,
                                                r=args.r,
                                                p=args.p,
                                                max_iter=args.niter,
                                                lam=args.lam)
            elif (model_type == 'JMSWD'):
                loss = model.compute_loss_JMSWD(data, torch.randn, gsw,
                                                args.niter)
            elif (model_type == 'JMGSWD'):
                loss = model.compute_loss_JMGSWD(data,
                                                 torch.randn,
                                                 theta,
                                                 opt_theta,
                                                 g_function,
                                                 args.r,
                                                 p=args.p,
                                                 max_iter=args.niter)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            toc = time.time()
            print('execution_time:', toc - tic, model_type)
            total_loss += loss.item()
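            # every 100 iterations: decode fresh noise for the full 10,000-image test batch and
            # log the exact and sliced Wasserstein distances against the real test images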
            if (ite % 100 == 0):
                model.eval()
                for _, (input, y) in enumerate(test_loader, start=0):
                    fixednoise_wd = torch.randn(
                        (10000, latent_size)).to(device)
                    data = input.to(device)
                    data = data.view(data.shape[0], -1)
                    fake = model.decoder(fixednoise_wd)
                    wd_list.append(
                        compute_true_Wasserstein(data.to('cpu'),
                                                 fake.to('cpu')))
                    swd_list.append(
                        sliced_wasserstein_distance(data, fake, 10000).item())
                    print("Iter:" + str(ite) + " WD: " + str(wd_list[-1]))
                    np.savetxt(model_dir + "/wd" + save_idx + ".csv",
                               wd_list,
                               delimiter=",")
                    print("Iter:" + str(ite) + " SWD: " + str(swd_list[-1]))
                    np.savetxt(model_dir + "/swd" + save_idx + ".csv",
                               swd_list,
                               delimiter=",")
                    break
                model.train()
            ite = ite + 1
        total_loss /= (batch_idx + 1)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))

        if (epoch % 1 == 0):
            model.eval()
            sampling(model_dir + '/sample_epoch_' + str(epoch) + ".png",
                     fixednoise, model.decoder, 64, image_size, num_chanel)
            if (model_type[0] == 'J'):
                for _, (input, y) in enumerate(test_loader2, start=0):
                    input = input.to(device)
                    input = input.view(-1, image_size**2)
                    reconstruct(
                        model_dir + '/reconstruction_epoch_' + str(epoch) +
                        ".png", input, model.encoder, model.decoder,
                        image_size, num_chanel, device)
                    break
            model.train()
        save_dmodel(model, optimizer, None, None, None, None, epoch, model_dir)
        if (epoch == args.epochs - 1):
            model.eval()
            sampling_eps(model_dir + '/sample_epoch_' + str(epoch), fixednoise,
                         model.decoder, 64, image_size, num_chanel)
            model.train()
Example #2
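The same MNIST training script as Example #1, trimmed to the SWD/GSWD/MSWD/DSWD family (no ASWD, PSW, or PSTW branches): paths default to the current directory, the random seed is actually applied, and the Wasserstein logs go to fixed wd.csv / swd.csv files.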
def main():
    # train args
    parser = argparse.ArgumentParser(
        description="Disributional Sliced Wasserstein Autoencoder")
    parser.add_argument("--datadir", default="./", help="path to dataset")
    parser.add_argument("--outdir",
                        default="./result",
                        help="directory to output images")
    parser.add_argument("--batch-size",
                        type=int,
                        default=512,
                        metavar="N",
                        help="input batch size for training (default: 512)")
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        metavar="N",
                        help="number of epochs to train (default: 200)")
    parser.add_argument("--lr",
                        type=float,
                        default=0.0005,
                        metavar="LR",
                        help="learning rate (default: 0.0005)")
    parser.add_argument(
        "--num-workers",
        type=int,
        default=16,
        metavar="N",
        help="number of dataloader workers if device is CPU (default: 16)",
    )
    parser.add_argument("--seed",
                        type=int,
                        default=16,
                        metavar="S",
                        help="random seed (default: 16)")
    parser.add_argument("--g", type=str, default="circular", help="g")
    parser.add_argument("--num-projection",
                        type=int,
                        default=1000,
                        help="number projection")
    parser.add_argument("--lam",
                        type=float,
                        default=1,
                        help="Regularization strength")
    parser.add_argument("--p", type=int, default=2, help="Norm p")
    parser.add_argument("--niter",
                        type=int,
                        default=10,
                        help="number of iterations")
    parser.add_argument("--r", type=float, default=1000, help="R")
    parser.add_argument("--kappa", type=float, default=50, help="R")
    parser.add_argument("--k", type=int, default=10, help="R")
    parser.add_argument("--e", type=float, default=1000, help="R")
    parser.add_argument("--latent-size",
                        type=int,
                        default=32,
                        help="Latent size")
    parser.add_argument("--hsize", type=int, default=100, help="h size")
    parser.add_argument("--dataset",
                        type=str,
                        default="MNIST",
                        help="(MNIST|FMNIST)")
    parser.add_argument(
        "--model-type",
        type=str,
        required=True,
        help="(SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD)")
    args = parser.parse_args()

    torch.random.manual_seed(args.seed)
    if args.g == "circular":
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    assert dataset in ["MNIST"]
    assert model_type in [
        "MGSWNN",
        "JMGSWNN",
        "SWD",
        "MSWD",
        "MGSWD",
        "DSWD",
        "GSWD",
        "DGSWD",
        "JSWD",
        "JMSWD",
        "JMGSWD",
        "JDSWD",
        "JGSWD",
        "JDGSWD",
        "DGSWNN",
        "JDGSWNN",
        "GSWNN",
        "JGSWNN",
    ]
    if model_type == "SWD" or model_type == "JSWD":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_n" + str(num_projection))
    elif model_type == "GSWD" or model_type == "JGSWD":
        model_dir = os.path.join(
            args.outdir, model_type + "_n" + str(num_projection) + "_" +
            args.g + str(args.r))
    elif model_type == "DSWD" or model_type == "JDSWD":
        model_dir = os.path.join(
            args.outdir, model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam))
    elif model_type == "DGSWD" or model_type == "JDGSWD":
        model_dir = os.path.join(
            args.outdir,
            model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam) + "_" + args.g +
            str(args.r),
        )
    elif model_type == "MSWD" or model_type == "JMSWD":
        model_dir = os.path.join(args.outdir, model_type)
    elif model_type == "MGSWNN" or model_type == "JMGSWNN":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_size" + str(args.hsize))
    elif model_type == "MGSWD" or model_type == "JMGSWD":
        model_dir = os.path.join(args.outdir, model_type + "_" + args.g)
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("batch size {}\nepochs {}\nAdam lr {} \n using device {}\n".format(
        args.batch_size, args.epochs, args.lr, device.type))
    # build train and test set data loaders
    if dataset == "MNIST":
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=10000,
            shuffle=False,
            num_workers=args.num_workers,
        )
        test_loader2 = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=False,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=64,
            shuffle=False,
            num_workers=args.num_workers,
        )
        model = MnistAutoencoder(image_size=28,
                                 latent_size=args.latent_size,
                                 hidden_size=100,
                                 device=device).to(device)
    if model_type == "DSWD" or model_type == "DGSWD":
        transform_net = TransformNet(28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # op_trannet = optim.Adam(transform_net.parameters(), lr=1e-4)
        # train_net(28 * 28, 1000, transform_net, op_trannet)
    elif model_type == "JDSWD" or model_type == "JDSWD2" or model_type == "JDGSWD":
        transform_net = TransformNet(args.latent_size + 28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(args.latent_size + 28 * 28, 1000, transform_net, op_trannet)
    if model_type == "MGSWNN":
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "GSWNN" or model_type == "DGSWNN":
        gsw = GSW_NN(din=28 * 28,
                     nofprojections=num_projection,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "JMGSWNN":
        gsw = GSW_NN(din=28 * 28 + 32,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "MSWD" or model_type == "JMSWD":
        gsw = GSW()
    if model_type == "MGSWD":
        theta = torch.randn((1, 784), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # optimize theta directly; transform_net is not defined for MGSWD
        opt_theta = optim.Adam([theta], lr=args.lr, betas=(0.5, 0.999))
    if model_type == "JMGSWD":
        theta = torch.randn((1, 784 + 32), device=device, requires_grad=True)
        theta.data = theta.data / torch.sqrt(torch.sum(theta.data**2, dim=1))
        # optimize theta directly; transform_net is not defined for JMGSWD
        opt_theta = torch.optim.Adam([theta], lr=args.lr, betas=(0.5, 0.999))
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    fixednoise = torch.randn((64, latent_size)).to(device)
    ite = 0
    wd_list = []
    swd_list = []
    for epoch in range(args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if model_type == "SWD":
                loss = model.compute_loss_SWD(data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif model_type == "GSWD":
                loss = model.compute_loss_GSWD(data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=2)
            elif model_type == "GSWNN":
                loss = model.compute_loss_GSWNN(data, torch.randn, gsw, p=2)
            elif model_type == "DGSWNN":
                loss = model.compute_loss_DGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 args.lam,
                                                 args.lr,
                                                 p=2)
            elif model_type == "MGSWNN":
                loss = model.compute_loss_MGSWNN(data,
                                                 torch.randn,
                                                 gsw,
                                                 args.niter,
                                                 p=args.p)
            elif model_type == "JMGSWNN":
                loss = model.compute_loss_JMGSWNN(data,
                                                  torch.randn,
                                                  gsw,
                                                  args.niter,
                                                  p=args.p)
            elif model_type == "MSWD":
                loss = model.compute_loss_MSWD(data, torch.randn, gsw,
                                               args.niter)
            elif model_type == "MGSWD":
                loss = model.compute_loss_MGSWD(data,
                                                torch.randn,
                                                theta,
                                                opt_theta,
                                                g_function,
                                                args.r,
                                                args.lr,  # the parser defines no --lr2; reusing --lr here
                                                p=args.p,
                                                max_iter=args.niter)
            elif model_type == "DSWD":
                loss = model.compute_lossDSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "DGSWD":
                loss = model.compute_lossDGSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    g_function,
                    r=args.r,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "JSWD":
                loss = model.compute_loss_JSWD(data,
                                               torch.randn,
                                               num_projection,
                                               p=args.p)
            elif model_type == "JGSWD":
                loss = model.compute_loss_JGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                num_projection,
                                                p=args.p)
            elif model_type == "JDSWD":
                loss = model.compute_lossJDSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "JDGSWD":
                loss = model.compute_lossJDGSWD(
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    g_function,
                    r=args.r,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "JMSWD":
                loss = model.compute_loss_JMSWD(data, torch.randn, gsw,
                                                args.niter)
            elif model_type == "JMGSWD":
                loss = model.compute_loss_JMGSWD(data,
                                                 torch.randn,
                                                 theta,
                                                 opt_theta,
                                                 g_function,
                                                 args.r,
                                                 p=args.p,
                                                 max_iter=args.niter)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
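            # every 100 iterations: log exact and sliced Wasserstein distances on one 10,000-image test batch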
            if ite % 100 == 0:
                model.eval()
                for _, (input, y) in enumerate(test_loader, start=0):
                    fixednoise_wd = torch.randn(
                        (10000, latent_size)).to(device)
                    data = input.to(device)
                    data = data.view(data.shape[0], -1)
                    fake = model.decoder(fixednoise_wd)
                    wd_list.append(
                        compute_true_Wasserstein(data.to("cpu"),
                                                 fake.to("cpu")))
                    swd_list.append(
                        sliced_wasserstein_distance(data, fake, 10000).item())
                    print("Iter:" + str(ite) + " WD: " + str(wd_list[-1]))
                    np.savetxt(model_dir + "/wd.csv", wd_list, delimiter=",")
                    print("Iter:" + str(ite) + " SWD: " + str(swd_list[-1]))
                    np.savetxt(model_dir + "/swd.csv", swd_list, delimiter=",")
                    break
                model.train()
            ite = ite + 1
        total_loss /= batch_idx + 1
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))

        if epoch % 1 == 0:
            model.eval()
            sampling(
                model_dir + "/sample_epoch_" + str(epoch) + ".png",
                fixednoise,
                model.decoder,
                64,
                image_size,
                num_chanel,
            )
            if model_type[0] == "J":
                for _, (input, y) in enumerate(test_loader2, start=0):
                    input = input.to(device)
                    input = input.view(-1, image_size**2)
                    reconstruct(
                        model_dir + "/reconstruction_epoch_" + str(epoch) +
                        ".png",
                        input,
                        model.encoder,
                        model.decoder,
                        image_size,
                        num_chanel,
                        device,
                    )
                    break
            model.train()
        save_dmodel(model, optimizer, None, None, None, None, epoch, model_dir)
        if epoch == args.epochs - 1:
            model.eval()
            sampling_eps(model_dir + "/sample_epoch_" + str(epoch), fixednoise,
                         model.decoder, 64, image_size, num_chanel)
            model.train()
Example #3
File: main.py, Project: xiongjiechen/ASWD-1
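Training entry point of the Augmented Sliced Wasserstein autoencoder on CIFAR-10 or CelebA: a DCGAN-style autoencoder with an auxiliary discriminator, the loss chosen by --model-type; every few epochs it decodes 20,000 random latents to disk and scores them with FID against precomputed dataset statistics.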
def main():
    # train args
    parser = argparse.ArgumentParser(
        description='Augmented Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir', default='./', help='path to dataset')
    parser.add_argument('--outdir',
                        default='./result/',
                        help='directory to output images')
    parser.add_argument('--batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed',
                        type=int,
                        default=11,
                        metavar='S',
                        help='random seed (default: 11)')
    parser.add_argument('--g', type=str, default='circular', help='defining function g (circular)')
    parser.add_argument('--num-projection',
                        type=int,
                        default=1000,
                        help='number of projections')
    parser.add_argument('--lam',
                        type=float,
                        default=0.5,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter',
                        type=int,
                        default=5,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--latent-size',
                        type=int,
                        default=32,
                        help='Latent size')
    parser.add_argument('--dataset',
                        type=str,
                        default='CIFAR',
                        help='(CELEBA|CIFAR)')
    parser.add_argument('--model-type',
                        type=str,
                        required=True,
                        help='(ASWD|SWD|MSWD|DSWD|GSWD)')
    parser.add_argument('--gpu', type=str, required=False, default='0', help='gpu index')
    args = parser.parse_args()
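    # Example invocation (a sketch, not from the source; assumes the usual __main__ guard):
    #   python main.py --model-type ASWD --dataset CIFAR --num-projection 1000 --niter 5 --lam 0.5 --gpu 0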
    torch.random.manual_seed(args.seed)
    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ['CELEBA', 'CIFAR']
    assert model_type in ['ASWD', 'SWD', 'MSWD', 'DSWD', 'GSWD']
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.gpu) if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))

    if (dataset == 'CIFAR'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                args.datadir,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.ToTensor(),
                    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)

    elif (dataset == 'CELEBA'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        dataset = CustomDataset(
            root=args.datadir + 'img_align_celeba',
            transform=transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        # Create the dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)

    model = DCGANAE(image_size=64,
                    latent_size=latent_size,
                    num_chanel=3,
                    hidden_chanels=64,
                    device=device).to(device)
    #model=nn.DataParallel(model)
    #model.to(device)
    dis = Discriminator(64, args.latent_size, 3, 64).to(device)
    disoptimizer = optim.Adam(dis.parameters(), lr=args.lr, betas=(0.5, 0.999))
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(64 * 8 * 4 * 4).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=0.0005,
                                betas=(0.5, 0.999))
        train_net(64 * 8 * 4 * 4,
                  1000,
                  transform_net,
                  op_trannet,
                  device=device)

    if model_type == 'ASWD':
        phi = Mapping(64 * 8 * 4 * 4).to(device)
        phi_op = optim.Adam(phi.parameters(), lr=0.001, betas=(0.5, 0.999))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    epoch_cont = 0
    generated_sample_number = 64
    fixednoise = torch.randn((generated_sample_number, latent_size)).to(device)
    loss_recorder = []
    generated_sample_size = 64
    W2_recorder = np.zeros([41, 40])
    save_idx = str(time.time()).split('.')
    save_idx = save_idx[0] + save_idx[1]
    path_0 = model_dir + '/' + args.dataset + '/fid/' + save_idx
    os.makedirs(path_0)  # nested <outdir>/<model>/<dataset>/fid/<run id>; create missing parent folders too
    if args.dataset == 'CIFAR':
        interval_ = 10
        fid_stats_file = args.datadir + '/fid_stats_cifar10_train.npz'
    else:
        interval_ = 5
        fid_stats_file = args.datadir + '/fid_stats_celeba.npz'
    fid_recorder = np.zeros(args.epochs // interval_ + 1)
    for epoch in range(epoch_cont, args.epochs):
        total_loss = 0.0

        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p,
                                              epoch=epoch,
                                              batch_idx=batch_idx)
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'MSWD'):
                loss, v = model.compute_loss_MSWD(dis,
                                                  disoptimizer,
                                                  data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=10)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=10)
            elif (model_type == 'SINKHORN'):
                loss = model.compute_loss_sinkhorn(dis,
                                                   disoptimizer,
                                                   data,
                                                   torch.randn,
                                                   p=2,
                                                   n_iter=100,
                                                   e=100)
            elif (model_type == 'CRAMER'):
                loss = model.compute_loss_cramer(dis, disoptimizer, data,
                                                 torch.randn)
            elif model_type == 'ASWD':
                loss = model.compute_lossASWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              phi,
                                              phi_op,
                                              p=2,
                                              max_iter=args.niter,
                                              lam=args.lam,
                                              epoch=epoch,
                                              batch_idx=batch_idx)
            optimizer.zero_grad()
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
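        # periodic FID evaluation: write 20 x 1,000 decoded samples to a temporary folder,
        # score them against the precomputed statistics file, record the value, then delete the folder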
        if epoch == 0 or (epoch + 1) % interval_ == 0:
            path_1 = path_0 + '/' + str('%03d' % epoch)
            os.mkdir(path_1)
            for j in range(20):
                fixednoise_ = torch.randn((1000, 32)).to(device)
                imgs = model.decoder(fixednoise_)

                for i, img in enumerate(imgs):
                    img = img.transpose(0, -1).transpose(
                        0, 1).cpu().detach().numpy()
                    img = (img * 255).astype(np.uint8)
                    imageio.imwrite(
                        path_1 + '/' + args.model_type + '_' +
                        str(args.num_projection) + '_' + str('%03d' % epoch) +
                        '_' + str(1000 * j + i) + '.png', img)
            fid_value = calculate_fid_given_paths(
                [path_1 + '/', fid_stats_file], 50, True, 2048)
            fid_recorder[(epoch + 1) // interval_] = fid_value
            np.save(
                path_0 + '/fid_recorder_' + 'np_' + str(num_projection) +
                '.npy', fid_recorder)
            print('fid score:', fid_value)
            os.system("rm -rf " + path_1)
        total_loss /= (batch_idx + 1)
        loss_recorder.append(total_loss)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))

        if (epoch % 1 == 0 or epoch == args.epochs - 1):
            sampling(
                model_dir + '/' + args.dataset + '/sample_epoch_' +
                str(epoch) + ".png", fixednoise, model.decoder,
                generated_sample_number, generated_sample_size, num_chanel)
            torch.save(
                model.state_dict(), args.outdir + args.dataset + '/' +
                model_type + '_' + str(args.batch_size) + '_' +
                str(num_projection) + '_' + str(latent_size) + '_model.pth')
            torch.save(
                dis.state_dict(),
                args.outdir + args.dataset + '/' + model_type + '_' +
                str(args.batch_size) + '_' + str(num_projection) + '_' +
                str(latent_size) + '_discriminator.pth')
    np.save(
        args.outdir + args.dataset + '/' + model_type + '_' +
        str(args.batch_size) + '_' + str(num_projection) + '_' +
        str(latent_size) + '_loss.npy', loss_recorder)
Example #4
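A leaner variant of Example #3: the same DCGAN autoencoder and discriminator setup on CIFAR-10 or CelebA, restricted to the SWD, MSWD, DSWD, GSWD, DGSWD, and CRAMER losses.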
def main():
    # train args
    parser = argparse.ArgumentParser(
        description='Distributional Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir', default='./', help='path to dataset')
    parser.add_argument('--outdir',
                        default='./result',
                        help='directory to output images')
    parser.add_argument('--batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed',
                        type=int,
                        default=16,
                        metavar='S',
                        help='random seed (default: 16)')
    parser.add_argument('--g', type=str, default='circular', help='defining function g (circular)')
    parser.add_argument('--num-projection',
                        type=int,
                        default=1000,
                        help='number of projections')
    parser.add_argument('--lam',
                        type=float,
                        default=1,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter',
                        type=int,
                        default=10,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--latent-size',
                        type=int,
                        default=32,
                        help='Latent size')
    parser.add_argument('--dataset',
                        type=str,
                        default='CIFAR',  # the assert below only accepts CELEBA or CIFAR
                        help='(CELEBA|CIFAR)')
    parser.add_argument('--model-type',
                        type=str,
                        required=True,
                        help='(SWD|MSWD|DSWD|GSWD|DGSWD|CRAMER)')
    args = parser.parse_args()
    torch.random.manual_seed(args.seed)
    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ['CELEBA', 'CIFAR']
    assert model_type in ['SWD', 'MSWD', 'DSWD', 'GSWD', 'DGSWD', 'CRAMER']
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))

    if (dataset == 'CIFAR'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                args.datadir,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.ToTensor(),
                    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)

    elif (dataset == 'CELEBA'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        dataset = CustomDataset(
            root=args.datadir + '/img_align_celeba',
            transform=transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        # Create the dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)

    model = DCGANAE(image_size=64,
                    latent_size=latent_size,
                    num_chanel=3,
                    hidden_chanels=64,
                    device=device).to(device)
    dis = Discriminator(64, args.latent_size, 3, 64).to(device)
    disoptimizer = optim.Adam(dis.parameters(), lr=args.lr, betas=(0.5, 0.999))
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(64 * 8 * 4 * 4).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        train_net(64 * 8 * 4 * 4, 1000, transform_net, op_trannet)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    epoch_cont = 0

    fixednoise = torch.randn((64, latent_size)).to(device)

    for epoch in range(epoch_cont, args.epochs):
        total_loss = 0.0

        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'MSWD'):
                loss, v = model.compute_loss_MSWD(dis,
                                                  disoptimizer,
                                                  data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=args.lam)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'CRAMER'):
                loss = model.compute_loss_cramer(dis, disoptimizer, data,
                                                 torch.randn)

            optimizer.zero_grad()
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        total_loss /= (batch_idx + 1)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))

        if (epoch % 1 == 0 or epoch == args.epochs - 1):
            sampling(model_dir + '/sample_epoch_' + str(epoch) + ".png",
                     fixednoise, model.decoder, 64, image_size, num_chanel)
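
Note: compute_loss_SWD and its variants above are methods of the project's DCGANAE model. As a rough orientation only, the sliced Wasserstein idea they build on can be sketched as follows; this is an illustrative stand-in, not the project's implementation, and it assumes flattened (batch, dim) tensors of equal batch size with projection directions drawn via torch.randn, as in the calls above.

import torch

# Illustrative sketch only: sliced Wasserstein distance between two batches.
def sliced_wasserstein_distance(x, y, num_projections=1000, p=2):
    # x, y: (batch, dim) tensors with the same batch size, on the same device
    dim = x.size(1)
    theta = torch.randn(num_projections, dim, device=x.device)
    theta = theta / theta.norm(dim=1, keepdim=True)   # unit-norm directions
    x_proj = x @ theta.t()                            # (batch, num_projections)
    y_proj = y @ theta.t()
    # the 1D Wasserstein distance along each direction compares sorted projections
    x_sorted, _ = torch.sort(x_proj, dim=0)
    y_sorted, _ = torch.sort(y_proj, dim=0)
    return (x_sorted - y_sorted).abs().pow(p).mean().pow(1.0 / p)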
Example #5
File: mnist.py  Project: HaoWen6588/DSW
def main():
    # train args
    parser = argparse.ArgumentParser(
        description='Distributional Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir', default='./', help='path to dataset')
    parser.add_argument('--outdir',
                        default='./result',
                        help='directory to output images')
    parser.add_argument('--batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed',
                        type=int,
                        default=16,
                        metavar='S',
                        help='random seed (default: 16)')
    parser.add_argument('--g', type=str, default='circular', help='g')
    parser.add_argument('--num-projection',
                        type=int,
                        default=1000,
                        help='number of projections')
    parser.add_argument('--lam',
                        type=float,
                        default=1,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter',
                        type=int,
                        default=10,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--latent-size',
                        type=int,
                        default=32,
                        help='Latent size')
    parser.add_argument('--dataset',
                        type=str,
                        default='MNIST',
                        help='(MNIST|FMNIST)')
    parser.add_argument(
        '--model-type',
        type=str,
        required=True,
        help=
        '(SWD|MSWD|DSWD|GSWD|DGSWD|JSWD|JMSWD|JDSWD|JGSWD|JDGSWD|CRAMER|JCRAMER|SINKHORN|JSINKHORN)'
    )
    args = parser.parse_args()

    torch.random.manual_seed(args.seed)
    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ['MNIST', 'FMNIST']
    assert model_type in [
        'SWD', 'MSWD', 'DSWD', 'GSWD', 'DGSWD', 'JSWD', 'JMSWD', 'JDSWD',
        'JGSWD', 'JDGSWD', 'CRAMER', 'JCRAMER', 'SINKHORN', 'JSINKHORN'
    ]
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))
    # build train and test set data loaders
    if (dataset == 'MNIST'):
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose(
                               [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)
        test_loader = torch.utils.data.DataLoader(datasets.MNIST(
            args.datadir,
            train=False,
            download=True,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                  batch_size=64,
                                                  shuffle=False,
                                                  num_workers=args.num_workers)
        model = MnistAutoencoder(image_size=28,
                                 latent_size=args.latent_size,
                                 hidden_size=100,
                                 device=device).to(device)
    elif (dataset == 'FMNIST'):
        image_size = 28
        num_chanel = 1
        train_loader = torch.utils.data.DataLoader(
            datasets.FashionMNIST(args.datadir,
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose(
                                      [transforms.ToTensor()])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)
        test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(
            args.datadir,
            train=False,
            download=True,
            transform=transforms.Compose([transforms.ToTensor()])),
                                                  batch_size=64,
                                                  shuffle=False,
                                                  num_workers=args.num_workers)
        model = MnistAutoencoder(image_size=28,
                                 latent_size=args.latent_size,
                                 hidden_size=100,
                                 device=device).to(device)
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(28 * 28, 1000, transform_net, op_trannet)
    elif (model_type == 'JDSWD' or model_type == 'JDSWD2'
          or model_type == 'JDGSWD'):
        transform_net = TransformNet(args.latent_size + 28 * 28).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(args.latent_size + 28 * 28, 1000, transform_net, op_trannet)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    fixednoise = torch.randn((64, latent_size)).to(device)
    for epoch in range(args.epochs):
        total_loss = 0.0
        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'MSWD'):
                loss, v = model.compute_loss_MSWD(data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=args.lam)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               r=args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JSWD'):
                loss = model.compute_loss_JSWD(data,
                                               torch.randn,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'JGSWD'):
                loss = model.compute_loss_JGSWD(data,
                                                torch.randn,
                                                g_function,
                                                args.r,
                                                num_projection,
                                                p=args.p)
            elif (model_type == 'JDSWD'):
                loss = model.compute_lossJDSWD(data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'JDGSWD'):
                loss = model.compute_lossJDGSWD(data,
                                                torch.randn,
                                                num_projection,
                                                transform_net,
                                                op_trannet,
                                                g_function,
                                                r=args.r,
                                                p=args.p,
                                                max_iter=args.niter,
                                                lam=args.lam)
            elif (model_type == 'JMSWD'):
                loss, v = model.compute_loss_MSWD(data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'CRAMER'):
                loss = model.compute_loss_cramer(data, torch.randn)
            elif (model_type == 'JCRAMER'):
                loss = model.compute_loss_join_cramer(data, torch.randn)
            elif (model_type == 'SINKHORN'):
                loss = model.compute_wasserstein_vi_loss(data,
                                                         torch.randn,
                                                         n_iter=args.niter,
                                                         p=args.p,
                                                         e=1)
            elif (model_type == 'JSINKHORN'):
                loss = model.compute_join_wasserstein_vi_loss(
                    data, torch.randn, n_iter=args.niter, p=args.p, e=1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        total_loss /= (batch_idx + 1)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))

        if (epoch % 1 == 0):
            model.eval()
            sampling(model_dir + '/sample_epoch_' + str(epoch) + ".png",
                     fixednoise, model.decoder, 64, image_size, num_chanel)
            if (model_type[0] == 'J'):
                for _, (input, y) in enumerate(test_loader, start=0):
                    input = input.to(device)
                    input = input.view(-1, image_size**2)
                    reconstruct(
                        model_dir + '/reconstruction_epoch_' + str(epoch) +
                        ".png", input, model.encoder, model.decoder,
                        image_size, num_chanel, device)
                    break
            model.train()
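
A typical invocation of mnist.py (paths and hyper-parameter values are placeholders; only flags defined in the parser above are used):

python mnist.py --model-type DSWD --dataset MNIST --datadir ./ --outdir ./result --num-projection 1000 --niter 10 --lam 1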
Example #6
def main():
    # train args
    parser = argparse.ArgumentParser(
        description="Disributional Sliced Wasserstein Autoencoder")
    parser.add_argument("--datadir", default="./", help="path to dataset")
    parser.add_argument("--outdir",
                        default="./result",
                        help="directory to output images")
    parser.add_argument("--batch-size",
                        type=int,
                        default=512,
                        metavar="N",
                        help="input batch size for training (default: 512)")
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        metavar="N",
                        help="number of epochs to train (default: 200)")
    parser.add_argument("--lr",
                        type=float,
                        default=0.0005,
                        metavar="LR",
                        help="learning rate (default: 0.0005)")
    parser.add_argument(
        "--num-workers",
        type=int,
        default=16,
        metavar="N",
        help="number of dataloader workers if device is CPU (default: 16)",
    )
    parser.add_argument("--seed",
                        type=int,
                        default=16,
                        metavar="S",
                        help="random seed (default: 16)")
    parser.add_argument("--g", type=str, default="circular", help="g")
    parser.add_argument("--num-projection",
                        type=int,
                        default=1000,
                        help="number projection")
    parser.add_argument("--lam",
                        type=float,
                        default=1,
                        help="Regularization strength")
    parser.add_argument("--p", type=int, default=2, help="Norm p")
    parser.add_argument("--niter",
                        type=int,
                        default=10,
                        help="number of iterations")
    parser.add_argument("--r", type=float, default=1000, help="R")
    parser.add_argument("--latent-size",
                        type=int,
                        default=32,
                        help="Latent size")
    parser.add_argument("--hsize", type=int, default=100, help="Latent size")
    parser.add_argument("--dataset",
                        type=str,
                        default="MNIST",
                        help="(CELEBA|CIFAR)")
    parser.add_argument("--model-type",
                        type=str,
                        required=True,
                        help="(SWD|MSWD|DSWD|GSWD|DGSWD|CRAMER|)")
    parser.add_argument("--cont", type=bool, help="")
    parser.add_argument("--dim", type=int, default=100, help="subspace size")
    parser.add_argument("--e", type=float, default=1000, help="R")
    args = parser.parse_args()
    torch.random.manual_seed(args.seed)
    if args.g == "circular":
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ["CELEBA", "CIFAR", "LSUN"]
    assert model_type in ["SWD", "MSWD", "DSWD", "GSWD", "DGSWD", "MGSWNN"]
    if model_type == "SWD":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_n" + str(num_projection))
    elif model_type == "DSWD":
        model_dir = os.path.join(
            args.outdir, model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam))
    elif model_type == "MSWD":
        model_dir = os.path.join(args.outdir, model_type)
    elif model_type == "MGSWNN":
        model_dir = os.path.join(args.outdir,
                                 model_type + "_size" + str(args.hsize))
    elif model_type == "GSWD":
        model_dir = os.path.join(
            args.outdir, model_type + "_n" + str(num_projection) + "_" +
            args.g + str(args.r))
    elif model_type == "DGSWD":
        model_dir = os.path.join(
            args.outdir,
            model_type + "_iter" + str(args.niter) + "_n" +
            str(num_projection) + "_lam" + str(args.lam) + "_" + args.g +
            str(args.r),
        )
    print(model_dir)
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("batch size {}\nepochs {}\nAdam lr {} \n using device {}\n".format(
        args.batch_size, args.epochs, args.lr, device.type))

    if dataset == "CIFAR":
        from DCGANAE import Discriminator

        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                args.datadir,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
            ),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )
    elif dataset == "LSUN":
        from DCGANAE import Discriminator

        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.LSUN(
                args.datadir + "/lsun",
                classes=["bedroom_train"],
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.CenterCrop(64),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]),
            ),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        )

    elif dataset == "CELEBA":
        from DCGANAE import Discriminator

        image_size = 64
        num_chanel = 3
        dataset = CustomDataset(
            root=args.datadir + "/img_align_celeba",
            transform=transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]),
        )
        # Create the dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)
        # test_loader = torch.utils.data.DataLoader(
        #     dataset, batch_size=1000, shuffle=True, num_workers=args.num_workers, pin_memory=True
        # )

    model = DCGANAE(image_size=64,
                    latent_size=latent_size,
                    num_chanel=3,
                    hidden_chanels=64,
                    device=device).to(device)
    dis = Discriminator(64, args.latent_size, 3, 64).to(device)
    disoptimizer = optim.Adam(dis.parameters(), lr=args.lr, betas=(0.5, 0.999))
    if model_type == "DSWD" or model_type == "DGSWD":
        transform_net = TransformNet(64 * 8 * 4 * 4).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        # train_net(64 * 8 * 4 * 4, 1000, transform_net, op_trannet)
    if model_type == "MGSWNN":
        gsw = GSW_NN(din=64 * 8 * 4 * 4,
                     nofprojections=1,
                     model_depth=3,
                     num_filters=args.hsize,
                     use_cuda=True)
    if model_type == "MSWD":
        gsw = GSW()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    epoch_cont = 0
    if args.cont:
        epoch_cont, modelstate, optimizerstate, tnetstate, optnetstate, disstate, opdistate = load_dmodel(
            model_dir)
        model.load_state_dict(modelstate)
        optimizer.load_state_dict(optimizerstate)
        dis.load_state_dict(disstate)
        disoptimizer.load_state_dict(opdistate)
        epoch_cont = epoch_cont + 1
        print("Continue from epoch " + str(epoch_cont))
    fixednoise = torch.randn((64, latent_size)).to(device)

    for epoch in range(epoch_cont, args.epochs):
        total_loss = 0.0

        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if model_type == "SWD":
                loss = model.compute_loss_SWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif model_type == "GSWD":
                loss = model.compute_loss_GSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=args.p)
            elif model_type == "MGSWNN":
                loss = model.compute_loss_MGSWNN(dis,
                                                 disoptimizer,
                                                 data,
                                                 torch.randn,
                                                 gsw,
                                                 p=args.p)
            elif model_type == "MSWD":
                loss = model.compute_loss_MSWD(dis, disoptimizer, data,
                                               torch.randn, gsw)
            elif model_type == "DSWD":
                loss = model.compute_lossDSWD(
                    dis,
                    disoptimizer,
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            elif model_type == "DGSWD":
                loss = model.compute_lossDGSWD(
                    dis,
                    disoptimizer,
                    data,
                    torch.randn,
                    num_projection,
                    transform_net,
                    op_trannet,
                    g_function,
                    args.r,
                    p=args.p,
                    max_iter=args.niter,
                    lam=args.lam,
                )
            optimizer.zero_grad()
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        total_loss /= batch_idx + 1
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))
        save_dmodel(model, optimizer, dis, disoptimizer, None, None, epoch,
                    model_dir)
        sampling(model_dir + "/sample_epoch_" + str(epoch) + ".png",
                 fixednoise, model.decoder, 64, image_size, num_chanel)
        if epoch == args.epochs - 1:
            model.eval()
            sampling_eps(model_dir + "/sample_epoch_" + str(epoch), fixednoise,
                         model.decoder, 64, image_size, num_chanel)
            model.train()
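
Because save_dmodel writes a checkpoint into model_dir after every epoch and --cont triggers load_dmodel, an interrupted run can be resumed roughly as follows (the script name celeba.py is assumed here for illustration; all flags come from the parser above):

python celeba.py --model-type DSWD --dataset CELEBA --cont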