コード例 #1
0
def get_recons(loader, hidden_size=256):
    """Reconstruct one batch of images from *loader* with a pretrained
    VQ-VAE and return the reconstructions as a single image grid.

    Parameters
    ----------
    loader : iterable yielding ``(images, labels)`` batches.
    hidden_size : hidden dimension of the VQ-VAE checkpoint to load.
        Sizes below 256 are stored in per-size subdirectories.

    Returns
    -------
    Tensor: CPU grid of reconstructed images, 8 per row.
    """
    model = VectorQuantizedVAE(3, hidden_size, 512).to(DEVICE)

    # The default (hs=256) checkpoint lives at the top level; reduced
    # hidden sizes live in "hs_<size>" subdirectories.  The original
    # code called .format(hidden_size) on a placeholder-less string,
    # which was a no-op — dropped.
    if hidden_size < 256:
        ckpt_path = "./models/imagenet/hs_{}/best.pt".format(hidden_size)
    else:
        ckpt_path = "./models/imagenet/best.pt"
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt)

    # generate_samples only reads args.device, so a minimal anonymous
    # namespace is enough here.
    args = type('', (), {})()
    args.device = DEVICE

    gen_img, _ = next(iter(loader))

    reconstruction = generate_samples(gen_img, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8)

    return grid
コード例 #2
0
def main(args):
    """Train an AffineClassifier on top of a frozen, pretrained VQ-VAE
    encoder using the triplet whale dataset.

    Expects ``args`` to provide: output_folder, data_folder, batch_size,
    num_workers, hidden_size, k, model_path, device, lr, num_epochs.
    Side effects: writes TensorBoard logs under ./logs/<output_folder>
    and checkpoints (best.pt plus one per epoch) under
    ./models/<output_folder>.
    """
    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    save_filename = "./models/{0}".format(args.output_folder)

    # Define the train, valid & test datasets
    all_dataset = TripletWhaleDataset(args.data_folder, min_instance_count=10)
    from torch.utils.data import random_split

    # 60/20/20 split; the validation share absorbs rounding remainders
    # so the three sizes always sum to len(all_dataset).
    k = len(all_dataset)
    train_s = int(k * 0.6)
    test_s = int(k * 0.2)
    valid_s = k - train_s - test_s
    train_dataset, test_dataset, valid_dataset = random_split(
        all_dataset, (train_s, test_s, valid_s)
    )
    num_channels = 3

    # Define the data loaders
    # NOTE(review): shuffle=False on the training loader — confirm this
    # is intentional (training data is usually shuffled).
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    # The triplet dataset yields at least 3 elements per batch; take the
    # first three (images, _, ids).
    fixed_images, _, fixed_ids = next(iter(test_loader))[:3]
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image("original", fixed_grid, 0)
    n_classes = len(all_dataset.index_to_class)

    # Restore the pretrained VQ-VAE, freeze its encoder (eval mode) and
    # train only the classifier head on top of it.
    vae = VectorQuantizedVAE(num_channels, args.hidden_size, args.k)
    vae.load_state_dict(torch.load(args.model_path))
    encoder = vae.encoder.eval().to(args.device)
    model = AffineClassifier(all_dataset, encoder, n_classes).to(args.device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # best_loss starts at -1.0 but is only compared after the first
    # epoch (the `epoch == 0` guard), so the sentinel is never used.
    best_loss = -1.0
    for epoch in range(args.num_epochs):
        # Anomaly detection slows training; presumably left on for
        # debugging — TODO confirm before long runs.
        with torch.autograd.set_detect_anomaly(True):
            train(train_loader, model, optimizer, args, writer)
        loss = test(valid_loader, model, args, writer)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open("{0}/best.pt".format(save_filename), "wb") as f:
                torch.save(model.state_dict(), f)
        with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
            torch.save(model.state_dict(), f)
コード例 #3
0
def main(args):
    """Train the affine-cropper generator on whale images using a frozen,
    pretrained VQ-VAE encoder.

    Expects ``args`` to provide: output_folder, dataset, data_folder,
    batch_size, num_workers, hidden_size, k, model_path, device, lr,
    num_epochs.  Writes TensorBoard logs and generator checkpoints.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    # NOTE(review): only 'whales' is handled; any other args.dataset
    # leaves train/test/valid datasets undefined and raises NameError
    # at the DataLoader construction below.
    if args.dataset == 'whales':
        # Define the train, valid & test datasets
        all_dataset = WhaleDataset(args.data_folder, image_transformation=BASIC_IMAGE_T)
        from torch.utils.data import random_split
        # 60/20/20 split; valid absorbs rounding remainders.
        k = len(all_dataset)
        train_s = int(k * .6)
        test_s = int(k * .2)
        valid_s = k - train_s - test_s
        train_dataset, test_dataset, valid_dataset = random_split(all_dataset, (train_s, test_s, valid_s))
        num_channels = 3

    # Define the data loaders
    # NOTE(review): shuffle=False on the training loader — confirm
    # intentional.
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
        batch_size=args.batch_size, shuffle=False, drop_last=True,
        num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
        batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _, fixed_ids = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    # Restore the pretrained VQ-VAE and freeze its encoder; only the
    # cropper's generator is optimized.
    vae = VectorQuantizedVAE(num_channels, args.hidden_size, args.k)
    vae.load_state_dict(torch.load(args.model_path))
    encoder = vae.encoder.eval().to(args.device)
    affine_resample = AffineCropper(all_dataset, args.hidden_size, encoder).to(args.device)
    generator = affine_resample.generator.train()
    optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, fixed_ids, affine_resample, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('cropped', grid, 0)

    # best_loss sentinel is never compared on the first epoch thanks to
    # the `epoch == 0` guard below.
    best_loss = -1.
    for epoch in range(args.num_epochs):
        # Anomaly detection slows training — presumably a debugging aid;
        # TODO confirm before long runs.
        with torch.autograd.set_detect_anomaly(True):
            train(train_loader, encoder, affine_resample, optimizer, args, writer)
        loss = test(valid_loader, encoder, affine_resample, args, writer)

        # Log the cropped reconstructions of the fixed batch each epoch.
        reconstruction = generate_samples(fixed_images, fixed_ids, affine_resample, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        writer.add_image('cropped', grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(generator.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(generator.state_dict(), f)
コード例 #4
0
def get_model(ae_type, hidden_size, k, num_channels):
    """Build an ImgnetClassifier wrapping a VQ-VAE restored from a
    fixed checkpoint.

    Note: ``ae_type`` is currently unused and the checkpoint path is
    pinned to the ``hs_32_4`` run (a hidden_size-dependent path was
    deliberately commented out).
    """
    checkpoint_path = "models/imagenet/hs_32_4/best.pt"  # .format(hidden_size) intentionally disabled
    vae = VectorQuantizedVAE(num_channels, hidden_size, k)
    classifier = ImgnetClassifier(vae, hidden_size)

    # Restore the pretrained weights into the wrapped VAE; the
    # classifier holds a reference, so it sees the loaded weights too.
    vae.load_state_dict(torch.load(checkpoint_path))

    return classifier
コード例 #5
0
def val_test(args):
    """Run validation passes over a (possibly restored) VQ-VAE and log
    reconstruction/interpolation grids, losses and latent metrics.

    Expects ``args`` to provide: output_folder, hidden_size, k,
    enc_type, dec_type, lr, recons_loss, device, weights,
    stop_patience, and (for GAN loss) img_res, input_type, disc_lr.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    # if torch.cuda.device_count() > 1 and args.device == "cuda":
    # 	model = torch.nn.DataParallel(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Optional adversarial reconstruction loss: keep each discriminator
    # paired with its own optimizer.
    discriminators = {}

    if args.recons_loss == "gan":
        recons_disc = Discriminator(input_dim, args.img_res,
                                    args.input_type).to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    # Resume from a saved state when requested; otherwise start fresh.
    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    # NOTE(review): stop_patience is read but never used in this
    # function, and best_loss is never updated inside the loop —
    # presumably train_util.save_state tracks the best internally;
    # confirm.
    stop_patience = args.stop_patience
    best_loss = torch.tensor(np.inf)
    # NOTE(review): epoch range is hard-coded to 4 — confirm this cap is
    # intentional for the validation run.
    for epoch in tqdm(range(start_epoch, 4), file=sys.stdout):
        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators, True)
        # if args.weights == "init" and epoch==1:
        # 	epoch+=1
        # 	break

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

    print(val_loss_dict)
