Exemple #1
0
def val_test(args, *, end_epoch=4):
    """Run repeated validation passes of a VQ-VAE without any training step.

    Builds the model and, when ``args.recons_loss == "gan"``, a
    reconstruction discriminator. It optionally restores a saved state
    (``args.weights == "load"``). Then, for every epoch in
    ``range(start_epoch, end_epoch)``, it:

    * evaluates on the validation loader via ``train_util.test``;
    * calls the ``train_util.log_*`` helpers (reconstructions,
      interpolations, losses, latent metrics) at step ``epoch + 1``;
    * calls ``train_util.save_state``.

    The last validation loss dict is printed at the end.

    Args:
        args: parsed command-line namespace (output_folder, hidden_size, k,
            enc_type, dec_type, lr, recons_loss, img_res, input_type,
            disc_lr, device, weights, ...).
        end_epoch: exclusive upper bound of the epoch loop. The default of 4
            is the value that used to be hard-coded.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    _, valid_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)

    # Only used to restore/save the optimizer state alongside the model.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # name -> [network, optimizer]
    discriminators = {}
    if args.recons_loss == "gan":
        recons_disc = Discriminator(input_dim, args.img_res,
                                    args.input_type).to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    model.to(args.device)
    for disc_net, _ in discriminators.values():
        disc_net.to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    # Stays None when the loop body never runs (start_epoch >= end_epoch).
    val_loss_dict = None
    for epoch in tqdm(range(start_epoch, end_epoch), file=sys.stdout):
        # NOTE(review): the trailing True is a flag whose meaning is defined
        # in train_util.test -- confirm.
        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators, True)
        step = epoch + 1

        train_util.log_recons_img_grid(recons_input_img, model, step,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, step,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, step, writer)
        train_util.log_latent_metrics("val", z, step, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

    if val_loss_dict is None:
        print("no validation epochs run (start_epoch={0}, end_epoch={1})"
              .format(start_epoch, end_epoch))
    else:
        print(val_loss_dict)
Exemple #2
0
def main(args):
    """Train a VQ-VAE, optionally alongside a pretrained VGG19 target model.

    Builds train/valid/test datasets for ``args.dataset`` and logs a fixed
    batch of 16 test images and their initial reconstructions to
    TensorBoard. Then, for ``args.num_epochs`` epochs, it trains, evaluates
    on the validation loader and logs reconstructions. It also writes
    ``best.pt`` (lowest validation loss so far) and ``model_<epoch>.pt``
    into ``./models/<args.output_folder>``.

    Args:
        args: parsed command-line namespace. A non-empty ``args.ckp`` is a
            VQ-VAE state dict to start from. A non-empty ``args.tmodel`` is
            a checkpoint whose ``'net'`` entry holds the VGG19 weights.

    Raises:
        ValueError: if ``args.dataset`` is not one of 'mnist',
            'fashion-mnist', 'cifar10' or 'miniimagenet'.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        dataset_cls, num_channels = {
            'mnist': (datasets.MNIST, 1),
            'fashion-mnist': (datasets.FashionMNIST, 1),
            'cifar10': (datasets.CIFAR10, 3),
        }[args.dataset]
        # One mean/std entry per channel: a 3-entry Normalize cannot be
        # broadcast in place onto the 1-channel MNIST/Fashion-MNIST tensors.
        norm_stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(norm_stats, norm_stats)
        ])
        # Define the train & test datasets
        train_dataset = dataset_cls(args.data_folder,
                                    train=True,
                                    download=True,
                                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    else:
        raise ValueError('Unsupported dataset: {0}'.format(args.dataset))

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.k).to(args.device)
    if args.ckp != "":
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load(args.ckp, map_location=args.device))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # train() always receives target_model; it stays None when
    # args.tmodel is empty.  NOTE(review): confirm train() accepts None.
    target_model = None
    if args.tmodel != '':
        net = vgg.VGG('VGG19')
        net = net.to(args.device)
        net = torch.nn.DataParallel(net)
        checkpoint = torch.load(args.tmodel, map_location=args.device)
        net.load_state_dict(checkpoint['net'])
        target_model = net

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        print(epoch)
        train(train_loader, model, target_model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("test loss:", loss)
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
Exemple #3
0
def main(args):
    """Train a VQ-VAE on one of several image datasets and checkpoint it.

    Builds train/valid/test datasets for ``args.dataset`` and logs a fixed
    batch of 16 test images (4 per row) and their initial reconstructions to
    TensorBoard. Then, for ``args.num_epochs`` epochs, it trains, reports
    the validation loss via ``eprint`` and logs reconstructions. It also
    writes ``best.pt`` (lowest validation loss so far) and
    ``model_<epoch>.pt`` into ``./models/<args.output_folder>``.

    Raises:
        ValueError: if ``args.dataset`` is not one of 'mnist',
            'fashion-mnist', 'cifar10', 'PubTabNet' or 'miniimagenet'.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        dataset_cls, num_channels = {
            'mnist': (datasets.MNIST, 1),
            'fashion-mnist': (datasets.FashionMNIST, 1),
            'cifar10': (datasets.CIFAR10, 3),
        }[args.dataset]
        # One mean/std entry per channel: a 3-entry Normalize cannot be
        # broadcast in place onto the 1-channel MNIST/Fashion-MNIST tensors.
        norm_stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(norm_stats, norm_stats)
        ])
        # Define the train & test datasets
        train_dataset = dataset_cls(args.data_folder,
                                    train=True,
                                    download=True,
                                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)
        valid_dataset = test_dataset
    elif args.dataset == 'PubTabNet':
        # No ToTensor here -- presumably PubTabNet already yields tensors;
        # verify against the dataset class.
        transform = transforms.Compose(
            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = PubTabNet(args.data_folder,
                                  args.data_name,
                                  'TRAIN',
                                  transform=transform)
        test_dataset = PubTabNet(args.data_folder,
                                 args.data_name,
                                 'VAL',
                                 transform=transform)
        valid_dataset = test_dataset
        num_channels = 3
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    else:
        raise ValueError('Unsupported dataset: {0}'.format(args.dataset))

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=4, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=4,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer, epoch)
        loss, _ = test(valid_loader, model, args, writer)
        eprint('Validation loss at epoch %d: Loss = %.4f' % (epoch, loss))

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=4,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
Exemple #4
0
def main(args):
    """Train a VQ-VAE on the loaders returned by ``load_data`` with fixed seeds.

    Seeds ``random``, ``numpy`` and ``torch`` (CPU and all CUDA devices) from
    ``args.manualseed``. Wraps the model in ``DataParallel`` when
    ``args.gpu_ids`` holds more than one id. Then runs epochs
    ``1..args.num_epochs``. After each epoch it:

    * logs reconstructions of a fixed validation batch;
    * saves ``best.pt`` whenever the evaluation loss improves;
    * saves ``model_<epoch>.pt`` every ``args.save_step`` epochs into
      ``args.save_dir``.
    """
    # set manualseed
    random.seed(args.manualseed)
    torch.manual_seed(args.manualseed)
    torch.cuda.manual_seed_all(args.manualseed)
    np.random.seed(args.manualseed)
    torch.backends.cudnn.deterministic = True

    writer = SummaryWriter(args.log_dir)
    save_filename = args.save_dir

    dataloader = load_data(args)
    train_loader = dataloader['train']
    valid_loader = dataloader['valid']
    test_loader = dataloader['test']

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(valid_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(args.nc, args.hidden_size,
                               args.k).to(args.device)
    if len(args.gpu_ids) > 1:
        # NOTE(review): the state_dict() of a DataParallel wrapper prefixes
        # every key with "module."; the checkpoints saved below keep that
        # prefix.
        model = torch.nn.DataParallel(model,
                                      device_ids=args.gpu_ids,
                                      output_device=args.gpu_ids[0])

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    # Epochs are numbered from 1, so `epoch` itself is both the TensorBoard
    # step (step 0 holds the pre-training reconstruction) and the checkpoint
    # index.
    for epoch in range(1, args.num_epochs + 1):
        train(train_loader, model, optimizer, args, writer, epoch)
        # NOTE(review): the best-model criterion is computed on test_loader
        # rather than valid_loader -- confirm this is intended.
        loss, _ = test(test_loader, model, args, writer)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch)

        if (epoch == 1) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        if (epoch % args.save_step) == 0:
            with open('{0}/model_{1}.pt'.format(save_filename, epoch),
                      'wb') as f:
                torch.save(model.state_dict(), f)
Exemple #5
0
def main(args):
    """Train a VQ-VAE on MNIST, logging progress and checkpoints.

    Prints wall-clock start/end times and records the model graph plus the
    initial reconstruction in TensorBoard. Then, for ``args.num_epochs``
    epochs, it trains, evaluates on the test set and logs the reconstruction
    and test loss. It also writes ``best.pt`` (lowest test loss so far) and
    ``model_<epoch>.pt`` into ``./models/<args.output_folder>``.
    """
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("Start Time =", current_time)

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    transform = transforms.Compose([
        transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

    # Define the train & test dataSets
    train_set = datasets.MNIST(args.data_folder, train=True,
                               download=True, transform=transform)
    test_set = datasets.MNIST(args.data_folder, train=False,
                              download=True, transform=transform)
    num_channels = 1

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.num_workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_set, num_workers=args.num_workers,
                                              batch_size=16, shuffle=False)

    # Fixed images for TensorBoard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    writer.add_graph(model, fixed_images.to(args.device))  # get model structure on tensorboard

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction at start', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(test_loader, model, args, writer)
        step = epoch + 1

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)

        # Epoch numbers are ints, so format them with {:d}.
        writer.add_image('reconstruction at epoch {:d}'.format(step), grid, step)
        print("loss = {:f} at epoch {:d}".format(loss, step))
        writer.add_scalar('loss/testing_loss', loss, step)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, step), 'wb') as f:
            torch.save(model.state_dict(), f)

    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("End Time =", current_time)
Exemple #6
0
def main(args):
    """Train a VQ-VAE, optionally with a discriminator/comparator recons loss.

    ``args.recons_loss`` selects the reconstruction criterion:

    * "mse" uses no extra network;
    * "gan" adds a ``Discriminator``;
    * "comp" adds an ``AnchorComparator``;
    * names containing "comp_2" add a ``ClubbedPermutationComparator``;
    * names containing "comp_6" add a ``FullPermutationComparator``.

    The extra network and the autoencoder each get an Adam optimizer and a
    ``ReduceLROnPlateau`` scheduler. State is restored from
    ``./models/<args.output_folder>`` when ``args.weights == "load"``.

    Raises:
        ValueError: if ``args.recons_loss`` matches none of the names above.
        SystemExit: with status 1 if ``train`` raises ``RuntimeError``.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)

    num_channels = 3
    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    model.to(args.device)

    # name -> [network, optimizer]
    discriminators = {}

    input_dim = num_channels
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            # Fail early with a clear message rather than a NameError on
            # recons_disc below.
            raise ValueError(
                "Unsupported recons_loss: {0}".format(args.recons_loss))

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        recons_disc_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            recons_disc_opt,
            "min",
            patience=args.lr_patience,
            factor=0.5,
            threshold=args.threshold,
            threshold_mode="abs",
            min_lr=1e-7)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        "min",
        patience=args.lr_patience,
        factor=0.5,
        threshold=args.threshold,
        threshold_mode="abs",
        min_lr=1e-7)

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    # Fixed images for Tensorboard plus the step-0 reconstruction. They are
    # logged once, after the model has reached its final (possibly
    # data-parallel) form.
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):

        try:
            train(epoch, train_loader, model, optimizer, args, writer,
                  discriminators)
        except RuntimeError as err:
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******")
            print(err, file=sys.stderr)
            print(f"batch_size:{args.batch_size}", file=sys.stderr)
            # Non-zero status so callers/schedulers can tell the run failed.
            sys.exit(1)

        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators)

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

        # Plateau schedulers track the validation losses.
        ae_lr_scheduler.step(val_loss_dict["recons_loss"])
        if args.recons_loss != "mse":
            recons_disc_lr_scheduler.step(val_loss_dict["recons_disc_loss"])
                                           shuffle=False,
                                           num_workers=NUM_WORKERS,
                                           pin_memory=True)
