Ejemplo n.º 1
0
def main(argv):
	"""Inpaint the central square of an image with a trained model.

	The central ``fill_region`` x ``fill_region`` pixels of the first channel
	are overwritten with uniform noise. Unless ``--init`` is given, the hole
	is then re-initialised by ancestral sampling (best of ``--candidates``
	draws by log-likelihood). Afterwards the hole is refined for
	``--num_epochs`` sweeps over overlapping patches, either with an
	accept/reject step on sampled proposals (``SAMPLE``) or with gradient
	steps (``MAP``).

	Files written to ``--output``: ``original.png``, ``start.png``,
	``fillin.<k>.png`` after each sweep and ``final.npz`` (keys ``image`` and
	``epoch``), which can be passed back through ``--init`` to resume.

	Args:
		argv: command line, with the program name in ``argv[0]``.

	Returns:
		0 on completion, including after a ``KeyboardInterrupt`` during the
		sweeps (the current state is saved in that case).
	"""
	parser = ArgumentParser(argv[0], description=__doc__)
	parser.add_argument('image',                    type=str)
	parser.add_argument('model',                    type=str)
	parser.add_argument('--init',             '-I', type=str, default=None)
	parser.add_argument('--index',            '-x', type=int, default=0,
		help='Determines which image is used when whole dataset is given instead of image.')
	parser.add_argument('--fill_region',      '-f', type=int, default=71)
	parser.add_argument('--outer_patch_size', '-p', type=int, default=19)
	parser.add_argument('--inner_patch_size', '-i', type=int, default=5)
	parser.add_argument('--stride',           '-s', type=int, default=3)
	parser.add_argument('--candidates',       '-C', type=int, default=5,
		help='The best initialization is taken out of this many initializations.')
	parser.add_argument('--num_epochs',       '-e', type=int, default=1000)
	parser.add_argument('--method',           '-m', type=str, default='SAMPLE', choices=['SAMPLE', 'MAP'])
	parser.add_argument('--step_width',       '-l', type=float, default=100.)
	parser.add_argument('--output',           '-o', type=str, default='results/inpainting/')
	parser.add_argument('--flip',             '-F', type=int, default=0,
		help='If > 0, assume horizontal symmetry. If > 1, assume vertical symmetry.')

	args = parser.parse_args(argv[1:])

	# every result below is written into this directory, so create it up front
	if args.output and not os.path.isdir(args.output):
		os.makedirs(args.output)


	### DATA

	# load image: either a regular image file or one entry of a .mat dataset
	if args.image.lower().endswith(('.gif', '.png', '.jpg', '.jpeg')):
		image = imread(args.image)[None]
		vmin, vmax = 0, 255
	else:
		image = loadmat(args.image)['data'][[args.index]]
		vmin, vmax = image.min(), image.max()
	if image.ndim < 4:
		# append a trailing axis so that the array is always 4-dimensional
		image = image[:, :, :, None]

	image = asarray(image, dtype=float)

	def save_png(filename, img):
		# writes the first image / first channel, scaled to [vmin, vmax]
		imwrite(os.path.join(args.output, filename),
			imformat(img[0, :, :, 0], vmin=vmin, vmax=vmax, symmetric=False))

	save_png('original.png', image)

	# remove center portion by overwriting it with uniform noise
	size = args.fill_region
	i_start = (image.shape[1] - size) // 2
	j_start = (image.shape[2] - size) // 2
	image[0,
		i_start:i_start + size,
		j_start:j_start + size, 0] = vmin + rand(size, size) * (vmax - vmin)

	save_png('start.png', image)


	### MODEL

	# load model
	model = Experiment(args.model)['model']
	model.verbosity = False

	# use different models for sampling and likelihoods because of SLSTM caching
	model_copy = deepcopy(model)

	# create mask indicating pixels to replace: an N x N window inside an
	# M x M patch, with m rows/columns of context before and n after it
	M = args.outer_patch_size
	N = args.inner_patch_size
	m = (M - N) // 2
	n = M - N - m
	patch_mask = zeros([M, M], dtype=bool)
	# explicit stop M - n instead of -n: with n == 0 a stop of -0 selects nothing
	patch_mask[m:M - n, m:M - n] = True

	if args.init is None:
		candidates = []
		logliks = []

		# hole plus M pixels of context above and on both sides
		rows = slice(i_start - M, i_start + size)
		cols = slice(j_start - M, j_start + size + M)

		for _ in range(args.candidates):
			# replace missing pixels by ancestral sampling
			patch = image[:, rows, cols]
			sample_mask = zeros([patch.shape[1], patch.shape[2]], dtype=bool)
			sample_mask[M:, M:-M] = True
			image[:, rows, cols] = model.sample(patch, mask=sample_mask,
				min_values=vmin, max_values=vmax)

			candidates.append(image.copy())
			logliks.append(model.loglikelihood(image).sum())

		# keep the most likely candidate; with no candidates the noise stays
		if candidates:
			image = candidates[argmax(logliks)]

		save_png('fillin.0.png', image)

		start_epoch = 0

	else:
		init = load(args.init)
		image = init['image']
		# stored as a 0-d array by savez; range() below needs a plain int
		start_epoch = int(init['epoch'])


	### INPAINTING

	# bound before the loop so that the final save also works when
	# start_epoch >= num_epochs and the loop body never runs
	epoch = start_epoch

	try:
		for epoch in range(start_epoch, args.num_epochs):
			print(epoch)

			h_flipped = False
			v_flipped = False

			try:
				if args.flip > 0 and rand() < .5:
					print('Horizontal flip.')
					# flip image horizontally
					image = image[:, :, ::-1]
					j_start = image.shape[2] - j_start - size
					h_flipped = True

				# vertical symmetry is only assumed for flip > 1 (see --flip help)
				if args.flip > 1 and rand() < .5:
					print('Vertical flip.')
					# flip image vertically
					image = image[:, ::-1, :]
					i_start = image.shape[1] - i_start - size
					v_flipped = True

				for i in range(i_start - m, i_start - m + size - N + 1, args.stride):
					for j in range(j_start - m, j_start - m + size - N + 1, args.stride):
						# view into the image, so writes to patch update image
						patch = image[:, i:i + M, j:j + M]

						if args.method == 'SAMPLE':
							# proposal
							patch_pr, logq_pr = model.sample(patch.copy(), mask=patch_mask,
								min_values=vmin, max_values=vmax, return_loglik=True)

							# conditional log-density
							logq = model_copy._logq(patch, patch_mask)

							# joint log-densities
							logp = model_copy.loglikelihood(patch).sum()
							logp_pr = model_copy.loglikelihood(patch_pr).sum()

							if rand() < exp(logp_pr - logp - logq_pr + logq):
								# accept proposal
								patch[:] = patch_pr

						else:
							# gradient step
							grad = model.gradient(patch)[1]
							patch[:, patch_mask] += grad[:, patch_mask] * args.step_width

			finally:
				# flip back, also when the sweep is interrupted, so that the
				# saved image and the start offsets are never left mirrored
				if h_flipped:
					image = image[:, :, ::-1]
					j_start = image.shape[2] - j_start - size
				if v_flipped:
					image = image[:, ::-1, :]
					i_start = image.shape[1] - i_start - size

			save_png('fillin.{0}.png'.format(epoch + 1), image)

	except KeyboardInterrupt:
		# stop early on Ctrl-C but still save the current state below
		pass

	# epoch is the last sweep that was started; --init resumes from it
	savez(os.path.join(args.output, 'final.npz'), image=image, epoch=epoch)

	return 0