コード例 #6
0
ファイル: imgnet.py プロジェクト: sidwa/ae_thesis
def get_model(ae_type, hidden_size, k, num_channels):
    """Load a pretrained autoencoder and wrap it in an ImgnetClassifier.

    Parameters
    ----------
    ae_type : "vqvae" or "vae" — which autoencoder family to restore.
    hidden_size : hidden dimension (selects the VQ-VAE checkpoint dir).
    k : codebook size (VQ-VAE only).
    num_channels : input channel count.

    Returns
    -------
    ImgnetClassifier wrapping the restored autoencoder.

    Raises
    ------
    ValueError for an unknown ``ae_type`` (previously this crashed with
    a NameError on undefined locals).
    """
    if ae_type == "vqvae":
        # The full-size (hs=256) checkpoint lives at the top level; all
        # reduced hidden sizes live in per-size subdirectories.  The
        # original had identical elif/else branches for this case.
        if hidden_size == 256:
            ckpt_path = "models/imagenet/best.pt"
        else:
            ckpt_path = "models/imagenet/hs_{}/best.pt".format(hidden_size)
        model = VectorQuantizedVAE(num_channels, hidden_size, k)
        imgnetclassifier = ImgnetClassifier(model, hidden_size)
    elif ae_type == "vae":
        ckpt_path = "models/imagenet_vae.pt"
        model = VAE(num_channels, hidden_size, 4096)
        imgnetclassifier = ImgnetClassifier(model, 4)
    else:
        raise ValueError("unknown ae_type: {}".format(ae_type))

    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt)

    return imgnetclassifier
コード例 #7
0
ファイル: cifar.py プロジェクト: sidwa/ae_thesis
def get_model(model, hidden_size, k, num_channels, resolution, num_classes):
    """Load a pretrained autoencoder of the requested type and wrap it
    in an ImgnetClassifier.

    Parameters
    ----------
    model : "vqvae", "vae" or "aae"; any other value yields a
        classifier wrapping ``None`` with no weights loaded.
    hidden_size : hidden dimension; 0 or negative skips weight loading.
    k : VQ-VAE codebook size.
    num_channels, resolution : passed through to the model/classifier.
    num_classes : currently unused — kept for interface compatibility.

    Returns
    -------
    ImgnetClassifier wrapping the (possibly None) autoencoder.
    """
    ckpt_path = None
    if model == "vqvae":
        # The full-size (hs=256) checkpoint lives at the top level; all
        # reduced hidden sizes live in per-size subdirectories.  (The
        # original had identical elif/else branches here.)
        if hidden_size == 256:
            ckpt_path = "models/imagenet/best.pt"
        else:
            ckpt_path = "models/imagenet/hs_{}/best.pt".format(hidden_size)
        model = VectorQuantizedVAE(num_channels, hidden_size, k)
    elif model == "vae":
        ckpt_path = f"models/imagenet_hs_128_{hidden_size}_vae.pt"
        model = VAE(num_channels, hidden_size, hidden_size)
    elif model == "aae":
        ckpt_path = f"models/aae/imagenet_hs_32_{hidden_size}/best.pt"
        model = AAE(32, num_channels, hidden_size)
    else:
        model = None

    imgnetclassifier = ImgnetClassifier(model, hidden_size, resolution)

    # Only restore weights when a known model type was actually built;
    # previously an unknown type with hidden_size > 0 raised NameError
    # on the undefined checkpoint path.
    if model is not None and hidden_size > 0:
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt)

    return imgnetclassifier
コード例 #8
0
def get_model(model,
              hidden_size,
              num_channels,
              resolution,
              enc_type,
              dec_type,
              num_classes,
              k=512):
    """Build an image classifier on top of a pretrained autoencoder, or
    a fully supervised classifier when ``model == "supervised"``.

    Parameters
    ----------
    model : "vqvae", "vae", "acai" or "supervised".
    hidden_size : hidden dimension; 0 or negative skips weight loading.
    num_channels, resolution, enc_type, dec_type, num_classes, k :
        forwarded to the underlying model/classifier constructors.

    NOTE(review): reads the module-level ``args`` for recons_loss,
    train_dataset and img_res — confirm it is defined at call time.
    """
    # Remember the requested type before ``model`` is rebound to a
    # model instance below: the original compared the rebound object
    # against the string "supervised", so the supervised branch was
    # unreachable.
    model_name = model
    CKPT_DIR = f"models/{model_name}_{args.recons_loss}/{args.train_dataset}/depth_{enc_type}_{dec_type}_hs_{args.img_res}_{hidden_size}/best.pt"
    if model_name == "vqvae":
        model = VectorQuantizedVAE(num_channels, hidden_size, k, enc_type,
                                   dec_type)
    elif model_name == "vae":
        model = VAE(num_channels, hidden_size, enc_type, dec_type)
    elif model_name == "acai":
        model = ACAI(resolution, num_channels, hidden_size, enc_type, dec_type)
    else:
        # Includes "supervised": no autoencoder is built.
        model = None

    if model_name != "supervised":
        imgclassifier = ImgClassifier(model, hidden_size, resolution,
                                      num_classes)
    else:
        imgclassifier = SupervisedImgClassifier(hidden_size, enc_type,
                                                resolution, num_classes)

    # Only restore weights when an autoencoder was actually built;
    # previously a None model here raised AttributeError.
    if model is not None and hidden_size > 0:
        ckpt = torch.load(CKPT_DIR)
        model.load_state_dict(ckpt["model"])

    return imgclassifier
コード例 #9
0
def main(args):
    """Run a pretrained VQ-VAE on one batch (MNIST test images or random
    noise) and save the generated samples as a PNG grid.

    Expects ``args`` to provide: input, hidden_size_vae, k, device,
    model (checkpoint path) and filename (output image name).
    """
    print("Start Time =", datetime.now().strftime("%H:%M:%S"))

    if args.input == "MNIST":
        # Normalize MNIST pixels into [-1, 1].
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5])])

        # Define the train & test dataSets
        test_set = datasets.MNIST("MNIST",
                                  train=False,
                                  download=True,
                                  transform=transform)

        # Define the data loaders
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=100,
                                                  shuffle=True,
                                                  pin_memory=True)
        real_batch, _ = next(iter(test_loader))
    else:
        # No dataset requested: feed standard-normal noise instead.
        real_batch = torch.normal(mean=0, std=1, size=(100, 1, 28, 28))

    model = VectorQuantizedVAE(1, args.hidden_size_vae, args.k).to(args.device)
    with open(args.model, 'rb') as f:
        model.load_state_dict(torch.load(f))

    generated = generate_samples(real_batch, model, args)
    save_image(make_grid(generated, nrow=10),
               './generatedImages/{0}.png'.format(args.filename))

    print("End Time =", datetime.now().strftime("%H:%M:%S"))
