Ejemplo n.º 1
0
def main(args):
    """Train a VQ-VAE on the dataset named by ``args.dataset``.

    Builds the datasets and loaders, logs a fixed batch of test images and
    their reconstructions to TensorBoard, then trains for ``args.num_epochs``
    epochs, writing ``model_<epoch>.pt`` after every epoch and ``best.pt``
    whenever the validation loss improves.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        dataset_cls, num_channels = {
            'mnist': (datasets.MNIST, 1),
            'fashion-mnist': (datasets.FashionMNIST, 1),
            'cifar10': (datasets.CIFAR10, 3),
        }[args.dataset]
        # One mean/std entry per channel: a 3-value Normalize cannot be
        # applied to the single-channel MNIST-style tensors.
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)
        ])
        # Define the train & test datasets
        train_dataset = dataset_cls(args.data_folder,
                                    train=True,
                                    download=True,
                                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)
        # These datasets have no separate validation split.
        valid_dataset = test_dataset
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    else:
        raise ValueError('Unsupported dataset: {0}'.format(args.dataset))

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.k).to(args.device)
    if args.ckp != "":
        # Resume from an existing checkpoint.
        model.load_state_dict(torch.load(args.ckp))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Always bind target_model so the train() call below cannot hit a
    # NameError when no target model path is given.
    # NOTE(review): confirm that train() accepts None here.
    target_model = None
    if args.tmodel != '':
        net = vgg.VGG('VGG19')
        net = net.to(args.device)
        net = torch.nn.DataParallel(net)
        checkpoint = torch.load(args.tmodel)
        net.load_state_dict(checkpoint['net'])
        target_model = net

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        print(epoch)
        train(train_loader, model, target_model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("test loss:", loss)
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        # The first epoch always becomes the initial best checkpoint.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
Ejemplo n.º 2
0
def main(args):
    """Train a VQ-VAE using the loaders returned by ``load_data(args)``.

    Seeds the Python, NumPy and torch random number generators from
    ``args.manualseed``, logs a fixed validation batch and its
    reconstructions to TensorBoard, trains for ``args.num_epochs`` epochs,
    and saves ``best.pt`` when the loss improves plus ``model_<epoch>.pt``
    every ``args.save_step`` epochs.
    """
    # set manualseed
    random.seed(args.manualseed)
    torch.manual_seed(args.manualseed)
    torch.cuda.manual_seed_all(args.manualseed)
    np.random.seed(args.manualseed)
    torch.backends.cudnn.deterministic = True

    writer = SummaryWriter(args.log_dir)
    save_filename = args.save_dir

    dataloader = load_data(args)
    train_loader = dataloader['train']
    valid_loader = dataloader['valid']
    test_loader = dataloader['test']

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(valid_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(args.nc, args.hidden_size,
                               args.k).to(args.device)
    if len(args.gpu_ids) > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=args.gpu_ids,
                                      output_device=args.gpu_ids[0])

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    # Epochs are 1-based here, so `epoch` is used directly as the step.
    for epoch in range(1, args.num_epochs + 1):
        train(train_loader, model, optimizer, args, writer, epoch)
        # NOTE(review): the best checkpoint is selected on test_loader, not
        # valid_loader -- confirm this is intended.
        loss, _ = test(test_loader, model, args, writer)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=8,
                         range=(-1, 1),
                         normalize=True)
        # Step 0 holds the untrained reconstruction; epoch N is logged at
        # step N so that no step is skipped.
        writer.add_image('reconstruction', grid, epoch)

        # The first epoch always becomes the initial best checkpoint.
        if (epoch == 1) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        if (epoch % args.save_step) == 0:
            with open('{0}/model_{1}.pt'.format(save_filename, epoch),
                      'wb') as f:
                torch.save(model.state_dict(), f)
Ejemplo n.º 3
0
def main(args):
    """Train a VQ-VAE on the dataset named by ``args.dataset``.

    Supports 'mnist', 'fashion-mnist', 'cifar10', 'PubTabNet' and
    'miniimagenet'. Logs a fixed batch of test images and their
    reconstructions to TensorBoard, trains for ``args.num_epochs`` epochs,
    and writes ``model_<epoch>.pt`` every epoch plus ``best.pt`` whenever the
    validation loss improves.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        dataset_cls, num_channels = {
            'mnist': (datasets.MNIST, 1),
            'fashion-mnist': (datasets.FashionMNIST, 1),
            'cifar10': (datasets.CIFAR10, 3),
        }[args.dataset]
        # One mean/std entry per channel: a 3-value Normalize cannot be
        # applied to the single-channel MNIST-style tensors.
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)
        ])
        # Define the train & test datasets
        train_dataset = dataset_cls(args.data_folder,
                                    train=True,
                                    download=True,
                                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)
        # These datasets have no separate validation split.
        valid_dataset = test_dataset
    elif args.dataset == 'PubTabNet':
        # No ToTensor here -- presumably PubTabNet already yields tensors;
        # TODO confirm against the dataset class.
        transform = transforms.Compose(
            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train_dataset = PubTabNet(args.data_folder,
                                  args.data_name,
                                  'TRAIN',
                                  transform=transform)
        test_dataset = PubTabNet(args.data_folder,
                                 args.data_name,
                                 'VAL',
                                 transform=transform)
        valid_dataset = test_dataset
        num_channels = 3
    elif args.dataset == 'miniimagenet':
        transform = transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(args.data_folder,
                                     train=True,
                                     download=True,
                                     transform=transform)
        valid_dataset = MiniImagenet(args.data_folder,
                                     valid=True,
                                     download=True,
                                     transform=transform)
        test_dataset = MiniImagenet(args.data_folder,
                                    test=True,
                                    download=True,
                                    transform=transform)
        num_channels = 3
    else:
        raise ValueError('Unsupported dataset: {0}'.format(args.dataset))

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=4, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=4,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer, epoch)
        loss, _ = test(valid_loader, model, args, writer)
        eprint('Validataion loss at epoch %d: Loss = %.4f' % (epoch, loss))

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(),
                         nrow=4,
                         range=(-1, 1),
                         normalize=True)
        writer.add_image('reconstruction', grid, epoch + 1)

        # The first epoch always becomes the initial best checkpoint.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1),
                  'wb') as f:
            torch.save(model.state_dict(), f)
