Ejemplo n.º 1
0
    def save_data(self, x, name, save_dir, x_is_real_x=False):
        """Post-process points and save them as a single PNG image.

        Args:
          x: A numpy array of points. Unless `x_is_real_x` is True they are
              first mapped to dataspace with `self.decode` (presumably they
              are latent codes in that case -- confirm against callers).
          name: A string used as the file's base name; `.png` is appended.
          save_dir: A string indicating the directory to put the saved file.
          x_is_real_x: A boolean indicating whether `x` is already in
              dataspace. If not, `x` is converted to dataspace with
              `self.decode` before saving.
        """
        if x_is_real_x:
            data_points = x
        else:
            data_points = self.decode(x)
        # Post-process with the model config, then tile the batch into one
        # image via common.batch_image before writing it out.
        image_batch = common.batch_image(
            common.post_proc(data_points, self.config))
        common.save_image(image_batch, join(save_dir, '%s.png' % name))
Ejemplo n.º 2
0
def _save_prior_samples(sess, m, config, image_path, sample_count):
    """Decode standard-normal latent draws and save them as `sample_prior.png`.

    Args:
      sess: Session holding the restored model `m`.
      m: The built dataspace model; `m.z` is fed and `m.x_mean` evaluated.
      config: Model config passed to `common.post_proc`.
      image_path: Directory the PNG is written to.
      sample_count: Number of latent vectors drawn with `np.random.randn`.
    """
    z_p = np.random.randn(sample_count, m.n_latent)
    x_p = sess.run(m.x_mean, {m.z: z_p})
    x_p = common.post_proc(x_p, config)
    common.save_image(common.batch_image(x_p),
                      os.path.join(image_path, 'sample_prior.png'))


def _save_grid_samples(sess, m, config, image_path, n_latent,
                       boundary=2.0, number_grid=50):
    """Decode latents from `common.make_grid` and save `sample_grid.png`.

    Args:
      sess: Session holding the restored model `m`.
      m: The built dataspace model; `m.z` is fed and `m.x_mean` evaluated.
      config: Model config passed to `common.post_proc`.
      image_path: Directory the PNG is written to.
      n_latent: Latent dimensionality passed to `common.make_grid`.
      boundary: Passed to `common.make_grid` (presumably the extent of the
          grid along each latent axis -- confirm against `make_grid`).
      number_grid: Grid resolution passed to `common.make_grid` and
          `common.make_batch_image_grid`.
    """
    blob = common.make_grid(boundary=boundary,
                            number_grid=number_grid,
                            dim_latent=n_latent)
    z_grid, dim_grid = blob.z_grid, blob.dim_grid
    x_grid = sess.run(m.x_mean, {m.z: z_grid})
    x_grid = common.post_proc(x_grid, config)
    batch_image_grid = common.make_batch_image_grid(dim_grid, number_grid)
    common.save_image(batch_image_grid(x_grid),
                      os.path.join(image_path, 'sample_grid.png'))


def _save_reconstructions(sess, m, config, image_path, train_data,
                          sample_count):
    """Save the first `sample_count` training items and their reconstructions.

    The items are encoded to `(m.mu, m.sigma)`, which are then fed back to
    evaluate `m.x_mean`. Writes `image_real.png` and `image_rec.png`.

    Args:
      sess: Session holding the restored model `m`.
      m: The built dataspace model.
      config: Model config passed to `common.post_proc`.
      image_path: Directory the PNGs are written to.
      train_data: Training array; only its first `sample_count` rows are used.
      sample_count: Number of leading training items to reconstruct.
    """
    x_real = train_data[:sample_count]
    mu, sigma = sess.run([m.mu, m.sigma], {m.x: x_real})
    x_rec = sess.run(m.x_mean, {m.mu: mu, m.sigma: sigma})
    x_rec = common.post_proc(x_rec, config)

    x_real = common.post_proc(x_real, config)
    common.save_image(common.batch_image(x_real),
                      os.path.join(image_path, 'image_real.png'))
    common.save_image(common.batch_image(x_rec),
                      os.path.join(image_path, 'image_rec.png'))