コード例 #10
0
),
                                           batch_size=BATCH_SIZE,
                                           shuffle=False,
                                           num_workers=NUM_WORKERS,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(eval('datasets.' + DATASET)(
    '../data/{}/'.format(DATASET),
    train=False,
    transform=preproc_transform,
),
                                          batch_size=BATCH_SIZE,
                                          shuffle=False,
                                          num_workers=NUM_WORKERS,
                                          pin_memory=True)

model = VectorQuantizedVAE(INPUT_DIM, DIM, K).to(DEVICE)
print(model)
opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)


def train():
    train_loss = []
    for batch_idx, (x, _) in enumerate(train_loader):
        start_time = time.time()
        x = x.to(DEVICE)

        opt.zero_grad()

        x_tilde, z_e_x, z_q_x = model(x)
        z_q_x.retain_grad()
コード例 #11
0
def main(args):
    """Train a VQ-VAE on the selected dataset, logging reconstructions
    and saving checkpoints.

    Expects ``args`` to provide: output_folder, dataset, data_folder,
    data_name (PubTabNet only), batch_size, num_workers, hidden_size,
    k, device, lr, num_epochs.  Writes TensorBoard logs under
    ./logs/<output_folder> and checkpoints (best.pt plus one per epoch)
    under ./models/<output_folder>.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        # NOTE(review): this 3-channel Normalize is also applied to the
        # 1-channel MNIST/Fashion-MNIST tensors — confirm it behaves as
        # intended for those datasets.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder,
                                           train=True,
                                           download=True,
                                           transform=transform)
            test_dataset = datasets.MNIST(args.data_folder,
                                          train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False,
                                            transform=transform)
            num_channels = 3
        # These datasets ship no separate validation split: validate on
        # the test set.
        valid_dataset = test_dataset
    elif args.dataset == 'PubTabNet':
        # Only Normalize here (no ToTensor) — presumably PubTabNet
        # samples are already tensors; verify against the dataset class.
        transform = transforms.Compose(
            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = PubTabNet(args.data_folder,
                                  args.data_name,
                                  'TRAIN',
                                  transform=transform)
        test_dataset = PubTabNet(args.data_folder,
                                 args.data_name,
                                 'VAL',
                                 transform=transform)
        valid_dataset = test_dataset
        num_channels = 3
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    # NOTE(review): any other args.dataset value leaves the datasets
    # undefined and raises NameError below.

    # Define the data loaders
    # NOTE(review): shuffle=False on the training loader — confirm
    # intentional.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=4, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=4,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    # best_loss sentinel is never compared on epoch 0 thanks to the
    # guard below.
    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer, epoch)
        loss, _ = test(valid_loader, model, args, writer)
        # ("Validataion" typo is preserved — it is a runtime string.)
        eprint('Validataion loss at epoch %d: Loss = %.4f' % (epoch, loss))

        # Log reconstructions of the fixed batch each epoch.
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=4,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
コード例 #12
0
def main(args):
    """Train a GatedPixelCNN prior over the discrete codes of a frozen,
    pretrained VQ-VAE.

    Expects ``args`` to provide: output_folder, dataset, data_folder,
    batch_size, num_workers, hidden_size_vae, k, model (VQ-VAE
    checkpoint path), hidden_size_prior, num_layers, device, lr,
    num_epochs.  Saves only the best prior checkpoint.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}/prior.pt'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        # NOTE(review): 3-channel Normalize is also applied to the
        # 1-channel MNIST/Fashion-MNIST tensors — confirm intentional.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder,
                                           train=True,
                                           download=True,
                                           transform=transform)
            test_dataset = datasets.MNIST(args.data_folder,
                                          train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False,
                                            transform=transform)
            num_channels = 3
        # No separate validation split: validate on the test set.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    # NOTE(review): any other args.dataset value raises NameError below.

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Save the label encoder
    # with open('./models/{0}/labels.json'.format(args.output_folder), 'w') as f:
    #     json.dump(train_dataset._label_encoder, f)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    # Restore the pretrained VQ-VAE and freeze it; only the prior is
    # trained below.
    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        state_dict = torch.load(f)
        model.load_state_dict(state_dict)
    model.eval()

    # n_classes is hard-coded to 32 (the label-encoder-based value is
    # commented out below) — TODO confirm this matches the dataset.
    prior = GatedPixelCNN(args.k,
                          args.hidden_size_prior,
                          args.num_layers,
                          n_classes=32).to(args.device)
    # args.num_layers, n_classes=len(train_dataset._label_encoder)).to(args.device)
    optimizer = torch.optim.Adam(prior.parameters(), lr=args.lr)

    # best_loss sentinel never compared on epoch 0 due to the guard.
    best_loss = -1.
    for epoch in range(args.num_epochs):
        print(epoch)
        train(train_loader, model, prior, optimizer, args, writer)
        # The validation loss is not properly computed since
        # the classes in the train and valid splits of Mini-Imagenet
        # do not overlap.
        loss = test(valid_loader, model, prior, args, writer)

        # Only the best prior is checkpointed (no per-epoch saves here).
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open(save_filename, 'wb') as f:
                torch.save(prior.state_dict(), f)
