# Example 1
    def get_inception_score(batch_size=100):
        """Generate N_SAMPLES images with `gen` and return their Inception Score.

        Latent vectors are drawn once from a standard normal of shape
        (N_SAMPLES, N_LATENT). They are pushed through `gen` on GPU 0 in chunks
        of `batch_size` and collected on the CPU as NumPy arrays. The result is
        concatenated and handed to `inception_score(..., resize=True, cuda=True)`.

        Args:
            batch_size: number of latent vectors per generator forward pass
                (defaults to the previously hard-coded 100).

        Returns:
            Whatever `inception_score` returns (presumably mean/std of the
            score -- TODO confirm against its definition).

        Raises:
            ValueError: if `batch_size` is not a positive integer.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got %r" % (batch_size,))

        all_samples = []
        samples = torch.randn(N_SAMPLES, N_LATENT)
        # `range` instead of the Python-2-only `xrange`: works on Python 2 and 3.
        # NOTE(review): on PyTorch >= 0.4, wrapping this loop in torch.no_grad()
        # would avoid recording an autograd graph for every batch.
        for start in range(0, N_SAMPLES, batch_size):
            batch = samples[start:start + batch_size].cuda(0)
            # `.data` drops autograd history before converting to NumPy.
            all_samples.append(gen(batch).cpu().data.numpy())

        all_samples = np.concatenate(all_samples, axis=0)
        return inception_score(torch.from_numpy(all_samples), resize=True, cuda=True)
# Example 2
def is_ch1(checkpoint_path='results/cifar_10/main/d500_i40000_10/netGS_80000.pth',
           p=0.5, z_dim=100, num_classes=10, n_per_class=5000):
    """Compute the Inception Score of a class-conditional DCGAN generator.

    Steps:

    1. Load ``GeneratorDCGAN_cifar`` weights from ``checkpoint_path``.
    2. For every class id in ``range(num_classes)``, draw ``n_per_class``
       Bernoulli(p) latent vectors and generate one output per vector.
    3. Reshape each output to 32x32, replicate it into 3 channels and
       print the result of ``inception_score`` on the whole set.

    All parameters default to the values that were previously hard-coded,
    so ``is_ch1()`` behaves as before (apart from also returning the score).

    Args:
        checkpoint_path: path of the generator ``state_dict`` to load.
        p: success probability of the Bernoulli latent distribution.
        z_dim: latent dimensionality; used both to build the generator and
            to shape the noise, so the two always agree.
        num_classes: number of class ids to generate samples for.
        n_per_class: number of images generated per class id.

    Returns:
        The value returned by ``inception_score`` (it is also printed).
    """
    class IgnoreLabelDataset(torch.utils.data.Dataset):
        """Minimal Dataset adapter that indexes straight into ``orig``."""

        def __init__(self, orig):
            self.orig = orig

        def __getitem__(self, index):
            return self.orig[index]

        def __len__(self):
            return len(self.orig)

    # A real-data baseline can be scored the same way by wrapping
    # torchvision's CIFAR10 dataset in IgnoreLabelDataset.

    bernoulli = torch.distributions.Bernoulli(torch.tensor([p]))
    netG = GeneratorDCGAN_cifar(z_dim=z_dim, model_dim=64,
                                num_classes=num_classes).cuda()
    netG.load_state_dict(torch.load(checkpoint_path))
    # NOTE(review): netG is left in its default (train) mode, as before; if it
    # contains BatchNorm/Dropout, netG.eval() may be intended -- TODO confirm.

    sample_list = []
    # Inference only: no_grad avoids recording an autograd graph (and keeping
    # the activations for n_per_class images alive) on every forward pass.
    with torch.no_grad():
        for class_id in range(num_classes):
            # Bernoulli with a 1-element probs tensor samples (n, z_dim, 1);
            # flatten the trailing axis.
            noise = bernoulli.sample((n_per_class, z_dim)).view(n_per_class, z_dim).cuda()
            label = torch.full((n_per_class, ), class_id, dtype=torch.long).cuda()
            sample = netG(noise, label).view(n_per_class, 32, 32)
            sample_list.append(sample.cpu().numpy())

    # (num_classes, n, 32, 32) -> (n, num_classes, 32, 32) so the flattened set
    # interleaves classes; then add a channel axis and tile it to 3 channels.
    samples = np.transpose(np.array(sample_list), [1, 0, 2, 3])
    samples = np.reshape(samples, [num_classes * n_per_class, 1, 32, 32])
    samples = np.repeat(samples, 3, axis=1)

    print("Calculating Inception Score...")
    score = inception_score(IgnoreLabelDataset(samples),
                            cuda=True,
                            batch_size=32,
                            resize=True,
                            splits=10)
    print(score)
    return score