# Held-out split of the torchvision dataset named by DATASET. The class is
# looked up with getattr, which resolves the same attribute as the former
# eval('datasets.' + DATASET) but never executes code.
test_loader = torch.utils.data.DataLoader(
    getattr(datasets, DATASET)(
        '../data/{}/'.format(DATASET),
        train=False,
        transform=preproc_transform,
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

model = VectorQuantizedVAE(INPUT_DIM, DIM, K).to(DEVICE)
print(model)
opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)


def train():
    """Iterate over ``train_loader`` and backprop the reconstruction loss.

    Uses the module-level ``model``, ``opt``, ``train_loader`` and ``DEVICE``.

    NOTE(review): as visible, ``train_loss`` and ``start_time`` are never
    used, and no ``opt.step()`` call appears. The body seems to have been
    truncated in this excerpt -- confirm against the full source.
    """
    train_loss = []
    for batch_idx, (x, _) in enumerate(train_loader):
        start_time = time.time()
        x = x.to(DEVICE)

        opt.zero_grad()

        # presumably (reconstruction, pre-quantization latents,
        # quantized latents) -- verify against VectorQuantizedVAE.forward.
        x_tilde, z_e_x, z_q_x = model(x)
        # retain_grad() keeps .grad populated on this non-leaf tensor after
        # backward so it can be inspected later.
        z_q_x.retain_grad()

        loss_recons = F.mse_loss(x_tilde, x)
        # retain_graph=True keeps the autograd graph so further backward
        # passes through the same forward computation remain possible.
        loss_recons.backward(retain_graph=True)
Exemple #8
0
def main(args):
    """Train a VQ-VAE on ``args.dataset`` and checkpoint it after every epoch.

    Builds the train/valid/test datasets and logs a fixed grid of test images
    (TensorBoard tag ``original``, file ``true.png``) and its reconstruction
    (tag ``reconstruction``, file ``rec.png``). It then alternates ``train``
    and ``test`` for ``args.num_epochs`` epochs. Weights are written to
    ``./models/<output_folder>/model_<epoch>.pt`` every epoch, and to
    ``best.pt`` whenever the validation loss is the lowest seen so far.

    Args:
        args: parsed command-line namespace. Read here: ``output_folder``,
            ``dataset``, ``data_folder``, ``image_size`` (ImageFolder
            datasets only), ``batch_size``, ``num_workers``, ``device``,
            ``hidden_size``, ``k``, ``lr`` and ``num_epochs``. It is also
            forwarded to ``train``, ``test`` and ``generate_samples``.
    """
    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    save_filename = "./models/{0}".format(args.output_folder)
    # Checkpoints are opened for writing inside this folder below.
    os.makedirs(save_filename, exist_ok=True)

    if args.dataset in ["mnist", "fashion-mnist", "cifar10"]:
        if args.dataset == "mnist":
            dataset_cls, num_channels = datasets.MNIST, 1
        elif args.dataset == "fashion-mnist":
            dataset_cls, num_channels = datasets.FashionMNIST, 1
        else:  # "cifar10"
            dataset_cls, num_channels = datasets.CIFAR10, 3
        # Normalize takes one mean/std entry per channel. A 3-tuple applied to
        # a 1-channel (MNIST) tensor fails to broadcast in current
        # torchvision, so size the statistics to the channel count.
        stats = (0.5,) * num_channels
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(stats, stats),
            ]
        )
        # Define the train & test datasets
        train_dataset = dataset_cls(
            args.data_folder, train=True, download=True, transform=transform
        )
        test_dataset = dataset_cls(args.data_folder, train=False, transform=transform)
        # No separate validation split here: validate on the test set.
        valid_dataset = test_dataset
    elif args.dataset == "miniimagenet":
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(
            args.data_folder, train=True, download=True, transform=transform
        )
        valid_dataset = MiniImagenet(
            args.data_folder, valid=True, download=True, transform=transform
        )
        test_dataset = MiniImagenet(
            args.data_folder, test=True, download=True, transform=transform
        )
        num_channels = 3
    else:
        # Any other name is treated as an ImageFolder tree with
        # "train" and "val" subfolders.
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(args.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        train_dataset = ImageFolder(
            os.path.join(args.data_folder, "train"), transform=transform
        )
        valid_dataset = ImageFolder(
            os.path.join(args.data_folder, "val"), transform=transform
        )
        test_dataset = valid_dataset
        num_channels = 3

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    save_image(fixed_grid, "true.png")
    writer.add_image("original", fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def log_reconstruction(step):
        """Reconstruct the fixed images, save ``rec.png`` and log it at ``step``."""
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        save_image(grid, "rec.png")
        writer.add_image("reconstruction", grid, step)

    # Generate the samples first once
    log_reconstruction(0)

    best_loss = -1  # placeholder; replaced unconditionally at epoch 0
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print(epoch, "test loss: ", loss)
        log_reconstruction(epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open("{0}/best.pt".format(save_filename), "wb") as f:
                torch.save(model.state_dict(), f)
        with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
            torch.save(model.state_dict(), f)
Exemple #9
0
def main(args):
    """Train (or, with ``args.test_mode``, only evaluate) a VQ-VAE.

    Supports torchvision MNIST / FashionMNIST / CIFAR10 and the ``Clevr``
    loader for the ``clevr`` and ``carla`` datasets. A fixed batch of 16 test
    images and its reconstruction are logged to TensorBoard.

    In training mode, each epoch runs ``train`` and then ``test_old`` on the
    validation loader. Weights go to ``./models/<output_folder>/recent.pt``
    every epoch, and to ``best.pt`` whenever the validation loss is the
    lowest seen so far. In test mode, each epoch only runs ``test`` and logs
    a reconstruction.

    Args:
        args: parsed command-line namespace. Read here: ``output_folder``,
            ``dataset``, ``data_folder`` (torchvision datasets), ``modname``,
            ``object_level``, ``use_depth`` (carla), ``batch_size``,
            ``num_workers``, ``hidden_size``, ``k``, ``device``, ``lr``,
            ``load_model``, ``num_epochs`` and ``test_mode``. It is also
            forwarded to ``train``, ``test``, ``test_old`` and
            ``generate_samples``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    import os
    import socket

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    # Checkpoints are opened for writing inside this folder below.
    os.makedirs(save_filename, exist_ok=True)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        if args.dataset == 'mnist':
            dataset_cls, num_channels = datasets.MNIST, 1
        elif args.dataset == 'fashion-mnist':
            dataset_cls, num_channels = datasets.FashionMNIST, 1
        else:  # 'cifar10'
            dataset_cls, num_channels = datasets.CIFAR10, 3
        # Normalize takes one mean/std entry per channel. A 3-tuple applied to
        # a 1-channel (MNIST) tensor fails to broadcast in current
        # torchvision, so size the statistics to the channel count.
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)
        ])
        # Define the train & test datasets
        train_dataset = dataset_cls(args.data_folder,
                                    train=True,
                                    download=True,
                                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)
        # No separate validation split here: validate on the test set.
        valid_dataset = test_dataset
    elif args.dataset == 'clevr':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Dataset root depends on which machine this runs on.
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            # Other known location: "/projects/katefgroup/datasets/clevr_veggies/"
            dataset_name = '/home/mprabhud/dataset/clevr_veggies'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name, mod=args.modname, train=True,
                              transform=transform,
                              object_level=args.object_level)
        valid_dataset = Clevr(dataset_name, mod=args.modname, valid=True,
                              transform=transform,
                              object_level=args.object_level)
        test_dataset = Clevr(dataset_name, mod=args.modname, test=True,
                             transform=transform,
                             object_level=args.object_level)
        num_channels = 3
    elif args.dataset == 'carla':
        # One extra input channel when depth is loaded alongside RGB.
        num_channels = 4 if args.use_depth else 3
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)
        ])
        # Dataset root depends on which machine this runs on.
        if "Alien" in socket.gethostname():
            # NOTE(review): this is the clevr_veggies folder although the
            # dataset is carla; looks copied from the clevr branch. Confirm
            # the intended path on that machine.
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            # Other known locations: '/home/shamitl/datasets/carla',
            # "/projects/katefgroup/datasets/carla/"
            dataset_name = '/home/mprabhud/dataset/carla'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name, mod=args.modname, train=True,
                              transform=transform,
                              object_level=args.object_level,
                              use_depth=args.use_depth)
        valid_dataset = Clevr(dataset_name, mod=args.modname, valid=True,
                              transform=transform,
                              object_level=args.object_level,
                              use_depth=args.use_depth)
        test_dataset = Clevr(dataset_name, mod=args.modname, test=True,
                             transform=transform,
                             object_level=args.object_level,
                             use_depth=args.use_depth)
    else:
        raise ValueError('Unsupported dataset: {0}'.format(args.dataset))

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.object_level, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # An empty (or missing) path means "start from fresh weights".
    if args.load_model:
        with open(args.load_model, 'rb') as f:
            # map_location lets GPU-saved checkpoints load on a CPU-only host.
            state_dict = torch.load(f, map_location=args.device)
        model.load_state_dict(state_dict)

    def log_reconstruction(step):
        """Reconstruct the fixed images and log the grid at ``step``."""
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, step)

    # Generate the samples first once
    log_reconstruction(0)

    best_loss = -1.  # placeholder; replaced unconditionally at epoch 0
    for epoch in range(args.num_epochs):
        if args.test_mode:
            # NOTE(review): evaluation runs over train_loader, not the
            # validation/test loaders; confirm this is intended.
            test(train_loader, model, args, writer)
            log_reconstruction(epoch + 1)
            continue

        train(train_loader, model, optimizer, args, writer, epoch)
        loss, _ = test_old(valid_loader, model, args, writer)
        log_reconstruction(epoch + 1)

        with open('{0}/recent.pt'.format(save_filename), 'wb') as f:
            torch.save(model.state_dict(), f)
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
Exemple #10
0
def main(args):
    """Train a VQ-VAE on pre-extracted Atari trajectory data.

    Reads pickled dictionaries from ``<data_folder>/data_dict/``, skipping
    any file whose name contains ``args.out_game``:

    * files named ``*partition*`` have their per-split ID lists concatenated;
    * files named ``*labels*`` are merged into one label mapping.

    The ``train`` and ``validation`` ID lists are wrapped in ``Dataset``
    objects backed by ``<data_folder>/data_traj/``. Training then runs for
    ``args.num_epochs`` epochs. A reconstruction of a fixed validation batch
    is logged to TensorBoard each epoch. Weights are saved to
    ``./models/<output_folder>/model_<epoch>.pt`` every epoch, and to
    ``best.pt`` whenever the validation loss improves.

    Args:
        args: parsed command-line namespace. Read here: ``dataset``,
            ``output_folder``, ``data_folder``, ``out_game``, ``batch_size``,
            ``num_workers``, ``hidden_size``, ``k``, ``device``, ``lr`` and
            ``num_epochs``. It is also forwarded to ``train``, ``test`` and
            ``generate_samples``.

    Raises:
        ValueError: if ``args.dataset`` is not ``'atari'``, the only dataset
            this entry point can load.
    """
    if args.dataset != 'atari':
        raise ValueError('Unsupported dataset: {0}'.format(args.dataset))

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    # Checkpoints are opened for writing inside this folder below.
    os.makedirs(save_filename, exist_ok=True)

    # os.path.join works whether or not data_folder ends with a separator.
    # data_dir keeps a trailing separator because it is handed to Dataset as
    # a path prefix (presumably concatenated with sample IDs; TODO confirm).
    dict_dir = os.path.join(args.data_folder, 'data_dict')
    data_dir = os.path.join(args.data_folder, 'data_traj', '')

    all_partition = defaultdict(list)
    all_labels = defaultdict(list)
    # Sorted so the order in which partition ID lists are concatenated does
    # not depend on the arbitrary os.listdir() order.
    for dictionary in sorted(os.listdir(dict_dir)):
        # Leave out every dictionary that belongs to the held-out game.
        if args.out_game in dictionary:
            continue
        # pickle can execute arbitrary code on load; only point data_folder
        # at trusted files.
        with open(os.path.join(dict_dir, dictionary), 'rb') as dfile:
            d = pickle.load(dfile)
        if "partition" in dictionary:
            for key in d:
                all_partition[key] += d[key]
        elif "labels" in dictionary:
            for key in d:
                all_labels[key] = d[key]
        else:
            print("Error: Unexpected data dictionary")

    # Generators
    training_set = Dataset(all_partition['train'], all_labels, data_dir)
    train_loader = data.DataLoader(training_set, batch_size=args.batch_size,
                                   shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)

    validation_set = Dataset(all_partition['validation'], all_labels, data_dir)
    valid_loader = data.DataLoader(validation_set, batch_size=args.batch_size,
                                   shuffle=True, drop_last=False,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    test_loader = data.DataLoader(validation_set, batch_size=16, shuffle=True)
    # Model input / output channel counts; only the first 3 output channels
    # are rendered as images below.
    input_channels = 13
    output_channels = 4

    # Fixed batch for Tensorboard
    fixed_images, fixed_y = next(iter(test_loader))
    fixed_y = fixed_y[:, 0:3, :, :]
    fixed_grid = make_grid(fixed_y, nrow=8, range=(0, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(input_channels, output_channels,
                               args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def log_reconstruction(step):
        """Reconstruct the fixed batch and log its first 3 channels at ``step``."""
        reconstruction = generate_samples(fixed_images, model, args)
        reconstruction_image = reconstruction[:, 0:3, :, :]
        grid = make_grid(reconstruction_image.cpu(), nrow=8, range=(0, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, step)

    # Generate the samples first once
    log_reconstruction(0)

    best_loss = -1.  # placeholder; replaced unconditionally at epoch 0
    print("Starting to train...")
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("Finished Epoch: " + str(epoch) + "   Validation Loss: " + str(loss))
        log_reconstruction(epoch + 1)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)