Example #1
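# Note: these snippets assume the surrounding repo's module-level imports and
# helpers -- numpy as np, torch, torchvision.transforms as transforms, tqdm,
# PIL.Image, absl-style FLAGS, the dataset classes, and the get_fid_score /
# get_inception_score metric helpers.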
def conceptcombineeval(model_list, select_idx):
    dataset = CelebAHQOverfit()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=4)

    n = 64
    labels = []

    for six in select_idx:
        label_ix = np.eye(2)[six]
        label_batch = np.tile(label_ix[None, :], (n, 1))
        label = torch.Tensor(label_batch).cuda()
        labels.append(label)
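    # labels[k] is an (n, 2) batch of identical one-hot attribute labels;
    # model_list[k] conditions on labels[k] during sampling below.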

    def get_color_distortion(s=1.0):
        # s is the strength of color distortion.
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s,
                                              0.4 * s)
        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.2)
        color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
        return color_distort

    color_transform = get_color_distortion()
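    # (This mirrors the SimCLR-style color-distortion recipe: random color
    # jitter followed by random grayscale.)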

    im_size = 128
    transform = transforms.Compose([
        transforms.RandomResizedCrop(im_size, scale=(0.4, 1.0)),
        transforms.RandomHorizontalFlip(), color_transform,
        transforms.ToTensor()
    ])

    gt_ims = []
    fake_ims = []

    for data, label in tqdm(dataloader):
        gt_ims.extend(
            list((data.numpy().transpose(
                (0, 2, 3, 1)) * 255).astype(np.uint8)))

        im = torch.rand(n, 3, 128, 128).cuda()
        im_noise = torch.randn_like(im).detach()
        # (A multi-stage initialization pass, which used `transform` above to
        # re-augment samples between short Langevin runs, is disabled in this
        # variant; Example #2 below keeps it active.)

        # Then refine the images
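        # Noisy gradient descent on the composed energy: perturb with small
        # Gaussian noise, step against the gradient of sum_k E_k(x, y_k),
        # then clamp to [0, 1].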

        for i in range(FLAGS.num_steps):
            im_noise.normal_()
            im = im + 0.001 * im_noise
            im.requires_grad_(True)
            energy = 0

            for model, label in zip(model_list, labels):
                energy = model.forward(im, label) + energy

            print("step: ", i, energy.mean())
            im_grad = torch.autograd.grad([energy.sum()], [im])[0]

            im = im - FLAGS.step_lr * im_grad
            im = im.detach()

            im = torch.clamp(im, 0, 1)

        im = im.detach().cpu()
        fake_ims.extend(
            list((im.numpy().transpose((0, 2, 3, 1)) * 255).astype(np.uint8)))
        if len(gt_ims) > 10000:
            break

    fid = get_fid_score(gt_ims, fake_ims)
    print("FID score: {}".format(fid))
    fake_ims = np.array(fake_ims)
    std_im = np.std(fake_ims, axis=0).mean()
    print("standard deviation of generated images:", std_im)

Example #2
def conceptcombineeval(model_list, select_idx):
    dataset = ImageNet()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=4)

    n = 64
    labels = []

    # Note: select_idx only sets the number of models sampled here; each
    # model gets its own random batch of n one-hot ImageNet class labels.
    for _ in select_idx:
        six = np.random.permutation(1000)[:n]
        label_batch = np.eye(1000)[six]
        label = torch.Tensor(label_batch).cuda()
        labels.append(label)

    def get_color_distortion(s=1.0):
        # s is the strength of color distortion.
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s,
                                              0.4 * s)
        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.2)
        color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
        return color_distort

    color_transform = get_color_distortion(0.5)

    im_size = 128
    transform = transforms.Compose([
        transforms.RandomResizedCrop(im_size, scale=(0.3, 1.0)),
        transforms.RandomHorizontalFlip(), color_transform,
        transforms.ToTensor()
    ])

    gt_ims = []
    fake_ims = []

    im = None

    for _, data, _ in tqdm(dataloader):
        gt_ims.extend(list((data.numpy() * 255).astype(np.uint8)))

        if im is None:
            im = torch.rand(n, 3, 128, 128).cuda()

        im_noise = torch.randn_like(im).detach()
        # First get good initializations for sampling
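        # Two-stage sampling: several short Langevin runs, each followed by a
        # strong re-augmentation of the current samples (crop/flip/color),
        # before the final refinement pass below.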
        for stage in range(5):
            for step in range(60):
                im_noise.normal_()
                im = im + 0.001 * im_noise
                im.requires_grad_(True)
                energy = 0

                for model, label in zip(model_list, labels):
                    energy = model.forward(im, label) + energy

                # print("step: ", i, energy.mean())
                im_grad = torch.autograd.grad([energy.sum()], [im])[0]

                im = im - FLAGS.step_lr * im_grad
                im = im.detach()

                im = torch.clamp(im, 0, 1)

            im = im.detach().cpu().numpy().transpose((0, 2, 3, 1))
            im = (im * 255).astype(np.uint8)

            ims = [np.array(transform(Image.fromarray(im_i))) for im_i in im]
            im = torch.Tensor(np.array(ims)).cuda()

        # Then refine the images

        for i in range(FLAGS.num_steps):
            im_noise.normal_()
            im = im + 0.001 * im_noise
            im.requires_grad_(True)
            energy = 0

            for model, label in zip(model_list, labels):
                energy = model.forward(im, label) + energy

            print("step: ", i, energy.mean())
            im_grad = torch.autograd.grad([energy.sum()], [im])[0]

            im = im - FLAGS.step_lr * im_grad
            im = im.detach()

            im = torch.clamp(im, 0, 1)

        im_cpu = im.detach().cpu()
        ims = list((im_cpu.numpy().transpose(
            (0, 2, 3, 1)) * 255).astype(np.uint8))

        fake_ims.extend(ims)
        if len(gt_ims) > 50000:
            break

    splits = max(1, len(fake_ims) // 5000)
    score, std = get_inception_score(fake_ims, splits=splits)
    print("inception score {}, with std {} ".format(score, std))
    fid = get_fid_score(gt_ims, fake_ims)
    print("FID score: {}".format(fid))
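
A hypothetical invocation (checkpoint loading is repo-specific; the names
below are illustrative):

# models = [torch.load(path).cuda().eval() for path in checkpoint_paths]
# conceptcombineeval(models, select_idx=[0, 1])
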
Example #3
def compute_inception(sess, target_vars):
    X_START = target_vars['X_START']
    Y_GT = target_vars['Y_GT']
    X_finals = target_vars['X_finals']
    NOISE_SCALE = target_vars['NOISE_SCALE']
    energy_noise = target_vars['energy_noise']

    size = FLAGS.im_number
    num_steps = size // 1000
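    # FLAGS.im_number samples are generated in chunks of 1000: each outer
    # step below fills a 1000-slot replay buffer and appends its contents to
    # test_images.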

    test_ims = []
    test_images = []

    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(full=True, noise=False)
    elif FLAGS.dataset == "celeba":
        test_dataset = CelebA()
    elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
        test_dataset = Imagenet(train=False)

    if FLAGS.dataset != "imagenetfull":
        test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size,
                                     num_workers=4, shuffle=True, drop_last=False)
    else:
        test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)

    for data_corrupt, data, label_gt in tqdm(test_dataloader):
        data = data.numpy()
        test_ims.extend(list(rescale_im(data)))

        if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000:
            test_ims = test_ims[:60000]
            break


    print("collected {} ground-truth images".format(len(test_ims)))

    if FLAGS.dataset == "cifar10":
        classes = 10
    else:
        classes = 1000

    if FLAGS.dataset == "imagenetfull":
        n = 128
    else:
        n = 32

    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)
        data_buffer = InceptionReplayBuffer(1000)
        curr_index = 0

        identity = np.eye(classes)

        for i in tqdm(range(itr)):
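            # Replay-buffer MCMC: until the buffer holds 1000 entries, chains
            # start from uniform noise with random labels; afterwards buffer
            # entries are continued, with ~1% of images and ~10% of labels
            # re-randomized per iteration (except during the final "nomix"
            # phase).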
            model_index = curr_index % len(X_finals)
            x_final = X_finals[model_index]

            noise_scale = [1]
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
                label = np.random.randint(0, classes, (FLAGS.batch_size))
                label = identity[label]
                x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
                data_buffer.add(x_new, label)
            else:
                (x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
                keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
                label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
                label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
                label_corrupt = identity[label_corrupt]
                x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))

                if i < itr - FLAGS.nomix:
                    x_init[keep_mask] = x_init_corrupt[keep_mask]
                    label[label_keep_mask] = label_corrupt[label_keep_mask]

                x_new, _ = sess.run([x_final, energy_noise],
                                    {X_START: x_init, Y_GT: label,
                                     NOISE_SCALE: noise_scale})
                data_buffer.set_elms(idx, x_new, label)

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims)
        test_images.extend(list(ims))

    saveim = osp.join(FLAGS.logdir, FLAGS.exp, "test{}.png".format(FLAGS.resume_iter))
    row = 15
    col = 20
    ims = ims[:row * col]
    if FLAGS.dataset != "imagenetfull":
        im_panel = ims.reshape((row, col, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((32*row, 32*col, 3))
    else:
        im_panel = ims.reshape((row, col, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((128*row, 128*col, 3))
    imsave(saveim, im_panel)

    splits = max(1, len(test_images) // 5000)
    score, std = get_inception_score(test_images, splits=splits)
    print("Inception score of {} with std of {}".format(score, std))

    # FID score
    fid = get_fid_score(test_images, test_ims)
    print("FID of score {}".format(fid))
Example #4
def compute_inception(model):
    size = FLAGS.im_number
    num_steps = size // 1000

    images = []
    test_ims = []

    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(FLAGS)
    elif FLAGS.dataset == "celeba":
        test_dataset = CelebAHQ()
    elif FLAGS.dataset == "mnist":
        test_dataset = Mnist(train=True)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=4,
                                 shuffle=True,
                                 drop_last=False)

    if FLAGS.dataset == "cifar10":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(rescale_im(data)))

            if len(test_ims) > 10000:
                break
    elif FLAGS.dataset == "mnist":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(np.tile(rescale_im(data), (1, 1, 3))))

            if len(test_ims) > 10000:
                break

    test_ims = test_ims[:10000]

    classes = 10

    print("batch size:", FLAGS.batch_size)
    data_buffer = None

    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)

        if data_buffer is None:
            data_buffer = InceptionReplayBuffer(1000)

        curr_index = 0

        identity = np.eye(classes)

        if FLAGS.dataset == "celeba":
            n = 128
            c = 3
        elif FLAGS.dataset == "mnist":
            n = 28
            c = 1
        else:
            n = 32
            c = 3

        for i in tqdm(range(itr)):
            noise_scale = [1]
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, c, n, n))
                label = np.random.randint(0, classes, (FLAGS.batch_size))

                x_init = torch.Tensor(x_init).cuda()
                label = identity[label]
                label = torch.Tensor(label).cuda()

                x_new, _ = gen_image(label, FLAGS, model, x_init,
                                     FLAGS.num_steps)
                x_new = x_new.detach().cpu().numpy()
                label = label.detach().cpu().numpy()
                data_buffer.add(x_new, label)
            else:
                if i < itr - FLAGS.nomix:
                    (x_init, label), idx = data_buffer.sample(
                        FLAGS.batch_size, transform=FLAGS.transform)
                else:
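                    # Terminal "nomix" phase: sweep the buffer in contiguous
                    # chunks so each stored sample gets a final refinement
                    # without random restarts of images or labels.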
                    if FLAGS.dataset == "celeba":
                        n = 20
                    else:
                        n = 2

                    ix = i % n
                    # for i in range(n):
                    start_idx = (1000 // n) * ix
                    end_idx = (1000 // n) * (ix + 1)
                    (x_init, label) = data_buffer._encode_sample(
                        list(range(start_idx, end_idx)), transform=False)
                    idx = list(range(start_idx, end_idx))

                x_init = torch.Tensor(x_init).cuda()
                label = torch.Tensor(label).cuda()
                x_new, energy = gen_image(label, FLAGS, model, x_init,
                                          FLAGS.num_steps)
                energy = energy.cpu().detach().numpy()
                x_new = x_new.cpu().detach().numpy()
                label = label.cpu().detach().numpy()
                data_buffer.set_elms(idx, x_new, label)

                if FLAGS.im_number != 50000:
                    print(np.mean(energy), np.std(energy))

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims).transpose((0, 2, 3, 1))

        if FLAGS.dataset == "mnist":
            ims = np.tile(ims, (1, 1, 1, 3))

        images.extend(list(ims))

    random.shuffle(images)
    saveim = osp.join('sandbox_cachedir', FLAGS.exp,
                      "test{}.png".format(FLAGS.idx))

    if FLAGS.dataset == "cifar10":
        rix = np.random.permutation(1000)[:100]
        ims = ims[rix]
        im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((320, 320, 3))
        imsave(saveim, im_panel)

        print("Saved image!!!!")
        splits = max(1, len(images) // 5000)
        score, std = get_inception_score(images, splits=splits)
        print("Inception score of {} with std of {}".format(score, std))

        # FID score
        fid = get_fid_score(images, test_ims)
        print("FID score: {}".format(fid))

    elif FLAGS.dataset == "mnist":
        ims = ims[:100]
        im_panel = ims.reshape((10, 10, 28, 28, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((280, 280, 3))
        imsave(saveim, im_panel)

        print("Saved image!!!!")
        splits = max(1, len(images) // 5000)
        # score, std = get_inception_score(images, splits=splits)
        # print("Inception score of {} with std of {}".format(score, std))

        # FID score
        fid = get_fid_score(images, test_ims)
        print("FID score: {}".format(fid))

    elif FLAGS.dataset == "celeba":

        ims = ims[:25]
        im_panel = ims.reshape((5, 5, 128, 128, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((5 * 128, 5 * 128, 3))
        imsave(saveim, im_panel)
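
A hypothetical invocation (model construction and checkpoint loading are
repo-specific; names below are illustrative):

# model = create_model(FLAGS).cuda().eval()
# model.load_state_dict(torch.load(checkpoint_path))
# compute_inception(model)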