def load_fid(self, dataname):
     """Load the reference Inception moments for ``dataname`` and the net.

     The statistics are read from
     ``metrics/stats/<dataname>_inception_moments.npz``, which must contain
     the arrays ``mu`` and ``sigma``. Both are stored on the instance as
     float32 CUDA tensors (``self.mu_real`` / ``self.sigma_real``), and the
     Inception network is loaded into ``self.net``.

     Args:
         dataname: Name used to build the statistics file name.

     Raises:
         KeyError: If the archive lacks the ``mu`` or ``sigma`` entry.
         OSError: If the statistics file cannot be opened by ``np.load``.
     """
     stats_path = 'metrics/stats/%s_inception_moments.npz' % dataname
     print('Load stats of %s' % dataname)
     # The context manager closes the archive even when a key is missing;
     # the arrays are fully read on access, so they stay valid afterwards.
     with np.load(stats_path) as stats:
         mu, sigma = stats['mu'], stats['sigma']
     self.mu_real = torch.tensor(mu).float().cuda()
     self.sigma_real = torch.tensor(sigma).float().cuda()
     self.net = load_inception_net(parallel=False)
 def __init__(self, G, z_dim, model_dir, log_path, device,
              batchsize=100, dim=1):
     """Record the evaluation settings, set up the writer, load Inception.

     Each argument is stored unchanged on the instance under its own name.

     Args:
         G: Stored as ``self.G`` (presumably the generator being
             evaluated -- TODO confirm against callers).
         z_dim: Stored as ``self.z_dim``.
         model_dir: Stored as ``self.model_dir``.
         log_path: Stored as ``self.log_path``.
         device: Stored as ``self.device``.
         batchsize: Stored as ``self.batchsize``; defaults to 100.
         dim: Stored as ``self.dim``; defaults to 1.
     """
     # Both flags start out cleared.
     self.is_flag = self.fid_flag = False

     # Plain bookkeeping of the constructor arguments.
     self.G, self.z_dim = G, z_dim
     self.model_dir, self.log_path = model_dir, log_path
     self.device = device
     self.batchsize, self.dim = batchsize, dim

     # The writer is initialised only after every attribute above exists;
     # the Inception network is loaded last.
     self.init_writer()
     self.net = load_inception_net(parallel=False)
示例#3
0
def run(config):
    """Compute Inception statistics of a dataset and save its moments.

    Every batch from the data loader is fed through the Inception network.
    The pooled features and the softmax of the logits are collected, the
    Inception Score of the data is printed, and the mean and covariance of
    the pooled features are written to
    ``metrics/stats/<dataset>_inception_moments.npz`` under the keys
    ``mu`` and ``sigma``.

    Args:
        config: Mapping providing ``'dataset'``, ``'data_path'`` and
            ``'parallel'``. ``config['drop_last']`` is set to ``False``
            in place.

    Raises:
        ValueError: If the loader yields no batches.
    """
    import os  # local import keeps this block self-contained

    # Not read below; presumably consumed by code sharing this config.
    config['drop_last'] = False

    # Create the output directory up front so a missing folder fails fast
    # instead of after the whole dataset has been processed.
    stats_path = ('metrics/stats/' + config['dataset'] +
                  '_inception_moments.npz')
    os.makedirs(os.path.dirname(stats_path), exist_ok=True)

    # Get loader
    loader = get_data(dataname=config['dataset'], path=config['data_path'])

    # Load inception net
    net = load_inception_net(parallel=config['parallel'])
    pool, logits = [], []
    device = torch.device('cuda:0')
    print(device)
    with torch.no_grad():
        # Each batch is an (inputs, labels) pair; the labels are not used.
        for x, _ in tqdm(loader):
            pool_val, logits_val = net(x.to(device))
            pool.append(pool_val.cpu().numpy())
            logits.append(F.softmax(logits_val, 1).cpu().numpy())

    if not pool:
        raise ValueError('Data loader for dataset %s yielded no batches' %
                         config['dataset'])

    pool, logits = [np.concatenate(item, 0) for item in (pool, logits)]
    # To keep the raw activations as well, save `pool` and `logits` with
    # np.savez at this point.

    # Calculate inception metrics and report them
    print('Calculating inception metrics...')
    IS_mean, IS_std = calculate_inception_score(logits)
    print('Training data from dataset %s has IS of %5.5f +/- %5.5f' %
          (config['dataset'], IS_mean, IS_std))

    # Prepare mu and sigma (rows of `pool` are observations), save to disk.
    print('Calculating means and covariances...')
    mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
    print('Saving calculated means and covariances to disk...')
    np.savez(stats_path, mu=mu, sigma=sigma)