예제 #1
0
def main(args):
    """Train and evaluate a RealNVP model, optionally resuming from the best checkpoint.

    Args:
        args: Parsed command-line arguments. Reads: name, ow, gpu_ids,
            num_scales, num_blocks, benchmark, resume, weight_decay, lr,
            num_epochs, max_grad_norm, num_samples.
    """
    base_path = util.make_directory(args.name, args.ow)
    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    trainloader, testloader, in_channels = get_datasets(args)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=args.num_scales, in_channels=in_channels, mid_channels=64, num_blocks=args.num_blocks)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load the best checkpoint saved by a previous run under <base_path>/ckpts.
        ckpt_path = base_path / 'ckpts'
        best_path_ckpt = ckpt_path / 'best.pth.tar'
        print(f'Resuming from checkpoint at {best_path_ckpt}')
        checkpoint = torch.load(best_path_ckpt)
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    # Exclude weight-norm magnitude params ('weight_g' suffix) from weight decay.
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples, in_channels, base_path)
예제 #2
0
def main(args):
    """Train and evaluate a RealNVP model on CelebA (64x64), optionally
    resuming from ckpts/best.pth.tar.

    Args:
        args: Parsed command-line arguments. Reads: gpu_ids, batch_size,
            benchmark, resume, weight_decay, lr, num_epochs, max_grad_norm,
            num_samples.
    """
    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 8, 'pin_memory': False}
    trainloader = torch.utils.data.DataLoader(
        datasets.CelebA('./data', split='train', download=True, transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    testloader = torch.utils.data.DataLoader(
        datasets.CelebA('./data', split='test', transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    # Exclude weight-norm magnitude params ('weight_g' suffix) from weight decay.
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
예제 #3
0
def plugin_estimator_training_loop(real_nvp_model, dataloader, learning_rate,
                                   optim, device, total_iters,
                                   checkpoint_intervals, batchsize, algorithm,
                                   weight_decay, max_grad_norm, save_dir,
                                   save_suffix):
    """Train the RealNVP model for ``total_iters`` iterations using a plugin
    mean-estimation algorithm to aggregate per-sample gradients.

    Args:
        real_nvp_model: RealNVP network to optimize (updated in place).
        dataloader: Iterable of (images, labels) batches; labels are ignored.
        learning_rate: Learning rate handed to the optimizer constructor.
        optim: Optimizer spec understood by ``utils.get_optimizer_cons``.
        device: Device the input batches are moved to.
        total_iters: Number of optimizer steps to perform.
        checkpoint_intervals: Iteration counts at which a state_dict is saved.
        batchsize: Unused here; kept for interface parity with the other
            training loops in this project.
        algorithm: Mean-estimation callable; the special ``__name__`` 'mean'
            selects plain backprop on the averaged loss.
        weight_decay: L2 penalty for non-'weight_g' parameter groups.
        max_grad_norm: Gradient clipping threshold (<= 0 disables clipping).
        save_dir: Directory checkpoints are written into.
        save_suffix: Suffix appended to checkpoint filenames.

    Returns:
        The trained model.
    """
    param_groups = util.get_param_groups(real_nvp_model,
                                         weight_decay,
                                         norm_suffix='weight_g')
    optimizer_cons = utils.get_optimizer_cons(optim, learning_rate)
    optimizer = optimizer_cons(param_groups)

    loss_fn = RealNVPLoss()
    flag = False
    iteration = 0
    # Consistency with streaming_approx_training_loop: ensure train-mode
    # layers behave as in training during optimization.
    real_nvp_model.train()
    while not flag:
        for x, _ in dataloader:
            # Update iteration counter
            iteration += 1

            x = x.to(device)
            z, sldj = real_nvp_model(x, reverse=False)
            unaggregated_loss = loss_fn(z, sldj, aggregate=False)
            if algorithm.__name__ == 'mean':
                # Fix: clear stale gradients first — .backward() accumulates
                # into .grad, so without zero_grad() each step would apply
                # the sum of all previous iterations' gradients.
                optimizer.zero_grad()
                agg_loss = unaggregated_loss.mean()
                agg_loss.backward()
            else:
                # First sample gradients
                sgradients = utils.gradient_sampler(unaggregated_loss,
                                                    real_nvp_model)
                # Then get the estimate with the mean estimation algorithm
                stoc_grad = algorithm(sgradients)
                # Perform the update of .grad attributes (overwrites .grad,
                # so no zero_grad is needed on this path)
                with torch.no_grad():
                    utils.update_grad_attributes(
                        real_nvp_model.parameters(),
                        torch.as_tensor(stoc_grad, device=device))
            # Clip gradient if required
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            # Perform the update
            optimizer.step()

            if iteration in checkpoint_intervals:
                print(f"Completed {iteration}")
                torch.save(
                    real_nvp_model.state_dict(),
                    f"{save_dir}/real_nvp_{algorithm.__name__}_{iteration}_{save_suffix}.pt"
                )

            if iteration == total_iters:
                flag = True
                break

    return real_nvp_model
예제 #4
0
def main(args):
    """Train and evaluate a RealNVP model on the GPR B12 image dataset.

    Args:
        args: Parsed command-line arguments. Reads: gpu_ids, batch_size,
            num_workers, benchmark, resume, weight_decay, lr, num_epochs,
            max_grad_norm, num_samples.
    """
    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])

    # NOTE(review): the file list is read from a machine-specific absolute
    # path while the datasets below use the relative 'data/...' directory —
    # confirm both point at the same folder and prefer the relative path.
    # (os.listdir already returns a list, so the list() wrapper was dropped.)
    img_file = os.listdir('/Users/test/Desktop/Master_Thesis_Flow/real-nvp/data/GPR_Data/B12/B12_Pictures_LS')

    trainset = GPRDataset(img_file, root_dir='data/GPR_Data/B12/B12_Pictures_LS', train=True, transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    testset = GPRDataset(img_file, root_dir='data/GPR_Data/B12/B12_Pictures_LS', train=False, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    # Exclude weight-norm magnitude params ('weight_g' suffix) from weight decay.
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
예제 #5
0
def main(args, train):
    """Train a Flow++ model on CIFAR-10.

    Args:
        args: Parsed command-line arguments. Reads: gpu_ids, batch_size, seed,
            num_workers, num_channels, num_blocks, num_dequant_blocks,
            num_components, use_attn, drop_prob, benchmark, resume,
            weight_decay, lr, warm_up, num_epochs, max_grad_norm.
        train: Training callable invoked once per epoch.
    """
    # Set up main device and scale batch size
    device = 'cuda' if torch.cuda.is_available() and args.gpu_ids else 'cpu'
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # No normalization applied, since model expects inputs in (0, 1)
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])

    trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # Model
    print('Building model..')
    net = FlowPlusPlus(scales=[(0, 4), (2, 3)],
                       in_shape=(3, 32, 32),
                       mid_channels=args.num_channels,
                       num_blocks=args.num_blocks,
                       num_dequant_blocks=args.num_dequant_blocks,
                       num_components=args.num_components,
                       use_attn=args.use_attn,
                       drop_prob=args.drop_prob)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    start_epoch = 0
    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at save/best.pth.tar...')
        assert os.path.isdir('save'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('save/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        global best_loss
        global global_step
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']
        # global_step counts samples seen, so the LR warm-up resumes correctly.
        global_step = start_epoch * len(trainset)

    loss_fn = util.NLLLoss().to(device)
    # Exclude weight-norm magnitude params ('weight_g' suffix) from weight decay.
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)
    # Linear LR warm-up over the first (warm_up * batch_size) samples.
    warm_up = args.warm_up * args.batch_size
    scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / warm_up))

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm, args, scheduler)
예제 #6
0
def streaming_approx_training_loop(real_nvp_model, dataloader, learning_rate,
                                   optim, device, total_iters,
                                   checkpoint_intervals, alpha, batchsize,
                                   n_discard, weight_decay, max_grad_norm,
                                   save_dir, save_suffix):
    """Train ``real_nvp_model`` for ``total_iters`` iterations using the
    streaming rank-1 gradient approximation.

    Each step samples per-example gradients, refines a running top
    eigenvector/eigenvalue and mean via ``streaming_update_algorithm``,
    writes the aggregated estimate into the model's ``.grad`` attributes,
    optionally clips, and applies one optimizer step. A state_dict snapshot
    is saved whenever the iteration count hits ``checkpoint_intervals``.

    Returns:
        The trained model.
    """
    groups = util.get_param_groups(real_nvp_model,
                                   weight_decay,
                                   norm_suffix='weight_g')
    make_optimizer = utils.get_optimizer_cons(optim, learning_rate)
    optimizer = make_optimizer(groups)

    criterion = RealNVPLoss()
    step = 0
    # Streaming state threaded through every update.
    eigvec, eigval, mean_estimate = None, None, None

    real_nvp_model.train()
    done = False
    while not done:
        # Loop over the dataloader repeatedly until total_iters is reached.
        for batch, _ in dataloader:
            step += 1

            batch = batch.to(device)
            latent, sldj = real_nvp_model(batch, reverse=False)
            per_sample_loss = criterion(latent, sldj, aggregate=False)
            # Sample per-example gradients ...
            grads = utils.gradient_sampler(per_sample_loss, real_nvp_model)
            # ... then refine the rank-1 direction, eigenvalue and mean.
            estimate, eigvec, eigval, mean_estimate = streaming_update_algorithm(
                grads,
                n_discard=n_discard,
                top_v=eigvec,
                top_lambda=eigval,
                old_mean=mean_estimate,
                alpha=alpha)
            # Overwrite the .grad attributes with the aggregated estimate.
            with torch.no_grad():
                utils.update_grad_attributes(
                    real_nvp_model.parameters(),
                    torch.as_tensor(estimate, device=device))
            # Optional gradient clipping.
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            if step in checkpoint_intervals:
                print(f"Completed {step}")
                torch.save(
                    real_nvp_model.state_dict(),
                    f"{save_dir}/real_nvp_streaming_approx_{step}_{save_suffix}.pt"
                )

            if step == total_iters:
                done = True
                break

    return real_nvp_model