コード例 #13
0
ファイル: vqvae.py プロジェクト: MioChiu/pytorch_vqvae
def main(args):
    """Train a VQ-VAE with reproducible seeding and periodic
    checkpointing.

    Expects ``args`` to provide: manualseed, log_dir, save_dir, nc,
    hidden_size, k, device, gpu_ids, lr, num_epochs, save_step, plus
    whatever ``load_data`` needs.  Saves best.pt and a checkpoint every
    ``save_step`` epochs.
    """
    # set manualseed
    # Seed every RNG source and force deterministic cuDNN for
    # reproducibility.
    random.seed(args.manualseed)
    torch.manual_seed(args.manualseed)
    torch.cuda.manual_seed_all(args.manualseed)
    np.random.seed(args.manualseed)
    torch.backends.cudnn.deterministic = True

    writer = SummaryWriter(args.log_dir)
    save_filename = args.save_dir

    # The inline dataset construction below was replaced by load_data;
    # kept for reference.
    # if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
    #     transform = transforms.Compose([
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    #     ])
    #     if args.dataset == 'mnist':
    #         assert args.nc == 1
    #         # Define the train & test datasets
    #         train_dataset = datasets.MNIST(args.dataroot,
    #                                        train=True,
    #                                        download=True,
    #                                        transform=transform)
    #         test_dataset = datasets.MNIST(args.dataroot,
    #                                       train=False,
    #                                       transform=transform)

    #     elif args.dataset == 'fashion-mnist':
    #         # Define the train & test datasets
    #         train_dataset = datasets.FashionMNIST(args.dataroot,
    #                                               train=True,
    #                                               download=True,
    #                                               transform=transform)
    #         test_dataset = datasets.FashionMNIST(args.dataroot,
    #                                              train=False,
    #                                              transform=transform)

    #     elif args.dataset == 'cifar10':
    #         # Define the train & test datasets
    #         train_dataset = datasets.CIFAR10(args.dataroot,
    #                                          train=True,
    #                                          download=True,
    #                                          transform=transform)
    #         test_dataset = datasets.CIFAR10(args.dataroot,
    #                                         train=False,
    #                                         transform=transform)

    #     valid_dataset = test_dataset
    # elif args.dataset == 'miniimagenet':
    #     transform = transforms.Compose([
    #         transforms.RandomResizedCrop(128),
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    #     ])
    #     # Define the train, valid & test datasets
    #     train_dataset = MiniImagenet(args.dataroot,
    #                                  train=True,
    #                                  download=True,
    #                                  transform=transform)
    #     valid_dataset = MiniImagenet(args.dataroot,
    #                                  valid=True,
    #                                  download=True,
    #                                  transform=transform)
    #     test_dataset = MiniImagenet(args.dataroot,
    #                                 test=True,
    #                                 download=True,
    #                                 transform=transform)

    # Define the data loaders
    # train_loader = torch.utils.data.DataLoader(train_dataset,
    #                                            batch_size=args.batch_size,
    #                                            shuffle=False,
    #                                            num_workers=args.num_workers,
    #                                            pin_memory=True)
    # valid_loader = torch.utils.data.DataLoader(valid_dataset,
    #                                            batch_size=args.batch_size,
    #                                            shuffle=False,
    #                                            drop_last=True,
    #                                            num_workers=args.num_workers,
    #                                            pin_memory=True)
    # test_loader = torch.utils.data.DataLoader(test_dataset,
    #                                           batch_size=16,
    #                                           shuffle=True)

    dataloader = load_data(args)
    train_loader = dataloader['train']
    valid_loader = dataloader['valid']
    test_loader = dataloader['test']

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(valid_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(args.nc, args.hidden_size,
                               args.k).to(args.device)
    if len(args.gpu_ids) > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=args.gpu_ids,
                                      output_device=args.gpu_ids[0])

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    # best_loss sentinel is never compared on the first epoch thanks to
    # the `epoch == 1` guard (this loop is 1-based).
    best_loss = -1.
    # for epoch in range(args.num_epochs):
    for epoch in range(1, args.num_epochs + 1):
        train(train_loader, model, optimizer, args, writer, epoch)
        # NOTE(review): "validation" loss is computed on the test
        # loader here — confirm intentional.
        loss, _ = test(test_loader, model, args, writer)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        if (epoch == 1) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        # Periodic (every save_step epochs) full checkpoint.
        if (epoch % args.save_step) == 0:
            with open('{0}/model_{1}.pt'.format(save_filename, epoch),
                      'wb') as f:
                torch.save(model.state_dict(), f)
コード例 #14
0
def main(args):
    """Train a VQ-VAE on MNIST and log reconstructions to TensorBoard.

    Expects ``args`` to provide: output_folder, data_folder, batch_size,
    num_workers, hidden_size, k, lr, num_epochs, device.
    Saves the best model (lowest test loss) to ``<save_filename>/best.pt``
    and a per-epoch snapshot to ``model_<epoch>.pt``.
    """
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("Start Time =", current_time)

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    # Map pixels to [-1, 1]; single mean/std value because MNIST is grayscale.
    transform = transforms.Compose([
        transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

    # Define the train & test datasets
    train_set = datasets.MNIST(args.data_folder, train=True,
                               download=True, transform=transform)
    test_set = datasets.MNIST(args.data_folder, train=False,
                              download=True, transform=transform)
    num_channels = 1  # grayscale input

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.num_workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_set, num_workers=args.num_workers,
                                              batch_size=16, shuffle=False)

    # Fixed images for TensorBoard: the same batch is reconstructed every
    # epoch so the logged grids are comparable over training.
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    writer.add_graph(model, fixed_images.to(args.device))  # get model structure on tensorboard

    # Generate the samples first once (untrained baseline).
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction at start', grid, 0)

    best_loss = -1.  # sentinel; always overwritten on epoch 0
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(test_loader, model, args, writer)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)

        writer.add_image('reconstruction at epoch {:f}'.format(epoch + 1), grid, epoch + 1)
        print("loss = {:f} at epoch {:f}".format(loss, epoch + 1))
        writer.add_scalar('loss/testing_loss', loss, epoch + 1)
        # NOTE(review): the original also appended every grid to an unused
        # img_list, keeping all epochs' image tensors in memory; dropped.

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        # Unconditional per-epoch snapshot.
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)

    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("End Time =", current_time)