def main(unused_argv):
    """Restore the best dataspace VAE checkpoint and save sample images.

    Writes `sample_prior.png`, `sample_grid.png`, `image_real.png` and
    `image_rec.png` into `<dataset.basepath>/sample/<model_uid>/`.

    Args:
      unused_argv: Ignored.
    """
    del unused_argv

    # Load Config
    config_name = FLAGS.config
    config_module = importlib.import_module('configs.%s' % config_name)
    config = config_module.config
    model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
    n_latent = config['n_latent']

    # Load dataset
    dataset = common.load_dataset(config)
    basepath = dataset.basepath
    save_path = dataset.save_path
    train_data = dataset.train_data

    # Make the directory
    save_dir = os.path.join(save_path, model_uid)
    best_dir = os.path.join(save_dir, 'best')
    tf.gfile.MakeDirs(save_dir)
    tf.gfile.MakeDirs(best_dir)
    tf.logging.info('Save Dir: %s', save_dir)

    # Number of images drawn from the prior and reconstructed from data.
    sample_count = 64

    np.random.seed(FLAGS.random_seed)

    # Load Model
    tf.reset_default_graph()
    # tf.set_random_seed sets the seed of the *current* default graph, so it
    # must come after reset_default_graph(); otherwise it is discarded along
    # with the old graph and TF random ops in the new graph are unseeded.
    tf.set_random_seed(FLAGS.random_seed)
    sess = tf.Session()
    try:
        with tf.device(tf.train.replica_device_setter(ps_tasks=0)):
            m = model_dataspace.Model(config, name=model_uid)
            _ = m()  # noqa

            # Initialize, then overwrite with the best checkpoint.
            sess.run(tf.global_variables_initializer())
            m.vae_saver.restore(
                sess, os.path.join(best_dir, 'vae_best_%s.ckpt' % model_uid))

            image_path = os.path.join(basepath, 'sample', model_uid)
            tf.gfile.MakeDirs(image_path)

            # Order matters for reproducibility: the prior draws consume the
            # seeded numpy RNG.
            _save_prior_samples(sess, m, config, image_path, sample_count)
            _save_grid_samples(sess, m, config, image_path, n_latent)
            _save_reconstructions(sess, m, config, image_path, train_data,
                                  sample_count)
    finally:
        sess.close()
Ejemplo n.º 3
0
def _pick_label_indices(index_grouped_by_label, labels, from_end=False):
    """Pick one dataset index for each entry of `labels` without reuse.

    Repeated occurrences of a label take successive entries of
    `index_grouped_by_label[label]`. Entries are taken from the front
    (positions 0, 1, ...) or, when `from_end` is True, from the back
    (positions -1, -2, ...).

    Args:
      index_grouped_by_label: Indexable by label, giving a sequence of indices.
      labels: Iterable of labels, in the order the picks are wanted.
      from_end: Whether to consume each label's indices from the back.

    Returns:
      A list with one picked index per entry of `labels`.
    """
    # A dict (not a fixed-size list) so that any label value works and
    # distinct labels never share a counter.
    used_per_label = {}
    picked = []
    for label in labels:
        offset = used_per_label.get(label, 0)
        used_per_label[label] = offset + 1
        position = -1 - offset if from_end else offset
        picked.append(index_grouped_by_label[label][position])
    return picked


def _interpolate_latents(mus, indices, nb_between):
    """Build a piecewise-linear path through `mus[i]` for `i` in `indices`.

    The path starts at `mus[indices[0]]`. For every following index it
    appends `nb_between` evenly spaced points from the previously appended
    point towards `mus[index]`; the last of these equals that endpoint up to
    float rounding.

    Args:
      mus: Indexable collection of latent vectors (numpy arrays).
      indices: Non-empty sequence of indices into `mus`.
      nb_between: Number of points appended per segment.

    Returns:
      A list of `1 + (len(indices) - 1) * nb_between` latent vectors.
    """
    path = [mus[indices[0]]]
    for index in indices[1:]:
        start = path[-1]
        end = mus[index]
        for j in range(1, nb_between + 1):
            path.append(start + (end - start) * (float(j) / nb_between))
    return path


def _save_column_image(helper, z, path):
    """Decode latents with `helper.m_helper` and save them as one PNG column.

    Args:
      helper: A `common_joint.OneSideHelper` whose `m_helper` provides
          `decode` and `config`.
      z: Sequence of latent vectors; converted with `np.array`.
      path: Output file path.
    """
    z = np.array(z)
    x = helper.m_helper.decode(z)
    x = common.post_proc(x, helper.m_helper.config)
    # One image per row, single column.
    batched_x = common.batch_image(
        x,
        max_images=len(x),
        rows=len(x),
        cols=1,
    )
    common.save_image(batched_x, path)


