Beispiel #1
0
            condition_embd = embedder(labels, captions)
            outputs = model.forward(imgs, condition_embd)
            loss = outputs['loss'].mean()
            bpd += loss / np.log(2)
    bpd /= len(test_loader)
    print("VAL bpd : {}".format(bpd))
    return bpd


if __name__ == "__main__":
    opt = parser.parse_args()
    print(opt)

    print("loading dataset")
    if opt.dataset == "imagenet32":
        train_dataset = Imagenet32Dataset(train=not opt.train_on_val,
                                          max_size=1 if opt.debug else -1)
        val_dataset = Imagenet32Dataset(train=0,
                                        max_size=1 if opt.debug else -1)
    else:
        assert opt.dataset == "cifar10"
        train_dataset = CIFAR10Dataset(train=not opt.train_on_val,
                                       max_size=1 if opt.debug else -1)
        val_dataset = CIFAR10Dataset(train=0, max_size=1 if opt.debug else -1)

    print("creating dataloaders")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
Beispiel #2
0
def main(args=None):
    """Train a conditional PixelCNN++, or (when ``opt.train`` is unset)
    load a checkpoint, sample images, and evaluate bits-per-dim on the
    validation set.

    Parameters
    ----------
    args : list[str] | None
        Command-line arguments to parse. ``None`` (the default) makes
        argparse fall back to ``sys.argv``.

    Returns
    -------
    None
    """
    # argparse already interprets None as "use sys.argv", so a single call
    # covers both cases.  (The original `if args:` also mis-handled
    # args=[] by silently falling back to sys.argv.)
    opt = parser.parse_args(args)
    print(opt)

    print("loading dataset")
    # Debug mode shrinks every dataset to a single example for fast smoke
    # tests; otherwise imagenet32 honours the configured train/val sizes.
    if opt.dataset == "imagenet32":
        train_dataset = Imagenet32Dataset(
            train=not opt.train_on_val,
            max_size=1 if opt.debug else opt.train_size)
        val_dataset = Imagenet32Dataset(
            train=0,
            max_size=1 if opt.debug else opt.val_size,
            start_idx=opt.val_start_idx)
    else:
        assert opt.dataset == "cifar10"
        train_dataset = CIFAR10Dataset(train=not opt.train_on_val,
                                       max_size=1 if opt.debug else -1)
        val_dataset = CIFAR10Dataset(train=0, max_size=1 if opt.debug else -1)

    print("creating dataloaders")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )

    print("Len train : {}, val : {}".format(len(train_dataloader),
                                            len(val_dataloader)))

    device = torch.device("cuda") if (
        torch.cuda.is_available() and opt.use_cuda) else torch.device("cpu")
    print("Device is {}".format(device))

    print("Loading models on device...")

    # Initialize the embedder that turns labels/captions into the
    # conditioning vector fed to the PixelCNN.
    if opt.conditioning == 'unconditional':
        encoder = UnconditionalClassEmbedding()
    elif opt.conditioning == "bert":
        encoder = BERTEncoder()
    else:
        assert opt.conditioning == "one-hot"
        encoder = OneHotClassEmbedding(train_dataset.n_classes)

    # 1-channel (grayscale) images use 3 logistic mixture components,
    # 3-channel images use 10 — presumably following the PixelCNN++
    # reference configuration.
    generative_model = ConditionalPixelCNNpp(
        embd_size=encoder.embed_size,
        img_shape=train_dataset.image_shape,
        nr_resnet=opt.n_resnet,
        nr_filters=opt.n_filters,
        nr_logistic_mix=3 if train_dataset.image_shape[0] == 1 else 10)

    generative_model = generative_model.to(device)
    encoder = encoder.to(device)
    print("Models loaded on device")

    # Configure data loader

    print("dataloaders loaded")
    # Optimizer with per-epoch exponential LR decay (step_size=1 decays
    # the learning rate by `lr_decay` after every epoch).
    optimizer = torch.optim.Adam(generative_model.parameters(), lr=opt.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=opt.lr_decay)

    # Create output directories for checkpoints and tensorboard logs.
    os.makedirs(os.path.join(opt.output_dir, "models"), exist_ok=True)
    os.makedirs(os.path.join(opt.output_dir, "tensorboard"), exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(opt.output_dir, "tensorboard"))

    # ----------
    #  Training
    # ----------
    if opt.train:
        train(model=generative_model,
              embedder=encoder,
              optimizer=optimizer,
              scheduler=scheduler,
              train_loader=train_dataloader,
              val_loader=val_dataloader,
              opt=opt,
              writer=writer,
              device=device)
    else:
        # Evaluation path requires a pre-trained checkpoint.
        assert opt.model_checkpoint is not None, 'no model checkpoint specified'
        print("Loading model from state dict...")
        load_model(opt.model_checkpoint, generative_model)
        print("Model loaded.")
        sample_images_full(generative_model,
                           encoder,
                           opt.output_dir,
                           dataloader=val_dataloader,
                           device=device)
        eval(model=generative_model,
             embedder=encoder,
             test_loader=val_dataloader,
             opt=opt,
             writer=writer,
             device=device)
    # BUG FIX: the original ended with
    #   return np.mean(split_scores), np.std(split_scores)
    # but `split_scores` is never defined in this function (a leftover
    # from an inception-score script), so every call raised NameError.
    # The function now returns None like a conventional main().


