def infer(self, source_obj, model_dir, save_dir):
    """Run the restored generator over *source_obj* and save output sheets.

    Restores only the generator variables from *model_dir*, then feeds each
    batch from the source pickle through the generator. Merged result images
    are buffered and written to *save_dir* in chunks of 10 batches per PNG.

    Args:
        source_obj: path to the pickled source-image object fed to
            ``InjectDataProvider``.
        model_dir: checkpoint directory to restore the generator from.
        save_dir: directory that receives the ``inferred_%04d.png`` files.
    """
    provider = InjectDataProvider(source_obj)
    batch_iter = provider.get_iter(self.batch_size)

    tf.global_variables_initializer().run()
    # Only the generator weights are needed for inference.
    saver = tf.train.Saver(var_list=self.retrieve_generator_vars())
    self.restore_model(saver, model_dir)

    def flush(imgs, count):
        # Concatenate the buffered sheets into a single PNG on disk.
        p = os.path.join(save_dir, "inferred_%04d.png" % count)
        save_concat_images(imgs, img_path=p)
        print("generated images saved at %s" % p)

    chunk_idx = 0
    pending = []
    for source_imgs in batch_iter:
        fake_imgs = self.generate_fake_samples(source_imgs)[0]
        pending.append(merge(scale_back(fake_imgs), [self.batch_size, 1]))
        if len(pending) == 10:
            flush(pending, chunk_idx)
            pending = []
            chunk_idx += 1
    if pending:
        # last batch
        flush(pending, chunk_idx)
def main(_):
    """Entry point: build an inference-only model and run it on the source set.

    Reads ``args.source_obj``, ``args.model_dir`` and ``args.save_dir`` from
    the module-level argument namespace.
    """
    config = tf.ConfigProto()
    # Grow GPU memory on demand rather than reserving it all up front.
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        provider = InjectDataProvider(args.source_obj)
        # Batch size is the number of source examples, capped at 10.
        batch_size = min(10, len(provider.data.examples))

        model = Font2Font(batch_size=batch_size)
        model.register_session(sess)
        model.build_model(is_training=False)
        model.test(provider, args.model_dir, args.save_dir)