def load_dataset_wavegan():
  """Load WaveGAN's dataset.

  The loaded dataset consists of:
    - original data (dataset_blob, train_data, train_label),
    - encoded data from a pretrained model (train_mu, train_sigma), and
    - index grouped by label (index_grouped_by_label).

  Some of these attributes are not available (set to None) but are kept here
  to stay aligned with the return value of `load_dataset`.

  Returns:
    A tuple of the above-mentioned components of the dataset.
  """

  latent_dir = os.path.expanduser(FLAGS.wavegan_latent_dir)
  path_train = os.path.join(latent_dir, 'data_train.npz')
  train = np.load(path_train)
  train_z = train['z']
  train_label = train['label']
  index_grouped_by_label = common.get_index_grouped_by_label(train_label)

  dataset_blob, train_data = None, None
  train_mu, train_sigma = train_z, None
  return (dataset_blob, train_data, train_label, train_mu, train_sigma,
          index_grouped_by_label)
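

# Usage sketch (an illustration added for this excerpt, not original code):
# the return value unpacks positionally, mirroring `load_dataset`, and the
# entries that are unavailable for WaveGAN come back as None.
#
#   (dataset_blob, train_data, train_label, train_mu, train_sigma,
#    index_grouped_by_label) = load_dataset_wavegan()
#   assert dataset_blob is None and train_data is None
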
def load_dataset(config_name, exp_uid):
  """Load a dataset from a config's name.

  The loaded dataset consists of:
    - original data (dataset, train_data, train_label),
    - encoded data from a pretrained model (train_mu, train_sigma), and
    - index grouped by label (index_grouped_by_label).
    - path of saving (save_path) for restoring pre-trained models.

  Args:
    config_name: A string indicating the name of config to parameterize the
        model that associates with the dataset.
    exp_uid: A string representing the unique id of experiment to be used in
        model that associates with the dataset.

  Returns:
    A DatasetBlob of abovementioned components in the dataset.
  """

  config = common.load_config(config_name)
  this_config_is_wavegan = common.config_is_wavegan(config)
  if this_config_is_wavegan:
    return load_dataset_wavegan()

  model_uid = common.get_model_uid(config_name, exp_uid)

  dataset = common.load_dataset(config)
  train_data = dataset.train_data
  attr_train = dataset.attr_train
  save_path = dataset.save_path
  path_train = os.path.join(dataset.basepath, 'encoded', model_uid,
                            'encoded_train_data.npz')
  train = np.load(path_train)
  train_mu = train['mu']
  train_sigma = train['sigma']
  train_label = np.argmax(attr_train, axis=-1)  # from one-hot to label
  index_grouped_by_label = common.get_index_grouped_by_label(train_label)

  tf.logging.info('index_grouped_by_label size: %s',
                  [len(_) for _ in index_grouped_by_label])

  tf.logging.info('train loaded from %s', path_train)
  tf.logging.info('train shapes: mu = %s, sigma = %s', train_mu.shape,
                  train_sigma.shape)
  return DatasetBlob(
      train_data=train_data,
      train_label=train_label,
      train_mu=train_mu,
      train_sigma=train_sigma,
      index_grouped_by_label=index_grouped_by_label,
      save_path=save_path,
  )
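

# `DatasetBlob` is not defined in this excerpt. Below is a minimal stand-in
# matching the keyword arguments used above; it is an assumption about, not
# a copy of, the original definition.
import collections

DatasetBlob = collections.namedtuple('DatasetBlob', [
    'train_data', 'train_label', 'train_mu', 'train_sigma',
    'index_grouped_by_label', 'save_path'
])
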
def main(unused_argv):
    del unused_argv

    # Load Config
    config_name = FLAGS.config
    config_module = importlib.import_module('configs.%s' % config_name)
    config = config_module.config
    model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
    n_latent = config['n_latent']

    # Load dataset
    dataset = common.load_dataset(config)
    basepath = dataset.basepath
    save_path = dataset.save_path
    train_data = dataset.train_data
    attr_train = dataset.attr_train

    train_label = np.argmax(attr_train, axis=-1)  # from one-hot to label
    index_grouped_by_label = common.get_index_grouped_by_label(train_label)

    # Make the directories
    save_dir = os.path.join(save_path, model_uid)
    best_dir = os.path.join(save_dir, 'best')
    tf.gfile.MakeDirs(save_dir)
    tf.gfile.MakeDirs(best_dir)
    tf.logging.info('Save Dir: %s', save_dir)

    # Set random seed
    np.random.seed(FLAGS.random_seed)
    tf.set_random_seed(FLAGS.random_seed)

    # Load Model
    tf.reset_default_graph()
    sess = tf.Session()
    with tf.device(tf.train.replica_device_setter(ps_tasks=0)):
        m = model_dataspace.Model(config, name=model_uid)
        _ = m()  # noqa

        # Initialize
        sess.run(tf.global_variables_initializer())

        # Load
        m.vae_saver.restore(
            sess, os.path.join(best_dir, 'vae_best_%s.ckpt' % model_uid))

        # Sample from prior
        sample_count = 60

        image_path = os.path.join(basepath, 'sample', model_uid)
        tf.gfile.MakeDirs(image_path)

        # Draw z from the standard normal prior and decode to data space
        z_p = np.random.randn(sample_count, m.n_latent)
        x_p = sess.run(m.x_mean, {m.z: z_p})
        x_p = common.post_proc(x_p, config)
        common.save_image(common.batch_image(x_p),
                          os.path.join(image_path, 'sample_prior.png'))

        # Sample from the prior, on a grid
        boundary = 2.0
        number_grid = 50
        blob = common.make_grid(boundary=boundary,
                                number_grid=number_grid,
                                dim_latent=n_latent)
        z_grid, dim_grid = blob.z_grid, blob.dim_grid
        x_grid = sess.run(m.x_mean, {m.z: z_grid})
        x_grid = common.post_proc(x_grid, config)
        batch_image_grid = common.make_batch_image_grid(dim_grid, number_grid)
        common.save_image(batch_image_grid(x_grid),
                          os.path.join(image_path, 'sample_grid.png'))

        # Reconstruction (image grouped by label)
        sample_count = 60
        # Round-robin across the (assumed ten) label groups: sample i comes
        # from group i % 10 and is that group's (i // 10)-th element.
        sample_index = []
        for i in range(sample_count):
            sample_index.append(index_grouped_by_label[i % 10][i // 10])
        x_real = train_data[sample_index]
        mu, sigma = sess.run([m.mu, m.sigma], {m.x: x_real})
        x_rec = sess.run(m.x_mean, {m.mu: mu, m.sigma: sigma})
        x_rec = common.post_proc(x_rec, config)

        x_real = common.post_proc(x_real, config)
        common.save_image(common.batch_image(x_real),
                          os.path.join(image_path, 'image_real.png'))
        common.save_image(common.batch_image(x_rec),
                          os.path.join(image_path, 'image_rec.png'))
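

# The flag definitions and program entry point are not included in this
# excerpt. The sketch below is inferred from the FLAGS referenced above:
# the flag names appear in the code, but the defaults and help strings are
# placeholders, not the original values.
flags = tf.flags
flags.DEFINE_string('config', None, 'Name of the config to use.')
flags.DEFINE_string('exp_uid', None, 'Unique id of the experiment.')
flags.DEFINE_integer('random_seed', 0, 'Seed for the np and tf RNGs.')
flags.DEFINE_string('wavegan_latent_dir', '',
                    'Directory holding pre-computed WaveGAN latents.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    tf.app.run(main)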