if __name__ == '__main__':

    #NEW STUFF PixelCNN

    pdb.set_trace()

    n_resnet = 5
    n_filters = 160

    total = 2  #number of images to generate
    batch_size = 2  #number of images to generate at atime

    val_dataset = Imagenet32Dataset(train=0, max_size=total)

    dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    encoder = BERTEncoder()
    generative_model = ConditionalPixelCNNpp(
        embd_size=encoder.embed_size,
        img_shape=val_dataset.image_shape,
        nr_resnet=n_resnet,
        nr_filters=n_filters,
        nr_logistic_mix=3 if val_dataset.image_shape[0] == 1 else 10)

    device = torch.device(
Beispiel #4
0
# BUG FIX: `type=bool` is broken for CLI flags — bool("False") is True, so
# any supplied value (including "False" or "0") enabled the option.  A
# string-to-bool converter keeps the value-taking interface backward
# compatible (`--wass True` still works) while making falsy spellings
# actually parse as False.  (Also fixes the "Wassersten" typo in help.)
def _str2bool(value):
    """Parse common true/false spellings from the command line."""
    return str(value).strip().lower() in ('true', '1', 'yes', 'y', 't')


parser.add_argument('--wass', type=_str2bool, default=False,
                    help='apply Wasserstein GAN')

opt = parser.parse_args()
print(opt)

################# make output dirs #####################
# exist_ok=True makes re-runs idempotent.
os.makedirs(os.path.join(opt.output_dir, "models"), exist_ok=True)
os.makedirs(os.path.join(opt.output_dir, "samples"), exist_ok=True)
os.makedirs(os.path.join(opt.output_dir, "tensorboard"), exist_ok=True)

writer = SummaryWriter(log_dir=os.path.join(opt.output_dir, "tensorboard"), comment='Cifar10')

################# load data #####################
print("loading dataset")
if opt.dataset == "imagenet":
    train_dataset = Imagenet32Dataset(train=True, max_size=1 if opt.debug else -1)
    val_dataset = Imagenet32Dataset(train=0, max_size=1 if opt.debug else -1)
elif opt.dataset == "cifar10":
    train_dataset = CIFAR10Dataset(train=True, max_size=1 if opt.debug else -1)
#    train_dataset = dset.CIFAR10(
#        root=opt.dataroot, download=True,
#        transform=transforms.Compose([
#            transforms.Scale(opt.imageSize),
#            transforms.ToTensor(),
#            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#        ]))
    val_dataset = CIFAR10Dataset(train=0, max_size=1 if opt.debug else -1)
elif opt.dataset == "coco":
    print("INFO: using coco")
    path2data="/home/ooo/Data/train2017"
    path2json="/home/ooo/Data/annotations_trainval2017/annotations/captions_train2017.json"