Example #1
File: main.py  Project: LiamMa/DL_AS3
def main(args):
    train_iter, valid_iter, test_iter = get_data_loader('data', args.batch_size, model=args.model)
    # ---------- Try to load model first ----------
    model = None
    if args.load_model:
        try:
            model = utils.model_load('P3_{}.pt'.format(args.model))
        except Exception:
            logger.info('Loading model failed, will train the model first.')
            args.load_model = False
    # ---------- GAN related ----------
    if args.model == 'GAN':
        if not args.load_model or args.mode == 'train':
            if model is None:
                model = GAN(N_LATENT)
            model.to(dev)
            train_gan(model, train_iter, test_iter, args.num_epochs,
                      G_update_iterval=args.G_update_interval, test_interval=args.test_interval,
                      save_model=args.save_model)
        model.eval()

        if args.mode == 'test':
            test_all(model, test_iter, args.batch_size, N_LATENT, model_name='GAN')

        if args.mode == 'gen':
            latent = torch.randn(size=(1000, N_LATENT), device=dev)
            generate_images(model.generator, latent, save_path='sample/GAN/samples')

    # ---------- VAE related ----------
    else:
        if args.mode == 'train' or not args.load_model:
            model = VAE(N_LATENT)
            model.to(dev)
            train_vae(model, train_iter, test_iter, args.num_epochs,
                      test_interval=args.test_interval, save_model=args.save_model)
        model.eval()

        if args.mode == 'test':
            test_all(model, test_iter, args.batch_size, N_LATENT, model_name='VAE')

        if args.mode == 'gen':
            # n_imgs = 0
            # latent = []
            # with torch.no_grad():
            #     for batch in test_iter:
            #         X = batch[0].to(dev)
            #         n_imgs += X.shape[0]
            #         latent.append(model.reparam(*model.encode(X)))
            #
            #         if n_imgs > 1000:
            #             break
            #     latent = torch.cat(latent, dim=0)[:1000]
            latent = torch.randn(size=(1000, N_LATENT), device=dev)
            generate_images(model.decoder, latent, save_path='sample/VAE/samples')
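Note: `generate_images` is project-specific and not shown in this excerpt. A minimal sketch of such a helper, assuming the passed-in module maps a batch of latent vectors to image tensors (the signature and save-path layout below are assumptions, not the project's actual code):

import torch
import torchvision

def generate_images(decoder, latent, save_path):
    # Decode the latent batch without tracking gradients and write each
    # image to disk as an individual PNG under the save_path prefix.
    decoder.eval()
    with torch.no_grad():
        images = decoder(latent)
    for idx, img in enumerate(images):
        torchvision.utils.save_image(img, '{}_{:04d}.png'.format(save_path, idx), normalize=True)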
Example #2
        "--eps",
        type=float,
        default=1e-1,
        help="Perturbation value to the latent when evaluating")
    parser.add_argument("--sample_dir",
                        type=str,
                        default="samples",
                        help="Directory containing samples for"
                        "evaluation")

    # get the arguments
    args = parser.parse_args()
    # check for cuda and set the device once
    device = torch.device("cuda" if cuda.is_available() else "cpu")
    args.device = device

    # load the dataset
    train, valid, test = get_data_loader(args.data_dir, args.batch_size)

    # Create model. Load or train depending on choice
    model = VAE(batch_size=args.batch_size, dimz=args.dimz).to(args.device)
    if args.t:
        train_model(model, train, valid, args.save_path)
    else:
        model.load_state_dict(torch.load(args.load_path))
        model.eval()
        evaluation(model)
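The `--eps` flag above points to a qualitative evaluation that perturbs the latent code. A sketch of that idea, assuming the model exposes a `decode` method (the helper name and the decoding call are assumptions, not this script's actual evaluation code):

import torch

def perturb_latent(model, dimz, eps, device):
    # Sample one latent vector, then add eps to each dimension in turn and
    # decode, yielding dimz + 1 images (the original plus one per dimension).
    with torch.no_grad():
        z = torch.randn(1, dimz, device=device)
        images = [model.decode(z)]
        for d in range(dimz):
            z_pert = z.clone()
            z_pert[0, d] += eps
            images.append(model.decode(z_pert))
        return torch.cat(images, dim=0)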
Example #3
            filename = f"images/vae/fid/img/{i * 100 + j:03d}.png"
            torchvision.utils.save_image(image, filename, normalize=True)


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Running on {device}")

    vae = VAE()
    vae = vae.to(device)

    optimizer = optim.Adam(vae.parameters(), lr=3e-4)

    running_loss = 0

    trainloader, validloader, testloader = get_data_loader("svhn", 64)

    try:
        vae.load_state_dict(torch.load('q3_vae_save.pth', map_location=device))
        print('----Using saved model----')

    except FileNotFoundError:
        for epoch in range(5):

            print(f"------- EPOCH {epoch} --------")

            for i, (x, _) in enumerate(trainloader):
                vae.train()
                optimizer.zero_grad()

                x = x.to(device)
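The rest of the training step is cut off in this excerpt. A typical VAE update at this point computes a reconstruction term plus the KL divergence of the approximate posterior from a standard normal prior; the sketch below assumes the model's forward pass returns (reconstruction, mu, logvar), which may not match this project's interface:

import torch
import torch.nn.functional as F

# assumed signature: vae(x) -> (x_recon, mu, logvar)
x_recon, mu, logvar = vae(x)
recon = F.mse_loss(x_recon, x, reduction='sum')
kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = (recon + kld) / x.size(0)
loss.backward()
optimizer.step()
running_loss += loss.item()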
Example #4
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty
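# Editorial sketch: the lines computing `gradients` are cut off in this
# excerpt. A common WGAN-GP formulation of the full penalty looks roughly
# like the function below; the variable names and interpolation scheme are
# assumptions, not this script's actual code.
def gradient_penalty_sketch(D, real_data, fake_data):
    # Random convex combinations of (detached) real and fake samples.
    alpha = torch.rand(real_data.size(0), 1, 1, 1, device=real_data.device)
    interpolates = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
    interpolates.requires_grad_(True)
    disc_interpolates = D(interpolates)
    # Gradient of the critic output w.r.t. the interpolated inputs.
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True)[0]
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA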

# generate grid of sample images
def gen_samples(i, G):
    noise = torch.randn(64, NOISE_DIM)
    if use_cuda:
        noise = noise.cuda(gpu)
    noisev = autograd.Variable(noise)
    samples = G(noisev)
    samples = samples.view(-1, 3, 32, 32)
    torchvision.utils.save_image(samples, './gan_data/samples/sample_{}.png'.format(i), nrow=8, padding=2)

# Dataset iterator
directory = "svhn/"
train_loader, valid_loader, test_loader = classify_svhn.get_data_loader(directory, BATCH_SIZE)

for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        if i == 0 or i % (DISC_ITERS + 1): # update discriminator
            D.zero_grad()
            real_data, _ = data
            dim = real_data.size(0)

            if use_cuda:
                real_data = real_data.cuda(gpu)
            real_data = autograd.Variable(real_data)

            D_real = D(real_data)
            D_real = D_real.mean()
            D_real.backward(mone)
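`mone` is not defined in this excerpt; in WGAN training scripts it is typically a scalar tensor holding -1 that is passed to backward() so the critic's score on real data is maximized. A sketch of the usual setup (an assumption, not this script's exact code):

one = torch.tensor(1.0)
mone = one * -1
if use_cuda:
    one, mone = one.cuda(gpu), mone.cuda(gpu)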
Example #5
import argparse
import os

import torch
import torch.nn as nn
import torchvision

# project-local imports (e.g. `data`, `Generator`, `VAE`) are not shown in
# this excerpt

parser = argparse.ArgumentParser(description='PyTorch model')
parser.add_argument('--eval_mode',
                    type=str,
                    default='Train',
                    help='eval mode to use: Train or Test')

args = parser.parse_args()
#import data
directory = "svhn/"
model_directory = "vae/"
batch_size = 32
torch.manual_seed(1111)
train_loader, valid_loader, test_loader = data.get_data_loader(
    directory, batch_size)
num_epochs = 100

if not os.path.exists(model_directory):
    print("creating VAE directory")
    os.mkdir(model_directory)
    os.mkdir(model_directory + '/imgs')
    os.mkdir(model_directory + '/models')

MSE = nn.MSELoss(reduction='sum').cuda()

generator = Generator(latent_size=100)

model = VAE(latent_size=100)
if torch.cuda.is_available():
    model.cuda()
Example #6
                    generated = G(z)
                    generated_score = D(generated)
                    loss_g = -generated_score.mean()
                    loss_g.backward()
                    adam_g.step()

            print("Epoch:", epoch, "; Discriminator loss:", loss_d.item(),
                  "; Generator loss:", loss_g.item())
            save_image(generated,
                       os.path.join(SAMPLES_DIR, "gan-norm",
                                    str(epoch) + ".png"),
                       normalize=True,
                       nrow=2)
    finally:
        checkpoint = {
            "d_state_dict": D.state_dict(),
            "g_state_dict": G.state_dict(),
            "adam_d_state_dict": adam_d.state_dict(),
            "adam_g_state_dict": adam_g.state_dict(),
            "epoch": epoch,
            "loss_d": loss_d,
            "loss_g": loss_g
        }
        torch.save(checkpoint, "gan-q3-norm.tar")


if __name__ == "__main__":
    dataloader, *_ = get_data_loader("svhn", 32)
    train(Discriminator(), Generator(), dataloader)
    train(VAE(), None, dataloader, 3)
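The `finally` block above persists a checkpoint to "gan-q3-norm.tar"; restoring it later would look roughly like this sketch (it assumes the Discriminator and Generator are rebuilt with the same constructor arguments):

checkpoint = torch.load("gan-q3-norm.tar", map_location="cpu")
D, G = Discriminator(), Generator()
D.load_state_dict(checkpoint["d_state_dict"])
G.load_state_dict(checkpoint["g_state_dict"])
start_epoch = checkpoint["epoch"] + 1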
Example #7


if __name__ == "__main__":

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	model = VAE().to(device)
	optimizer = optim.Adam(model.parameters(), lr=3e-4)

	###Training###

	#n_epochs = 50

	#Load data
	train_loader, valid_loader, test_loader = svhn.get_data_loader("svhn", 32)

	#Train + val
	#for epoch in range(n_epochs):
	#	train(epoch, train_loader)
	#	eval(epoch, valid_loader)

	#	with torch.no_grad():
			#Generate a batch of images using current parameters 
			#Sample z from prior p(z) = N(0,1)
	#		sample = torch.randn(16, 100).to(device)
	#		sample = model.decode(sample)
	#		save_image(sample.view(16, 3, 32, 32),
	#				   'results/sample_' + str(epoch) + '.png', normalize=True)