Ejemplo n.º 4
0
def main(args):
    """Train a VQ-VAE on MNIST and log progress to TensorBoard.

    Prints the wall-clock start and end times, logs the model graph, a fixed
    batch of test images and per-epoch reconstructions and test loss, and
    writes ``model_<epoch>.pt`` every epoch plus ``best.pt`` whenever the
    test loss improves.
    """
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("Start Time =", current_time)

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    transform = transforms.Compose([
        transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

    # Define the train & test dataSets
    train_set = datasets.MNIST(args.data_folder, train=True,
                               download=True, transform=transform)
    test_set = datasets.MNIST(args.data_folder, train=False,
                              download=True, transform=transform)
    num_channels = 1

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.num_workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_set, num_workers=args.num_workers,
                                              batch_size=16, shuffle=False)

    # Fixed images for TensorBoard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    writer.add_graph(model, fixed_images.to(args.device))  # get model structure on tensorboard

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('reconstruction at start', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(test_loader, model, args, writer)

        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)

        # The epoch number is an integer: format it with `d` so tags and
        # messages read "epoch 1" rather than "epoch 1.000000".
        writer.add_image('reconstruction at epoch {:d}'.format(epoch + 1), grid, epoch + 1)
        print("loss = {:f} at epoch {:d}".format(loss, epoch + 1))
        writer.add_scalar('loss/testing_loss', loss, epoch + 1)

        # The first epoch always becomes the initial best checkpoint.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)

    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    print("End Time =", current_time)
Ejemplo n.º 5
0
def generate_samples():
    """Save a grid of test images together with their reconstructions.

    Takes up to 32 images from the first batch of the module-level
    ``test_loader``, runs them through the module-level ``model``, and writes
    originals followed by reconstructions to
    ``samples/vqvae_reconstructions_<DATASET>.png``.
    """
    # Use the iterator protocol; iterators do not provide a `.next()` method
    # in Python 3.
    x, _ = next(iter(test_loader))
    x = x[:32].to(DEVICE)
    # No gradients are needed for visualisation.
    with torch.no_grad():
        x_tilde, _, _ = model(x)

    x_cat = torch.cat([x, x_tilde], 0)
    # Shift by +1 and halve before saving -- presumably maps data normalised
    # to [-1, 1] back to [0, 1]; TODO confirm against the data transform.
    images = (x_cat.cpu().data + 1) / 2

    save_image(images,
               'samples/vqvae_reconstructions_{}.png'.format(DATASET),
               nrow=8)


# Lowest test loss seen so far. Start at +inf so the first epoch is always
# checkpointed, whatever the scale of the loss.
BEST_LOSS = float('inf')
# Epoch at which the checkpoint was last written (-1: never).
LAST_SAVED = -1
# NOTE(review): range(1, N_EPOCHS) runs N_EPOCHS - 1 epochs -- confirm this
# is intended.
for epoch in range(1, N_EPOCHS):
    print("Epoch {}:".format(epoch))
    train()
    cur_loss, _ = test()

    if cur_loss <= BEST_LOSS:
        BEST_LOSS = cur_loss
        LAST_SAVED = epoch
        print("Saving model!")
        torch.save(model.state_dict(), 'models/{}_vqvae.pt'.format(DATASET))
    else:
        print("Not saving model! Last saved: {}".format(LAST_SAVED))

    # Refresh the reconstruction image after every epoch.
    generate_samples()
Ejemplo n.º 6
0
def main(args):
    """Train a VQ-VAE on a torchvision, MiniImagenet or image-folder dataset.

    'mnist', 'fashion-mnist', 'cifar10' and 'miniimagenet' are handled
    explicitly; any other ``args.dataset`` value is loaded with
    ``ImageFolder`` from the ``train`` and ``val`` sub-directories of
    ``args.data_folder``. Logs originals and reconstructions to TensorBoard
    (also saved as ``true.png`` / ``rec.png``), and writes
    ``model_<epoch>.pt`` every epoch plus ``best.pt`` whenever the validation
    loss improves.
    """
    writer = SummaryWriter("./logs/{0}".format(args.output_folder))
    save_filename = "./models/{0}".format(args.output_folder)

    if args.dataset in ["mnist", "fashion-mnist", "cifar10"]:
        dataset_cls, num_channels = {
            "mnist": (datasets.MNIST, 1),
            "fashion-mnist": (datasets.FashionMNIST, 1),
            "cifar10": (datasets.CIFAR10, 3),
        }[args.dataset]
        # One mean/std entry per channel: a 3-value Normalize cannot be
        # applied to the single-channel MNIST-style tensors.
        stats = (0.5,) * num_channels
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(stats, stats),
            ]
        )
        # Define the train & test datasets
        train_dataset = dataset_cls(
            args.data_folder, train=True, download=True, transform=transform
        )
        test_dataset = dataset_cls(
            args.data_folder, train=False, transform=transform
        )
        # These datasets have no separate validation split.
        valid_dataset = test_dataset
    elif args.dataset == "miniimagenet":
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(128),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = MiniImagenet(
            args.data_folder, train=True, download=True, transform=transform
        )
        valid_dataset = MiniImagenet(
            args.data_folder, valid=True, download=True, transform=transform
        )
        test_dataset = MiniImagenet(
            args.data_folder, test=True, download=True, transform=transform
        )
        num_channels = 3
    else:
        # Fallback: a generic image folder with "train" and "val" splits.
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(args.image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        # Define the train, valid & test datasets
        train_dataset = ImageFolder(
            os.path.join(args.data_folder, "train"), transform=transform
        )
        valid_dataset = ImageFolder(
            os.path.join(args.data_folder, "val"), transform=transform
        )
        test_dataset = valid_dataset
        num_channels = 3

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    save_image(fixed_grid, "true.png")
    writer.add_image("original", fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
    save_image(grid, "rec.png")
    writer.add_image("reconstruction", grid, 0)

    best_loss = -1
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print(epoch, "test loss: ", loss)
        reconstruction = generate_samples(fixed_images, model, args)
        grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
        # rec.png is overwritten each epoch with the latest reconstruction.
        save_image(grid, "rec.png")

        writer.add_image("reconstruction", grid, epoch + 1)

        # The first epoch always becomes the initial best checkpoint.
        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open("{0}/best.pt".format(save_filename), "wb") as f:
                torch.save(model.state_dict(), f)
        with open("{0}/model_{1}.pt".format(save_filename, epoch + 1), "wb") as f:
            torch.save(model.state_dict(), f)
Ejemplo n.º 7
0
def main(args):
    """Train or evaluate a VQ-VAE on the dataset named by ``args.dataset``.

    Supports 'mnist', 'fashion-mnist', 'cifar10', 'clevr' and 'carla'.
    Optionally restores weights from ``args.load_model``. When
    ``args.test_mode`` is false, each epoch trains, validates, and writes
    ``recent.pt`` plus ``best.pt`` on improvement; when it is true, each epoch
    only runs ``test`` on the training loader. Reconstructions of a fixed
    test batch are logged to TensorBoard in both modes.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    import socket

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    if args.dataset in ['mnist', 'fashion-mnist', 'cifar10']:
        dataset_cls, num_channels = {
            'mnist': (datasets.MNIST, 1),
            'fashion-mnist': (datasets.FashionMNIST, 1),
            'cifar10': (datasets.CIFAR10, 3),
        }[args.dataset]
        # One mean/std entry per channel: a 3-value Normalize cannot be
        # applied to the single-channel MNIST-style tensors.
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)
        ])
        # Define the train & test datasets
        train_dataset = dataset_cls(args.data_folder,
                                    train=True,
                                    download=True,
                                    transform=transform)
        test_dataset = dataset_cls(args.data_folder,
                                   train=False,
                                   transform=transform)
        # These datasets have no separate validation split.
        valid_dataset = test_dataset
    elif args.dataset == 'clevr':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # The dataset location depends on the host the script runs on.
        if "Alien" in socket.gethostname():
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = '/home/mprabhud/dataset/clevr_veggies'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name,
                              mod=args.modname,
                              train=True,
                              transform=transform,
                              object_level=args.object_level)
        valid_dataset = Clevr(dataset_name,
                              mod=args.modname,
                              valid=True,
                              transform=transform,
                              object_level=args.object_level)
        test_dataset = Clevr(dataset_name,
                             mod=args.modname,
                             test=True,
                             transform=transform,
                             object_level=args.object_level)
        num_channels = 3
    elif args.dataset == 'carla':
        # With depth enabled the model and normalisation use 4 channels
        # instead of 3.
        num_channels = 4 if args.use_depth else 3
        stats = (0.5,) * num_channels
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(stats, stats)
        ])
        if "Alien" in socket.gethostname():
            # NOTE(review): this is the clevr_veggies path inside the carla
            # branch -- looks like a copy-paste leftover; confirm.
            dataset_name = "/media/mihir/dataset/clevr_veggies/"
        else:
            dataset_name = '/home/mprabhud/dataset/carla'
        # Define the train, valid & test datasets
        train_dataset = Clevr(dataset_name,
                              mod=args.modname,
                              train=True,
                              transform=transform,
                              object_level=args.object_level,
                              use_depth=args.use_depth)
        valid_dataset = Clevr(dataset_name,
                              mod=args.modname,
                              valid=True,
                              transform=transform,
                              object_level=args.object_level,
                              use_depth=args.use_depth)
        test_dataset = Clevr(dataset_name,
                             mod=args.modname,
                             test=True,
                             transform=transform,
                             object_level=args.object_level,
                             use_depth=args.use_depth)
    else:
        raise ValueError('Unsupported dataset: {0}'.format(args.dataset))

    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               drop_last=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=16,
                                              shuffle=True)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(num_channels, args.hidden_size,
                               args.object_level, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Compare by value: identity checks against a string literal are
    # unreliable and raise a SyntaxWarning on Python 3.8+.
    if args.load_model != "":
        with open(args.load_model, 'rb') as f:
            state_dict = torch.load(f)
            model.load_state_dict(state_dict)

    # Generate the samples first once
    reconstruction = generate_samples(fixed_images, model, args)
    grid = make_grid(reconstruction.cpu(),
                     nrow=8,
                     range=(-1, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, 0)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        if not args.test_mode:
            train(train_loader, model, optimizer, args, writer, epoch)
            loss, _ = test_old(valid_loader, model, args, writer)
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(),
                             nrow=8,
                             range=(-1, 1),
                             normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)
            # Always keep the most recent weights, independent of the loss.
            with open('{0}/recent.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
            # The first epoch always becomes the initial best checkpoint.
            if (epoch == 0) or (loss < best_loss):
                best_loss = loss
                with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                    torch.save(model.state_dict(), f)
        else:
            # Test mode: no optimisation and no checkpointing.
            test(train_loader, model, args, writer)
            reconstruction = generate_samples(fixed_images, model, args)
            grid = make_grid(reconstruction.cpu(),
                             nrow=8,
                             range=(-1, 1),
                             normalize=True)
            writer.add_image('reconstruction', grid, epoch + 1)
Ejemplo n.º 8
0
def _load_atari_dictionaries(dict_dir, out_game):
    """Merge the pickled partition/label dictionaries stored in ``dict_dir``.

    Args:
        dict_dir: Directory whose entries are pickled dictionaries.
        out_game: Substring identifying files to leave out; every file whose
            name contains it is skipped.

    Returns:
        A ``(all_partition, all_labels)`` tuple of ``defaultdict(list)``.
        Files with "partition" in their name have their per-key values
        concatenated into ``all_partition``; files with "labels" in their
        name have their entries copied key by key into ``all_labels``.
        Any other file is reported on stdout and ignored.
    """
    all_partition = defaultdict(list)
    all_labels = defaultdict(list)

    for filename in os.listdir(dict_dir):
        if out_game in filename:
            continue

        is_partition = "partition" in filename
        if not is_partition and "labels" not in filename:
            # Decide from the name before unpickling, so a stray file is
            # reported instead of crashing the loader.
            print("Error: Unexpected data dictionary")
            continue

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # dictionary files produced by a trusted pipeline.
        with open(os.path.join(dict_dir, filename), 'rb') as dict_file:
            loaded = pickle.load(dict_file)

        if is_partition:
            for key in loaded:
                all_partition[key] += loaded[key]
        else:
            for key in loaded:
                all_labels[key] = loaded[key]

    return all_partition, all_labels


def _log_reconstruction(writer, model, fixed_images, args, step):
    """Write the model's reconstruction of ``fixed_images`` to the writer.

    Only the first three channels of the reconstruction are turned into the
    image grid (presumably the displayable colour planes -- confirm against
    the model's output layout).
    """
    reconstruction = generate_samples(fixed_images, model, args)
    reconstruction_image = reconstruction[:, 0:3, :, :]
    # NOTE(review): newer torchvision releases renamed `range` to
    # `value_range` -- confirm the pinned version before upgrading.
    grid = make_grid(reconstruction_image.cpu(), nrow=8, range=(0, 1),
                     normalize=True)
    writer.add_image('reconstruction', grid, step)


def main(args):
    """Train a VectorQuantizedVAE on the Atari trajectory data.

    Builds the train/validation loaders from the pickled dictionaries under
    ``args.data_folder``, logs original and reconstructed image grids through
    a ``SummaryWriter``, and after every epoch saves ``model_<epoch>.pt`` plus
    ``best.pt`` whenever the validation loss improves.

    Args:
        args: Namespace providing ``dataset``, ``output_folder``,
            ``data_folder``, ``out_game``, ``batch_size``, ``num_workers``,
            ``hidden_size``, ``k``, ``device``, ``lr`` and ``num_epochs``.

    Raises:
        ValueError: If ``args.dataset`` is not ``'atari'``, the only dataset
            this entry point knows how to load.
    """
    # Fail fast, before any log directory is created: every loader below
    # exists only for the Atari dataset.
    if args.dataset != 'atari':
        raise ValueError(
            "Unsupported dataset {0!r}: only 'atari' is handled".format(
                args.dataset))

    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    # Make sure the checkpoint directory exists so the saves at the end of
    # each epoch cannot fail on a missing folder.
    os.makedirs(save_filename, exist_ok=True)

    dict_dir = os.path.join(args.data_folder, 'data_dict')
    # The trailing separator is kept because the original code handed
    # Dataset a '/'-terminated prefix.
    data_dir = os.path.join(args.data_folder, 'data_traj', '')
    all_partition, all_labels = _load_atari_dictionaries(dict_dir,
                                                         args.out_game)

    # Generators
    training_set = Dataset(all_partition['train'], all_labels, data_dir)
    train_loader = data.DataLoader(training_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    validation_set = Dataset(all_partition['validation'], all_labels,
                             data_dir)
    valid_loader = data.DataLoader(validation_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   drop_last=False,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    # Small loader over the validation set, used only to draw the batch that
    # is visualised in Tensorboard.
    test_loader = data.DataLoader(validation_set, batch_size=16, shuffle=True)
    input_channels = 13
    output_channels = 4

    # Fixed images for Tensorboard
    fixed_images, fixed_y = next(iter(test_loader))
    fixed_y = fixed_y[:, 0:3, :, :]
    fixed_grid = make_grid(fixed_y, nrow=8, range=(0, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = VectorQuantizedVAE(input_channels, output_channels,
                               args.hidden_size, args.k).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Log the untrained model's reconstruction once as step 0.
    _log_reconstruction(writer, model, fixed_images, args, 0)

    best_loss = None  # None until the first epoch has been evaluated
    print("Starting to train...")
    for epoch in range(args.num_epochs):
        train(train_loader, model, optimizer, args, writer)
        loss, _ = test(valid_loader, model, args, writer)
        print("Finished Epoch: " + str(epoch) + "   Validation Loss: " + str(loss))
        _log_reconstruction(writer, model, fixed_images, args, epoch + 1)

        # The first epoch always becomes the best; later ones must improve.
        if best_loss is None or loss < best_loss:
            best_loss = loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)
        with open('{0}/model_{1}.pt'.format(save_filename, epoch + 1), 'wb') as f:
            torch.save(model.state_dict(), f)