Example 1
def main(args):
    base_path = util.make_directory(args.name, args.ow)
    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    trainloader, testloader, in_channels = get_datasets(args)
    # Model
    print('Building model..')
    net = RealNVP(num_scales=args.num_scales, in_channels=in_channels, mid_channels=64, num_blocks=args.num_blocks)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        ckpt_path = base_path / 'ckpts'
        best_path_ckpt = ckpt_path / 'best.pth.tar'
        print(f'Resuming from checkpoint at {best_path_ckpt}')
        checkpoint = torch.load(best_path_ckpt)
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples, in_channels, base_path)
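
Example 1 delegates data loading to get_datasets(args), which is not shown here. Below is a minimal sketch of a function with the same return signature (trainloader, testloader, in_channels); the args.dataset flag is an assumption, not part of the original.

# Sketch only: the real get_datasets is not shown, and args.dataset is an
# assumed flag. No normalization beyond ToTensor(), since RealNVP expects
# inputs in (0, 1).
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

def get_datasets(args):
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.ToTensor()

    if getattr(args, 'dataset', 'cifar10') == 'mnist':
        dataset_cls, in_channels = torchvision.datasets.MNIST, 1
    else:
        dataset_cls, in_channels = torchvision.datasets.CIFAR10, 3

    trainset = dataset_cls(root='data', train=True, download=True, transform=transform_train)
    testset = dataset_cls(root='data', train=False, download=True, transform=transform_test)
    trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.num_workers)
    testloader = data.DataLoader(testset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=args.num_workers)
    return trainloader, testloader, in_channels

Returning in_channels alongside the loaders is what lets main() build RealNVP without hardcoding the channel count.
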
Example 2
def main(args):
    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])

    #trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
    #trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    #testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
    #testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    kwargs = {'num_workers': 8, 'pin_memory': False}
    #trainloader = torch.utils.data.DataLoader(datasets.MNIST('./data',train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),])),batch_size=args.batch_size,shuffle=True,**kwargs)
    #testloader = torch.utils.data.DataLoader(datasets.MNIST('./data',train=False,transform=transforms.Compose([transforms.ToTensor(),])),batch_size=args.batch_size,shuffle=True,**kwargs)

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    #dset = CustomImageFolder('data/CelebA', transform)
    #trainloader = torch.utils.data.DataLoader(dset,batch_size=args.batch_size,shuffle=True,num_workers=8,pin_memory=True,drop_last=True)
    #testloader = torch.utils.data.DataLoader(dset,batch_size=args.batch_size,shuffle=True,num_workers=8,pin_memory=True,drop_last=True)

    trainloader = torch.utils.data.DataLoader(
        datasets.CelebA('./data', split='train', download=True, transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    testloader = torch.utils.data.DataLoader(
        datasets.CelebA('./data', split='test', transform=transform),
        batch_size=args.batch_size, shuffle=False, **kwargs)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
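
The resume block above reads the keys 'net', 'test_loss', and 'epoch' from ckpts/best.pth.tar. The test() routine that writes that file is not shown; the following is a sketch of a compatible save step (save_best_checkpoint is a hypothetical helper name, not from the original).

# Hypothetical helper; only the checkpoint keys ('net', 'test_loss',
# 'epoch') are taken from the resume logic above.
import os
import torch

def save_best_checkpoint(net, test_loss, epoch, best_loss):
    if test_loss < best_loss:
        os.makedirs('ckpts', exist_ok=True)
        torch.save({'net': net.state_dict(),
                    'test_loss': test_loss,
                    'epoch': epoch},
                   'ckpts/best.pth.tar')
        best_loss = test_loss
    return best_loss
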
Example 3
def plugin_estimator_training_loop(real_nvp_model, dataloader, learning_rate,
                                   optim, device, total_iters,
                                   checkpoint_intervals, batchsize, algorithm,
                                   weight_decay, max_grad_norm, save_dir,
                                   save_suffix):
    """
    Function to train the RealNVP model using
    a plugin mean estimation algorithm for total_iters with learning_rate
    """
    param_groups = util.get_param_groups(real_nvp_model,
                                         weight_decay,
                                         norm_suffix='weight_g')
    optimizer_cons = utils.get_optimizer_cons(optim, learning_rate)
    optimizer = optimizer_cons(param_groups)

    loss_fn = RealNVPLoss()
    flag = False
    iteration = 0
    real_nvp_model.train()
    while not flag:
        for x, _ in dataloader:
            # Update iteration counter
            iteration += 1

            x = x.to(device)
            # Clear stale gradients; without this the mean branch's
            # backward() would accumulate across iterations.
            optimizer.zero_grad()
            z, sldj = real_nvp_model(x, reverse=False)
            unaggregated_loss = loss_fn(z, sldj, aggregate=False)
            if algorithm.__name__ == 'mean':
                agg_loss = unaggregated_loss.mean()
                agg_loss.backward()
            else:
                # First sample gradients
                sgradients = utils.gradient_sampler(unaggregated_loss,
                                                    real_nvp_model)
                # Then get the estimate with the mean estimation algorithm
                stoc_grad = algorithm(sgradients)
                # Perform the update of .grad attributes
                with torch.no_grad():
                    utils.update_grad_attributes(
                        real_nvp_model.parameters(),
                        torch.as_tensor(stoc_grad, device=device))
            # Clip gradient if required
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            # Perform the update
            optimizer.step()

            if iteration in checkpoint_intervals:
                print(f"Completed {iteration}")
                torch.save(
                    real_nvp_model.state_dict(),
                    f"{save_dir}/real_nvp_{algorithm.__name__}_{iteration}_{save_suffix}.pt"
                )

            if iteration == total_iters:
                flag = True
                break

    return real_nvp_model
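
The loop treats algorithm as any callable that maps a batch of per-sample gradients to a single gradient estimate, and uses algorithm.__name__ to tag checkpoints. Below is a minimal sketch of one such plugin estimator, assuming utils.gradient_sampler returns per-sample gradients stacked along the first axis; the estimators actually used are not shown in this example.

# Sketch of a plugin estimator: the coordinate-wise median, one robust
# alternative to the plain mean handled by the special case above.
import numpy as np

def coordinate_median(sgradients):
    return np.median(np.asarray(sgradients), axis=0)
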
Example 4
def main(args):
    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])

    img_file = sorted(os.listdir('/Users/test/Desktop/Master_Thesis_Flow/real-nvp/data/GPR_Data/B12/B12_Pictures_LS'))
    
    trainset = GPRDataset(img_file, root_dir='data/GPR_Data/B12/B12_Pictures_LS', train=True, transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    testset = GPRDataset(img_file, root_dir='data/GPR_Data/B12/B12_Pictures_LS', train=False, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
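
GPRDataset is not defined in this example. The sketch below is a compatible Dataset under stated assumptions: img_file holds filenames under root_dir, labels are dummies (the training loop ignores targets), and the 90/10 train/test split is arbitrary, since the real split rule is not shown.

import os
from PIL import Image
from torch.utils import data

class GPRDataset(data.Dataset):
    # Sketch only: split rule and label handling are assumptions.
    def __init__(self, img_files, root_dir, train=True, transform=None):
        split = int(0.9 * len(img_files))
        self.files = img_files[:split] if train else img_files[split:]
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.root_dir, self.files[idx])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, 0  # dummy label; the training loop discards targets
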
Example 5
def streaming_approx_training_loop(real_nvp_model, dataloader, learning_rate,
                                   optim, device, total_iters,
                                   checkpoint_intervals, alpha, batchsize,
                                   n_discard, weight_decay, max_grad_norm,
                                   save_dir, save_suffix):
    """
    Function to train the RealNVP model using
    the streaming rank-1 approximation with algorithm for total_iters
    with optimizer optim
    """
    param_groups = util.get_param_groups(real_nvp_model,
                                         weight_decay,
                                         norm_suffix='weight_g')
    optimizer_cons = utils.get_optimizer_cons(optim, learning_rate)
    optimizer = optimizer_cons(param_groups)

    loss_fn = RealNVPLoss()
    flag = False
    iteration = 0
    top_eigvec, top_eigval, running_mean = None, None, None

    real_nvp_model.train()
    while not flag:
        for x, _ in dataloader:
            # Update iteration counter
            iteration += 1

            x = x.to(device)
            z, sldj = real_nvp_model(x, reverse=False)
            unaggregated_loss = loss_fn(z, sldj, aggregate=False)
            # First sample gradients
            sgradients = utils.gradient_sampler(unaggregated_loss,
                                                real_nvp_model)
            # Then get the estimate with the previously computed direction
            stoc_grad, top_eigvec, top_eigval, running_mean = streaming_update_algorithm(
                sgradients,
                n_discard=n_discard,
                top_v=top_eigvec,
                top_lambda=top_eigval,
                old_mean=running_mean,
                alpha=alpha)
            # Perform the update of .grad attributes
            with torch.no_grad():
                utils.update_grad_attributes(
                    real_nvp_model.parameters(),
                    torch.as_tensor(stoc_grad, device=device))
            # Clip gradient if required
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            # Perform the update
            optimizer.step()

            if iteration in checkpoint_intervals:
                print(f"Completed {iteration}")
                torch.save(
                    real_nvp_model.state_dict(),
                    f"{save_dir}/real_nvp_streaming_approx_{iteration}_{save_suffix}.pt"
                )

            if iteration == total_iters:
                flag = True
                break

    return real_nvp_model
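
streaming_update_algorithm is not shown. Its signature suggests it maintains a running mean and the top eigenpair of the gradient covariance across batches; the sketch below illustrates one plausible reading (exponential running mean with weight alpha, one power-iteration step on the batch covariance, then discarding the n_discard samples that project furthest along the top direction). The actual method may differ.

# Hedged sketch only: the real streaming_update_algorithm is not shown.
import numpy as np

def streaming_update_algorithm(sgradients, n_discard, top_v, top_lambda,
                               old_mean, alpha):
    G = np.asarray(sgradients)                        # (batch, n_params)
    mean = G.mean(axis=0)
    if old_mean is not None:
        mean = alpha * mean + (1.0 - alpha) * old_mean
    centered = G - mean
    if top_v is None:
        top_v = np.random.randn(G.shape[1])
        top_v /= np.linalg.norm(top_v)
    w = centered.T @ (centered @ top_v) / G.shape[0]  # one power-iteration step
    top_lambda = float(np.linalg.norm(w))
    top_v = w / (top_lambda + 1e-12)
    scores = np.abs(centered @ top_v)                 # alignment with top direction
    keep = np.argsort(scores)[:max(G.shape[0] - n_discard, 1)]
    return G[keep].mean(axis=0), top_v, top_lambda, mean
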
Example 6
def main(args):
    global best_loss
    global cnt_early_stop

    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        MyToTensor()
        # transforms.ToTensor()
    ])

    transform_test = transforms.Compose([
        MyToTensor()
        #transforms.ToTensor()
    ])

    assert DATASET in ['mnist', 'cifar10']
    if DATASET == 'mnist':
        dataset_picker = torchvision.datasets.MNIST
    else:
        dataset_picker = torchvision.datasets.CIFAR10

    trainset = dataset_picker(root='data', train=True, download=True, transform=transform_train)
    testset = dataset_picker(root='data', train=False, download=True, transform=transform_test)

    valset = copy.deepcopy(trainset)

    train_val_idx = np.random.choice(np.arange(trainset.data.shape[0]), size=N_TRAIN+N_VAL, replace=False)
    train_idx = train_val_idx[:N_TRAIN]
    val_idx = train_val_idx[N_TRAIN:]
    valset.data = valset.data[val_idx]
    trainset.data = trainset.data[train_idx]

    test_idx = np.random.choice(np.arange(testset.data.shape[0]), size=N_TEST, replace=False)
    testset.data = testset.data[test_idx]

    if DATASET == 'mnist':
        trainset.targets = trainset.targets[train_idx]
        valset.targets = valset.targets[val_idx]
        testset.targets = testset.targets[test_idx]
    else:
        trainset.targets = np.array(trainset.targets)[train_idx]
        valset.targets = np.array(valset.targets)[val_idx]
        testset.targets = np.array(testset.targets)[test_idx]

    noisytestset = copy.deepcopy(testset)  # used by the CIFAR-10 branch below
    if DATASET == 'mnist':
        trainset.data, trainset.targets = get_noisy_data(trainset.data, trainset.targets)
        valset.data, valset.targets = get_noisy_data(valset.data, valset.targets)
        testset.data, testset.targets = get_noisy_data(testset.data, testset.targets)
    else:
        noisy_samples = np.random.rand(N_TEST * N_NOISY_SAMPLES_PER_TEST_SAMPLE, 32, 32, 3) - 0.5
        noisytestset.data = np.tile(noisytestset.data, (N_NOISY_SAMPLES_PER_TEST_SAMPLE, 1, 1, 1)) + noisy_samples
        noisytestset.targets = np.tile(noisytestset.targets, [N_NOISY_SAMPLES_PER_TEST_SAMPLE])

    trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    valloader = data.DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    # noisytestloader = data.DataLoader(noisytestset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # Model
    print('Building model..')
    if DATASET == 'mnist':
        net = RealNVP(num_scales=2, in_channels=1, mid_channels=64, num_blocks=8)
    else:
        net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['val_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        prev_best_loss = best_loss
        test(epoch, net, valloader, device, loss_fn, args.num_samples, 'val')
        if best_loss < prev_best_loss:
            cnt_early_stop = 0
        else:
            cnt_early_stop += 1
        if cnt_early_stop >= PATIENCE:
            break
        # test(epoch, net, testloader, device, loss_fn, args.num_samples, 'test')
        # test(epoch, net, noisytestloader, device, loss_fn, args.num_samples, 'noisytest')

    checkpoint = torch.load('ckpts/best.pth.tar')
    net.load_state_dict(checkpoint['net'])

    pixelwise_ll = -pixelwise_test(net, testloader, device, loss_fn, args.num_samples)
    pixelwise_ll = pixelwise_ll.reshape([-1, 28, 28])
    os.makedirs('pixelwise_loglikelihood', exist_ok=True)
    for i in range(len(pixelwise_ll)):
        tmp = np.exp(pixelwise_ll[i])
        tmp = 255 * (tmp / np.max(tmp))
        im = Image.fromarray(tmp.astype(np.uint8))
        im.convert('RGB').save('pixelwise_loglikelihood/' + str(i) + '.png')
    for x, _ in testloader:
        x_np = np.array(x.cpu(), dtype=np.float32)
        for j in range(args.num_samples):
            im = Image.fromarray((255 * x_np[j].reshape([28, 28])).astype(np.uint8))
            im.convert('RGB').save('pixelwise_loglikelihood/' + str(j) + '-orig.png')
        break
        break
    test(epoch, net, testloader, device, loss_fn, args.num_samples, 'test')
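
This example relies on two helpers that are not shown, MyToTensor and get_noisy_data. Both bodies below are assumptions: MyToTensor simply mirrors transforms.ToTensor() (the real class presumably differs in some detail, or it would not exist), and get_noisy_data shows one way to perturb uint8 MNIST-style tensors, with noise_scale an invented parameter.

# Both helpers below are assumed implementations; the originals are not
# shown in this example.
import torch
import torchvision.transforms.functional as TF

class MyToTensor:
    def __call__(self, pic):
        # Same behavior as transforms.ToTensor(): HWC uint8 -> CHW float in [0, 1].
        return TF.to_tensor(pic)

def get_noisy_data(data_, targets, noise_scale=0.1):
    # Add Gaussian noise in [0, 1] space, clamp, and convert back to
    # uint8 so the dataset's __getitem__ keeps working unchanged.
    scaled = data_.float() / 255.0
    noisy = scaled + noise_scale * torch.randn_like(scaled)
    noisy = (255 * noisy.clamp(0, 1)).to(torch.uint8)
    return noisy, targets
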
Example 7
def main(args):
    device = 'cuda' if torch.cuda.is_available() and len(
        args.gpu_ids) > 0 else 'cpu'
    print(device)
    if device == 'cuda':
        # Guarded: torch.cuda.get_device_name() raises on CPU-only machines.
        print(torch.cuda.device_count())
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
        print(torch.backends.cudnn.version())
        print(torch.version.cuda)
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])

    transform_test = transforms.Compose([transforms.ToTensor()])
    print(os.getcwd())
    trainset = torchvision.datasets.CIFAR10(root='data',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(root='data',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        global best_loss
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    total_num_parameters = 0
    for name, value in net.named_parameters():
        print(name, value.shape)
        total_num_parameters += np.prod(value.shape)
        print(f'{name} has {np.prod(value.shape)} parameters')

    print('Total number of parameters', total_num_parameters)

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net,
                                         args.weight_decay,
                                         norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn,
              args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
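
For reference, the parameter count above collapses to a one-liner, since p.numel() is exactly the product over p.shape that np.prod computes:

total_num_parameters = sum(p.numel() for p in net.parameters())
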