def setup_snapshot_3d(G, training_set):
    """Build a one-sample snapshot: a real minibatch, zero labels and a latent.

    Args:
        G: generator used by ``misc.random_latents`` to shape the latent.
        training_set: dataset whose ``get_minibatch_np(1)`` returns
            ``(images, labels)``; the dataset labels are discarded.

    Returns:
        ``(reals, labels, latents)``. ``reals`` is the 1-item minibatch,
        ``labels`` is all zeros with shape ``[1, training_set.label_size]``,
        and ``latents`` holds a single random latent.
    """
    fake_count = 1
    real_batch, _ = training_set.get_minibatch_np(1)
    z = misc.random_latents(fake_count, G)
    zero_labels = np.zeros([fake_count, training_set.label_size],
                           dtype=training_set.label_dtype)
    return real_batch, zero_labels, z
def setup_snapshot_image_grid(G, training_set,
                              size='6by8',
                              layout='random'):  # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
    """Assemble a snapshot grid of real samples, their labels and fresh latents.

    Args:
        G: generator passed to ``misc.random_latents`` to shape the latents.
        training_set: dataset whose ``get_minibatch_np(1)`` returns four
            values; only the first two (images, labels) are kept here.
        size: '6by8' gives 8 columns x 6 rows; any other value gives 1x1.
        layout: 'random' keeps the first sample drawn for every cell.
            'row_per_class' applies only when the dataset has labels; it
            redraws until label column ``row % label_size`` is non-zero.

    Returns:
        ``((gw, gh), reals, labels, latents)``.
    """
    grid_w, grid_h = (8, 6) if size == '6by8' else (1, 1)
    n_cells = grid_w * grid_h
    filter_by_row = layout == 'row_per_class' and training_set.label_size > 0

    reals = np.zeros([n_cells] + training_set.shape, dtype=training_set.dtype)
    labels = np.zeros([n_cells, training_set.label_size],
                      dtype=training_set.label_dtype)

    for cell in range(n_cells):
        row = cell // grid_w
        real, label, _, _ = training_set.get_minibatch_np(1)
        # Keep redrawing while this row's class column is zero.
        while filter_by_row and label[0, row % training_set.label_size] == 0.0:
            real, label, _, _ = training_set.get_minibatch_np(1)
        reals[cell] = real[0]
        labels[cell] = label[0]

    latents = misc.random_latents(n_cells, G)
    return (grid_w, grid_h), reals, labels, latents
