Example #1
def gen_batch(dataset,
              N=None,
              loader=None,
              shuffle=False,
              seed=None,
              ret_loader=False):
    """Fetch a single batch of N samples, building a fresh loader if needed
    and restarting it when exhausted or when it yields a short final batch."""
    if seed is not None:
        util.set_seed(seed)

    if loader is None:
        assert N is not None, 'need a batch size N to build a new loader'
        loader = iter(new_loader(dataset, batch_size=N, shuffle=shuffle))

    try:
        batch = util.to(next(loader), 'cuda')
        B = batch[0].size(0)  # actual batch size; may be short on the final batch
    except StopIteration:
        pass
    else:
        if N is None or B == N:
            if ret_loader:
                return batch[0], loader
            return batch[0]

    # loader exhausted or returned a short batch: rebuild it and draw a full one
    loader = iter(new_loader(dataset, batch_size=N, shuffle=shuffle))
    batch = util.to(next(loader), 'cuda')

    if ret_loader:
        return batch[0], loader

    return batch[0]
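For reference, a minimal usage sketch (hypothetical; it assumes dataset and the surrounding util/new_loader helpers are in scope). Keeping the returned loader alive across calls lets gen_batch cycle through the dataset instead of rebuilding a loader each time, which is how Example #6 uses it:

# hypothetical usage sketch for gen_batch
loader = None
for _ in range(10):
    X, loader = gen_batch(dataset,
                          N=64,
                          shuffle=True,
                          loader=loader,
                          ret_loader=True)
    # X: cuda tensor of 64 samples; the loader restarts automatically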
Example #2
def rec_generate(S, **unused):
    """Return a generate(N) closure that decodes N randomly chosen latent
    vectors previously collected from the dataset (S.full_q)."""

    A = S.A

    Q = S.full_q
    assert Q is not None, 'no latent space'

    if isinstance(Q, distrib.Distribution):
        Q = Q.loc

    model = S.model

    util.set_seed(0)

    def generate(N):

        idx = torch.randperm(len(Q))[:N]

        q = Q[idx].to(A.device)

        with torch.no_grad():
            gen = model.decode(q).detach()

        return gen

    return generate
Example #3
def prior_generate(S, **unused):
    """Return a generate(N) closure that samples N images from the model prior."""

    model = S.model

    util.set_seed(0)

    def generate(N):
        with torch.no_grad():
            return model.generate(N)

    return generate
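Both factories above return a generate(N) callable, the interface that compute_inception_stat consumes in Examples #4 and #6. A hedged usage sketch, assuming S already carries a loaded model as in Example #5:

# hypothetical usage sketch
generate = prior_generate(S)   # or rec_generate(S) to decode real latent vectors
samples = generate(16)         # tensor of 16 decoded images, computed without gradients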
Example #4
def main(argv=None):

    if sys.gettrace() is not None:  # a tracer is active, e.g. the PyCharm debugger
        print('in pycharm')

        # c = 'n/mpi3d'
        # c = 'n/celeba'
        # name = 'mpi3d_stats_fid.pkl'
        #
        # argv = '--pbar --save-path /is/ei/fleeb/workspace/local_data/mpi3d/{} --config {}'.format(name, c).split(' ')

        # argv.extend(['--n-samples', '100'])

        # return 1 # no accidental runs in debugger

    else:
        print('not in pycharm')

    parser = get_parser()
    args = parser.parse_args(argv)

    root = '/is/ei/fleeb/workspace/local_data/fid_stats'

    # names = ['box', 'cyl', 'sph', 'cap', 'box-cyl', 'box-cyl-sph']

    names = ['MsPacman', 'SpaceInvaders', 'Asterix', 'Seaquest']

    # unique = [trn.get_config('n/t/box') for _ in range(len(names))]

    unique = [trn.get_config('n/atari') for _ in range(len(names))]

    for u in unique:
        u.dataset.train = None
        u.dataset.device = 'cpu'

    # unique[0].dataset.counts = [-1, 0, 0, 0]
    # unique[1].dataset.counts = [0, -1, 0, 0]
    # unique[2].dataset.counts = [0, 0, -1, 0]
    # unique[3].dataset.counts = [0, 0, 0, -1]
    # unique[4].dataset.counts = [-1, -1, 0, 0]
    # unique[5].dataset.counts = [-1, -1, -1, 0]

    args.pbar = True  # force the progress bar regardless of the CLI flag

    for name, u in zip(names, unique):

        u.dataset.game = name

        print('Running {}'.format(name))

        save_path = os.path.join(root, '{}_fid_stats.pkl'.format(name))

        print(save_path)

        dataset = trn.get_dataset(info=u.dataset)

        print('Loaded dataset: {}'.format(len(dataset)))
        print(dataset)

        gen = Dataset_Generator(dataset)

        util.set_seed(args.seed)
        print('Set seed: {}'.format(args.seed))

        m, s = compute_inception_stat(gen,
                                      batch_size=128,
                                      n_samples=50000,
                                      pbar=tqdm)

        print(m.shape, s.shape)

        with open(save_path, 'wb') as f:
            pkl.dump({'m': m, 'sigma': s}, f)
        print('Saved stats to {}'.format(save_path))

        print('-' * 50)

    print('All done.')
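The pickled stats can be read back the same way Example #6 does; a minimal sketch using the path layout from main above (the game name is just an illustration):

# hypothetical: reload the reference stats written above
import os
import pickle as pkl

path = os.path.join(root, '{}_fid_stats.pkl'.format('MsPacman'))
with open(path, 'rb') as f:
    stats = pkl.load(f)
m, sigma = stats['m'], stats['sigma']  # inception mean and covariance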
Example #5
def _full_analyze(run, save_dir):

	S = run.reset()

	dname = run.meta.dataset

	if 'box' in dname or 'nocap' in dname:
		dname = '3dshapes'

	# check for existing results

	save_path = os.path.join(save_dir, run.name)
	util.create_dir(save_path)

	results_path = os.path.join(save_path, 'results.pth.tar')
	evals_path = os.path.join(save_path, 'evals.pth.tar')

	if os.path.isfile(results_path):
		print('Found existing results: {}'.format(results_path))
		results = torch.load(results_path)
		print(results.keys())
	else:
		results = {}

	if os.path.isfile(evals_path):
		print('Found existing evals: {}'.format(evals_path))
		evals = torch.load(evals_path)
		print(evals.keys())
	else:
		evals = {}

	# check for completion
	if 'fid' in results and ('disent' in results or dname != '3dshapes') and 'H' in results:
		print('Skipping {}, all analysis is already done'.format(run.name))
		# raise Exception()  # testing
		return

	# build dataset

	print('loading dataset {} for {}'.format(run.meta.dataset, run.name))

	eval_disentanglement_metrics = dname == '3dshapes'

	if dname == '3dshapes':

		C = trn.get_config('n/3dshapes')
		fid_stats_ref_path = '3dshapes/3dshapes_stats_fid.pkl'
		#     batch_size = 128

	elif dname == 'celeba':
		C = trn.get_config('n/celeba')
		fid_stats_ref_path = 'celeba/celeba_stats_fid.pkl'
	#     batch_size = 32
	elif dname == 'atari' or dname in {'spaceinv', 'pacman', 'seaquest', 'asterix'}:
		C = trn.get_config('n/atari')
		#     batch_size = 32

		C.dataset.game = run.config.dataset.game
		print('using {} game'.format(C.dataset.game))
		fid_stats_ref_path = 'fid_stats/{}_fid_stats.pkl'.format(C.dataset.game)
	# get game

	elif dname == 'mpi3d':
		C = trn.get_config('n/mpi3d')

		C.dataset.category = run.config.dataset.category
		print('using {} cat'.format(C.dataset.category))
		fid_stats_ref_path = 'mpi3d/mpi3d_{}_stats_fid.pkl'.format(C.dataset.category)

	#     batch_size = 128

	# get category

	else:
		raise Exception('{} not found'.format(dname))

	batch_size = 128

	C.dataset.device = 'cpu'

	C.dataset.train = False

	if 'val_split' in C.dataset:
		del C.dataset.val_split

	print('loading model {}'.format(run.ckpt_path))

	S.A = trn.get_config()

	if run.meta.dataset in {'celeba', 'atari', 'spaceinv', 'pacman', 'seaquest', 'asterix'}:
		din, dout = (3, 128, 128), (3, 128, 128)
	else:
		din, dout = (3, 64, 64), (3, 64, 64)

	S.A.din, S.A.dout = din, dout

	with open(os.devnull, 'w') as devnull, redirect_stdout(devnull):
		run.load(fast=True)

	model = S.model

	print('model loaded')


	if dname == '3dshapes':  # disentanglement metrics

		if 'disent' not in results:
			results['disent'] = {}

		print('Computing Disentanglement metrics')

		results['disent'] = compute_all_disentanglement(model, disent=results['disent'])
		torch.save(results, results_path)
		print('Saved results to {}'.format(results_path))
		evals['disent'] = results['disent']

	datasets, = trn.load(A=C, get_data=get_data, get_model=None, mode='test')
	dataset = datasets[0]

	S.dataset = dataset
	S.dname = dname

	S.batch_size = batch_size

	print('dataset {} loaded: {}'.format(run.meta.dataset, len(dataset)))

	# run

	# rec error

	if 'L' not in results:
		print('Computing rec loss')

		dataset = S.dataset
		batch_size = S.batch_size

		util.set_seed(1)
		loader = new_loader(dataset, batch_size=batch_size, shuffle=True)

		loader = tqdm(loader)
		loader.set_description('Evaluating rec error')

		criterion = nn.BCELoss(reduction='none')

		L = []
		Q = None
		R = None
		O = None

		with torch.no_grad():

			for batch in loader:

				batch = util.to(batch, 'cuda')
				X = batch[0]
				B = X.size(0)

				rec = gen_target(model, X=X, hybrid=False, ret_q=Q is None)
				if Q is None:
					rec, Q = rec
					Q = Q.cpu()
				elif R is None:
					R = rec.cpu()
					O = X.cpu()

				loss = criterion(rec, X).view(B, -1).sum(-1)
				L.append(loss)

				# if R is not None:
				# 	break  # TESTING

			util.set_seed(2)
			G = model.generate(len(R)).cpu()  # prior samples (assumes R was filled, i.e. at least two batches)
			util.set_seed(2)
			H = gen_target(model, Q=Q.cuda(), hybrid=True).cpu()  # hybrid samples from the first batch's latents

			L = torch.cat(L)

		del loader

		results.update({
			'O': O,  # original images
			'R': R,  # reconstructed images
			'L': L,  # reconstruction loss
			'Q': Q,  # latent vectors
			'G': G,  # generated samples using prior
			'H': H,  # generated samples using hybridization (drop-in, prob=1)
			'key': {
				'O': 'original images',
				'R': 'reconstructed images',
				'L': 'reconstruction error of each sample in the test set',
				'Q': 'latent vectors',
				'G': 'images generated from the prior',
				'H': 'images generated using hybridization (dropin, prob=1)',
			}
		})

	# FID score

	if 'fid' not in results:
		results['fid'] = {}
	if 'scores' not in results['fid']:
		results['fid']['scores'] = {}
		results['fid']['stats'] = {}

	if len(results['fid']['scores']) < 3:  # expecting three scores: rec, hyb, prior
		print('Computing FID scores')

		with torch.no_grad():
			results['fid'] = compute_all_fid_scores(model, dataset, fid_stats_ref_path, fid=results['fid'])

		torch.save(results, results_path)
		print('Saved results to {}'.format(results_path))

	if 'fid' not in evals:
		evals['fid'] = results['fid']['scores']

	print('Run {} complete'.format(run.name))

	run.state.evals = evals
	run.state.results = results

	run.save(save_dir=save_dir, overwrite=True)

	run.reset()
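A hedged sketch of inspecting a finished run's results file afterwards, using the key names documented in the 'key' dict above and the save_path layout from _full_analyze:

# hypothetical: inspect a finished run's saved results
results = torch.load(os.path.join(save_path, 'results.pth.tar'))
print(results['key'])              # human-readable description of each entry
print(results['L'].mean().item())  # average per-sample reconstruction error
print(results['fid']['scores'])    # FID for rec / hyb / prior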
Example #6
def compute_all_fid_scores(model, dataset, fid_stats_ref_path, fid=None):
    if fid is None:
        fid = {'scores': {}, 'stats': {}}

    path = os.path.join(os.environ["FOUNDATION_DATA_DIR"], fid_stats_ref_path)
    with open(path, 'rb') as f:
        ref = pickle.load(f)
    ref_stats = ref['m'][:], ref['sigma'][:]  # reference inception mean and covariance

    inception = load_inception_model(dim=2048, device='cuda')

    n_samples = 50000
    # n_samples = 100 # testing

    # rec
    name = 'rec'
    if name not in fid['scores']:
        util.set_seed(0)
        loader = None

        def _generate(N):
            nonlocal loader
            X, loader = gen_batch(dataset,
                                  loader=loader,
                                  shuffle=True,
                                  N=N,
                                  ret_loader=True)
            return gen_target(model, X=X, hybrid=False)

        stats = compute_inception_stat(_generate,
                                       inception=inception,
                                       pbar=tqdm,
                                       n_samples=n_samples)

        fid['scores'][name] = compute_frechet_distance(*stats, *ref_stats)
        fid['stats'][name] = stats
    print('FID-rec: {:.2f}'.format(fid['scores'][name]))

    # hyb
    name = 'hyb'
    if name not in fid['scores']:
        util.set_seed(0)
        loader = None

        def _generate(N):
            nonlocal loader
            X, loader = gen_batch(dataset,
                                  loader=loader,
                                  shuffle=True,
                                  N=N,
                                  ret_loader=True)
            return gen_target(model, X=X, hybrid=True)

        stats = compute_inception_stat(_generate,
                                       inception=inception,
                                       pbar=tqdm,
                                       n_samples=n_samples)

        fid['scores'][name] = compute_frechet_distance(*stats, *ref_stats)
        fid['stats'][name] = stats
    print('FID-hybrid: {:.2f}'.format(fid['scores'][name]))

    # prior
    name = 'prior'
    if name not in fid['scores']:
        util.set_seed(0)

        def _generate(N):
            return gen_prior(model, N)

        stats = compute_inception_stat(_generate,
                                       inception=inception,
                                       pbar=tqdm,
                                       n_samples=n_samples)

        fid['scores'][name] = compute_frechet_distance(*stats, *ref_stats)
        fid['stats'][name] = stats
    print('FID-prior: {:.2f}'.format(fid['scores'][name]))

    return fid
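compute_frechet_distance itself is not shown in these examples. For reference, the standard Fréchet distance between two Gaussians (the formula behind FID, from Heusel et al.) can be sketched as follows; this is a minimal sketch, not the project's implementation, assuming numpy arrays shaped like the m and sigma stored by main above:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):  # discard tiny imaginary parts from numerics
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)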
Example #7
def run_model(S, pbar=None, **unused):
    A = S.A
    dataset = S.dataset
    model = S.model

    assert False, 'unused'

    if 'loader' not in S:
        # DataLoader

        A.dataset.batch_size = 16

        util.set_seed(0)
        loader = train.get_loaders(
            dataset,
            batch_size=A.dataset.batch_size,
            num_workers=A.num_workers,
            shuffle=True,
            drop_last=False,
        )
        util.set_seed(0)
        loader = iter(loader)

        S.loader = loader

    common_Ws = {64: 8, 32: 4, 16: 4, 9: 3, 8: 2, 4: 2}
    border, between = 0.02, 0.01
    img_W = common_Ws[A.dataset.batch_size]
    S.img_W = img_W
    S.border, S.between = border, between

    if 'X' not in S:
        # Batch

        batch = next(loader)
        batch = util.to(batch, A.device)

        S.batch = batch
        S.X = batch[0]

    X = S.X

    with torch.no_grad():

        if model.enc is None:
            q = model.sample_prior(X.size(0))
        else:
            q = model.encode(X)
        qdis = None
        qmle = q
        if isinstance(q, distrib.Distribution):
            qdis = q
            q = q.rsample()
            qmle = qdis.loc
        rec = model.decode(q)
        # vrec = model.disc(rec) if model.disc is not None else None

        p = model.sample_prior(X.size(0))
        gen = model.decode(p)

        h = util.shuffle_dim(q)
        hyb = model.decode(h)

    S.rec = rec
    S.gen = gen
    S.hyb = hyb

    S.q = q
    S.p = p
    S.h = h

    S.qdis = qdis
    S.qmle = qmle

    batch_size = 128  # number of samples to get distribution
    util.set_seed(0)
    int_batch = next(
        iter(
            train.get_loaders(
                dataset,
                batch_size=batch_size,
                num_workers=A.num_workers,
                shuffle=True,
                drop_last=False,
            )))
    with torch.no_grad():
        int_batch = util.to(int_batch, A.device)
        int_X, = int_batch
        if model.enc is None:
            int_q = model.sample_prior(int_X.size(0))
        else:
            int_q = model.encode(int_X)
        dis_int_q = None
        if isinstance(int_q, distrib.Distribution):
            #         int_q = int_q.rsample()
            dis_int_q = int_q
            int_q = int_q.loc
    del int_batch
    del int_X

    S.int_q = int_q
    S.dis_int_q = dis_int_q

    latent_dim = A.model.latent_dim
    iH, iW = X.shape[-2:]

    rows = 4
    steps = 60

    if 'bounds' in S:
        bounds = S.bounds
    else:
        # bounds = -2,2
        bounds = None
    # dlayout = rows, latent_dim // rows
    dlayout = util.calc_tiling(latent_dim)
    outs = []

    all_diffs = []
    # inds = [0, 2, 3]  # (debug subset)
    inds = np.arange(len(q))  # walk every sample in the batch
    save_inds = [0, 1, 2, 3]
    # save_inds =  []

    saved_walks = []

    for idx in inds:

        walks = []
        for dim in range(latent_dim):

            dev = int_q[:, dim].std()
            if bounds is None:
                deltas = torch.linspace(int_q[:, dim].min(),
                                        int_q[:, dim].max(), steps)
            else:
                deltas = torch.linspace(bounds[0], bounds[1], steps)
            vecs = torch.stack([int_q[idx]] * steps)
            vecs[:, dim] = deltas

            with torch.no_grad():
                walks.append(model.decode(vecs).cpu())

        walks = torch.stack(walks, 1)
        chn = walks.shape[2]

        dsteps = 10
        # mean absolute change of the decoded image along each latent walk
        diffs = (walks[dsteps:] - walks[:-dsteps]).abs()
        diffs = diffs.view(steps - dsteps, latent_dim, chn * iH * iW).mean(-1)
        all_diffs.append(diffs.mean(0))

        if idx in save_inds:
            # save_dir = S.save_dir

            walks_full = walks.view(steps, dlayout[0], dlayout[1], chn, iH, iW) \
             .permute(0, 1, 4, 2, 5, 3).contiguous().view(steps, iH * dlayout[0], iW * dlayout[1], chn).squeeze()
            images = []
            for img in walks_full.cpu().numpy():
                images.append((img * 255).astype(np.uint8))

            saved_walks.append(images)

            # imageio.mimsave(os.path.join(save_dir, 'walks-idx{}.gif'.format(idx, dim)), images)
            #
            # with open(os.path.join(save_dir, 'walks-idx{}.gif'.format(idx, dim)), 'rb') as f:
            # 	outs.append(display.Image(data=f.read(), format='gif'))
            del walks_full

    all_diffs = torch.stack(all_diffs)

    S.all_diffs = all_diffs
    S.saved_walks = saved_walks

    dataset = S.dataset

    full_q = None

    if model.enc is not None:

        N = min(1000, len(dataset))

        print('Using {} samples'.format(N))

        assert (N // 4) * 4 == N, 'invalid num: {}'.format(N)

        loader = train.get_loaders(
            dataset,
            batch_size=N // 4,
            num_workers=A.num_workers,
            shuffle=True,
            drop_last=False,
        )

        util.set_seed(0)

        if pbar is not None:
            loader = pbar(loader, total=len(loader))
            loader.set_description('Collecting latent vectors')

        full_q = []

        for batch in loader:
            batch = util.to(batch, A.device)
            X, = batch

            with torch.no_grad():

                q = model.encode(X)
                if isinstance(q, distrib.Distribution):
                    q = torch.stack([q.loc, q.scale], 1)
                full_q.append(q.cpu())

        if pbar is not None:
            loader.close()

        if len(full_q):
            full_q = torch.cat(full_q)

            # print(full_q.shape)

            if len(full_q.shape) > 2:
                full_q = distrib.Normal(loc=full_q[:, 0], scale=full_q[:, 1])

        else:
            full_q = None

    S.full_q = full_q

    if full_q is not None:
        if 'results' not in S:
            S.results = {}
        S.results['val_Q'] = full_q

        n_latents = full_q.loc if isinstance(full_q, distrib.Distribution) else full_q
        print('Storing {} latent vectors'.format(len(n_latents)))