コード例 #15
0
ファイル: model_summary.py プロジェクト: sidwa/ae_thesis
def model_summary(model_type,
                  img_res,
                  hidden_size,
                  enc_type,
                  dec_type,
                  loss,
                  batch_size,
                  device=torch.device("cuda:1"),
                  verbose=True):
    """Build the requested autoencoder (plus any discriminators implied by
    ``loss``) and report parameter / forward-pass memory sizes in MB.

    Args:
        model_type: one of "acai", "vqvae", "vae".
        loss: reconstruction loss name; "gan"/"comp"/"comp_2…"/"comp_6…"
            add a reconstruction discriminator to the report.

    Returns:
        (encoder, decoder, discriminators): encoder/decoder are dicts with
        "params" and "forward" sizes in MB; discriminators maps a name to a
        (param_size_MB, forward_size_MB) tuple.

    Raises:
        ValueError: if ``model_type`` is not recognized.
    """
    # torchsummary emits a text table; these regexes pull out the MB figures.
    pattern = re.compile(r"Params size \(MB\):(.*)\n")
    pattern2 = re.compile(r"Forward/backward pass size \(MB\):(.*)\n")
    input_dim = 3  # RGB input assumed throughout this summary helper
    enc_input_size = (input_dim, img_res, img_res)
    # Decoder consumes the 4x-downsampled latent map — TODO confirm the
    # encoders here always downsample by exactly 4.
    dec_input_size = (hidden_size, img_res // 4, img_res // 4)
    # NOTE(review): removed a stray pdb.set_trace() debugging breakpoint here.

    if verbose:
        print(f"model:{model_type}")
        print(f"depth:{enc_type}_{dec_type}")

    if model_type == "acai":
        model = ACAI(img_res, input_dim, hidden_size, enc_type,
                     dec_type).to(device)
    elif model_type == "vqvae":
        model = VectorQuantizedVAE(input_dim,
                                   hidden_size,
                                   enc_type=enc_type,
                                   dec_type=dec_type).to(device)
    elif model_type == "vae":
        model = VAE(input_dim,
                    hidden_size,
                    enc_type=enc_type,
                    dec_type=dec_type).to(device)
    else:
        # Previously fell through to a confusing NameError on `model`.
        raise ValueError(f"unknown model_type: {model_type!r}")

    encoder_summary, _ = torchsummary.summary_string(model.encoder,
                                                     enc_input_size,
                                                     device=device,
                                                     batch_size=batch_size)
    decoder_summary, _ = torchsummary.summary_string(model.decoder,
                                                     dec_input_size,
                                                     device=device,
                                                     batch_size=batch_size)
    if verbose:
        print(encoder_summary)
        print(decoder_summary)

    discriminators = {}

    def summarize_disc(disc):
        # Shared size-extraction for every discriminator variant below.
        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        param_size = float(re.search(pattern, disc_summary).group(1))
        forward_size = float(re.search(pattern2, disc_summary).group(1))
        return param_size, forward_size

    if model_type == "acai":
        # ACAI always carries an interpolation discriminator, independent of
        # the reconstruction loss, hence `if` rather than `elif` chaining.
        disc = Discriminator(input_dim, img_res, "image").to(device)
        disc_param_size, disc_forward_size = summarize_disc(disc)
        discriminators["interp_disc"] = (disc_param_size, disc_forward_size)
    if loss == "gan":
        disc = Discriminator(input_dim, img_res, "image").to(device)
        disc_param_size, disc_forward_size = summarize_disc(disc)
        # Forward size doubled: presumably the reconstruction discriminator
        # sees both a real and a fake batch per step — TODO confirm.
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif loss == "comp":
        disc = AnchorComparator(input_dim * 2, img_res, "image").to(device)
        disc_param_size, disc_forward_size = summarize_disc(disc)
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif "comp_2" in loss:
        disc = ClubbedPermutationComparator(input_dim * 2, img_res,
                                            "image").to(device)
        disc_param_size, disc_forward_size = summarize_disc(disc)
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif "comp_6" in loss:
        disc = FullPermutationComparator(input_dim * 2, img_res,
                                         "image").to(device)
        disc_param_size, disc_forward_size = summarize_disc(disc)
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)

    encoder_param_size = float(re.search(pattern, encoder_summary).group(1))
    encoder_forward_size = float(re.search(pattern2, encoder_summary).group(1))
    decoder_param_size = float(re.search(pattern, decoder_summary).group(1))
    decoder_forward_size = float(re.search(pattern2, decoder_summary).group(1))

    if verbose:
        if "ACAI" in str(type(model)):
            print(
                f"discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )

        if loss == "gan":
            print(
                f"reconstruction discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )

        print(
            f"encoder:\n\tparams:{encoder_param_size}\n\tforward:{encoder_forward_size}"
        )
        print(
            f"decoder:\n\tparams:{decoder_param_size}\n\tforward:{decoder_forward_size}"
        )

    encoder = {"params": encoder_param_size, "forward": encoder_forward_size}
    decoder = {"params": decoder_param_size, "forward": decoder_forward_size}

    return encoder, decoder, discriminators
コード例 #16
0
ファイル: vqvae.py プロジェクト: mihirp1998/vqvae_2d
def main(args):
    """Train (or test) a VQ-VAE on the dataset named by ``args.dataset``.

    Supported datasets: mnist, fashion-mnist, cifar10, clevr, carla.
    Logs original and reconstructed image grids to TensorBoard; saves the
    latest weights to ``recent.pt`` and the best (lowest validation loss)
    to ``best.pt`` under ``./models/<output_folder>``.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        # Map pixels to [-1, 1].
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder,
                                           train=True,
                                           download=True,
                                           transform=transform)
            test_dataset = datasets.MNIST(args.data_folder,
                                          train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False,
                                            transform=transform)
            num_channels = 3
        # No separate validation split for the torchvision datasets.
        valid_dataset = test_dataset
    elif args.dataset == 'clevr':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        import socket
        # Hard-coded per-machine dataset roots; the last assignment wins.
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = "/projects/katefgroup/datasets/clevr_veggies/"
            dataset_name = '/home/mprabhud/dataset/clevr_veggies'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name, mod=args.modname, train=True,
                              transform=transform,
                              object_level=args.object_level)
        valid_dataset = Clevr(dataset_name, mod=args.modname, valid=True,
                              transform=transform,
                              object_level=args.object_level)
        test_dataset = Clevr(dataset_name, mod=args.modname, test=True,
                             transform=transform,
                             object_level=args.object_level)
        num_channels = 3
    elif args.dataset == 'carla':
        if args.use_depth:
            # RGB + depth: four channels in, four normalization stats.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5, 0.5),
                                     (0.5, 0.5, 0.5, 0.5))
            ])
            num_channels = 4
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            num_channels = 3
        import socket
        # Hard-coded per-machine dataset roots; the last assignment wins.
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = '/home/shamitl/datasets/carla'
            dataset_name = "/projects/katefgroup/datasets/carla/"
            dataset_name = '/home/mprabhud/dataset/carla'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name, mod=args.modname, train=True,
                              transform=transform,
                              object_level=args.object_level,
                              use_depth=args.use_depth)
        valid_dataset = Clevr(dataset_name, mod=args.modname, valid=True,
                              transform=transform,
                              object_level=args.object_level,
                              use_depth=args.use_depth)
        test_dataset = Clevr(dataset_name, mod=args.modname, test=True,
                             transform=transform,
                             object_level=args.object_level,
                             use_depth=args.use_depth)

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard (same batch reconstructed every epoch).
    # NOTE(review): a dead `next(iter(train_loader))` whose result was
    # immediately overwritten by this line has been removed.
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.object_level, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # `is not ""` (identity test on a literal) replaced with `!=`.
    if args.load_model != "":
        with open(args.load_model, 'rb') as f:
            state_dict = torch.load(f)
            model.load_state_dict(state_dict)

    # Generate the samples first once (untrained baseline).
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.  # sentinel; always overwritten on epoch 0
    for epoch in range(args.num_epochs):
        if not args.test_mode:
            train(train_loader, model, optimizer, args, writer, epoch)
            loss, _ = test_old(valid_loader, model, args, writer)
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(),
                             nrow=8,
                             range=(-1, 1),
                             normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)
            # Always keep the most recent weights, plus the best-so-far.
            with open('{0}/recent.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
            if (epoch == 0) or (loss < best_loss):
                best_loss = loss
                with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                    torch.save(model.state_dict(), f)
        else:
            # Evaluation-only mode: no optimizer step, no checkpointing.
            test(train_loader, model, args, writer)
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(),
                             nrow=8,
                             range=(-1, 1),
                             normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)
コード例 #17
0
def main(args):
    """Train a VQ-VAE (optionally with a reconstruction discriminator)
    using the project's ``train_util`` helpers.

    The reconstruction loss is selected by ``args.recons_loss``; anything
    other than "mse" adds an adversarial/comparator discriminator with its
    own optimizer and LR scheduler. Checkpoints are handled by
    ``train_util.save_state`` / ``load_state``.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)

    num_channels = 3
    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    model.to(args.device)

    # Fixed images for Tensorboard
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    # name -> [discriminator, its optimizer]
    discriminators = {}

    input_dim = 3
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        recons_disc_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            recons_disc_opt,
            "min",
            patience=args.lr_patience,
            factor=0.5,
            threshold=args.threshold,
            threshold_mode="abs",
            min_lr=1e-7)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        "min",
        patience=args.lr_patience,
        factor=0.5,
        threshold=args.threshold,
        threshold_mode="abs",
        min_lr=1e-7)

    # Wrap in DataParallel AFTER the optimizers captured the raw parameters.
    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    # Generate the samples first once (untrained baseline).
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    stop_patience = args.stop_patience
    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):

        try:
            train(epoch, train_loader, model, optimizer, args, writer,
                  discriminators)
        except RuntimeError as err:
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******")
            print(err, file=sys.stderr)
            print(f"batch_size:{args.batch_size}", file=sys.stderr)
            # Fixed: the original called exit(0), reporting success to the
            # shell/scheduler even though training crashed.
            sys.exit(1)

        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators)

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

        # NOTE(review): an early-stopping check on stop_patience was
        # deliberately disabled here in the original; stop_patience is
        # currently unused.

        ae_lr_scheduler.step(val_loss_dict["recons_loss"])
        if args.recons_loss != "mse":
            recons_disc_lr_scheduler.step(val_loss_dict["recons_disc_loss"])
