Example #1
0
def _get_inference_data():
    """Build the inference graph for a single image and restore its weights.

    Reads ``FLAGS.infile`` (the image to run inference on) and
    ``FLAGS.checkpoint_dir`` (the directory holding the trained model).

    Returns:
        A ``dm_utils.Container`` built from this function's ``locals()`` at
        the end of the call, i.e. ``sess``, ``filenames``, ``infer_images``
        and ``infer_model``. Callers read these names from the container, so
        local variable names here are part of the interface.

    Raises:
        ValueError: if ``--infile`` was not given.
        FileNotFoundError: if ``FLAGS.infile`` does not exist or is a
            directory.
        RuntimeError: if the model parameters cannot be restored from
            ``FLAGS.checkpoint_dir`` (the underlying error is chained as
            ``__cause__``).
    """
    # Setup global tensorflow state
    sess = _setup_tensorflow()

    # Load single image to use for inference
    if FLAGS.infile is None:
        raise ValueError(
            'Must specify inference input file through `--infile <filename>` command line argument'
        )

    if not tf.gfile.Exists(FLAGS.infile) or tf.gfile.IsDirectory(FLAGS.infile):
        raise FileNotFoundError('File `%s` does not exist or is a directory' %
                                (FLAGS.infile, ))

    filenames = [FLAGS.infile]
    infer_images = dm_input.input_data(sess, 'inference', filenames)

    print('Loading model...')
    # Create inference model
    infer_model = dm_model.create_model(sess, infer_images)

    # Load model parameters from checkpoint
    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    try:
        # Per the TF docs `checkpoint` may be None when no checkpoint is
        # found; dereferencing it then raises and is reported as RuntimeError.
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint.model_checkpoint_path)
    except Exception as err:
        # Catch `Exception`, not everything, so Ctrl-C still interrupts, and
        # keep the real failure reason attached via exception chaining.
        raise RuntimeError('Unable to read checkpoint from `%s`' %
                           (FLAGS.checkpoint_dir, )) from err
    # Drop the restore helpers so they are not packed into the container.
    del saver
    del checkpoint
    print('Done.')

    # Pack all for convenience
    infer_data = dm_utils.Container(locals())

    return infer_data
Example #2
0
# Maps every accepted `FLAGS.train_mode` value to
# (source_is_male, target_is_male). The two-letter-plus-"2" spellings are
# aliases of the "t" spellings.
_TRAIN_MODE_GENDERS = {
    'ftm': (False, True),   # Trans filter: female -> attractive male
    'f2m': (False, True),
    'mtf': (True, False),   # Trans filter: male -> attractive female
    'm2f': (True, False),
    'ftf': (False, False),  # Vanity filter: female -> attractive female
    'f2f': (False, False),
    'mtm': (True, True),    # Vanity filter: male -> attractive male
    'm2m': (True, True),
}


def _celeba_filter(male):
    """Return a new CelebA attribute filter for attractive, unobstructed faces.

    Only the ``'Male'`` attribute varies; all other constraints are shared by
    every training mode:

    * eyeglasses (and sunglasses) are filtered out because they tend to
      produce artifacts;
    * hats are excluded as low gender-influenced objects;
    * facial hair (goatee, mustache, any beard) is excluded because otherwise
      the network becomes overly focused on rendering facial hair.

    Args:
        male: value required for the CelebA ``'Male'`` attribute.

    Returns:
        A fresh ``dict`` mapping CelebA attribute names to required values.
    """
    return {
        'Male': male,
        'Blurry': False,
        'Eyeglasses': False,
        'Attractive': True,
        'Wearing_Hat': False,
        'Goatee': False,
        'Mustache': False,
        'No_Beard': True,
    }


def _train_mode_filters(train_mode):
    """Return ``(source_filter, target_filter)`` for the given training mode.

    Raises:
        ValueError: if ``train_mode`` is not a key of ``_TRAIN_MODE_GENDERS``.
    """
    try:
        source_male, target_male = _TRAIN_MODE_GENDERS[train_mode]
    except KeyError:
        raise ValueError(
            '`train_mode` must be one of: `ftm`, `mtf`, `ftf` or `mtm`'
        ) from None
    return _celeba_filter(source_male), _celeba_filter(target_male)


def _get_train_data(random_test_sample=True):
    """Build the training and testing graphs and start the input queues.

    Args:
        random_test_sample: if true, test images are drawn from the same
            sample selection as the source training images; otherwise a
            separate selection is made with ``random_sample=False``.

    Returns:
        A ``dm_utils.Container`` built from this function's ``locals()`` at
        the end of the call (``sess``, ``source_filter``, ``target_filter``,
        ``selected``, ``source_images``, ``test_images``, ``target_images``,
        ``annealing``, ``halve_annealing``, ``train_model``, ``test_model``,
        ``random_test_sample`` and, when ``random_test_sample`` is false,
        ``test_selected``). Callers read these names from the container, so
        local variable names here are part of the interface.

    Raises:
        ValueError: if ``FLAGS.train_mode`` is not a recognised mode. This is
            raised after TensorFlow setup and directory preparation.
    """
    # Setup global tensorflow state
    sess = _setup_tensorflow()

    # Prepare directories
    _prepare_train_dirs()

    # Which type of transformation? Unknown modes raise ValueError here.
    source_filter, target_filter = _train_mode_filters(FLAGS.train_mode)

    # Setup async input queues
    selected = dm_celeba.select_samples(source_filter)

    source_images = dm_input.input_data(sess, 'train', selected)
    if random_test_sample:
        test_images = dm_input.input_data(sess, 'test', selected)
    else:
        test_selected = dm_celeba.select_samples(source_filter,
                                                 random_sample=False)
        test_images = dm_input.input_data(sess, 'test', test_selected)
    print('%8d source images selected' % (len(selected), ))

    # NOTE: `selected` is rebound here, so the returned container holds the
    # *target* selection under this name.
    selected = dm_celeba.select_samples(target_filter)
    target_images = dm_input.input_data(sess, 'train', selected)
    print('%8d target images selected' % (len(selected), ))
    print()

    # Annealing temperature: starts at 1.0 and is halved each time
    # `halve_annealing` is run (presumably by the training loop, which reads
    # it from the returned container).
    annealing = tf.Variable(initial_value=1.0,
                            trainable=False,
                            name='annealing')
    halve_annealing = tf.assign(annealing, 0.5 * annealing)

    # `specifiedgan` is a module-level name defined outside this function.
    print('Using gan type: %s' % (specifiedgan, ))

    # Create and initialize training and testing models
    train_model = dm_model.create_model(sess,
                                        source_images,
                                        target_images,
                                        annealing,
                                        verbose=True,
                                        gan_type=specifiedgan)

    print("Building testing model...")
    test_model = dm_model.create_model(sess,
                                       test_images,
                                       None,
                                       annealing,
                                       gan_type=specifiedgan)
    print("Done.")

    # Forget this line and TF will deadlock at the beginning of training
    tf.train.start_queue_runners(sess=sess)

    # Pack all for convenience
    train_data = dm_utils.Container(locals())

    return train_data