def main(unused_argv):
    """Save latent interpolations for both sides of a joint model.

    Picks one training example per entry of `FLAGS.interpolate_labels` on
    each side. Side A takes each label's examples from the front of its
    group, side B from the back. The script then interpolates linearly
    between the examples' `train_mu` latents, with
    `FLAGS.nb_images_between_labels` points per segment, and maps every A
    latent through `m.x_A_to_B_direct`. It decodes and saves `x_A.png`,
    `x_B.png` and `x_B_tr.png` under
    `<sample_dir>/interpolate_sample/<load_ckpt_iter>/`.

    Args:
      unused_argv: Ignored.
    """
    # pylint:disable=unused-variable
    # Reason:
    #   This training script relys on many programmatical call to function and
    #   access to variables. Pylint cannot infer this case so it emits false alarm
    #   of unused-variable if we do not disable this warning.

    # pylint:disable=invalid-name
    # Reason:
    #   Following variables have their name consider to be invalid by pylint so
    #   we disable the warning.
    #   - Variable that in its name has A or B indictating their belonging of
    #     one side of data.
    del unused_argv

    # Load main config
    config_name = FLAGS.config
    config = load_config(config_name)

    config_name_A = config['config_A']
    config_name_B = config['config_B']
    config_name_classifier_A = config['config_classifier_A']
    config_name_classifier_B = config['config_classifier_B']

    # Load dataset
    dataset_A = common_joint.load_dataset(config_name_A, FLAGS.exp_uid_A)
    (dataset_blob_A, train_data_A, train_label_A, train_mu_A, train_sigma_A,
     index_grouped_by_label_A) = dataset_A
    dataset_B = common_joint.load_dataset(config_name_B, FLAGS.exp_uid_B)
    (dataset_blob_B, train_data_B, train_label_B, train_mu_B, train_sigma_B,
     index_grouped_by_label_B) = dataset_B

    # Prepare directories
    dirs = common_joint.prepare_dirs('joint', config_name, FLAGS.exp_uid)
    save_dir, sample_dir = dirs

    np.random.seed(FLAGS.random_seed)

    # Build the graph and restore weights.
    tf.reset_default_graph()
    # tf.set_random_seed applies to the current default graph, so it must be
    # called after reset_default_graph() or the seed is lost with the old
    # graph.
    tf.set_random_seed(FLAGS.random_seed)
    sess = tf.Session()
    try:
        # Load model's architecture (= build)
        one_side_helper_A = common_joint.OneSideHelper(
            config_name_A, FLAGS.exp_uid_A, config_name_classifier_A,
            FLAGS.exp_uid_classifier)
        one_side_helper_B = common_joint.OneSideHelper(
            config_name_B, FLAGS.exp_uid_B, config_name_classifier_B,
            FLAGS.exp_uid_classifier)
        m = common_joint.load_model(model_joint.Model, config_name,
                                    FLAGS.exp_uid)

        # Initialize and restore
        sess.run(tf.global_variables_initializer())

        one_side_helper_A.restore(dataset_blob_A)
        one_side_helper_B.restore(dataset_blob_B)

        # Restore the joint model from the requested iteration's checkpoint.
        model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
        save_name = join(
            save_dir,
            'transfer_%s_%d.ckpt' % (model_uid, FLAGS.load_ckpt_iter))
        m.vae_saver.restore(sess, save_name)

        interpolate_dir = join(sample_dir, 'interpolate_sample',
                               '%010d' % FLAGS.load_ckpt_iter)
        tf.gfile.MakeDirs(interpolate_dir)

        # Interpolation settings from flags.
        interpolate_labels = [
            int(_) for _ in FLAGS.interpolate_labels.split(',')
        ]
        nb_images_between_labels = FLAGS.nb_images_between_labels

        # Side A takes examples from the front of each label group, side B
        # from the back.
        index_list_A = _pick_label_indices(index_grouped_by_label_A,
                                           interpolate_labels)
        index_list_B = _pick_label_indices(index_grouped_by_label_B,
                                           interpolate_labels, from_end=True)

        z_A = _interpolate_latents(train_mu_A, index_list_A,
                                   nb_images_between_labels)
        z_B = _interpolate_latents(train_mu_B, index_list_B,
                                   nb_images_between_labels)

        # Transfer every A latent through the joint model, one at a time.
        z_B_tr = []
        for this_z_A in z_A:
            this_z_B_tr = sess.run(m.x_A_to_B_direct,
                                   {m.x_A: np.array([this_z_A])})
            z_B_tr.append(this_z_B_tr[0])

        # Generate data domain instances and save.
        _save_column_image(one_side_helper_A, z_A,
                           join(interpolate_dir, 'x_A.png'))
        _save_column_image(one_side_helper_B, z_B,
                           join(interpolate_dir, 'x_B.png'))
        _save_column_image(one_side_helper_B, z_B_tr,
                           join(interpolate_dir, 'x_B_tr.png'))
    finally:
        sess.close()