コード例 #18
0
def main(args):
    """Sample images from a trained VQ-VAE decoder driven by a pretrained
    GatedPixelCNN prior, logging the results to TensorBoard.

    Loads VQ-VAE weights from ``args.model`` and the prior from a
    hard-coded ``models/mnist-10-prior/prior.pt`` checkpoint; each "epoch"
    draws one batch of class-conditional samples (labels 0-9 cycled over
    a batch of 16).
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}/prior.pt'.format(args.output_folder)  # NOTE(review): unused

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        # MNIST mean/std normalization (also applied to the other two
        # datasets here — presumably a shortcut; verify if it matters).
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder,
                                           train=True,
                                           download=True,
                                           transform=transform)
            test_dataset = datasets.MNIST(args.data_folder,
                                          train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False,
                                            transform=transform)
            num_channels = 3
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3

    # Define the data loaders (unused for sampling itself, but kept to
    # preserve the script's side effects, e.g. dataset download).
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    # Frozen VQ-VAE: only its decoder is used below.
    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        state_dict = torch.load(f)
        model.load_state_dict(state_dict)
    model.eval()

    prior = GatedPixelCNN(args.k,
                          args.hidden_size_prior,
                          args.num_layers,
                          n_classes=10).to(args.device)
    state_dict = torch.load('models/mnist-10-prior/prior.pt')
    prior.load_state_dict(state_dict)

    # Class labels 0-9 cycled across the batch of 16; loop-invariant, so
    # build the tensor once instead of once per epoch.
    label_ = [int(i % 10) for i in range(16)]
    label_ = torch.tensor(label_).cuda()
    for epoch in range(args.num_epochs):
        reconstruction = model.decode(
            prior.generate(label=label_, batch_size=16, sample_index=20))
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('Full sampling result', grid, epoch + 1)
コード例 #19
0
ファイル: vqvae.py プロジェクト: mehdidc/pytorch-vqvae
def main(args):
    """Train a VQ-VAE on the dataset selected by ``args.dataset``.

    Supported: mnist, fashion-mnist, cifar10, miniimagenet, or any
    ImageFolder-layout directory. Logs original/reconstruction grids to
    TensorBoard and PNG files; saves the best model (lowest validation
    loss) to ``best.pt`` plus a per-epoch snapshot.
    """
    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    save_filename = "./models/{0}".format(args.output_folder)

    if args.dataset in ["mnist", "fashion-mnist", "cifar10"]:
        # Grayscale datasets need single-channel normalization stats; the
        # original passed 3-channel stats for all three datasets, which
        # Normalize rejects at runtime on 1-channel MNIST/FashionMNIST.
        if args.dataset == "cifar10":
            normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        else:
            normalize = transforms.Normalize((0.5,), (0.5,))
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        if args.dataset == "mnist":
            # Define the train & test datasets
            train_dataset = datasets.MNIST(
                args.data_folder, train=True, download=True, transform=transform
            )
            test_dataset = datasets.MNIST(
                args.data_folder, train=False, transform=transform
            )
            num_channels = 1
        elif args.dataset == "fashion-mnist":
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(
                args.data_folder, train=True, download=True, transform=transform
            )
            test_dataset = datasets.FashionMNIST(
                args.data_folder, train=False, transform=transform
            )
            num_channels = 1
        elif args.dataset == "cifar10":
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(
                args.data_folder, train=True, download=True, transform=transform
            )
            test_dataset = datasets.CIFAR10(
                args.data_folder, train=False, transform=transform
            )
            num_channels = 3
        # No separate validation split for the torchvision datasets.
        valid_dataset = test_dataset
    elif args.dataset == "miniimagenet":
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(
            args.data_folder, train=True, download=True, transform=transform
        )
        valid_dataset = MiniImagenet(
            args.data_folder, valid=True, download=True, transform=transform
        )
        test_dataset = MiniImagenet(
            args.data_folder, test=True, download=True, transform=transform
        )
        num_channels = 3
    else:
        # Generic ImageFolder layout: <data_folder>/train and /val.
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(args.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = ImageFolder(
            os.path.join(args.data_folder, "train"), transform=transform
        )
        valid_dataset = ImageFolder(
            os.path.join(args.data_folder, "val"), transform=transform
        )
        test_dataset = valid_dataset
        num_channels = 3

    # NOTE(review): shuffle=False on the training loader looks unintended
    # for SGD training — confirm before changing.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard (same batch reconstructed every epoch).
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    save_image(fixed_grid, "true.png")
    writer.add_image("original", fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once (untrained baseline).
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    save_image(grid, "rec.png")
    writer.add_image("reconstruction", grid, 0)

    best_loss = -1  # sentinel; always overwritten on epoch 0
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print(epoch, "test loss: ", loss)
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        save_image(grid, "rec.png")

        writer.add_image("reconstruction", grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open("{0}/best.pt".format(save_filename), "wb") as f:
                torch.save(model.state_dict(), f)
        # Unconditional per-epoch snapshot.
        with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
            torch.save(model.state_dict(), f)
コード例 #20
0
def main(args):
    """Sample class-conditional images from a trained VQ-VAE + PixelCNN prior.

    Loads a pretrained ``VectorQuantizedVAE`` and a class-conditional
    ``GatedPixelCNN`` prior, then, for every class label, generates latent
    code grids with three sampling strategies (softmax, DST, sparsemax),
    decodes each batch of latents to images, and saves the images to disk.
    The softmax samples are additionally indexed in a CSV file, and the
    maximum number of kept dimensions observed for the DST and sparsemax
    samplers is pickled at the end.

    Args:
        args: argparse-style namespace providing at least ``dataset``,
            ``data_folder``, ``output_folder``, ``batch_size``,
            ``num_workers``, ``model`` (VQ-VAE checkpoint path), ``prior``
            (PixelCNN checkpoint path), ``hidden_size_vae``,
            ``hidden_size_prior``, ``num_layers``, ``k`` and ``device``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported datasets.
    """
    # Human-readable class names. NOTE(review): currently unused below; kept
    # as a reference when inspecting the numeric labels written to the CSV.
    if args.dataset == 'miniimagenet':
        readable_labels = {
            38: 'organ',
            42: 'prayer_rug',
            31: 'file',
            61: 'cliff',
            58: 'consomme',
            59: 'hotdog',
            21: 'aircraft_carrier',
            14: 'French_bulldog',
            28: 'cocktail_shaker',
            63: 'ear',
            3: 'green_mamba',
            4: 'harvestman',
            17: 'Arctic_fox',
            32: 'fire_screen',
            11: 'komondor',
            43: 'reel',
            18: 'ladybug',
            45: 'snorkel',
            24: 'beer_bottle',
            36: 'lipstick',
            5: 'toucan',
            0: 'house_finch',
            16: 'miniature_poodle',
            50: 'tile_roof',
            15: 'Newfoundland',
            46: 'solar_dish',
            10: 'Gordon_setter',
            7: 'dugong',
            52: 'unicycle',
            20: 'rock_beauty',
            48: 'stage',
            22: 'ashcan',
            34: 'hair_slide',
            30: 'dome',
            13: 'Tibetan_mastiff',
            53: 'upright',
            62: 'bolete',
            2: 'triceratops',
            40: 'pencil_box',
            26: 'chime',
            47: 'spider_web',
            51: 'tobacco_shop',
            60: 'orange',
            49: 'tank',
            8: 'Walker_hound',
            23: 'barrel',
            6: 'jellyfish',
            33: 'frying_pan',
            9: 'Saluki',
            37: 'oboe',
            1: 'robin',
            19: 'three-toed_sloth',
            39: 'parallel_bars',
            55: 'worm_fence',
            27: 'clog',
            41: 'photocopier',
            25: 'carousel',
            29: 'dishrag',
            57: 'street_sign',
            35: 'holster',
            12: 'boxer',
            56: 'yawl',
            54: 'wok',
            44: 'slot'
        }

    elif args.dataset == 'cifar10':
        readable_labels = {
            0: 'airplane',
            1: 'automobile',
            2: 'bird',
            3: 'cat',
            4: 'deer',
            5: 'dog',
            6: 'frog',
            7: 'horse',
            8: 'ship',
            9: 'truck'
        }

    # Tensorboard writer; kept under its own name so the csv writer below
    # cannot shadow it (the original code rebound `writer` to csv.writer).
    writer = SummaryWriter('./VQVAE/logs/{0}'.format(args.output_folder))
    # NOTE(review): `save_filename` is never used in this function.
    save_filename = './VQVAE/models/{0}/prior.pt'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        # Images are normalized to [-1, 1]; mirrored by range=(-1, 1) in the
        # make_grid/save_image calls below.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder,
                                           train=True,
                                           download=True,
                                           transform=transform)
            test_dataset = datasets.MNIST(args.data_folder,
                                          train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False,
                                            transform=transform)
            num_channels = 3
        # No separate validation split for these datasets; reuse test set.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    else:
        # Fail fast instead of hitting a NameError on num_channels later.
        raise ValueError('unknown dataset: {0}'.format(args.dataset))

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    # Pretrained VQ-VAE: only its decoder is exercised below.
    model = VectorQuantizedVAE(num_channels, args.hidden_size_vae,
                               args.k).to(args.device)
    with open(args.model, 'rb') as f:
        state_dict = torch.load(f)
        model.load_state_dict(state_dict)
    model.eval()

    # Per-dataset sampling configuration: latent grid shape, label range,
    # CSV index filename and number of samples per class.
    if args.dataset == 'miniimagenet':
        print("number of training classes:", len(train_dataset._label_encoder))
        n_classes = len(train_dataset._label_encoder)
        shape = (32, 32)
        yrange = range(n_classes)
        csv_filename = 'miniimagenet_generated.csv'
        sample_size = 25
    elif args.dataset == 'cifar10':
        print("number of training classes:", 10)
        n_classes = 10
        shape = (8, 8)
        yrange = range(n_classes)
        csv_filename = 'cifar10_generated.csv'
        sample_size = 1000
    else:
        raise ValueError(
            'sampling is not configured for dataset: {0}'.format(args.dataset))

    # Class-conditional PixelCNN prior over the K-way discrete latents.
    prior = GatedPixelCNN(args.k,
                          args.hidden_size_prior,
                          args.num_layers,
                          n_classes=n_classes).to(args.device)

    with open(args.prior, 'rb') as f:
        state_dict = torch.load(f)
        prior.load_state_dict(state_dict)
    prior.eval()

    # maximum number of kept dimensions
    max_num_dst = 0
    max_num_sparsemax = 0

    # Single context manager: no gradients needed for sampling, and the CSV
    # file is opened/closed via `with` (the original opened it separately).
    with torch.no_grad(), open(
            './VQVAE/models/{0}/{1}'.format(args.output_folder, csv_filename),
            'w') as f:

        csv_writer = csv.writer(f)
        csv_writer.writerow(['filename', 'label'])

        for y in tqdm(yrange):
            label = torch.tensor([y])
            label = label.to(args.device)

            # --- Strategy 1: standard softmax sampling (indexed in CSV) ---
            z = prior.generate(label=label,
                               shape=shape,
                               batch_size=sample_size)
            x = model.decode(z)

            for im in range(x.shape[0]):
                save_image(x.cpu()[im],
                           './data/{0}/dataset_softmax/{1}_{2}.jpg'.format(
                               args.output_folder,
                               str(y).zfill(2),
                               str(im).zfill(3)),
                           range=(-1, 1),
                           normalize=True)
                if args.dataset == 'miniimagenet':
                    # Reverse-lookup the original string label from the
                    # dataset's label encoder.
                    y_str = list(train_dataset._label_encoder.keys())[list(
                        train_dataset._label_encoder.values()).index(y)]
                    csv_writer.writerow([
                        '{0}_{1}.jpg'.format(
                            str(y).zfill(2),
                            str(im).zfill(3)),
                        str(y_str)
                    ])
                elif args.dataset == 'cifar10':
                    csv_writer.writerow([
                        '{0}_{1}.jpg'.format(
                            str(y).zfill(2),
                            str(im).zfill(3)),
                        str(y)
                    ])

            # --- Strategy 2: DST sampling (tracks max kept dimensions) ---
            z, num_dst = prior.generate_dst(label=label,
                                            shape=shape,
                                            batch_size=sample_size)
            x = model.decode(z)

            if num_dst > max_num_dst:
                max_num_dst = num_dst

            for im in range(x.shape[0]):
                save_image(x.cpu()[im],
                           './data/{0}/dataset_dst/{1}_{2}.jpg'.format(
                               args.output_folder,
                               str(y).zfill(2),
                               str(im).zfill(3)),
                           range=(-1, 1),
                           normalize=True)

            # --- Strategy 3: sparsemax sampling ---
            z, num_sparsemax = prior.generate_sparsemax(
                label=label, shape=shape, batch_size=sample_size)
            x = model.decode(z)

            if num_sparsemax > max_num_sparsemax:
                max_num_sparsemax = num_sparsemax

            for im in range(x.shape[0]):
                save_image(
                    x.cpu()[im],
                    './data/{0}/dataset_sparsemax/{1}_{2}.jpg'.format(
                        args.output_folder,
                        str(y).zfill(2),
                        str(im).zfill(3)),
                    range=(-1, 1),
                    normalize=True)

    # Persist the per-strategy maxima for later analysis.
    pkl.dump([max_num_dst, max_num_sparsemax],
             open(
                 "max_num_" + str(args.dataset) + "_" + str(sample_size) +
                 "_correct_dst.pkl", "wb"))
コード例 #21
0
    train=True,
    download=True,
    transform=preproc_transform,
),
                                           batch_size=BATCH_SIZE,
                                           shuffle=False,
                                           num_workers=NUM_WORKERS,
                                           pin_memory=True)
# Test-split loader for the same dataset. `getattr` resolves the dataset
# class by name without `eval` (equivalent for valid names, but avoids
# arbitrary code execution if DATASET is ever externally controlled).
test_loader = torch.utils.data.DataLoader(getattr(datasets, DATASET)(
    '../data/{}/'.format(DATASET), train=False, transform=preproc_transform),
                                          batch_size=BATCH_SIZE,
                                          shuffle=False,
                                          num_workers=NUM_WORKERS,
                                          pin_memory=True)

# Frozen, pretrained VQ-VAE used to map images to/from discrete latent codes.
autoencoder = VectorQuantizedVAE(INPUT_DIM, VAE_DIM, K).to(DEVICE)
autoencoder.load_state_dict(torch.load('models/{}_vqvae.pt'.format(DATASET)))
autoencoder.eval()

# PixelCNN prior trained over the K-way discrete latents; cross-entropy over
# code indices, Adam with AMSGrad.
model = GatedPixelCNN(K, DIM, N_LAYERS).to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)


def train():
    train_loss = []
    for batch_idx, (x, label) in enumerate(train_loader):
        start_time = time.time()
        x = x.to(DEVICE)
        label = label.to(DEVICE)
コード例 #22
0
ファイル: vqvae.py プロジェクト: Tinemar/pytorch-vqvae
def main(args):
    """Train a VectorQuantizedVAE on MNIST/Fashion-MNIST/CIFAR-10/miniImageNet.

    Builds the dataset and loaders, optionally loads a VQ-VAE checkpoint and
    a pretrained VGG19 target model, then trains for ``args.num_epochs``
    epochs, logging reconstructions of a fixed test batch to Tensorboard and
    checkpointing every epoch (plus a ``best.pt`` on validation-loss
    improvement).

    Args:
        args: argparse-style namespace providing at least ``dataset``,
            ``data_folder``, ``output_folder``, ``batch_size``,
            ``num_workers``, ``hidden_size``, ``k``, ``lr``, ``num_epochs``,
            ``device``, ``ckp`` (optional VQ-VAE checkpoint, '' to skip) and
            ``tmodel`` (optional VGG19 checkpoint, '' to skip).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        # Normalize to [-1, 1]; mirrored by range=(-1, 1) in make_grid below.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if args.dataset == 'mnist':
            # Define the train & test datasets
            train_dataset = datasets.MNIST(args.data_folder,
                                           train=True,
                                           download=True,
                                           transform=transform)
            test_dataset = datasets.MNIST(args.data_folder,
                                          train=False,
                                          transform=transform)
            num_channels = 1
        elif args.dataset == 'fashion-mnist':
            # Define the train & test datasets
            train_dataset = datasets.FashionMNIST(args.data_folder,
                                                  train=True,
                                                  download=True,
                                                  transform=transform)
            test_dataset = datasets.FashionMNIST(args.data_folder,
                                                 train=False,
                                                 transform=transform)
            num_channels = 1
        elif args.dataset == 'cifar10':
            # Define the train & test datasets
            train_dataset = datasets.CIFAR10(args.data_folder,
                                             train=True,
                                             download=True,
                                             transform=transform)
            test_dataset = datasets.CIFAR10(args.data_folder,
                                            train=False,
                                            transform=transform)
            num_channels = 3
        # No separate validation split here; the test set doubles as it.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    else:
        # Fail fast instead of hitting a NameError on num_channels below.
        raise ValueError('unknown dataset: {0}'.format(args.dataset))

    # Define the data loaders.
    # NOTE(review): shuffle=False on the training loader is unusual for
    # training — confirm this is intentional before changing it.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.k).to(args.device)
    if args.ckp != "":
        # Resume from an existing VQ-VAE checkpoint.
        model.load_state_dict(torch.load(args.ckp))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Optional pretrained VGG19 target model. Default to None so that
    # train(...) below never sees an unbound name when args.tmodel == ''
    # (the original code raised UnboundLocalError in that case).
    target_model = None
    if args.tmodel != '':
        net = vgg.VGG('VGG19')
        net = net.to(args.device)
        net = torch.nn.DataParallel(net)
        checkpoint = torch.load(args.tmodel)
        net.load_state_dict(checkpoint['net'])
        target_model = net

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        print(epoch)
        train(train_loader, model, target_model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("test loss:", loss)

        # Log reconstructions of the fixed batch for this epoch.
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        # Checkpoint: best-so-far on validation loss, plus every epoch.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
コード例 #23
0
ファイル: my_vqvae.py プロジェクト: mmendiet/VQ-VAE_RL
def main(args):
    """Train a VQ-VAE on pickled Atari trajectory data.

    Assembles train/validation splits from pickled partition/label
    dictionaries on disk (excluding any game matching ``args.out_game``),
    then trains a VectorQuantizedVAE, logging reconstructions of a fixed
    batch to Tensorboard and checkpointing every epoch.

    Args:
        args: argparse-style namespace providing at least ``dataset``,
            ``data_folder``, ``output_folder``, ``out_game``, ``batch_size``,
            ``num_workers``, ``hidden_size``, ``k``, ``lr``, ``num_epochs``
            and ``device``.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset == 'atari':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(84),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Directory layout: *_dict holds partition/label pickles, *_traj the
        # actual trajectory data consumed by Dataset.
        dictDir = args.data_folder+'data_dict/'
        dataDir = args.data_folder+'data_traj/'
        #dictDir = args.data_folder+'test_dict/'
        #dataDir = args.data_folder+'test_traj/'
        all_partition = defaultdict(list)
        all_labels = defaultdict(list)
        # Datasets
        # Merge every on-disk dictionary, skipping files for the held-out
        # game named by args.out_game.
        for dictionary in os.listdir(dictDir):
            ########
            if args.out_game not in dictionary:
                #######
                dfile = open(dictDir+dictionary, 'rb')
                d = pickle.load(dfile)
                dfile.close()
                if("partition" in dictionary):
                    # Partition entries accumulate (lists of sample IDs).
                    for key in d:
                        all_partition[key] += d[key]
                elif("labels" in dictionary):
                    # Label entries overwrite (one label per sample ID).
                    for key in d:
                        all_labels[key] = d[key]
                else:
                    print("Error: Unexpected data dictionary")
        #partition = # IDs
        #labels = # Labels

        # Generators
        training_set = Dataset(all_partition['train'], all_labels, dataDir)
        train_loader = data.DataLoader(training_set, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True)

        # NOTE(review): the validation set is reused as the "test" loader
        # below (only for fixed Tensorboard images) — confirm intentional.
        validation_set = Dataset(all_partition['validation'], all_labels, dataDir)
        valid_loader = data.DataLoader(validation_set, batch_size=args.batch_size, shuffle=True, drop_last=False,
            num_workers=args.num_workers, pin_memory=True)
        test_loader = data.DataLoader(validation_set, batch_size=16, shuffle=True)
        # Channel counts for the VQ-VAE; presumably stacked frames plus
        # extra planes (13 in, 4 out) — TODO confirm against Dataset.
        input_channels = 13
        output_channels = 4

    # Define the data loaders
    # train_loader = torch.utils.data.DataLoader(train_dataset,
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.num_workers, pin_memory=True)
    # valid_loader = torch.utils.data.DataLoader(valid_dataset,
    #     batch_size=args.batch_size, shuffle=False, drop_last=True,
    #     num_workers=args.num_workers, pin_memory=True)
    # test_loader = torch.utils.data.DataLoader(test_dataset,
    #     batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    # Only the first 3 channels of the target are visualized as RGB.
    fixed_images, fixed_y = next(iter(test_loader))
    fixed_y = fixed_y[:,0:3,:,:]
    fixed_grid = make_grid(fixed_y, nrow=8, range=(0, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(input_channels, output_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    reconstruction_image = reconstruction[:,0:3,:,:]
    grid = make_grid(reconstruction_image.cpu(), nrow=8, range=(0, 1), normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    print("Starting to train...")
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("Finished Epoch: " + str(epoch) + "   Validation Loss: " + str(loss))
        # Log reconstructions of the fixed batch for this epoch.
        reconstruction = generate_samples(fixed_images, model, args)
        reconstruction_image = reconstruction[:,0:3,:,:]
        grid = make_grid(reconstruction_image.cpu(), nrow=8, range=(0, 1), normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        # Checkpoint: best-so-far on validation loss, plus every epoch.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)