def _get_inference_data():
    """Build the inference graph for a single image and restore trained weights.

    Reads ``FLAGS.infile`` (the image to run inference on) and
    ``FLAGS.checkpoint_dir`` (where the trained weights live).

    Returns:
        A ``dm_utils.Container`` built from this function's local variables
        (``sess``, ``filenames``, ``infer_images``, ``infer_model``).

    Raises:
        ValueError: if no ``--infile`` was given.
        FileNotFoundError: if ``FLAGS.infile`` does not exist or is a directory.
        RuntimeError: if no checkpoint is found in ``FLAGS.checkpoint_dir`` or
            restoring it fails (the underlying error is chained as the cause).
    """
    # Setup global tensorflow state
    sess = _setup_tensorflow()

    # Load single image to use for inference
    if FLAGS.infile is None:
        raise ValueError(
            'Must specify inference input file through `--infile <filename>` command line argument'
        )

    if not tf.gfile.Exists(FLAGS.infile) or tf.gfile.IsDirectory(FLAGS.infile):
        raise FileNotFoundError('File `%s` does not exist or is a directory' % (FLAGS.infile, ))

    filenames = [FLAGS.infile]
    infer_images = dm_input.input_data(sess, 'inference', filenames)

    print('Loading model...')
    # Create inference model
    infer_model = dm_model.create_model(sess, infer_images)

    # Load model parameters from checkpoint.
    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    # get_checkpoint_state yields None when the directory holds no checkpoint;
    # report that directly instead of letting an AttributeError leak out.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise RuntimeError('Unable to read checkpoint from `%s`' % (FLAGS.checkpoint_dir, ))
    try:
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint.model_checkpoint_path)
    except Exception as err:
        # Narrower than a bare except: Ctrl-C / SystemExit still propagate,
        # and the real restore error is kept as __cause__.
        raise RuntimeError('Unable to read checkpoint from `%s`' % (FLAGS.checkpoint_dir, )) from err
    # Remove restore-only objects so they are not packed into the container below.
    del saver
    del checkpoint
    print('Done.')

    # Pack all for convenience
    infer_data = dm_utils.Container(locals())
    return infer_data
def _celeba_train_filters(train_mode):
    """Return the ``(source_filter, target_filter)`` CelebA attribute filters for *train_mode*.

    Both filters share the same attribute constraints and differ only in ``'Male'``:
    blurry images, eyeglasses, hats, goatees and mustaches are excluded, and only
    ``Attractive`` / ``No_Beard`` samples are kept.

    Raises:
        ValueError: if *train_mode* is not one of the accepted spellings
            (``ftm``/``f2m``, ``mtf``/``m2f``, ``ftf``/``f2f``, ``mtm``/``m2m``).
    """
    # (source is male, target is male) for every accepted mode spelling.
    genders = {
        'ftm': (False, True), 'f2m': (False, True),    # trans: female -> attractive male
        'mtf': (True, False), 'm2f': (True, False),    # trans: male -> attractive female
        'ftf': (False, False), 'f2f': (False, False),  # vanity: female -> attractive female
        'mtm': (True, True), 'm2m': (True, True),      # vanity: male -> attractive male
    }
    if train_mode not in genders:
        raise ValueError(
            '`train_mode` must be one of: `ftm`, `mtf`, `ftf` or `mtm`')
    source_male, target_male = genders[train_mode]

    def make_filter(male):
        # Key order matches the original literal filters (duplicate 'Eyeglasses' dropped).
        return {
            'Male': male,
            'Blurry': False,
            'Eyeglasses': False,
            'Attractive': True,
            'Wearing_Hat': False,
            'Goatee': False,
            'Mustache': False,
            'No_Beard': True,
        }

    return make_filter(source_male), make_filter(target_male)


def _get_train_data(random_test_sample=True):
    """Build the training and testing graphs and their CelebA input queues.

    Args:
        random_test_sample: if True, the test queue reuses the source-image
            selection; otherwise a separate selection is made with
            ``random_sample=False``.

    Returns:
        A ``dm_utils.Container`` built from this function's local variables
        (session, filters, input queues, annealing variable/op, models).

    Raises:
        ValueError: if ``FLAGS.train_mode`` is not a recognised mode.
    """
    # Setup global tensorflow state
    sess = _setup_tensorflow()

    # Prepare directories
    _prepare_train_dirs()

    # Which type of transformation?
    # Note: eyeglasses and sunglasses are filtered out because they tend to produce artifacts.
    # Note: facial hair is filtered out of both source and target images because otherwise
    # the network becomes overly focused on rendering facial hair; hats are also removed
    # as low gender-influence objects.
    source_filter, target_filter = _celeba_train_filters(FLAGS.train_mode)

    # Setup async input queues
    selected = dm_celeba.select_samples(source_filter)
    source_images = dm_input.input_data(sess, 'train', selected)
    if random_test_sample:
        test_images = dm_input.input_data(sess, 'test', selected)
    else:
        test_selected = dm_celeba.select_samples(source_filter, random_sample=False)
        test_images = dm_input.input_data(sess, 'test', test_selected)
    print('%8d source images selected' % (len(selected), ))

    selected = dm_celeba.select_samples(target_filter)
    target_images = dm_input.input_data(sess, 'train', selected)
    print('%8d target images selected' % (len(selected), ))
    print()

    # Annealing temperature: starts at 1.0 and is halved each time halve_annealing runs
    annealing = tf.Variable(initial_value=1.0, trainable=False, name='annealing')
    halve_annealing = tf.assign(annealing, 0.5 * annealing)

    # Interpolate the value (the original passed it as a second print argument,
    # which printed a literal '%s').
    print('Using gan type: %s' % (specifiedgan, ))

    # Create and initialize training and testing models
    train_model = dm_model.create_model(sess, source_images, target_images, annealing,
                                        verbose=True, gan_type=specifiedgan)

    print("Building testing model...")
    test_model = dm_model.create_model(sess, test_images, None, annealing,
                                       gan_type=specifiedgan)
    print("Done.")

    # Forget this line and TF will deadlock at the beginning of training
    tf.train.start_queue_runners(sess=sess)

    # Pack all for convenience
    train_data = dm_utils.Container(locals())
    return train_data