def main():
    """Continuously sample a Glow model and show how each latent level contributes.

    Uses the module-level ``args`` (``gpu_device``, ``snapshot_path``,
    ``temperature``). Never returns. Each iteration draws a fresh base latent
    (normal, mean 0, std ``args.temperature``) and redraws a row of panels:

    * panel 0: the base latent decoded as-is;
    * panel 1: the base latent with the first ``levels - 1`` factorized pieces
      resampled (only the last piece kept);
    * panel ``n + 2``: the base latent with factorized piece ``n`` replaced by
      zeros, decoded with ``sampling=False`` for the block that consumes it.
    """
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    num_bins_x = 2.0**hyperparams.num_bits_x

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    num_levels = hyperparams.levels
    # Base panel + resampled panel + one panel per level. Previously the
    # resampled image shared index 1 with level 0 and was always painted over
    # before the figure was refreshed.
    total = num_levels + 2
    fig = plt.figure(figsize=(4 * total, 4))
    subplots = [fig.add_subplot(1, total, n + 1) for n in range(total)]

    def sample_latent(shape):
        return xp.random.normal(
            0, args.temperature, size=shape, dtype="float32")

    def show(subplot, x):
        # imshow adds a new AxesImage on every call; clear first so the
        # never-ending loop does not pile up images on each axes.
        subplot.cla()
        subplot.imshow(make_uint8(x.data[0], num_bins_x), interpolation="none")

    # `a and b` would evaluate to `b` only, so no_backprop_mode was never
    # entered; a comma enters both context managers.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        while True:
            base_z = sample_latent((1, 3) + hyperparams.image_size)

            # Panel 0: decode the base latent unchanged.
            factorized_z = encoder.factor_z(base_z)
            rev_x, _ = decoder.reverse_step(factorized_z)
            show(subplots[0], rev_x)

            # Panel 1: resample every factorized piece except the last one.
            factorized_z = encoder.factor_z(xp.copy(base_z))
            for n in range(num_levels - 1):
                factorized_z[n] = sample_latent(factorized_z[n].shape)
            rev_x, _ = decoder.reverse_step(factorized_z)
            show(subplots[1], rev_x)

            # Panels 2..: zero piece n. Blocks are visited in reverse order,
            # so piece n is consumed at step k == num_levels - n - 1; only
            # that step runs with sampling=False.
            for n in range(num_levels):
                factorized_z = encoder.factor_z(xp.copy(base_z))
                factorized_z[n] = xp.zeros_like(factorized_z[n])
                deterministic_k = num_levels - n - 1
                out = None
                for k, (block, zi) in enumerate(
                        zip(encoder.blocks[::-1], factorized_z[::-1])):
                    out, _ = block.reverse_step(
                        out,
                        gaussian_eps=zi,
                        squeeze_factor=encoder.hyperparams.squeeze_factor,
                        sampling=(k != deterministic_k))
                show(subplots[n + 2], out)

            plt.pause(.01)
def main():
    """Continuously show Glow samples decoded through an increasing number of levels.

    Uses the module-level ``args`` (``gpu_device``, ``snapshot_path``,
    ``temperature``). Never returns. Each iteration picks a seed from the
    current time and does the following:

    * For each ``level`` in ``1 .. levels - 1``, it regenerates the same base
      latent from that seed and splits it with
      ``glow.nn.functions.factor_z(z, level + 1, ...)``.
    * It unsqueezes the last piece as the starting activation.
    * It runs the remaining pieces back through ``encoder.blocks[level - 1]``
      down to ``encoder.blocks[0]``, drawing the result in panel
      ``level - 1``.
    * The last panel shows the same latent decoded by the full decoder.
    """
    xp = np
    using_gpu = args.gpu_device >= 0
    if using_gpu:
        cuda.get_device(args.gpu_device).use()
        xp = cupy

    hyperparams = Hyperparameters(args.snapshot_path)
    hyperparams.print()

    num_bins_x = 2.0**hyperparams.num_bits_x

    encoder = Glow(hyperparams, hdf5_path=args.snapshot_path)
    if using_gpu:
        encoder.to_gpu()

    total = hyperparams.levels
    fig = plt.figure(figsize=(4 * total, 4))
    subplots = [fig.add_subplot(1, total, n + 1) for n in range(total)]
    z_shape = (1, 3) + hyperparams.image_size

    def seeded_latent(seed):
        # Reseeding before each draw gives every panel the same base latent.
        xp.random.seed(seed)
        return xp.random.normal(
            0, args.temperature, size=z_shape, dtype="float32")

    def show(subplot, x, level):
        # imshow adds a new AxesImage on every call; clear first so the
        # never-ending loop does not pile up images on each axes.
        subplot.cla()
        subplot.imshow(make_uint8(x.data[0], num_bins_x), interpolation="none")
        subplot.set_title("level = {}".format(level))

    # `a and b` would evaluate to `b` only, so no_backprop_mode was never
    # entered; a comma enters both context managers.
    with chainer.no_backprop_mode(), encoder.reverse() as decoder:
        while True:
            # Whole-second granularity: frames drawn within the same second
            # use the same seed and therefore show the same images.
            seed = int(time.time())

            for level in range(1, hyperparams.levels):
                factorized_z = glow.nn.functions.factor_z(
                    seeded_latent(seed),
                    level + 1,
                    squeeze_factor=hyperparams.squeeze_factor)
                out = glow.nn.functions.unsqueeze(
                    factorized_z.pop(-1),
                    factor=hyperparams.squeeze_factor,
                    module=xp)
                for n, zi in enumerate(factorized_z[::-1]):
                    block = encoder.blocks[level - n - 1]
                    out, _ = block.reverse_step(
                        out,
                        gaussian_eps=zi,
                        squeeze_factor=hyperparams.squeeze_factor)
                show(subplots[level - 1], out, level)

            # Original model: all levels through the full decoder.
            factorized_z = encoder.factor_z(seeded_latent(seed))
            rev_x, _ = decoder.reverse_step(factorized_z)
            show(subplots[-1], rev_x, hyperparams.levels)

            plt.pause(.01)