def generate(network_pkl, out_dir):
    """Sample 1000 images from a trained generator and write them to ``out_dir``.

    Each image is clipped to [-1, 1], mapped to uint8 and written as
    ``<i>.png`` (first three channels only). An image with more than three
    channels is additionally saved in full as ``<i>.npy``.

    Args:
        network_pkl: pickle path for ``misc.load_pkl``, which yields
            ``(G, D, Gs)``.
        out_dir: destination directory; it must not exist yet.

    Raises:
        ValueError: if ``out_dir`` already exists.
    """
    if os.path.exists(out_dir):
        raise ValueError('{} already exists'.format(out_dir))

    misc.init_output_logging()
    np.random.seed(config.random_seed)
    tfutil.init_tf(config.tf_config)
    with tf.device('/gpu:0'):
        G, _D, Gs = misc.load_pkl(network_pkl)
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)

    # All-zero labels; latents are shaped by G.
    sample_count = 1000
    zero_labels = np.zeros([sample_count, training_set.label_size],
                           dtype=training_set.label_dtype)
    z = misc.random_latents(sample_count, G)

    # The per-GPU minibatch comes from the training schedule evaluated at
    # total_kimg * 1000 images.
    schedule = train.TrainingSchedule(config.train.total_kimg * 1000, training_set, **config.sched)
    fakes = Gs.run(z, zero_labels, minibatch_size=schedule.minibatch // config.num_gpus)

    os.makedirs(out_dir)
    for index, chw in enumerate(fakes):
        # CHW floats in [-1, 1] -> HWC uint8.
        hwc = np.clip(chw.transpose((1, 2, 0)), -1, 1)
        pixels = skimage.img_as_ubyte((hwc + 1) / 2)
        imageio.imwrite(os.path.join(out_dir, '{}.png'.format(index)), pixels[..., :3])
        if pixels.shape[-1] > 3:
            # Channels beyond the first three do not fit the PNG; dump the whole array.
            np.save(os.path.join(out_dir, '{}.npy'.format(index)), pixels)
def generate_fake_images(model_path, out_dir, num_samples, random_seed=1000, image_shrink=1, minibatch_size=32):
    """Generate ``num_samples`` images from a network pickle and store them.

    Writes two files into the directory returned by ``misc.make_dir(out_dir)``:

    * ``samples.png``: the first 100 images as a 10x10 grid.
    * ``generated.npz``: ``noise`` (the latents) and ``img_r01`` (all images
      as float32 in [0, 1], NHWC).

    Args:
        model_path: network pickle path.
        out_dir: output directory (created via ``misc.make_dir``).
        num_samples: number of images to generate.
        random_seed: seed for the latent RandomState.
        image_shrink: forwarded to ``Gs.run`` as ``out_shrink``.
        minibatch_size: forwarded to ``Gs.run``.
    """
    rng = np.random.RandomState(random_seed)
    print('Loading network from "%s"...' % model_path)
    _G, _D, Gs = misc.load_network_pkl(model_path)

    z = misc.random_latents(num_samples, Gs, rng)
    no_labels = np.zeros([z.shape[0], 0], np.float32)
    fakes = Gs.run(z, no_labels,
                   minibatch_size=minibatch_size,
                   num_gpus=config.num_gpus,
                   out_mul=127.5, out_add=127.5,
                   out_shrink=image_shrink,
                   out_dtype=np.uint8)

    save_dir = misc.make_dir(out_dir)
    misc.save_image_grid(fakes[:100], os.path.join(save_dir, 'samples.png'), [0, 255], [10, 10])

    # uint8 NCHW -> float32 in [0, 1], NHWC.
    fakes_nhwc = (fakes.astype(np.float32) / 255.).transpose(0, 2, 3, 1)
    np.savez_compressed(os.path.join(save_dir, 'generated.npz'), noise=z, img_r01=fakes_nhwc)
def setup_snapshot_image_grid(G, training_set,
                              size='1080p',     # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display.
                              layout='random'):  # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
    """Pick a snapshot grid of real samples plus matching random latents.

    Grid sizing:

    * '1080p': as many images of size ``G.output_shape[2:]`` as fit
      1920x1080, clipped to at least 3x2 and at most 32 per side.
    * '4k': the same for 3840x2160, at least 7x4.
    * Any other value: a single cell.

    With ``layout='row_per_class'`` and a labeled dataset, samples are redrawn
    until label column ``row % label_size`` is non-zero.

    Returns:
        ``((gw, gh), reals, labels, latents)``.
    """
    def fit(display_px, image_px, lo):
        # Number of images along one axis, clipped to [lo, 32].
        return np.clip(display_px // image_px, lo, 32)

    if size == '1080p':
        grid_w, grid_h = fit(1920, G.output_shape[3], 3), fit(1080, G.output_shape[2], 2)
    elif size == '4k':
        grid_w, grid_h = fit(3840, G.output_shape[3], 7), fit(2160, G.output_shape[2], 4)
    else:
        grid_w = grid_h = 1

    n_cells = grid_w * grid_h
    reals = np.zeros([n_cells] + training_set.shape, dtype=training_set.dtype)
    labels = np.zeros([n_cells, training_set.label_size], dtype=training_set.label_dtype)

    for cell in range(n_cells):
        row = cell // grid_w
        while True:
            real, label = training_set.get_minibatch_np(1)
            wants_class = layout == 'row_per_class' and training_set.label_size > 0
            if not wants_class or label[0, row % training_set.label_size] != 0.0:
                break
        reals[cell] = real[0]
        labels[cell] = label[0]

    latents = misc.random_latents(n_cells, G)
    return (grid_w, grid_h), reals, labels, latents
def setup_snapshot_image_grid(G, training_set,
                              size='1080p',     # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display.
                              layout='random'):  # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
    """Pick a snapshot grid of real samples, labels, masks and random latents.

    Grid sizing matches the display-based variant: '1080p' or '4k' fit images
    of ``G.output_shape[2:]`` onto the display (clipped per side), and any
    other ``size`` gives 1x1.

    ``training_set.get_minibatch_np(1)`` is expected to return
    ``(images, labels, masks)``. Masks are stored as
    ``[1, S, S]`` with ``S = training_set.shape[-1]``, so this assumes square
    masks (confirm against the dataset).

    Returns:
        ``((gw, gh), reals, labels, latents, masks)``.
    """
    def fit(display_px, image_px, lo):
        # Number of images along one axis, clipped to [lo, 32].
        return np.clip(display_px // image_px, lo, 32)

    if size == '1080p':
        grid_w, grid_h = fit(1920, G.output_shape[3], 3), fit(1080, G.output_shape[2], 2)
    elif size == '4k':
        grid_w, grid_h = fit(3840, G.output_shape[3], 7), fit(2160, G.output_shape[2], 4)
    else:
        grid_w = grid_h = 1

    n_cells = grid_w * grid_h
    side = training_set.shape[-1]
    reals = np.zeros([n_cells] + training_set.shape, dtype=training_set.dtype)
    labels = np.zeros([n_cells, training_set.label_size], dtype=training_set.label_dtype)
    masks = np.zeros([n_cells, 1, side, side], dtype=training_set.dtype)

    for cell in range(n_cells):
        row = cell // grid_w
        while True:
            real, label, mask = training_set.get_minibatch_np(1)
            wants_class = layout == 'row_per_class' and training_set.label_size > 0
            if not wants_class or label[0, row % training_set.label_size] != 0.0:
                break
        reals[cell] = real[0]
        labels[cell] = label[0]
        masks[cell] = mask[0]

    latents = misc.random_latents(n_cells, G)
    return (grid_w, grid_h), reals, labels, latents, masks
def generate_fake_images(pkl_path, out_dir, num_pngs, image_shrink=1, random_seed=1000, minibatch_size=1):
    """Generate ``num_pngs`` images and save each as its own PNG.

    Files are named ``ProGAN_%08d.png`` inside ``out_dir``, which is created
    if needed. Existing files are left untouched.

    All images are generated in one ``Gs.run`` call with a minibatch of
    ``config.num_gpus * 256``. The ``minibatch_size`` argument is accepted but
    currently unused.
    """
    rng = np.random.RandomState(random_seed)
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    print('Loading network...')
    _G, _D, Gs = misc.load_network_pkl(pkl_path)
    z = misc.random_latents(num_pngs, Gs, random_state=rng)
    no_labels = np.zeros([z.shape[0], 0], np.float32)
    fakes = Gs.run(z, no_labels,
                   minibatch_size=config.num_gpus * 256,
                   num_gpus=config.num_gpus,
                   out_mul=127.5, out_add=127.5,
                   out_shrink=image_shrink,
                   out_dtype=np.uint8)

    for index in range(num_pngs):
        print('Generating png to %s: %d / %d...' % (out_dir, index, num_pngs), end='\r')
        png_path = os.path.join(out_dir, 'ProGAN_%08d.png' % index)
        if os.path.exists(png_path):
            continue
        misc.save_image_grid(fakes[index:index + 1], png_path, [0, 255], [1, 1])
    print()
def hello():
    """Load the generator, run it once on random inputs and return a banner string.

    The generated images are discarded; the function always returns
    ``"DCGAN endpoint -> /predict "``.
    """
    tfutil.init_tf(config.tf_config)
    with tf.device('/gpu:0'):
        _G, _D, Gs = misc.load_pkl(resume_network_pkl)
    _resolution = Gs.output_shape[-1]  # looked up but not used below

    # Draw texture latent, shape mask and color, in this order.
    texture = misc.random_latents(1, Gs)
    shape_mask = get_random_mask(1)
    color = get_random_color(1)
    Gs.run(texture, color, shape_mask)
    return "DCGAN endpoint -> /predict "
def generate_fake_images_all(run_id, out_dir, num_pngs, image_shrink=1, random_seed=1000, minibatch_size=1, num_pkls=50):
    """Generate PNGs from each of the first ``num_pkls`` snapshots of a run.

    For every snapshot ``network-snapshot-<kimg>.pkl``, writes ``num_pngs``
    images named ``ProGAN_%08d.png`` into ``out_dir/<run_id>/<snapshot stem>/``.
    Files that already exist in that snapshot directory are not overwritten.

    Args:
        run_id: run identifier for ``misc.locate_result_subdir``.
        out_dir: root output directory.
        num_pngs: images per snapshot.
        image_shrink: forwarded to ``Gs.run`` as ``out_shrink``.
        random_seed: seed of the single RandomState shared by all snapshots.
        minibatch_size: accepted but unused; ``Gs.run`` uses
            ``config.num_gpus * 32``.
        num_pkls: maximum number of snapshots to process.
    """
    random_state = np.random.RandomState(random_seed)
    # NOTE(review): random_state is shared across snapshots, so each snapshot
    # gets different latents. Create it inside the loop if identical latents
    # per snapshot are wanted.
    out_dir = os.path.join(out_dir, str(run_id))
    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1

    prefix = 'network-snapshot-'
    postfix = '.pkl'
    for snapshot_pkl in snapshot_pkls[:num_pkls]:
        snapshot_name = os.path.basename(snapshot_pkl)
        snapshot_dir = os.path.join(out_dir, snapshot_name.split('.')[0])
        os.makedirs(snapshot_dir, exist_ok=True)
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(postfix)
        # Parsed only to validate the file name; the value is not used further.
        snapshot_kimg = int(snapshot_name[len(prefix):-len(postfix)])

        print('Loading network...')
        _G, _D, Gs = misc.load_network_pkl(snapshot_pkl)
        latents = misc.random_latents(num_pngs, Gs, random_state=random_state)
        labels = np.zeros([latents.shape[0], 0], np.float32)
        images = Gs.run(latents, labels,
                        minibatch_size=config.num_gpus * 32,
                        num_gpus=config.num_gpus,
                        out_mul=127.5, out_add=127.5,
                        out_shrink=image_shrink,
                        out_dtype=np.uint8)

        for png_idx in range(num_pngs):
            print('Generating png to %s: %d / %d...' % (snapshot_dir, png_idx, num_pngs), end='\r')
            # Check the same path that is written (previously the check looked
            # in out_dir, where the files never are, so nothing was skipped).
            png_path = os.path.join(snapshot_dir, 'ProGAN_%08d.png' % png_idx)
            if not os.path.exists(png_path):
                misc.save_image_grid(images[png_idx:png_idx + 1], png_path, [0, 255], [1, 1])
        print()
def predict():
    """Flask handler: render one image from an uploaded shape mask, an RGB color and a texture.

    Request fields:

    * ``files['image']``: shape image. It is converted to grayscale, resized
      to the generator resolution and scaled to [-1, 1] as a
      ``(1, 1, imsize, imsize)`` array.
    * ``form['R'|'G'|'B']``: color components, parsed as floats.
    * ``form['texflag']``: ``"true"`` draws a new random texture latent and
      appends it to ``texture_list``. Otherwise ``form['texid']`` selects an
      existing entry.

    Returns:
        JSON ``{"image": <base64 PNG string>, "id": <texture index>}``.
    """
    # NOTE(review): TF init and model loading run on every request; consider
    # loading once at startup.
    tfutil.init_tf(config.tf_config)
    with tf.device('/gpu:0'):
        _G, _D, Gs = misc.load_pkl(resume_network_pkl)
    imsize = Gs.output_shape[-1]

    mask = Image.open(request.files['image']).convert('L')
    mask = mask.resize((imsize, imsize))
    mask = (np.float32(mask) - 127.5) / 127.5
    masks = mask.reshape((1, 1, imsize, imsize))

    colors = np.array([[float(request.form['R']),
                        float(request.form['G']),
                        float(request.form['B'])]], dtype=object)

    if request.form['texflag'] == "true":
        selected_textures = misc.random_latents(1, Gs)
        texture_list.append(selected_textures[0])
        texid = len(texture_list) - 1
    else:
        texid = int(request.form['texid'])
        selected_textures = np.array([texture_list[texid]], dtype=object)

    fake_images = Gs.run(selected_textures, colors, masks)
    fake_images = convert_to_image(fake_images)

    # Encode in memory rather than via a shared 'localtemp.png' in the CWD,
    # which concurrent requests could overwrite. The PIL re-encode of the
    # matplotlib PNG is kept.
    mpl_png = io.BytesIO()
    matplotlib.image.imsave(mpl_png, fake_images[0], format='png')
    mpl_png.seek(0)
    conv_image = Image.open(mpl_png)
    buffered = io.BytesIO()
    conv_image.save(buffered, format="PNG")
    # base64 output is pure ASCII; decode instead of slicing "b'...'" off repr().
    img_str = base64.b64encode(buffered.getvalue()).decode('ascii')
    return jsonify({"image": img_str, "id": texid})
def setup_snapshot_image_grid(
        G,
        training_set,
        size='6by8',  # '6by8' = 6 rows x 8 columns, '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display.
        layout='random'):  # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
    """Assemble a snapshot grid of reals, labels, well facies, probability images and latents.

    ``training_set.get_minibatch_np(1)`` is expected to return
    ``(images, labels, probimages, wellfacies)``.

    Grid sizing:

    * '1080p' / '4k': fit ``G.output_shape[2:]``-sized images onto the
      display (clipped per side).
    * '6by8': 8 columns x 6 rows.
    * Any other value: 1x1.

    Returns:
        ``((gw, gh), reals, labels, wellfacies, latents, probimages)``.
    """
    gw = 1
    gh = 1
    if size == '1080p':
        gw = np.clip(1920 // G.output_shape[3], 3, 32)
        gh = np.clip(1080 // G.output_shape[2], 2, 32)
    if size == '4k':
        gw = np.clip(3840 // G.output_shape[3], 7, 32)
        gh = np.clip(2160 // G.output_shape[2], 4, 32)
    if size == '6by8':
        gw = 8
        gh = 6

    num_cells = gw * gh
    num_classes = training_set.label_size
    reals = np.zeros([num_cells] + training_set.shape, dtype=training_set.dtype)
    labels = np.zeros([num_cells, num_classes], dtype=training_set.label_dtype)
    wellfacies = np.zeros([num_cells] + training_set.shape, dtype=np.float32)
    probimages = np.zeros([num_cells] + training_set.shape, dtype=np.float32)

    for idx in range(num_cells):
        y = idx // gw
        while True:
            real, label, probimage, wellface = training_set.get_minibatch_np(1)
            # NOTE(review): this variant filters rows by class whenever labels
            # exist, regardless of `layout` (the layout check was disabled
            # upstream). The num_classes guard prevents `y % 0` raising
            # ZeroDivisionError on unlabeled datasets.
            if num_classes > 0 and label[0, y % num_classes] == 0.0:
                continue
            reals[idx] = real[0]
            labels[idx] = label[0]
            wellfacies[idx] = wellface[0]
            probimages[idx] = probimage[0]
            break

    latents = misc.random_latents(num_cells, G)
    return (gw, gh), reals, labels, wellfacies, latents, probimages
def f_z_generator(CNN_INPUT_SIZE):
    """Yield an endless stream of ``(z, x, f)`` triples.

    * ``z``: a latent point in the GAN generator's latent space.
    * ``x``: the generated image as an HWC uint8 array with the channel
      order swapped by ``cv2.COLOR_BGR2RGB``.
    * ``f``: facial features of ``x`` from ``calculate_facial_features``,
      reshaped to a single row.

    Samples in which ``MarkDetector`` finds no face box are skipped. Models
    are loaded from fixed local paths the first time the generator is
    advanced.
    """
    # Official CelebA-HQ networks (trusted local pickle).
    with open('models/pg_gan/karras2018iclr-celebahq-1024x1024.pkl', 'rb') as fh:
        G, _D, _Gs = pickle.load(fh)
    detector = MarkDetector()
    age_net = cv2.dnn.readNetFromCaffe('models/race_age/deploy_age.prototxt',
                                       'models/race_age/age_net.caffemodel')
    gender_net = cv2.dnn.readNetFromCaffe('models/race_age/deploy_gender.prototxt',
                                          'models/race_age/gender_net.caffemodel')

    while True:
        z = misc.random_latents(1, G)
        no_labels = np.zeros([z.shape[0], 0], np.float32)
        batch = G.run(z, no_labels, out_mul=127.5, out_add=127.5, out_dtype=np.uint8)
        # Single-sample batch -> CHW (assumed layout) -> HWC, then swap channel order.
        frame = cv2.cvtColor(np.transpose(np.squeeze(batch), (1, 2, 0)), cv2.COLOR_BGR2RGB)
        facebox = detector.extract_cnn_facebox(frame)
        if facebox is not None:
            features = calculate_facial_features(frame, facebox, CNN_INPUT_SIZE,
                                                 detector, age_net, gender_net)
            yield (z, frame, features.reshape(1, -1))
def setup_snapshot_image_grid(G, training_set, size='6by8'):  # '6by8'=6row and 8 column.
    """Fill an 8x6 snapshot grid with one real sample per cell plus fresh latents.

    ``size`` is accepted for interface compatibility but ignored; the grid is
    always 8 columns by 6 rows. Each cell takes the first sample returned by
    ``training_set.get_minibatch_np(1)`` (no class filtering).

    Returns:
        ``((8, 6), reals, labels, latents)``.
    """
    cols, rows = 8, 6
    n_cells = cols * rows

    reals = np.zeros([n_cells] + training_set.shape, dtype=training_set.dtype)
    labels = np.zeros([n_cells, training_set.label_size], dtype=training_set.label_dtype)
    for slot in range(n_cells):
        real, label = training_set.get_minibatch_np(1)
        reals[slot], labels[slot] = real[0], label[0]

    latents = misc.random_latents(n_cells, G)
    return (cols, rows), reals, labels, latents
# Make sure every output directory exists.
os.makedirs(args.data_dir, exist_ok=True)
os.makedirs(args.mask_dir, exist_ok=True)
os.makedirs(args.generated_images_dir, exist_ok=True)
os.makedirs(args.dlatent_dir, exist_ok=True)
os.makedirs(args.dlabel_dir, exist_ok=True)

# Initialize generator and perceptual model.
# network_pkl is resolved here for the log line. load_network_pkl is given
# results_dir directly (presumably resolving it the same way; verify).
network_pkl = misc.locate_network_pkl(args.results_dir)
print('Loading network from "%s"...' % network_pkl)
G, D, Gs = misc.load_network_pkl(args.results_dir, None)

# Seeded initial latent and random initial labels.
latents = misc.random_latents(1, Gs, random_state=np.random.RandomState(800))
labels = np.random.rand(1, args.labels_size)
# Size the learnable labels from the same setting as `labels` above instead
# of a hard-coded 572, so the two cannot disagree.
generator = Generator(Gs, labels_size=args.labels_size, batch_size=1,
                      clipping_threshold=args.clipping_threshold,
                      model_res=args.resolution)

perc_model = None
if args.use_lpips_loss > 0.00000001:
    # NOTE: pickle.load executes code from the file; only load trusted models.
    with open(args.load_perc_model, "rb") as f:
        perc_model = pickle.load(f)
ff_model = None
beautyrater_model = beautyrater.BeautyRater(args.load_vgg_beauty_rater_model)
def _print_metric_results(metric_objs, results, time_begin):
    """Finish a table row: elapsed time since ``time_begin`` then every metric value."""
    print('%-12s' % misc.format_time(time.time() - time_begin), end='')
    for obj, vals in zip(metric_objs, results):
        for val, fmt in zip(vals, obj.get_metric_formatting()):
            print(fmt % val, end='')
    print()


def evaluate_metrics(run_id, log, metrics, num_images, real_passes, minibatch_size=None):
    """Evaluate image-quality metrics for every network snapshot of a training run.

    Prints a table, which is also logged to ``<result_subdir>/<log>``. It has
    one row per real-data pass followed by one row per snapshot, newest first.

    Args:
        run_id: run identifier passed to ``misc.locate_result_subdir``.
        log: log file name, relative to the run's result directory.
        metrics: short names ('swd', 'fid', 'is', 'msssim') or fully qualified
            metric class names.
        num_images: number of images fed to each metric per row.
        real_passes: how many real-data passes to run before the snapshots.
            Pass 1 uses mode 'reals' and pass 2 feeds reals in mode 'fakes';
            values above 2 behave like 2.
        minibatch_size: images per batch. Defaults to
            ``clip(8192 // dataset_obj.shape[1], 4, 256)``.
    """
    metric_class_names = {
        'swd': 'metrics.sliced_wasserstein.API',
        'fid': 'metrics.frechet_inception_distance.API',
        'is': 'metrics.inception_score.API',
        'msssim': 'metrics.ms_ssim.API',
    }

    # Locate training run and initialize logging.
    result_subdir = misc.locate_result_subdir(run_id)
    snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False)
    assert len(snapshot_pkls) >= 1
    log_file = os.path.join(result_subdir, log)
    print('Logging output to', log_file)
    misc.set_output_log_file(log_file)

    # Initialize dataset and select minibatch size.
    dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(
        result_subdir, verbose=True, shuffle_mb=0)
    if minibatch_size is None:
        minibatch_size = np.clip(8192 // dataset_obj.shape[1], 4, 256)

    # Initialize metrics and warm each one up on 10 batches of random noise.
    metric_objs = []
    for name in metrics:
        class_name = metric_class_names.get(name, name)
        print('Initializing %s...' % class_name)
        class_def = tfutil.import_obj(class_name)
        image_shape = [3] + dataset_obj.shape[1:]
        obj = class_def(num_images=num_images, image_shape=image_shape,
                        image_dtype=np.uint8, minibatch_size=minibatch_size)
        tfutil.init_uninited_vars()
        mode = 'warmup'
        obj.begin(mode)
        for _ in range(10):
            obj.feed(mode, np.random.randint(0, 256, size=[minibatch_size] + image_shape, dtype=np.uint8))
        obj.end(mode)
        metric_objs.append(obj)

    # Print table header.
    print()
    print('%-10s%-12s' % ('Snapshot', 'Time_eval'), end='')
    for obj in metric_objs:
        for name, fmt in zip(obj.get_metric_names(), obj.get_metric_formatting()):
            print('%-*s' % (len(fmt % 0), name), end='')
    print()
    print('%-10s%-12s' % ('---', '---'), end='')
    for obj in metric_objs:
        for fmt in obj.get_metric_formatting():
            print('%-*s' % (len(fmt % 0), '---'), end='')
    print()

    # Labels used to condition the generator. Each real pass below overwrites
    # every row with dataset labels. With real_passes == 0 they stay zero;
    # previously the name was then unbound and the snapshot loop raised NameError.
    labels = np.zeros([num_images, dataset_obj.label_size], dtype=np.float32)

    # Feed in reals.
    for title, mode in [('Reals', 'reals'), ('Reals2', 'fakes')][:real_passes]:
        print('%-10s' % title, end='')
        time_begin = time.time()
        for obj in metric_objs:
            obj.begin(mode)
        for begin in range(0, num_images, minibatch_size):
            end = min(begin + minibatch_size, num_images)
            images, labels[begin:end] = dataset_obj.get_minibatch_np(end - begin)
            if mirror_augment:
                images = misc.apply_mirror_augment(images)
            if images.shape[1] == 1:
                images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
            for obj in metric_objs:
                obj.feed(mode, images)
        results = [obj.end(mode) for obj in metric_objs]
        _print_metric_results(metric_objs, results, time_begin)

    # Evaluate each network snapshot, newest first.
    prefix = 'network-snapshot-'
    postfix = '.pkl'
    for snapshot_pkl in reversed(snapshot_pkls):
        snapshot_name = os.path.basename(snapshot_pkl)
        assert snapshot_name.startswith(prefix) and snapshot_name.endswith(postfix)
        snapshot_kimg = int(snapshot_name[len(prefix):-len(postfix)])
        print('%-10d' % snapshot_kimg, end='')
        mode = 'fakes'
        for obj in metric_objs:
            obj.begin(mode)
        time_begin = time.time()
        # Fresh graph and session per snapshot.
        with tf.Graph().as_default(), tfutil.create_session(config.tf_config).as_default():
            G, D, Gs = misc.load_pkl(snapshot_pkl)
            for begin in range(0, num_images, minibatch_size):
                end = min(begin + minibatch_size, num_images)
                latents = misc.random_latents(end - begin, Gs)
                images = Gs.run(latents, labels[begin:end],
                                num_gpus=config.num_gpus,
                                out_mul=127.5, out_add=127.5,
                                out_dtype=np.uint8)
                if images.shape[1] == 1:
                    images = np.tile(images, [1, 3, 1, 1])  # grayscale => RGB
                for obj in metric_objs:
                    obj.feed(mode, images)
        results = [obj.end(mode) for obj in metric_objs]
        _print_metric_results(metric_objs, results, time_begin)
    print()
def __init__(self, model, labels_size=572, batch_size=1, clipping_threshold=1, model_res=128):
    """Create learnable latent/label variables and wire them into ``model``.

    Args:
        model: generator network. ``get_output_for`` is applied to the
            learnable variables, and its input/output templates are looked up
            by name in the default graph.
        labels_size: width of the learnable label vector.
        batch_size: number of latent/label rows optimized together.
        clipping_threshold: bound used by the stochastic clipping ops.
        model_res: accepted for interface compatibility; not used here.

    Raises:
        Exception: if the generator output tensor is ``None`` (all graph ops
            are printed first).
    """
    self.batch_size = batch_size
    self.clipping_threshold = clipping_threshold
    # Seeded start point with one latent row per batch element. For
    # batch_size == 1 this is the same draw as before. Drawing a single row
    # regardless of batch_size would not match the (batch_size, 512)
    # variable that set_dlatents presumably fills.
    self.initial_dlatents = misc.random_latents(
        batch_size, model, random_state=np.random.RandomState(800))
    self.initial_dlabels = np.random.rand(self.batch_size, labels_size)

    self.sess = tf.get_default_session()
    self.graph = tf.get_default_graph()

    def get_tensor(name):
        # Missing tensors yield None instead of raising.
        try:
            return self.graph.get_tensor_by_name(name)
        except KeyError:
            return None

    self.dlatent_variable = tf.get_variable(
        'learnable_dlatents', shape=(batch_size, 512), dtype='float32',
        initializer=tf.initializers.random_normal())
    self.dlabel_variable = tf.get_variable(
        'learnable_dlabels', shape=(batch_size, labels_size), dtype='float32',
        initializer=tf.initializers.random_normal())
    self.generator_output = model.get_output_for(self.dlatent_variable, self.dlabel_variable)

    self.latents_name_tensor = get_tensor(model.input_templates[0].name)
    self.labels_name_tensor = get_tensor(model.input_templates[1].name)
    self.output_name_tensor = get_tensor(model.output_templates[0].name)
    self.output_name_image = tflib.convert_images_to_uint8(
        self.output_name_tensor, nchw_to_nhwc=True, uint8_cast=False)
    self.output_name_image_uint8 = tf.saturate_cast(self.output_name_image, tf.uint8)

    self.set_dlatents(self.initial_dlatents)
    self.set_dlabels(self.initial_dlabels)

    self.generator_output_shape = model.output_shape
    if self.generator_output is None:
        for op in self.graph.get_operations():
            print(op)
        raise Exception("Couldn't find generator_output")

    self.generated_image = tflib.convert_images_to_uint8(
        self.generator_output, nchw_to_nhwc=True, uint8_cast=False)
    self.generated_image_uint8 = tf.saturate_cast(self.generated_image, tf.uint8)

    # Stochastic clipping similar to https://arxiv.org/abs/1702.04782: any
    # component outside the allowed range is resampled from N(0, 1) rather
    # than clamped. Latents use [-clipping_threshold, clipping_threshold].
    clipping_mask1 = tf.math.logical_or(
        self.dlatent_variable > self.clipping_threshold,
        self.dlatent_variable < -self.clipping_threshold)
    clipped_values1 = tf.where(
        clipping_mask1, tf.random_normal(shape=(self.batch_size, 512)),
        self.dlatent_variable)
    self.stochastic_clip_op1 = tf.assign(self.dlatent_variable, clipped_values1)

    # Labels: the first 60 columns are kept in [0, clipping_threshold] and the
    # rest in [-clipping_threshold, clipping_threshold]. The meaning of the
    # 60-column split is not visible here; confirm against the label schema.
    clipping_mask2_1 = tf.math.logical_or(
        self.dlabel_variable[:, 0:60] > self.clipping_threshold,
        self.dlabel_variable[:, 0:60] < 0)
    clipping_mask2_2 = tf.math.logical_or(
        self.dlabel_variable[:, 60:] > self.clipping_threshold,
        self.dlabel_variable[:, 60:] < -self.clipping_threshold)
    clipping_mask2 = tf.concat([clipping_mask2_1, clipping_mask2_2], axis=1)
    clipped_values2 = tf.where(
        clipping_mask2, tf.random_normal(shape=(self.batch_size, labels_size)),
        self.dlabel_variable)
    self.stochastic_clip_op2 = tf.assign(self.dlabel_variable, clipped_values2)
def recovery(name, pkl_path1, pkl_path2, out_dir, target_latents_dir, num_init=20, num_total_sample=100, image_shrink=1, random_seed=2020, minibatch_size=1, noise_sigma=0):
    """Recover source-generator latents that reproduce target-generator images.

    For each of ``num_total_sample`` target latents, the procedure is:

    1. The latent is read from ``np.load(target_latents_dir)`` when that is
       given, otherwise drawn at random.
    2. A target image is rendered with ``Gt`` (weights from ``pkl_path2``)
       and Gaussian noise with ``noise_sigma`` is added.
    3. ``num_init`` random starting latents in ``Gs`` (weights from
       ``pkl_path1``) are optimized for 500 steps to match it.

    Outputs under ``out_dir/name``:

    * ``<k>.png``: target next to the best reconstruction.
    * ``losses.txt``: one line of per-start MSE losses per sample (appended).
    * ``z_init.npy`` / ``z_re.npy``: target and recovered latents.

    TensorFlow/tfutil must already be initialized by the caller (no init
    happens here). ``random_seed`` and ``minibatch_size`` are currently
    unused.
    """
    print('num_init:' + str(num_init))
    # Load source model.
    print('Loading network1...' + pkl_path1)
    _, _, G_sorce = misc.load_network_pkl(pkl_path1)
    # Load target model.
    print('Loading network2...' + pkl_path2)
    _, _, G_target = misc.load_network_pkl(pkl_path2)

    # Build Gt (target) and Gs (source) recovery networks with copied weights.
    Gt = tfutil.Network('Gt', num_samples=num_init, num_channels=3, resolution=128, func='networks.G_recovery')
    latents = misc.random_latents(num_init, Gt, random_state=None)
    labels = np.zeros([latents.shape[0], 0], np.float32)
    Gt.copy_vars_from_with_input(G_target, latents)
    Gs = tfutil.Network('Gs', num_samples=num_init, num_channels=3, resolution=128, func='networks.G_recovery')
    Gs.copy_vars_from_with_input(G_sorce, latents)

    out_dir = os.path.join(out_dir, name)
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    def G_loss(G, target_images):
        # MSE between G's rescaled output for its current latents and the targets.
        tmp_latents = tfutil.run(G.trainables['Input/weight'])
        G_out = G.get_output_for(tmp_latents, labels, is_training=True)
        G_out = rescale_output(G_out)
        return tf.losses.mean_squared_error(target_images, G_out)

    z_init = []
    z_recovered = []
    if target_latents_dir is not None:
        print('using latents:' + target_latents_dir)
        pre_latents = np.load(target_latents_dir)

    # NOTE(review): a new loss/optimizer is built for every sample, so the TF
    # graph grows with num_total_sample.
    for k in range(num_total_sample):
        result_path = os.path.join(out_dir, str(k) + '.png')

        # Sample the target latent and replicate it num_init times.
        if target_latents_dir is not None:
            latent = pre_latents[k]
        else:
            latents = misc.random_latents(1, Gs, random_state=None)
            latent = latents[0]
        z_init.append(latent)
        latents = np.zeros((num_init, 512))
        latents[:] = latent
        Gt.change_input(inputs=latents)

        # Render the target and add noise.
        target_images = Gt.get_output_for(latents, labels, is_training=False)
        target_images_tf = rescale_output(target_images)
        target_images = tfutil.run(target_images_tf)
        target_images_noise = addGaussianNoise(target_images, sigma=noise_sigma)
        target_images = tf.cast(target_images_noise, dtype='float32')

        # Random starting points for the source generator.
        latents_2 = misc.random_latents(num_init, Gs, random_state=None)
        Gs.change_input(inputs=latents_2)

        # Loss and optimizer (a latent-norm regularizer was disabled and is not built).
        loss = G_loss(G=Gs, target_images=target_images)
        G_opt = tfutil.Optimizer(name='latent_recovery', learning_rate=0.01)
        G_opt.register_gradients(loss, Gs.trainables)
        G_train_op = G_opt.apply_updates()

        # Recovery.
        EPOCH = 500
        for _ in range(EPOCH):
            # NOTE(review): optimizer state is reset before every step, so
            # moment estimates never accumulate; confirm this is intended.
            G_opt.reset_optimizer_state()
            tfutil.run([G_train_op])

        learned_latent = tfutil.run(Gs.trainables['Input/weight'])
        result_images = Gs.run(learned_latent, labels,
                               minibatch_size=config.num_gpus * 256,
                               num_gpus=config.num_gpus,
                               out_mul=127.5, out_add=127.5,
                               out_shrink=image_shrink,
                               out_dtype=np.float32)

        # Per-start reconstruction losses.
        sample_losses = []
        tmp_latents = tfutil.run(Gs.trainables['Input/weight'])
        G_out = Gs.get_output_for(tmp_latents, labels, is_training=True)
        G_out = rescale_output(G_out)
        for i in range(num_init):
            sample_loss = tf.losses.mean_squared_error(target_images[i], G_out[i])
            sample_losses.append(tfutil.run(sample_loss))

        # Save the target next to the best reconstruction. Closing the figure
        # stops images from piling up on one reused figure across samples.
        plt.subplot(1, 2, 1)
        plt.imshow(tfutil.run(target_images)[0].transpose(1, 2, 0) / 255.0)
        plt.subplot(1, 2, 2)
        plt.imshow(result_images[np.argmin(sample_losses)].transpose(1, 2, 0) / 255.0)
        plt.savefig(result_path)
        plt.close()

        z_recovered.append(tmp_latents)

        # Append this sample's losses as one line.
        with open(os.path.join(out_dir, "losses.txt"), "a") as f:
            for value in sample_losses:
                f.write(str(value) + ' ')
            f.write('\n')

    np.save(os.path.join(out_dir, 'z_init'), np.array(z_init))
    np.save(os.path.join(out_dir, 'z_re'), np.array(z_recovered))