예제 #7
0
def main(args):
    """Train and evaluate a Flow++ model on CIFAR-10.

    Args:
        args: Parsed command-line arguments. Reads: gpu_ids, batch_size, seed,
            num_workers, num_channels, num_blocks, num_dequant_blocks,
            num_components, use_attn, drop_prob, benchmark, resume,
            weight_decay, lr, warm_up, num_epochs, max_grad_norm,
            num_samples, save_dir.
    """
    # Set up main device and scale batch size
    device = 'cuda' if torch.cuda.is_available() and args.gpu_ids else 'cpu'
    args.batch_size *= max(1, len(args.gpu_ids))
    # NOTE(review): anomaly detection adds significant overhead — keep it
    # enabled only while debugging NaN/Inf gradients.
    torch.autograd.set_detect_anomaly(True)
    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # No normalization applied, since model expects inputs in (0, 1).
    # Horizontal flipping is intentionally disabled here.
    transform_train = transforms.Compose([
        transforms.ToTensor()
    ])

    transform_test = transforms.Compose([transforms.ToTensor()])

    trainset = torchvision.datasets.CIFAR10(root='data',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(root='data',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    # Model
    print('Building model..')
    # NOTE(review): in_shape=(1, 32, 32) declares single-channel input, but
    # the active loaders serve 3-channel CIFAR-10 — confirm which dataset is
    # intended (previous revisions loaded MNIST here).
    net = FlowPlusPlus(scales=[(0, 4), (2, 3)],
                       in_shape=(1, 32, 32),
                       mid_channels=args.num_channels,
                       num_blocks=args.num_blocks,
                       num_dequant_blocks=args.num_dequant_blocks,
                       num_components=args.num_components,
                       use_attn=args.use_attn,
                       drop_prob=args.drop_prob)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    start_epoch = 0
    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at save/best.pth.tar...')
        assert os.path.isdir('save'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('save/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        global best_loss
        global global_step
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']
        # global_step counts samples seen, so the LR warm-up resumes correctly.
        global_step = start_epoch * len(trainset)

    loss_fn = util.NLLLoss().to(device)
    # Exclude weight-norm magnitude params ('weight_g' suffix) from weight decay.
    param_groups = util.get_param_groups(net,
                                         args.weight_decay,
                                         norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)
    # Linear LR warm-up over the first (warm_up * batch_size) samples.
    warm_up = args.warm_up * args.batch_size
    scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / warm_up))

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
              args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples,
             args.save_dir)
예제 #8
0
def main(args=None):
    """Build the dataset, conditioning encoder and Flow++ model from CLI
    options, then either train or load a checkpoint and sample/evaluate.

    Args:
        args: Optional list of argument strings for the module-level parser;
            when None, arguments are read from sys.argv.
    """
    if args:
        opt = parser.parse_args(args)
    else:
        opt = parser.parse_args()

    print(opt)

    print("loading dataset")
    # Dataset selection: ImageNet-32 or CIFAR-10; debug mode shrinks each
    # split to a single example.
    if opt.dataset == "imagenet32":
        train_dataset = Imagenet32DatasetDiscrete(
            train=not opt.train_on_val,
            max_size=1 if opt.debug else opt.train_size)
        val_dataset = Imagenet32DatasetDiscrete(
            train=0,
            max_size=1 if opt.debug else opt.val_size,
            start_idx=opt.val_start_idx)
    else:
        assert opt.dataset == "cifar10"
        train_dataset = CIFAR10Dataset(train=not opt.train_on_val,
                                       max_size=1 if opt.debug else -1)
        val_dataset = CIFAR10Dataset(train=0, max_size=1 if opt.debug else -1)

    print("creating dataloaders")
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
    )

    print("Len train : {}, val : {}".format(len(train_dataloader),
                                            len(val_dataloader)))

    device = torch.device("cuda") if (
        torch.cuda.is_available() and opt.use_cuda) else torch.device("cpu")
    print("Device is {}".format(device))

    print("Loading models on device...")

    # Initialize embedder: produces the conditioning embedding fed to the
    # generative model (no-op, BERT text encoder, or one-hot class labels).
    if opt.conditioning == 'unconditional':
        encoder = UnconditionalClassEmbedding()
    elif opt.conditioning == "bert":
        encoder = BERTEncoder()
    else:
        assert opt.conditioning == "one-hot"
        encoder = OneHotClassEmbedding(train_dataset.n_classes)

    # generative_model = ConditionalPixelCNNpp(embd_size=encoder.embed_size, img_shape=train_dataset.image_shape,
    #                                          nr_resnet=opt.n_resnet, nr_filters=opt.n_filters,
    #                                          nr_logistic_mix=3 if train_dataset.image_shape[0] == 1 else 10)

    generative_model = FlowPlusPlus(
        scales=[(0, 4), (2, 3)],
        # in_shape=(3, 32, 32),
        in_shape=train_dataset.image_shape,
        mid_channels=opt.n_filters,
        num_blocks=opt.num_blocks,
        num_dequant_blocks=opt.num_dequant_blocks,
        num_components=opt.num_components,
        use_attn=opt.use_attn,
        drop_prob=opt.drop_prob,
        condition_embd_size=encoder.embed_size)

    generative_model = generative_model.to(device)
    encoder = encoder.to(device)
    print("Models loaded on device")

    # Configure data loader

    print("dataloaders loaded")
    # Optimizers
    # optimizer = torch.optim.Adam(generative_model.parameters(), lr=opt.lr)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=opt.lr_decay)
    # NOTE(review): opt.lr_decay is passed where get_param_groups elsewhere
    # in this file receives a weight-decay value — confirm this is intended
    # and not a mix-up with a weight_decay option.
    param_groups = util.get_param_groups(generative_model,
                                         opt.lr_decay,
                                         norm_suffix='weight_g')
    optimizer = torch.optim.Adam(param_groups, lr=opt.lr)
    # Linear LR warm-up over the first (warm_up * batch_size) samples.
    warm_up = opt.warm_up * opt.batch_size
    scheduler = lr_scheduler.LambdaLR(optimizer,
                                      lambda s: min(1., s / warm_up))
    # create output directory

    os.makedirs(os.path.join(opt.output_dir, "models"), exist_ok=True)
    os.makedirs(os.path.join(opt.output_dir, "tensorboard"), exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(opt.output_dir, "tensorboard"))

    global global_step
    global_step = 0

    # ----------
    #  Training
    # ----------
    if opt.train:
        train(model=generative_model,
              embedder=encoder,
              optimizer=optimizer,
              scheduler=scheduler,
              train_loader=train_dataloader,
              val_loader=val_dataloader,
              opt=opt,
              writer=writer,
              device=device)
    else:
        # Evaluation-only path: requires a pre-trained checkpoint.
        assert opt.model_checkpoint is not None, 'no model checkpoint specified'
        print("Loading model from state dict...")
        load_model(opt.model_checkpoint, generative_model)
        print("Model loaded.")
        sample_images_full(generative_model,
                           encoder,
                           opt.output_dir,
                           dataloader=val_dataloader,
                           device=device)
        eval(model=generative_model,
             embedder=encoder,
             test_loader=val_dataloader,
             opt=opt,
             writer=writer,
             device=device)
