Example #1
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

from tools import get_embeddings, euclidean_knn
from metrics import average_precision
from models import give_me_resnet

if __name__ == '__main__':
    data_path = '../data/'
    device = torch.device('cuda:0')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = CIFAR10(data_path,
                       train=True,
                       transform=transform,
                       target_transform=None,
                       download=True)
    trainloader = DataLoader(trainset,
                             batch_size=256,
                             shuffle=True,
                             num_workers=0)
    testset = CIFAR10(data_path,
                      train=False,
                      transform=transform,
                      target_transform=None,
                      download=True)
    testloader = DataLoader(testset,
                            batch_size=256,
                            shuffle=True,
                            num_workers=0)
Example #2
                                  batch_size=64,
                                  num_workers=1,
                                  shuffle=False)
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
    optimizer_S = torch.optim.Adam(net.parameters(), lr=opt.lr_S)
else:
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if opt.dataset == 'cifar10':
        net = resnet.ResNet18().cuda()
        net = nn.DataParallel(net)
        data_test = CIFAR10(opt.data, train=False, transform=transform_test)
    if opt.dataset == 'cifar100':
        net = resnet.ResNet18(num_classes=100).cuda()
        net = nn.DataParallel(net)
        data_test = CIFAR100(opt.data, train=False, transform=transform_test)
    data_test_loader = DataLoader(data_test,
                                  batch_size=opt.batch_size,
                                  num_workers=0)
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
    optimizer_S = torch.optim.SGD(
        net.parameters(), lr=opt.lr_S, momentum=0.9,
        weight_decay=5e-4)  # wh: why use different optimizers for non-MNIST?


def kdloss(y, teacher_scores):
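    # knowledge-distillation loss; a common choice is the KL divergence
    # between the softened student outputs (y) and teacher_scores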
Example #3
def get_dataset(cfg, fine_tune):
    data_transform = Compose([Resize((32, 32)), ToTensor()])
    mnist_transform = Compose(
        [Resize((32, 32)),
         ToTensor(),
         Lambda(lambda x: swapaxes(x, 1, -1))])
    vade_transform = Compose([ToTensor()])

    if cfg.DATA.DATASET == 'mnist':
        transform = vade_transform if 'vade' in cfg.DIRS.CHKP_PREFIX \
            else mnist_transform

        training_set = MNIST(download=True,
                             root=cfg.DIRS.DATA,
                             transform=transform,
                             train=True)
        val_set = MNIST(download=False,
                        root=cfg.DIRS.DATA,
                        transform=transform,
                        train=False)
        plot_set = copy.deepcopy(val_set)

    elif cfg.DATA.DATASET == 'svhn':
        training_set = SVHN(download=True,
                            root=create_dir(cfg.DIRS.DATA, 'SVHN'),
                            transform=data_transform,
                            split='train')
        val_set = SVHN(download=True,
                       root=create_dir(cfg.DIRS.DATA, 'SVHN'),
                       transform=data_transform,
                       split='test')
        plot_set = copy.deepcopy(val_set)

    elif cfg.DATA.DATASET == 'cifar':
        training_set = CIFAR10(download=True,
                               root=create_dir(cfg.DIRS.DATA, 'CIFAR'),
                               transform=data_transform,
                               train=True)
        val_set = CIFAR10(download=True,
                          root=create_dir(cfg.DIRS.DATA, 'CIFAR'),
                          transform=data_transform,
                          train=False)
        plot_set = copy.deepcopy(val_set)

    elif cfg.DATA.DATASET == 'lines':
        vae = 'vae' in cfg.DIRS.CHKP_PREFIX
        training_set = LinesDataset(args=cfg,
                                    multiplier=1000,
                                    dataset_type='train',
                                    vae=vae)
        val_set = LinesDataset(args=cfg,
                               multiplier=10,
                               dataset_type='test',
                               vae=vae)
        plot_set = LinesDataset(args=cfg,
                                multiplier=1,
                                dataset_type='plot',
                                vae=vae)

    if 'idec' in cfg.DIRS.CHKP_PREFIX and fine_tune:
        training_set = IdecDataset(training_set)
        val_set = IdecDataset(val_set)
        plot_set = IdecDataset(plot_set)

    return training_set, val_set, plot_set
Example #4
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4,
                           saturation=0.4,
                           hue=0.4,
                           contrast=0.4),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
valid_data_preprocess = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
# Load the training and test sets
train_dataset = CIFAR10(root=cifar_10_dir,
                        train=True,
                        transform=train_data_preprocess)
test_dataset = CIFAR10(root=cifar_10_dir,
                       train=False,
                       transform=valid_data_preprocess)

# Create data loaders for the training and test sets
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=4)
Example #5
 dataset = args.dataset
 if dataset != "ImageNet":
     for arch in os.listdir(args.dir_path):
         if not os.path.isdir(args.dir_path + "/" + arch):
             continue
         model = StandardModel(dataset, arch, True)
         model.cuda()
         model.eval()
         transform_test = transforms.Compose([
             transforms.Resize(size=(IMAGE_SIZE[dataset][0],
                                     IMAGE_SIZE[dataset][1])),
             transforms.ToTensor(),
         ])
         if dataset == "CIFAR-10":
             test_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset],
                                    train=False,
                                    transform=transform_test)
         elif dataset == "CIFAR-100":
             test_dataset = CIFAR100(IMAGE_DATA_ROOT[dataset],
                                     train=False,
                                     transform=transform_test)
         data_loader = torch.utils.data.DataLoader(test_dataset,
                                                   batch_size=100,
                                                   shuffle=True,
                                                   num_workers=0)
         acc_top1, acc_top5 = validate(data_loader, model)
         log.info(
             "val_acc_top1:{:.4f}  val_acc_top5:{:.4f} dataset {} in model {}"
             .format(acc_top1, acc_top5, dataset, arch))
 else:
     for file_name in os.listdir(args.dir_path + "/checkpoints"):
Example #6
 def prepare_data(self):
     CIFAR10(self.data_path, train=True, download=True)
     CIFAR10(self.data_path, train=False, download=True)
Example #7
def main(dataset_root: str, mode: str):
    normalize_transform = Compose([
        ToTensor(),
        Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2460, 0.2411, 0.2576)),
    ])
    augment_transform = Compose(
        [RandomHorizontalFlip(),
         RandomCrop(32, padding=4)])
    train_dataset = CIFAR10(
        dataset_root,
        train=True,
        transform=Compose([augment_transform, normalize_transform]),
    )
    validation_dataset = CIFAR10(dataset_root,
                                 train=True,
                                 transform=normalize_transform)
    validation_length = int(math.floor(len(train_dataset) * 0.10))
    train_length = len(train_dataset) - validation_length
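    # The two random_split calls below share the same seed, so the second call
    # recovers the same validation indices, but from the copy whose transform
    # is not augmented.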
    train_dataset, _ = random_split(
        train_dataset,
        lengths=[train_length, validation_length],
        generator=Generator().manual_seed(0),
    )
    _, validation_dataset = random_split(
        validation_dataset,
        lengths=[train_length, validation_length],
        generator=Generator().manual_seed(0),
    )

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=128,
                                  shuffle=True,
                                  num_workers=4,
                                  pin_memory=True)
    validation_dataloader = DataLoader(
        validation_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    if torch.cuda.device_count() == 0:
        device: Optional[torch.device] = torch.device("cpu")
        blocks_per_component = ["rest"]
    elif torch.cuda.device_count() == 1:
        device = torch.device("cuda")
        blocks_per_component = ["rest"]
    else:
        device = None
        # For the demo, just put one block on each device, and all remaining
        # blocks on the last device.
        blocks_per_component = ["1"] * (torch.cuda.device_count() - 1) + [
            "rest"
        ]

    # block_type and architecture correspond to ResNet-32.
    main_nets, aux_nets = resnet_builder.resnet(
        block_type="basic",
        architecture="64,3/128,4/256,6/512,3",
        aux_net_architecture="conv128_bn_conv64_bn_gbpl_fc",
        blocks_per_component=blocks_per_component,
        dataset="cifar10",
        n_classes=10,
    )

    optimizer_constructor = lambda params: SGD(
        params, lr=0.1, momentum=0.9, weight_decay=2e-4)
    # Learning rate schedule for ResNet-50ish on CIFAR10, taken from :
    # https://github.com/tensorflow/models/blob/master/official/r1/resnet/cifar10_main.py#L217
    lr_scheduler_constructor = lambda optimizer: MultiStepLR(
        optimizer, milestones=[91, 136, 182], gamma=0.1)
    loss_function = F.cross_entropy

    if mode == "e2e":
        model = interlocking_backprop.build_e2e_model(
            main_nets, optimizer_constructor, lr_scheduler_constructor,
            loss_function)
    elif mode == "local":
        model = interlocking_backprop.build_local_model(
            main_nets,
            aux_nets,
            optimizer_constructor,
            lr_scheduler_constructor,
            loss_function,
        )
    elif mode == "pairwise":
        model = interlocking_backprop.build_pairwise_model(
            main_nets,
            aux_nets,
            optimizer_constructor,
            lr_scheduler_constructor,
            loss_function,
        )
    elif mode == "3wise":
        model = interlocking_backprop.build_nwise_model(
            main_nets,
            aux_nets,
            optimizer_constructor,
            lr_scheduler_constructor,
            loss_function,
            nwise_communication_distance=3 - 1,
        )
    else:
        raise ValueError(f"Unknown mode {mode}")

    if torch.cuda.device_count() > 1:
        model.enable_model_parallel()
    else:
        model = model.to(device)

    print(
        f"Epoch 0: "
        f"validation accuracy = {_compute_accuracy(validation_dataloader, model):.2f}"
    )

    for epoch in range(100):
        model.train()
        losses = []
        for inputs, targets in train_dataloader:
            loss = model.training_step(inputs, targets)
            losses.append(loss)
        train_loss = (torch.stack([loss.result() for loss in losses],
                                  axis=0).mean().item())
        validation_accuracy = _compute_accuracy(validation_dataloader, model)
        print(f"Epoch {epoch + 1}: "
              f"training loss = {train_loss:.3f} "
              f"validation accuracy = {validation_accuracy:.2f}")
Example #8
    ])
    x = im_aug(x)
    return x


def test_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    x = im_aug(x)
    return x


train_set = CIFAR10('../data', train=True, transform=train_tf, download=True)
train_data = torch.utils.data.DataLoader(train_set,
                                         batch_size=256,
                                         shuffle=True,
                                         num_workers=4)
valid_set = CIFAR10('../data', train=False, transform=test_tf)
valid_data = torch.utils.data.DataLoader(valid_set,
                                         batch_size=256,
                                         shuffle=False,
                                         num_workers=4)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

writer = SummaryWriter()
Example #9
def load_data(opt):
    """Load the dataset named by opt.dataset and build its dataloaders.

    Args:
        opt: parsed command-line options (argument parser output)

    Raises:
        IOError: if the dataset cannot be loaded

    Returns:
        dict: 'train' and 'test' dataloaders
    """

    ##
    # LOAD DATA SET
    if opt.dataroot == '':
        opt.dataroot = 'data/{}'.format(opt.dataset)

    if opt.dataset in ['cifar10']:
        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': False}

        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        classes = {
            'plane': 0,
            'car': 1,
            'bird': 2,
            'cat': 3,
            'deer': 4,
            'dog': 5,
            'frog': 6,
            'horse': 7,
            'ship': 8,
            'truck': 9
        }

        dataset = {}
        dataset['train'] = CIFAR10(root='./data',
                                   train=True,
                                   download=True,
                                   transform=transform)
        dataset['test'] = CIFAR10(root='./data',
                                  train=False,
                                  download=True,
                                  transform=transform)

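        # Note: depending on the torchvision version, these attributes may be
        # exposed as .data / .targets rather than .train_data / .train_labels
        # (and .test_data / .test_labels).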
        dataset['train'].train_data, dataset['train'].train_labels, \
        dataset['test'].test_data, dataset['test'].test_labels = get_cifar_anomaly_dataset(
            trn_img=dataset['train'].train_data,
            trn_lbl=dataset['train'].train_labels,
            tst_img=dataset['test'].test_data,
            tst_lbl=dataset['test'].test_labels,
            abn_cls_idx=classes[opt.anomaly_class]
        )

        dataloader = {
            x: torch.utils.data.DataLoader(dataset=dataset[x],
                                           batch_size=opt.batchsize,
                                           shuffle=shuffle[x],
                                           num_workers=int(opt.workers),
                                           drop_last=drop_last_batch[x])
            for x in splits
        }
        return dataloader

    elif opt.dataset in ['mnist']:
        opt.anomaly_class = int(opt.anomaly_class)

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': False}

        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        dataset = {}
        dataset['train'] = MNIST(root='./data',
                                 train=True,
                                 download=True,
                                 transform=transform)
        dataset['test'] = MNIST(root='./data',
                                train=False,
                                download=True,
                                transform=transform)
        # print(dataset['train'])
        a, b, c, d = get_mnist_anomaly_dataset(dataset['train'].train_data,
                                               dataset['train'].train_labels,
                                               dataset['test'].test_data,
                                               dataset['test'].test_labels,
                                               opt.anomaly_class, -1)
        # print(a)
        dataset['train'].train_data = a
        dataset['train'].train_labels = b
        dataset['test'].test_data = c
        dataset['test'].test_labels = d
        dataloader = {
            x: torch.utils.data.DataLoader(dataset=dataset[x],
                                           batch_size=opt.batchsize,
                                           shuffle=shuffle[x],
                                           num_workers=opt.workers,
                                           drop_last=drop_last_batch[x])
            for x in splits
        }
        return dataloader

    elif opt.dataset in ['mnist2']:
        opt.anomaly_class = int(opt.anomaly_class)

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': False}

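        # transforms.Scale is the older, deprecated name for transforms.Resize;
        # newer torchvision releases only provide Resize.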
        transform = transforms.Compose([
            transforms.Scale(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        dataset = {}
        dataset['train'] = MNIST(root='./data',
                                 train=True,
                                 download=True,
                                 transform=transform)
        dataset['test'] = MNIST(root='./data',
                                train=False,
                                download=True,
                                transform=transform)

        dataset['train'].train_data, dataset['train'].train_labels, \
        dataset['test'].test_data, dataset['test'].test_labels = get_mnist2_anomaly_dataset(
            trn_img=dataset['train'].train_data,
            trn_lbl=dataset['train'].train_labels,
            tst_img=dataset['test'].test_data,
            tst_lbl=dataset['test'].test_labels,
            nrm_cls_idx=opt.anomaly_class,
            proportion=opt.proportion
        )

        dataloader = {
            x: torch.utils.data.DataLoader(dataset=dataset[x],
                                           batch_size=opt.batchsize,
                                           shuffle=shuffle[x],
                                           num_workers=int(opt.workers),
                                           drop_last=drop_last_batch[x])
            for x in splits
        }
        return dataloader
    elif opt.dataset in ['fmnist']:
        opt.anomaly_class = int(opt.anomaly_class)

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}

        transform = transforms.Compose([
            transforms.Scale(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        dataset = {}
        dataset['train'] = FashionMNIST(root='./data/fmnist',
                                        train=True,
                                        download=True,
                                        transform=transform)
        dataset['test'] = FashionMNIST(root='./data/fmnist',
                                       train=False,
                                       download=True,
                                       transform=transform)

        dataset['train'].train_data, dataset['train'].train_labels, \
        dataset['test'].test_data, dataset['test'].test_labels = get_mnist2_anomaly_dataset(
            trn_img=dataset['train'].train_data,
            trn_lbl=dataset['train'].train_labels,
            tst_img=dataset['test'].test_data,
            tst_lbl=dataset['test'].test_labels,
            nrm_cls_idx=opt.anomaly_class,
            proportion=opt.proportion
        )

        dataloader = {
            x: torch.utils.data.DataLoader(dataset=dataset[x],
                                           batch_size=opt.batchsize,
                                           shuffle=shuffle[x],
                                           num_workers=int(opt.workers),
                                           drop_last=drop_last_batch[x])
            for x in splits
        }
        return dataloader
    else:
        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': False}
        transform = transforms.Compose([
            transforms.Scale(opt.isize),
            transforms.CenterCrop(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        dataset = {
            x: ImageFolder(os.path.join(opt.dataroot, x), transform)
            for x in splits
        }

        dataloader = {
            x: torch.utils.data.DataLoader(dataset=dataset[x],
                                           batch_size=opt.batchsize,
                                           shuffle=shuffle[x],
                                           num_workers=int(opt.workers),
                                           drop_last=drop_last_batch[x])
            for x in splits
        }
        return dataloader
Example #10
from fast_adv.models.cifar10.model_attention import wide_resnet
from fast_adv.utils import AverageMeter, save_checkpoint, requires_grad_, NormalizedModel, VisdomLogger
from fast_adv.attacks import DDN, DeepFool

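# Per-channel CIFAR-10 mean/std, shaped (1, 3, 1, 1) so they broadcast over
# NCHW image batches.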
image_mean = torch.tensor([0.491, 0.482, 0.447]).view(1, 3, 1, 1)
image_std = torch.tensor([0.247, 0.243, 0.262]).view(1, 3, 1, 1)

DEVICE = torch.device('cuda:0' if (torch.cuda.is_available()) else 'cpu')
test_transform = transforms.Compose([
    transforms.ToTensor(),
])

input = '../defenses/data/cifar10'
#path='/media/wanghao/000F5F8400087C68/CYJ-5-29/DDN/fast_adv/attacks/DeepFool'
test_set = data.Subset(
    CIFAR10(input, train=False, transform=test_transform, download=True),
    list(range(0, 1000)))
#test_set =CIFAR10(input, train=False, transform=test_transform, download=True)

test_loader = data.DataLoader(test_set,
                              batch_size=5,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=True)

m = wide_resnet(num_classes=10, depth=28, widen_factor=10, dropRate=0.3)
# model = NormalizedModel(model=m, mean=image_mean, std=image_std).to(DEVICE)  # keep images in the [0, 1] range
# model_dict = torch.load('../defenses/weights/cifar10/cifar10_80.pth')
# model.load_state_dict(model_dict)
# ########
# model2=NormalizedModel(model=m, mean=image_mean, std=image_std).to(DEVICE)  # keep images in the [0, 1] range
Example #11
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs

# L2 regularization (weight-decay lambda term)

def data_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    x = im_aug(x)
    return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)  # add the L2 regularization term via weight decay
criterion = nn.CrossEntropyLoss()

from utils import train
train(net, train_data, test_data, 20, optimizer, criterion)
Example #12
def show_worst(model, dataset, loss, n=20):
    to_tensor = transforms.ToTensor()
    dl = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    loss_dict = {}
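    # batch_size=1 and shuffle=False, so loader index i maps one-to-one to
    # dataset.data[i]; record each example's loss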
    for i, (inputs, labels) in enumerate(dl):
        outputs = model(inputs)
        loss_dict[i] = loss(outputs, labels).item()

    loss_dict = {
        k: v
        for k, v in sorted(loss_dict.items(), key=lambda item: -item[1])
    }

    for i in list(loss_dict)[:n]:
        input_img = to_tensor(dataset.data[i]).unsqueeze_(0)
        outputs = model(input_img)

        pred = outputs.cpu().detach().topk(k=3).indices.numpy()[0, :]
        title = "Class %d" % dataset.targets[i]
        title += " Predicted " + str(pred)

        plt.title(title)
        plt.imshow(dataset.data[i])
        plt.show()


criterion = nn.CrossEntropyLoss()
model = pt_models.ConvolutionalModel((32, 32), 3, [16, 32], [256, 128], 10)
ds_train = CIFAR10(DATA_DIR, train=False, transform=transforms.ToTensor())
show_worst(model, ds_train, criterion)
Example #13
def main():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size-test",
        default=256,
        type=int,
        metavar="N",
        help=
        "mini-batch size for test dataset (default: 256), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "--sample-rate",
        default=0.04,
        type=float,
        metavar="SR",
        help="sample rate used for batch construction (default: 0.005)",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.1,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument("--momentum",
                        default=0.9,
                        type=float,
                        metavar="M",
                        help="SGD momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=0,
        type=float,
        metavar="W",
        help="SGD weight decay",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument("--seed",
                        default=None,
                        type=int,
                        help="seed for initializing training. ")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.5,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=10.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help=
        "Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save check points",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument("--log-dir",
                        type=str,
                        default="",
                        help="Where Tensorboard log will be stored")
    parser.add_argument(
        "--optim",
        type=str,
        default="SGD",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )
    parser.add_argument("--lr-schedule",
                        type=str,
                        choices=["constant", "cos"],
                        default="cos")

    args = parser.parse_args()

    if args.disable_dp and args.n_accumulation_steps > 1:
        raise ValueError("Virtual steps only works with enabled DP")

    # The following few lines, enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(
        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir)))
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {"clip_per_layer": False, "enable_stat": True}

    if args.secure_rng:
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")

    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ]
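    # Augmentations are applied only when DP is disabled; the expression parses
    # as (augmentations + normalize) if args.disable_dp else normalize.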
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize)

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(root=args.data_root,
                            train=True,
                            download=True,
                            transform=train_transform)

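    # UniformWithReplacementSampler draws each example independently with
    # probability sample_rate, so batch sizes vary (Poisson-style sampling,
    # matching the assumption made by the DP accountant).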
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        num_workers=args.workers,
        generator=generator,
        batch_sampler=UniformWithReplacementSampler(
            num_samples=len(train_dataset),
            sample_rate=args.sample_rate,
            generator=generator,
        ),
    )

    test_dataset = CIFAR10(root=args.data_root,
                           train=False,
                           download=True,
                           transform=test_transform)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0
    device = torch.device(args.device)
    model = convnet(num_classes=10)
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(
            "Optimizer not recognized. Please check spelling")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            sample_rate=args.sample_rate * args.n_accumulation_steps,
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            secure_rng=args.secure_rng,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.lr_schedule == "cos":
            lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch /
                                             (args.epochs + 1)))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "Convnet",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )
Example #14
print(opt)

if opt.cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")

torch.manual_seed(opt.seed)

device = torch.device("cuda" if opt.cuda else "cpu")

print('===> Loading datasets')

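# Two grayscale views of the same CIFAR-10 training images: 8x8 inputs and
# 32x32 targets (a low-/high-resolution pair, e.g. for super-resolution).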
train_set_img = CIFAR10('dataset',
                        train=True,
                        download=True,
                        transform=transforms.Compose([
                            transforms.Grayscale(1),
                            transforms.Resize(8),
                            transforms.ToTensor()
                        ]))

train_set_target = CIFAR10('dataset',
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.Grayscale(1),
                               transforms.Resize(32),
                               transforms.ToTensor()
                           ]))

test_set_img = CIFAR10('dataset',
                       train=False,
Example #15
    torch.manual_seed(36246)

    # Identify whether a CUDA GPU is available to run the analyses on
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The parameters to run the model with
    no_cycles = 10
    train_epochs = 90
    prune_epochs = 10
    in_shape = (3, 32, 32)
    train_lr = 0.001
    prune_lr = 0.0001
    batch_size = 256

    training_data = CIFAR10(root="data",
                            train=True,
                            download=True,
                            transform=data_managers.train_transform)

    testing_data = CIFAR10(root="data",
                           train=False,
                           download=True,
                           transform=data_managers.test_transform)

    train_dataloader = DataLoader(training_data,
                                  batch_size=batch_size,
                                  shuffle=True)
    testing_dataloader = DataLoader(testing_data,
                                    batch_size=batch_size,
                                    shuffle=True)

    # Build the model, initialized with one column
Example #16
    def data_key(self, data):
        if isinstance(data, (list, tuple)):
            return data
        else:
            return (data, )

    def each_generate(self, data):
        samples = [(torch.clamp(sample, -1, 1) + 1) / 2
                   for sample in data[0:10]]
        samples = torch.cat(samples, dim=-1)
        self.writer.add_image("samples", samples, self.step_id)


if __name__ == "__main__":
    cifar = CIFAR10("examples/", download=False, transform=ToTensor())
    data = EnergyDataset(cifar)

    energy = ConvEnergy(depth=20)
    integrator = Langevin(rate=1,
                          noise=0.01,
                          steps=20,
                          clamp=None,
                          max_norm=None)

    training = CIFAR10EnergyTraining(
        energy,
        data,
        network_name="classifier-mnist-ebm/cifar-plain",
        device="cuda:0",
        integrator=integrator,
Example #17
    def __init__(self, tot_num_tasks, dataset, inner_batch_size, protocol):
        """
        Args:
            tot_num_tasks: total number of meta-tasks to generate
            dataset: dataset name (e.g. CIFAR-10, CIFAR-100, MNIST)
            inner_batch_size: number of images sampled for each task
            protocol: SPLIT_DATA_PROTOCOL value selecting which source models are used
        """
        self.img_size = IMAGE_SIZE[dataset]
        self.dataset = dataset

        if protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II:
            self.model_names = MODELS_TRAIN_STANDARD[self.dataset]
        elif protocol == SPLIT_DATA_PROTOCOL.TRAIN_II_TEST_I:
            self.model_names = MODELS_TEST_STANDARD[self.dataset]
        elif protocol == SPLIT_DATA_PROTOCOL.TRAIN_ALL_TEST_ALL:
            self.model_names = MODELS_TRAIN_STANDARD[
                self.dataset] + MODELS_TEST_STANDARD[self.dataset]

        self.model_dict = {}
        for arch in self.model_names:
            if StandardModel.check_arch(arch, dataset):
                model = StandardModel(dataset, arch, no_grad=False).eval()
                if dataset != "ImageNet":
                    model = model.cuda()
                self.model_dict[arch] = model
        is_train = True
        preprocessor = DataLoaderMaker.get_preprocessor(
            IMAGE_SIZE[dataset], is_train)
        if dataset == "CIFAR-10":
            train_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset],
                                    train=is_train,
                                    transform=preprocessor)
        elif dataset == "CIFAR-100":
            train_dataset = CIFAR100(IMAGE_DATA_ROOT[dataset],
                                     train=is_train,
                                     transform=preprocessor)
        elif dataset == "MNIST":
            train_dataset = MNIST(IMAGE_DATA_ROOT[dataset],
                                  train=is_train,
                                  transform=preprocessor)
        elif dataset == "FashionMNIST":
            train_dataset = FashionMNIST(IMAGE_DATA_ROOT[dataset],
                                         train=is_train,
                                         transform=preprocessor)
        elif dataset == "TinyImageNet":
            train_dataset = TinyImageNet(IMAGE_DATA_ROOT[dataset],
                                         preprocessor,
                                         train=is_train)
        elif dataset == "ImageNet":
            preprocessor = DataLoaderMaker.get_preprocessor(
                IMAGE_SIZE[dataset], is_train, center_crop=True)
            sub_folder = "/train" if is_train else "/validation"  # Note that ImageNet uses pretrainedmodels.utils.TransformImage to apply transformation
            train_dataset = ImageFolder(IMAGE_DATA_ROOT[dataset] + sub_folder,
                                        transform=preprocessor)
        self.train_dataset = train_dataset
        self.total_num_images = len(train_dataset)
        self.all_tasks = dict()
        all_images_indexes = np.arange(self.total_num_images).tolist()
        for i in range(tot_num_tasks):
            self.all_tasks[i] = {
                "image": random.sample(all_images_indexes, inner_batch_size),
                "arch": random.choice(list(self.model_dict.keys()))
            }
Example #18
import time
import itertools
import numpy as np
import scipy.sparse
import torch
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10
from sklearn.utils import gen_batches
from sklearn.svm import LinearSVC
from pcanet import PCANet

#%%
# load data
n_train = 1000
n_test = 1000
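# Subsample: every 50th of the 50,000 training images and every 10th of the
# 10,000 test images, i.e. 1,000 examples each.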
train_set = CIFAR10('datasets', download=True)
train_set_X = torch.from_numpy(train_set.data[::50]).float().div(255)
train_set_y = np.asarray(train_set.targets[::50])

test_set = CIFAR10('datasets', False, download=True)
test_set_X = torch.from_numpy(test_set.data[::10]).float().div(255.0)
test_set_y = np.asarray(test_set.targets[::10])

# %%
# visualization
nclasses = 10  # number of classes to visualize
nexamples = 10  # number of examples for each class

# Choosing indices from the training set images
img_idx = [
    np.where(train_set_y == class_id)[0][0:nexamples]
Example #19
from resnet import *
from arguments import parse_args
from utils import get_optimizer, get_transforms

args = parse_args()

device = f'cuda:{args.gpu_id[0]}'
args.device = torch.device(device)

print(args)
print(torch.cuda.get_device_name(0))

transform_train, transform_test = get_transforms()

trainset = CIFAR10(root=args.data_dir,
                   train=True,
                   download=True,
                   transform=transform_train)
trainloader = DataLoader(trainset,
                         batch_size=args.batch_size,
                         shuffle=True,
                         num_workers=args.num_workers)

testset = CIFAR10(root=args.data_dir,
                  train=False,
                  download=True,
                  transform=transform_test)
testloader = DataLoader(testset,
                        batch_size=100,
                        shuffle=False,
                        num_workers=args.num_workers)
Example #20
# https://stackoverflow.com/questions/38834378/path-to-a-directory-as-argparse-argument
def dir_path(string):
    if os.path.isdir(string):
        return string
    else:
        raise NotADirectoryError(string)


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()

    parser.add_argument('--path',
                        type=dir_path,
                        required=True,
                        help='Download cifar10 and cifar100 to this path.')

    args = parser.parse_args()
    c10_path = args.path + '/cifar10'
    c100_path = args.path + '/cifar100'

    if not os.path.exists(c10_path):
        os.makedirs(c10_path)

    if not os.path.exists(c100_path):
        os.makedirs(c100_path)

    CIFAR10(c10_path, download=True)
    CIFAR100(c100_path, download=True)
Example #21
    RandomHorizontalFlip(),
    CIFAR10Policy(),
    ToTensor(),
    Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    Cutout(1, 16),
])

test_transform = Compose([
    ToTensor(),
    Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])

# Dataset

data_home = "datasets/CIFAR10"
ds = CIFAR10(data_home, train=True, download=True)
ds = train_test_split(ds, test_ratio=0.5)[1]
ds_train, ds_val = train_test_split(
    ds,
    test_ratio=0.02,
    transform=train_transform,
    test_transform=test_transform,
)
ds_test = CIFAR10(data_home,
                  train=False,
                  download=True,
                  transform=test_transform)

net = LeNet5()
# net = efficientnet_b0(num_classes=10, dropout=0.3, drop_connect=0.2)
criterion = nn.CrossEntropyLoss()
Example #22
    def test_fid(self):
        assert self.config.data.dataset == 'CELEBA'
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if self.config.data.dataset == 'CIFAR10':
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                                   transform=transform)
        elif self.config.data.dataset == 'MNIST':
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=False, download=True,
                                 transform=transform)
        elif self.config.data.dataset == 'CELEBA':
            dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
            num_items = len(dataset)
            indices = list(range(num_items))
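            # Shuffle with a fixed seed (2019) for a reproducible split, then
            # restore the global NumPy RNG state.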
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            test_indices = indices[int(0.8 * num_items):]
            test_dataset = Subset(dataset, test_indices)

        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=False,
                                 num_workers=2)

        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        get_data_stats = False
        manual = False
        if get_data_stats:
            data_images = []
            for _, (X, y) in enumerate(test_loader):
                X = X.to(self.config.device)
                X = X + (torch.rand_like(X) - 0.5) / 128.
                data_images.extend(X / 2. + 0.5)
                if len(data_images) > 10000:
                    break

            if not os.path.exists(os.path.join(self.args.run, 'datasets', 'celeba140_fid', 'raw_images')):
                os.makedirs(os.path.join(self.args.run, 'datasets', 'celeba140_fid', 'raw_images'))
            logging.info("Saving data images")
            for i, image in enumerate(data_images):
                save_image(image,
                           os.path.join(self.args.run, 'datasets', 'celeba140_fid', 'raw_images', '{}.png'.format(i)))
            logging.info("Images saved. Calculating fid statistics now")
            fid.calculate_data_statics(os.path.join(self.args.run, 'datasets', 'celeba140_fid', 'raw_images'),
                                       os.path.join(self.args.run, 'datasets', 'celeba140_fid'), 50, True, 2048)


        else:
            if manual:
                states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
                decoder = Decoder(self.config).to(self.config.device)
                decoder.eval()
                if self.config.training.algo == 'vae':
                    encoder = Encoder(self.config).to(self.config.device)
                    encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                elif self.config.training.algo == 'ssm':
                    score = Score(self.config).to(self.config.device)
                    imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])
                    score.load_state_dict(states[2])
                elif self.config.training.algo in ['spectral', 'stein']:
                    from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
                    imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                    imp_encoder.load_state_dict(states[0])
                    decoder.load_state_dict(states[1])

                all_samples = []
                logging.info("Generating samples")
                for i in range(100):
                    with torch.no_grad():
                        z = torch.randn(100, self.config.model.z_dim, device=self.config.device)
                        samples, _ = decoder(z)
                        samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                               self.config.data.image_size)
                        all_samples.extend(samples / 2. + 0.5)

                if not os.path.exists(os.path.join(self.args.log, 'samples', 'raw_images')):
                    os.makedirs(os.path.join(self.args.log, 'samples', 'raw_images'))
                logging.info("Images generated. Saving images")
                for i, image in enumerate(all_samples):
                    save_image(image, os.path.join(self.args.log, 'samples', 'raw_images', '{}.png'.format(i)))
                logging.info("Generating fid statistics")
                fid.calculate_data_statics(os.path.join(self.args.log, 'samples', 'raw_images'),
                                           os.path.join(self.args.log, 'samples'), 50, True, 2048)
                logging.info("Statistics generated.")
            else:
                for iter in range(1, 11):
                    states = torch.load(os.path.join(self.args.log, 'checkpoint_{}0k.pth'.format(iter)),
                                        map_location=self.config.device)
                    decoder = Decoder(self.config).to(self.config.device)
                    decoder.eval()
                    if self.config.training.algo == 'vae':
                        encoder = Encoder(self.config).to(self.config.device)
                        encoder.load_state_dict(states[0])
                        decoder.load_state_dict(states[1])
                    elif self.config.training.algo == 'ssm':
                        score = Score(self.config).to(self.config.device)
                        imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                        imp_encoder.load_state_dict(states[0])
                        decoder.load_state_dict(states[1])
                        score.load_state_dict(states[2])
                    elif self.config.training.algo in ['spectral', 'stein']:
                        from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
                        imp_encoder = ImplicitEncoder(self.config).to(self.config.device)
                        imp_encoder.load_state_dict(states[0])
                        decoder.load_state_dict(states[1])

                    all_samples = []
                    logging.info("Generating samples")
                    for i in range(100):
                        with torch.no_grad():
                            z = torch.randn(100, self.config.model.z_dim, device=self.config.device)
                            samples, _ = decoder(z)
                            samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                                   self.config.data.image_size)
                            all_samples.extend(samples / 2. + 0.5)

                    if not os.path.exists(os.path.join(self.args.log, 'samples', 'raw_images_{}0k'.format(iter))):
                        os.makedirs(os.path.join(self.args.log, 'samples', 'raw_images_{}0k'.format(iter)))
                    else:
                        shutil.rmtree(os.path.join(self.args.log, 'samples', 'raw_images_{}0k'.format(iter)))
                        os.makedirs(os.path.join(self.args.log, 'samples', 'raw_images_{}0k'.format(iter)))

                    if not os.path.exists(os.path.join(self.args.log, 'samples', 'statistics_{}0k'.format(iter))):
                        os.makedirs(os.path.join(self.args.log, 'samples', 'statistics_{}0k'.format(iter)))
                    else:
                        shutil.rmtree(os.path.join(self.args.log, 'samples', 'statistics_{}0k'.format(iter)))
                        os.makedirs(os.path.join(self.args.log, 'samples', 'statistics_{}0k'.format(iter)))

                    logging.info("Images generated. Saving images")
                    for i, image in enumerate(all_samples):
                        save_image(image, os.path.join(self.args.log, 'samples', 'raw_images_{}0k'.format(iter),
                                                       '{}.png'.format(i)))
                    logging.info("Generating fid statistics")
                    fid.calculate_data_statics(os.path.join(self.args.log, 'samples', 'raw_images_{}0k'.format(iter)),
                                               os.path.join(self.args.log, 'samples', 'statistics_{}0k'.format(iter)),
                                               50, True, 2048)
                    logging.info("Statistics generated.")
                    fid_number = fid.calculate_fid_given_paths([
                        'run/datasets/celeba140_fid/celeba_test.npz',
                        os.path.join(self.args.log, 'samples', 'statistics_{}0k'.format(iter), 'celeba_test.npz')]
                        , 50, True, 2048)
                    logging.info("Number of iters: {}0k, FID: {}".format(iter, fid_number))
Example #23
def process_CIFAR10(if_balanced, if_weighted, batch_size = 16, num_workers=2):
    # Note: the same mean/std normalization, computed on the training set, is applied to both the train and test transforms.

    transform_train = transforms.Compose([
                                        transforms.ToPILImage(),
                                        transforms.RandomCrop(32, padding=4),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.4914, 0.4822, 0.4465], 
                                            [0.247, 0.243, 0.261])
                                        ])
    
    transform_test = transforms.Compose([
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.4914, 0.4822, 0.4465], 
                                            [0.247, 0.243, 0.261])
                                        ])

    trainset = CIFAR10(root='./cifar10', train=True, transform=None, target_transform=None, download=True)
    testset = CIFAR10(root='./cifar10', train=False, transform=None, target_transform=None, download=True)

    #trainset = MapDataset(trainset, transform_train)
    #testset = MapDataset(testset, transform_test)

    #classes = trainset.classes
    classDict = {'plane':0, 'car':1, 'bird':2, 'cat':3, 'deer':4, 'dog':5, 'frog':6, 'horse':7, 'ship':8, 'truck':9}
    
    def get_class_i(x, y, i):
      """
      x: trainset.train_data or testset.test_data
      y: trainset.train_labels or testset.test_labels
      i: class label, a number between 0 to 9
      return: x_i
      """
      # Convert to a numpy array
      y = np.array(y)
      # Locate position of labels that equal to i
      pos_i = np.argwhere(y == i)
      # Convert the result into a 1-D list
      pos_i = list(pos_i[:,0])
      # Collect all data that match the desired label
      x_i = [x[j] for j in pos_i]

      return x_i
    
    x_train = trainset.data
    y_train = trainset.targets
    x_test = testset.data
    y_test = testset.targets


    test_data = get_class_i(x_test, y_test, classDict['cat']) + get_class_i(x_test, y_test, classDict['dog']) + \
                get_class_i(x_test, y_test, classDict['car']) + get_class_i(x_test, y_test, classDict['truck'])
    test_labels = [0 for j in range(2000)] + [1 for j in range(2000)]
    test_weights = [1 for j in range(4000)] 
    testset = list(zip(test_data, test_labels, test_weights))
    test = MapDataset(testset, transform_test)

    if if_balanced == True:
        train_data = get_class_i(x_train, y_train, classDict['cat']) + get_class_i(x_train, y_train, classDict['dog']) + \
                     get_class_i(x_train, y_train, classDict['car']) + get_class_i(x_train, y_train, classDict['truck'])
        train_labels = [0 for j in range(10000)] + [1 for j in range(10000)]
        train_weights = [1 for j in range(20000)] 
    else:
        train_data = random.sample(get_class_i(x_train, y_train, classDict['cat']), 500) + get_class_i(x_train, y_train, classDict['dog']) + \
                     random.sample(get_class_i(x_train, y_train, classDict['car']), 500) + get_class_i(x_train, y_train, classDict['truck'])
        train_labels = [0 for j in range(5500)] + [1 for j in range(5500)]

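        # When if_weighted is set, each of the four groups (500 cats, 5000 dogs,
        # 500 cars, 5000 trucks) gets weight N / (K * n_group) with N = 11000 and
        # K = 4, so every group contributes the same total weight
        # (500 * 5.5 = 5000 * 0.55 = 2750) and the mean weight stays 1.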
        if if_weighted:
            train_weights = [5.5 for j in range(500)] + [0.55 for j in range(5000)] + [5.5 for j in range(500)] + [0.55 for j in range(5000)]
        else:
            train_weights = [1 for j in range(11000)]
    
    trainset = list(zip(train_data, train_labels, train_weights))
    train = MapDataset(trainset, transform_train)
    

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    
    return train_loader, test_loader
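# Minimal usage sketch (illustrative only; `model` is a placeholder network and
# F refers to torch.nn.functional, neither defined in this snippet). The third
# element of each batch is the per-sample weight, meant to rescale an unreduced loss:
#
#   train_loader, test_loader = process_CIFAR10(if_balanced=False, if_weighted=True)
#   for x, y, w in train_loader:
#       loss = (F.cross_entropy(model(x), y, reduction='none') * w).mean()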
Exemple #24
0
    def train(self):
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if self.config.data.dataset == 'CIFAR10':
            dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=True, download=True,
                              transform=transform)
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                                   transform=transform)
        elif self.config.data.dataset == 'MNIST':
            dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=True, download=True,
                            transform=transform)
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            train_indices, test_indices = indices[:int(num_items * 0.8)], indices[int(num_items * 0.8):]
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)

        elif self.config.data.dataset == 'CELEBA':
            dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            train_indices, test_indices = indices[:int(num_items * 0.7)], indices[
                                                                          int(num_items * 0.7):int(num_items * 0.8)]
            test_dataset = Subset(dataset, test_indices)
            dataset = Subset(dataset, train_indices)

        dataloader = DataLoader(dataset, batch_size=self.config.training.batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=True,
                                 num_workers=2)
        test_iter = iter(test_loader)
        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        tb_path = os.path.join(self.args.run, 'tensorboard', self.args.doc)
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        tb_logger = tensorboardX.SummaryWriter(log_dir=tb_path)
        decoder = MLPDecoder(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' \
            else Decoder(self.config).to(self.config.device)
        if self.config.training.algo == 'vae':
            encoder = MLPEncoder(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' \
                else Encoder(self.config).to(self.config.device)
            optimizer = self.get_optimizer(itertools.chain(encoder.parameters(), decoder.parameters()))
            if self.args.resume_training:
                states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
                encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])
        elif self.config.training.algo == 'ssm':
            score = MLPScore(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' else \
                Score(self.config).to(self.config.device)
            imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' \
                else ImplicitEncoder(self.config).to(self.config.device)

            opt_ae = optim.RMSprop(itertools.chain(decoder.parameters(), imp_encoder.parameters()),
                                   lr=self.config.optim.lr)
            opt_score = optim.RMSprop(score.parameters(), lr=self.config.optim.lr)
            if self.args.resume_training:
                states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                score.load_state_dict(states[2])
                opt_ae.load_state_dict(states[3])
                opt_score.load_state_dict(states[4])
        elif self.config.training.algo in ['spectral', 'stein']:
            from models.kernel_score_estimators import SpectralScoreEstimator, SteinScoreEstimator
            imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device) if self.config.data.dataset == 'MNIST' \
                else ImplicitEncoder(self.config).to(self.config.device)
            estimator = SpectralScoreEstimator() if self.config.training.algo == 'spectral' else SteinScoreEstimator()
            optimizer = self.get_optimizer(itertools.chain(imp_encoder.parameters(), decoder.parameters()))
            if self.args.resume_training:
                states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'))
                imp_encoder.load_state_dict(states[0])
                decoder.load_state_dict(states[1])
                optimizer.load_state_dict(states[2])

        step = 0
        best_validation_loss = np.inf
        validation_losses = []
        recon_type = 'bernoulli' if self.config.data.dataset == 'MNIST' else 'gaussian'

        for _ in range(self.config.training.n_epochs):
            for _, (X, y) in enumerate(dataloader):
                decoder.train()
                X = X.to(self.config.device)
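                # Dequantize CelebA inputs with uniform noise (pixels lie in [-1, 1]
                # after normalization, so one intensity bin is 1/128 wide); binarize
                # MNIST dynamically to match the Bernoulli decoder used below.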
                if self.config.data.dataset == 'CELEBA':
                    X = X + (torch.rand_like(X) - 0.5) / 128.
                elif self.config.data.dataset == 'MNIST':
                    eps = torch.rand_like(X)
                    X = (eps <= X).float()

                if self.config.training.algo == 'vae':
                    encoder.train()
                    loss, *_ = elbo(encoder, decoder, X, recon_type)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                elif self.config.training.algo == 'ssm':
                    imp_encoder.train()
                    loss, ssm_loss, *_ = elbo_ssm(imp_encoder, decoder, score, opt_score, X, recon_type,
                                                  training=True, n_particles=self.config.model.n_particles)
                    opt_ae.zero_grad()
                    loss.backward()
                    opt_ae.step()
                elif self.config.training.algo in ['spectral', 'stein']:
                    imp_encoder.train()
                    loss = elbo_kernel(imp_encoder, decoder, estimator, X, recon_type,
                                       n_particles=self.config.model.n_particles)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                if step % 10 == 0:
                    try:
                        test_X, _ = next(test_iter)
                    except StopIteration:
                        test_iter = iter(test_loader)
                        test_X, _ = next(test_iter)

                    test_X = test_X.to(self.config.device)
                    if self.config.data.dataset == 'CELEBA':
                        test_X = test_X + (torch.rand_like(test_X) - 0.5) / 128.
                    elif self.config.data.dataset == 'MNIST':
                        test_eps = torch.rand_like(test_X)
                        test_X = (test_eps <= test_X).float()

                    decoder.eval()
                    if self.config.training.algo == 'vae':
                        encoder.eval()
                        with torch.no_grad():
                            test_loss, *_ = elbo(encoder, decoder, test_X, recon_type)
                            logging.info("loss: {}, test_loss: {}".format(loss.item(), test_loss.item()))
                    elif self.config.training.algo == 'ssm':
                        imp_encoder.eval()
                        test_loss, *_ = elbo_ssm(imp_encoder, decoder, score, None, test_X, recon_type, training=False)
                        logging.info("loss: {}, ssm_loss: {}, test_loss: {}".format(loss.item(), ssm_loss.item(),
                                                                                    test_loss.item()))
                        z = imp_encoder(test_X)
                        tb_logger.add_histogram('z_X', z, global_step=step)
                    elif self.config.training.algo in ['spectral', 'stein']:
                        imp_encoder.eval()
                        with torch.no_grad():
                            test_loss = elbo_kernel(imp_encoder, decoder, estimator, test_X, recon_type, 10)

                            logging.info("loss: {}, test_loss: {}".format(loss.item(), test_loss.item()))

                    validation_losses.append(test_loss.item())
                    tb_logger.add_scalar('loss', loss, global_step=step)
                    tb_logger.add_scalar('test_loss', test_loss, global_step=step)

                    if self.config.training.algo == 'ssm':
                        tb_logger.add_scalar('ssm_loss', ssm_loss, global_step=step)

                if step % 500 == 0:
                    with torch.no_grad():
                        z = torch.randn(100, self.config.model.z_dim, device=X.device)
                        decoder.eval()
                        if self.config.data.dataset == 'CELEBA':
                            samples, _ = decoder(z)
                            samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                                   self.config.data.image_size)
                            image_grid = make_grid(samples, 10)
                            image_grid = torch.clamp(image_grid / 2. + 0.5, 0.0, 1.0)
                            data_grid = make_grid(X[:100], 10)
                            data_grid = torch.clamp(data_grid / 2. + 0.5, 0.0, 1.0)
                        elif self.config.data.dataset == 'MNIST':
                            samples_logits = decoder(z)
                            samples = torch.sigmoid(samples_logits)
                            samples = samples.view(100, self.config.data.channels, self.config.data.image_size,
                                                   self.config.data.image_size)
                            image_grid = make_grid(samples, 10)
                            data_grid = make_grid(X[:100], 10)

                        tb_logger.add_image('samples', image_grid, global_step=step)
                        tb_logger.add_image('data', data_grid, global_step=step)

                        if len(validation_losses) != 0:
                            validation_loss = sum(validation_losses) / len(validation_losses)
                            if validation_loss < best_validation_loss:
                                best_validation_loss = validation_loss
                                validation_losses = []
                            # else:
                            #     return 0

                if (step + 1) % 10000 == 0:
                    if self.config.training.algo == 'vae':
                        states = [
                            encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict()
                        ]
                    elif self.config.training.algo == 'ssm':
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            score.state_dict(),
                            opt_ae.state_dict(),
                            opt_score.state_dict()
                        ]
                    elif self.config.training.algo in ['spectral', 'stein']:
                        states = [
                            imp_encoder.state_dict(),
                            decoder.state_dict(),
                            optimizer.state_dict()
                        ]
                    torch.save(states,
                               os.path.join(self.args.log, 'checkpoint_{}0k.pth'.format((step + 1) // 10000)))
                    torch.save(states, os.path.join(self.args.log, 'checkpoint.pth'))

                step += 1
                if step >= self.config.training.n_iters:
                    return 0
Exemple #25
0
def load_cifar(inds=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
               feature_extractor='hog',
               composite=False,
               composite_labels=None,
               normalize=False):

    start = timer()
    train_set = CIFAR10('.cifar/',
                        train=True,
                        download=True,
                        transform=torchvision.transforms.ToTensor())
    test_set = CIFAR10('.cifar/',
                       train=False,
                       download=True,
                       transform=torchvision.transforms.ToTensor())

    trainX = train_set.data
    trainY = np.array(train_set.targets)

    testX = test_set.data
    testY = np.array(test_set.targets)

    print("using only the following indices in CIFAR: ", str(inds))

    train_inds = np.where(np.isin(trainY, inds))[0]
    test_inds = np.where(np.isin(testY, inds))[0]

    trainX = trainX[train_inds, :]
    trainLabels = trainY[train_inds]
    testX = testX[test_inds, :]
    testLabels = testY[test_inds]

    # hogify
    if feature_extractor == 'hog':
        print("using HOG as feature extractor.")
        trainX = np.array([
            hog(dat, multichannel=True, feature_vector=True) for dat in trainX
        ])
        testX = np.array([
            hog(dat, multichannel=True, feature_vector=True) for dat in testX
        ])

    # resnet
    elif feature_extractor == 'resnet':
        print("using ResNet50 as feature extractor.")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        extractor = torchvision.models.resnet50(pretrained=True)
        extractor = nn.Sequential(*list(extractor.children())[:-1])
        extractor = extractor.to(device)

        train_feats = []
        test_feats = []

        trainloader = DataLoader(train_set, batch_size=128)
        testloader = DataLoader(test_set, batch_size=128)
        extractor.eval()
        with torch.no_grad():
            for data, label in trainloader:
                data = data.to(device)
                feats = extractor(data)
                feats = feats.cpu().reshape(-1, 2048)
                train_feats.append(feats)

            for data, label in testloader:
                data = data.to(device)
                feats = extractor(data)
                feats = feats.cpu().reshape(-1, 2048)
                test_feats.append(feats)
        trainX = torch.cat(train_feats, dim=0)
        testX = torch.cat(test_feats, dim=0)

    # no feature extraction. just flatten and return.
    else:
        all_dims = np.prod(trainX.shape[1:])
        trainX = trainX.reshape(-1, all_dims)
        testX = testX.reshape(-1, all_dims)

    # normalize
    if normalize:
        print("normalizing...", end='')
        scaler = StandardScaler()
        trainX = scaler.fit_transform(trainX)
        testX = scaler.transform(testX)
        trainX = torch.from_numpy(trainX)
        testX = torch.from_numpy(testX)

    # composite labels
    if composite:
        print("using the following composite labels: " +
              str(composite_labels.keys()))
        trainY = np.zeros(trainLabels.shape)
        testY = np.zeros(testLabels.shape)

        for label in composite_labels:
            label_set = composite_labels[label]
            trainY[np.where(np.isin(trainLabels, label_set))] = label
            testY[np.where(np.isin(testLabels, label_set))] = label

    else:
        trainY = trainLabels
        testY = testLabels

    # cast everything to tensors
    trainX = torch.as_tensor(trainX)
    testX = torch.as_tensor(testX)
    trainY = torch.as_tensor(trainY)
    testY = torch.as_tensor(testY)
    trainLabels = torch.as_tensor(trainLabels)
    testLabels = torch.as_tensor(testLabels)

    print("trainX shape: ", trainX.shape)
    print("trainY.shape: ", trainY.shape)
    print("done. elapsed: ", timer() - start)

    return trainX.float(), trainY, testX.float(), testY, trainLabels, testLabels
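# Illustrative usage (a sketch; assumes scikit-learn is installed): extract HOG
# features for cats (label 3) and dogs (label 5) and fit a linear classifier.
#
#   trainX, trainY, testX, testY, *_ = load_cifar(inds=[3, 5], feature_extractor='hog')
#   from sklearn.linear_model import LogisticRegression
#   clf = LogisticRegression(max_iter=1000).fit(trainX.numpy(), trainY.numpy())
#   print('test accuracy:', clf.score(testX.numpy(), testY.numpy()))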
Exemple #26
0
    def test_ais(self):
        assert self.config.data.dataset == 'MNIST'
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if self.config.data.dataset == 'CIFAR10':
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                                   transform=transform)
        elif self.config.data.dataset == 'MNIST':
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=False, download=True,
                                 transform=transform)
        elif self.config.data.dataset == 'CELEBA':
            dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            test_indices = indices[int(0.8 * num_items):]
            test_dataset = Subset(dataset, test_indices)

        # test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=False,
        #                          num_workers=2)

        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                                 num_workers=2)

        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
        decoder = MLPDecoder(self.config).to(self.config.device)
        if self.config.training.algo == 'vae':
            encoder = MLPEncoder(self.config).to(self.config.device)
            encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
        elif self.config.training.algo == 'ssm':
            score = MLPScore(self.config).to(self.config.device)
            imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device)
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            score.load_state_dict(states[2])
        elif self.config.training.algo in ['spectral', 'stein']:
            imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device)
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])

        recon_type = 'bernoulli' if self.config.data.dataset == 'MNIST' else 'gaussian'

        def recon_energy(X, z):
            # Negative log-likelihood of X under the decoder, i.e. -log p(X | z).
            if recon_type == 'gaussian':
                mean_x, logstd_x = decoder(z)
                recon = (X - mean_x) ** 2 / (2. * (2 * logstd_x).exp()) + np.log(2. * np.pi) / 2. + logstd_x
                recon = recon.sum(dim=(1, 2, 3))
            elif recon_type == 'bernoulli':
                x_logits = decoder(z)
                recon = F.binary_cross_entropy_with_logits(input=x_logits, target=X, reduction='none')
                recon = recon.sum(dim=(1, 2, 3))
            return recon

        from evaluations.ais import AISLatentVariableModels
        ais = AISLatentVariableModels(recon_energy,
                                      self.config.model.z_dim,
                                      self.config.device, n_Ts=1000)

        total_l = 0.
        total_n = 0
        for _, (X, y) in enumerate(test_loader):
            X = X.to(self.config.device)
            if self.config.data.dataset == 'CELEBA':
                X = X + (torch.rand_like(X) - 0.5) / 128.
            elif self.config.data.dataset == 'MNIST':
                eps = torch.rand_like(X)
                X = (eps <= X).float()

            ais_lb = ais.ais(X).mean().item()
            total_l += ais_lb * X.shape[0]
            total_n += X.shape[0]
            print('current ais lb: {}, mean ais lb: {}'.format(ais_lb, total_l / total_n))
Exemple #27
0
def get_dataset(name='CIFAR10',
                batch_size=128,
                test_batch_size=1024,
                root='.',
                device=None,
                seed=0):

    if name == 'CIFAR10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        refset = CIFAR10(root=root + '/CIFAR10_data',
                         train=True,
                         download=True,
                         transform=transform_test)
        trainset = CIFAR10(root=root + '/CIFAR10_data',
                           train=True,
                           download=True,
                           transform=transform_train)
        testset = CIFAR10(root=root + '/CIFAR10_data',
                          train=False,
                          download=True,
                          transform=transform_test)
        
    elif name == 'FMNIST':
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (0.5,)),
        ])
        refset = FashionMNIST(root + '/F_MNIST_data/',
                              download=True,
                              train=True,
                              transform=transform)
        trainset = FashionMNIST(root + '/F_MNIST_data/',
                                download=True,
                                train=True,
                                transform=transform)
        testset = FashionMNIST(root + '/F_MNIST_data/',
                               download=True,
                               train=False,
                               transform=transform)

    elif name == 'CIFAR100':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
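        # NOTE: these are the CIFAR-10 channel statistics reused for CIFAR-100;
        # dataset-specific values (roughly mean (0.507, 0.487, 0.441),
        # std (0.267, 0.256, 0.276)) are sometimes preferred.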

        trainset = CIFAR100(root=root + '/CIFAR100/',
                            train=True,
                            download=True,
                            transform=transform_train)
        testset = CIFAR100(root=root + '/CIFAR100/',
                           train=False,
                           download=False,
                           transform=transform_test)
        refset = None


    else:
        raise RuntimeError('Unknown dataset')

    n_dataset = len(trainset)
    n_train = int(1 * n_dataset)  # keep the full training set; no held-out split
    n_val = n_dataset - n_train
    trainset, validationset = torch.utils.data.random_split(
            trainset,
            [n_train, n_val],
            generator=torch.Generator().manual_seed(seed))
    
    if device is not None:
        # Dataset objects cannot be moved to a device with .to(); batches are
        # transferred to `device` inside the training loop instead
        # (pin_memory=True below speeds up those host-to-device copies).
        pass

    train_loader = torch.utils.data.DataLoader(trainset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=4,
                                              pin_memory=True)
    # NOTE: no validation examples are held out above, so the validation loader
    # currently mirrors the test set.
    validation_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)
    test_loader = torch.utils.data.DataLoader(testset,
                                             batch_size=test_batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True)

    return train_loader, validation_loader, test_loader
Exemple #28
0
    def test_svi(self):
        transform = transforms.Compose([
            transforms.Resize(self.config.data.image_size),
            transforms.ToTensor()
        ])

        if self.config.data.dataset == 'CIFAR10':
            test_dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar10'), train=False, download=True,
                                   transform=transform)
        elif self.config.data.dataset == 'MNIST':
            test_dataset = MNIST(os.path.join(self.args.run, 'datasets', 'mnist'), train=False, download=True,
                                 transform=transform)
        elif self.config.data.dataset == 'CELEBA':
            dataset = ImageFolder(root=os.path.join(self.args.run, 'datasets', 'celeba'),
                                  transform=transforms.Compose([
                                      transforms.CenterCrop(140),
                                      transforms.Resize(self.config.data.image_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
            num_items = len(dataset)
            indices = list(range(num_items))
            random_state = np.random.get_state()
            np.random.seed(2019)
            np.random.shuffle(indices)
            np.random.set_state(random_state)
            test_indices = indices[int(0.8 * num_items):]
            test_dataset = Subset(dataset, test_indices)

        test_loader = DataLoader(test_dataset, batch_size=self.config.training.batch_size, shuffle=False,
                                 num_workers=2)

        self.config.input_dim = self.config.data.image_size ** 2 * self.config.data.channels

        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'), map_location=self.config.device)
        decoder = MLPDecoder(self.config).to(self.config.device)
        if self.config.training.algo == 'vae':
            encoder = MLPEncoder(self.config).to(self.config.device)
            encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
        elif self.config.training.algo == 'ssm':
            score = Score(self.config).to(self.config.device)
            imp_encoder = MLPImplicitEncoder(self.config).to(self.config.device)
            imp_encoder.load_state_dict(states[0])
            decoder.load_state_dict(states[1])
            score.load_state_dict(states[2])

        total_l = 0.
        total_n = 0
        recon_type = 'bernoulli' if self.config.data.dataset == 'MNIST' else 'gaussian'
        from models.gmm import Gaussian4SVI
        for batch, (X, y) in enumerate(test_loader):
            X = X.to(self.config.device)
            if self.config.data.dataset == 'CELEBA':
                X = X + (torch.rand_like(X) - 0.5) / 128.
            elif self.config.data.dataset == 'MNIST':
                eps = torch.rand_like(X)
                X = (eps <= X).float()

            gaussian = Gaussian4SVI(X.shape[0], self.config.model.z_dim).to(self.config.device)
            optimizer = optim.SGD(gaussian.parameters(), lr=0.01, momentum=0.5)
            lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200], gamma=0.3)
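            # Stochastic variational inference: fit a per-example Gaussian posterior
            # for this mini-batch by maximizing the IWAE bound (k=10) for 300 steps,
            # then report the bound in evaluation mode.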
            for i in range(300):
                loss = iwae(gaussian, decoder, X, type=recon_type, k=10, training=True)
                decoder.zero_grad()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()  # step the schedule after the optimizer update
                print(i, loss.item())
            loss = iwae(gaussian, decoder, X, type=recon_type, k=10, training=False)
            total_l += loss.item() * X.shape[0]
            total_n += X.shape[0]
            print('mini-batch: {}, current iwae-10: {}, average iwae-10: {}'.format(batch + 1, loss.item(),
                                                                                    total_l / total_n))
Exemple #29
0
def _load_cifar10(transform: transforms.Compose) -> Tuple[CIFAR10, CIFAR10]:
    return CIFAR10(root='./data', train=True, download=True, transform=transform), \
           CIFAR10(root='./data', train=False, download=True, transform=transform)
Exemple #30
0
    def test(self):
        # Load the score network
        states = torch.load(os.path.join(self.args.log, 'checkpoint.pth'),
                            map_location=self.config.device)
        scorenet = CondRefineNetDilated(self.config).to(self.config.device)
        scorenet = torch.nn.DataParallel(scorenet)
        scorenet.load_state_dict(states[0])
        scorenet.eval()

        # Grab some samples from the test set
        dataset = CIFAR10(os.path.join(self.args.run, 'datasets', 'cifar'),
                          train=False,
                          transform=transforms.ToTensor(),
                          download=True)
        dataloader = iter(DataLoader(dataset, batch_size=50, shuffle=False))
        x0, y0 = next(dataloader)
        x1, y1 = next(dataloader)

        self.write_images(x0, 'xgt.png')
        self.write_images(x1, 'ygt.png')

        mixed = (x0 + x1).cuda()

        self.write_images(mixed.cpu() / 2., 'mixed.png')

        x0 = nn.Parameter(torch.Tensor(50, 3, 32, 32).uniform_()).cuda()
        x1 = nn.Parameter(torch.Tensor(50, 3, 32, 32).uniform_()).cuda()

        recon = (x0 + x1 - mixed)**2

        step_lr = 0.00003

        # Noise amounts
        sigmas = np.array([
            1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,
            0.04641589, 0.02782559, 0.01668101, 0.01
        ])
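        # Ten noise levels, geometrically spaced from 1.0 down to 0.01
        # (equivalently np.geomspace(1., 0.01, 10)).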
        n_steps_each = 100

        for idx, sigma in enumerate(sigmas):
            lambda_recon = 1. / sigma**2
            labels = torch.ones(1, device=x0.device) * idx
            labels = labels.long()
            step_size = step_lr * (sigma / sigmas[-1])**2
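            # Annealed Langevin updates below: each image ascends its score estimate
            # plus the gradient of the data-consistency term
            # -lambda_recon / 2 * (x0 + x1 - mixed)^2, with Gaussian noise of
            # variance 2 * step_size added at every step.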

            print('sigma = {}'.format(sigma))
            for step in range(n_steps_each):
                recon = ((x0 + x1 - mixed)**2).view(-1,
                                                    3 * 32 * 32).sum(1).mean()

                noise_x = torch.randn_like(x0) * np.sqrt(step_size * 2)
                noise_y = torch.randn_like(x1) * np.sqrt(step_size * 2)

                grad_x0 = scorenet(x0, labels).detach()
                grad_x1 = scorenet(x1, labels).detach()

                norm0 = np.linalg.norm(grad_x0.view(-1,
                                                    3 * 32 * 32).cpu().numpy(),
                                       axis=1).mean()
                norm1 = np.linalg.norm(grad_x1.view(-1,
                                                    3 * 32 * 32).cpu().numpy(),
                                       axis=1).mean()

                x0 += step_size * (grad_x0 - lambda_recon *
                                   (x0 + x1 - mixed)) + noise_x
                x1 += step_size * (grad_x1 - lambda_recon *
                                   (x0 + x1 - mixed)) + noise_y

            print(' recon: {}, |grad_x0|: {}, |grad_x1|: {}'.format(
                recon.item(), norm0, norm1))

        # Write x0 and x1
        self.write_images(x0.detach().cpu(), 'x.png')
        self.write_images(x1.detach().cpu(), 'y.png')