예제 #9
0
def main(args):
    """Train a RealNVP density model on a noisy MNIST/CIFAR-10 subset with
    early stopping, then dump per-pixel log-likelihood images for the test set.

    Relies on module-level configuration: DATASET, N_TRAIN, N_VAL, N_TEST,
    N_NOISY_SAMPLES_PER_TEST_SAMPLE, PATIENCE, plus globals best_loss and
    cnt_early_stop that test() is expected to update.

    Args:
        args: Parsed command-line arguments. Reads: gpu_ids, batch_size,
            num_workers, benchmark, resume, weight_decay, lr, num_epochs,
            max_grad_norm, num_samples.
    """
    global best_loss
    global cnt_early_stop

    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        MyToTensor()
    ])

    transform_test = transforms.Compose([
        MyToTensor()
    ])

    assert DATASET in ['mnist', 'cifar10']
    if DATASET == 'mnist':
        dataset_picker = torchvision.datasets.MNIST
    else:
        dataset_picker = torchvision.datasets.CIFAR10

    trainset = dataset_picker(root='data', train=True, download=True, transform=transform_train)
    testset = dataset_picker(root='data', train=False, download=True, transform=transform_test)

    valset = copy.deepcopy(trainset)

    # Carve a disjoint train/val split out of the training data.
    train_val_idx = np.random.choice(np.arange(trainset.data.shape[0]), size=N_TRAIN+N_VAL, replace=False)
    train_idx = train_val_idx[:N_TRAIN]
    val_idx = train_val_idx[N_TRAIN:]
    valset.data = valset.data[val_idx]
    trainset.data = trainset.data[train_idx]

    test_idx = np.random.choice(np.arange(testset.data.shape[0]), size=N_TEST, replace=False)
    testset.data = testset.data[test_idx]

    if DATASET == 'mnist':
        trainset.targets = trainset.targets[train_idx]
        valset.targets = valset.targets[val_idx]
        testset.targets = testset.targets[test_idx]
    else:
        # CIFAR-10 stores targets as a plain list, so index via an ndarray.
        trainset.targets = np.array(trainset.targets)[train_idx]
        valset.targets = np.array(valset.targets)[val_idx]
        testset.targets = np.array(testset.targets)[test_idx]

    if DATASET == 'mnist':
        trainset.data, trainset.targets = get_noisy_data(trainset.data, trainset.targets)
        valset.data, valset.targets = get_noisy_data(valset.data, valset.targets)
        testset.data, testset.targets = get_noisy_data(testset.data, testset.targets)
    else:
        # Fix: noisytestset was referenced without ever being defined (its
        # deepcopy line had been commented out), which raised NameError on
        # the cifar10 branch.
        noisytestset = copy.deepcopy(testset)
        noisy_samples = np.random.rand(N_TEST * N_NOISY_SAMPLES_PER_TEST_SAMPLE, 32, 32, 3) - 0.5
        noisytestset.data = np.tile(noisytestset.data, (N_NOISY_SAMPLES_PER_TEST_SAMPLE, 1, 1, 1)) + noisy_samples
        noisytestset.targets = np.tile(noisytestset.targets, [N_NOISY_SAMPLES_PER_TEST_SAMPLE])

    trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    valloader = data.DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    # noisytestloader = data.DataLoader(noisytestset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # Model
    print('Building model..')
    if DATASET == 'mnist':
        net = RealNVP(num_scales=2, in_channels=1, mid_channels=64, num_blocks=8)
    else:
        net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['val_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    # Exclude weight-norm magnitude params ('weight_g' suffix) from weight decay.
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    # Train with early stopping on the validation loss: test() is expected to
    # lower the global best_loss when the model improves.
    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        prev_best_loss = best_loss
        test(epoch, net, valloader, device, loss_fn, args.num_samples, 'val')
        if best_loss < prev_best_loss:
            cnt_early_stop = 0
        else:
            cnt_early_stop += 1
        if cnt_early_stop >= PATIENCE:
            break

    # Reload the best checkpoint before the final evaluation.
    checkpoint = torch.load('ckpts/best.pth.tar')
    net.load_state_dict(checkpoint['net'])

    # Save per-pixel likelihood heatmaps (assumes 28x28 inputs, i.e. MNIST —
    # TODO confirm this path is never reached for cifar10).
    pixelwise_ll = -pixelwise_test(net, testloader, device, loss_fn, args.num_samples)
    pixelwise_ll = pixelwise_ll.reshape([-1, 28, 28])
    os.makedirs('pixelwise_loglikelihood', exist_ok=True)
    for i in range(len(pixelwise_ll)):
        tmp = np.exp(pixelwise_ll[i])
        tmp = 255 * (tmp / np.max(tmp))
        im = Image.fromarray(tmp)
        im.convert('RGB').save('pixelwise_loglikelihood/' + str(i) + '.png')
    for i, (x, _) in enumerate(testloader):
        # Fix: np.float was deprecated in NumPy 1.20 and later removed;
        # the builtin float is the documented replacement.
        x_np = np.array(x.cpu(), dtype=float)
        for j in range(args.num_samples):
            im = Image.fromarray(255 * x_np[j].reshape([28, 28]))
            im.convert('RGB').save('pixelwise_loglikelihood/' + str(j) + '-orig.png')
        break
    test(epoch, net, testloader, device, loss_fn, args.num_samples, 'test')
예제 #10
0
    def __init__(self, args):
        """
        Args:
            args: Configuration args passed in via the command line.
        """
        super(Flow2Flow, self).__init__()
        # NOTE(review): device choice only checks gpu_ids, not
        # torch.cuda.is_available() — confirm GPUs are guaranteed upstream.
        self.device = 'cuda' if len(args.gpu_ids) > 0 else 'cpu'
        self.gpu_ids = args.gpu_ids
        self.is_training = args.is_training

        self.in_channels = args.num_channels
        # Each RealNVP scale squeezes 2x2 spatial blocks into channels,
        # multiplying the channel count by 4 per scale after the first.
        self.out_channels = 4 ** (args.num_scales - 1) * self.in_channels

        # Set up RealNVP generators (g_src: X <-> Z, g_tgt: Y <-> Z)
        self.g_src = RealNVP(num_scales=args.num_scales,
                             in_channels=args.num_channels,
                             mid_channels=args.num_channels_g,
                             num_blocks=args.num_blocks,
                             un_normalize_x=True,
                             no_latent=False)
        util.init_model(self.g_src, init_method=args.initializer)
        self.g_tgt = RealNVP(num_scales=args.num_scales,
                             in_channels=args.num_channels,
                             mid_channels=args.num_channels_g,
                             num_blocks=args.num_blocks,
                             un_normalize_x=True,
                             no_latent=False)
        util.init_model(self.g_tgt, init_method=args.initializer)

        if self.is_training:
            # Set up discriminators
            self.d_tgt = PatchGAN(args)  # Answers Q "is this tgt image real?"
            self.d_src = PatchGAN(args)  # Answers Q "is this src image real?"

            # Wrap networks for multi-GPU before building optimizers.
            self._data_parallel()

            # Set up loss functions
            self.max_grad_norm = args.clip_gradient
            self.lambda_mle = args.lambda_mle
            self.mle_loss_fn = RealNVPLoss()
            self.gan_loss_fn = util.GANLoss(device=self.device, use_least_squares=True)

            self.clamp_jacobian = args.clamp_jacobian
            self.jc_loss_fn = util.JacobianClampingLoss(args.jc_lambda_min, args.jc_lambda_max)

            # Set up optimizers: one for both generators (weight-norm
            # 'weight_g' params excluded from L2), one for both discriminators.
            g_src_params = util.get_param_groups(self.g_src, args.weight_norm_l2, norm_suffix='weight_g')
            g_tgt_params = util.get_param_groups(self.g_tgt, args.weight_norm_l2, norm_suffix='weight_g')
            self.opt_g = torch.optim.Adam(chain(g_src_params, g_tgt_params),
                                          lr=args.rnvp_lr,
                                          betas=(args.rnvp_beta_1, args.rnvp_beta_2))
            self.opt_d = torch.optim.Adam(chain(self.d_tgt.parameters(), self.d_src.parameters()),
                                          lr=args.lr,
                                          betas=(args.beta_1, args.beta_2))
            self.optimizers = [self.opt_g, self.opt_d]
            self.schedulers = [util.get_lr_scheduler(opt, args) for opt in self.optimizers]

            # Setup image mixers (capacity 0 disables buffering entirely)
            buffer_capacity = 50 if args.use_mixer else 0
            self.src2tgt_buffer = util.ImageBuffer(buffer_capacity)  # Buffer of generated tgt images
            self.tgt2src_buffer = util.ImageBuffer(buffer_capacity)  # Buffer of generated src images
        else:
            self._data_parallel()

        # The attributes below are populated during forward/backward passes.

        # Images in flow src -> lat -> tgt
        self.src = None
        self.src2lat = None
        self.src2tgt = None

        # Images in flow tgt -> lat -> src
        self.tgt = None
        self.tgt2lat = None
        self.tgt2src = None

        # Jacobian clamping tensors
        self.src_jc = None
        self.tgt_jc = None
        self.src2tgt_jc = None
        self.tgt2src_jc = None

        # Discriminator loss
        self.loss_d_tgt = None
        self.loss_d_src = None
        self.loss_d = None

        # Generator GAN loss
        self.loss_gan_src = None
        self.loss_gan_tgt = None
        self.loss_gan = None

        # Generator MLE loss
        self.loss_mle_src = None
        self.loss_mle_tgt = None
        self.loss_mle = None

        # Jacobian Clamping loss
        self.loss_jc_src = None
        self.loss_jc_tgt = None
        self.loss_jc = None

        # Generator total loss
        self.loss_g = None
예제 #11
0
def main(args):
    """Train and evaluate a RealNVP model on CIFAR-10.

    Args:
        args: Parsed command-line arguments. Reads gpu_ids, benchmark,
            batch_size, num_workers, resume, weight_decay, lr, num_epochs,
            num_samples, and max_grad_norm.
    """
    device = 'cuda' if torch.cuda.is_available() and len(
        args.gpu_ids) > 0 else 'cpu'
    # Log the CUDA environment up front for reproducibility/debugging.
    print(device)
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
    print(torch.backends.cudnn.version())
    print(torch.version.cuda)
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])

    transform_test = transforms.Compose([transforms.ToTensor()])
    print(os.getcwd())
    trainset = torchvision.datasets.CIFAR10(root='data',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(root='data',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint. map_location=device lets a checkpoint saved on
        # GPU be restored on a CPU-only machine (and vice versa).
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        if not os.path.isdir('ckpts'):
            # Explicit raise instead of assert: asserts are stripped
            # under `python -O`.
            raise FileNotFoundError('Error: no checkpoint directory found!')
        checkpoint = torch.load('ckpts/best.pth.tar', map_location=device)
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    # Report the per-tensor and total parameter counts.
    total_num_parameters = 0
    for name, value in net.named_parameters():
        print(name, value.shape)
        num_params = int(np.prod(value.shape))  # cast: avoid np.int64 total
        total_num_parameters += num_params
        print(name + ' has parameters: ' + str(num_params))

    print('Total number of parameters', total_num_parameters)

    loss_fn = RealNVPLoss()
    # Weight-norm magnitude params ('weight_g') get a separate decay group.
    param_groups = util.get_param_groups(net,
                                         args.weight_decay,
                                         norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn,
              args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)