def load_dataset_wavegan():
  """Load WaveGAN's dataset.

  The loaded dataset consists of:
    - original data (dataset_blob, train_data, train_label),
    - encoded data from a pretrained model (train_mu, train_sigma), and
    - index grouped by label (index_grouped_by_label).

  Some of these attributes are not available (set as None) but are left here
  to keep everything aligned with the return value of `load_dataset`.

  Returns:
    A tuple of the above-mentioned components of the dataset.
  """
  latent_dir = os.path.expanduser(FLAGS.wavegan_latent_dir)
  path_train = os.path.join(latent_dir, 'data_train.npz')
  train = np.load(path_train)
  train_z = train['z']
  train_label = train['label']
  index_grouped_by_label = common.get_index_grouped_by_label(train_label)

  # Not available for WaveGAN; kept as None to align with `load_dataset`.
  dataset_blob, train_data = None, None
  train_mu, train_sigma = train_z, None
  return (dataset_blob, train_data, train_label, train_mu, train_sigma,
          index_grouped_by_label)
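
# Example usage of `load_dataset_wavegan` (a minimal sketch, not part of the
# original pipeline; it assumes `FLAGS.wavegan_latent_dir` points at a
# directory containing `data_train.npz` with `z` and `label` arrays):
#
#   (_, _, train_label, train_mu, _,
#    index_grouped_by_label) = load_dataset_wavegan()
#   tf.logging.info('mu shape: %s, num labels: %d',
#                   train_mu.shape, len(index_grouped_by_label))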
def load_dataset(config_name, exp_uid):
  """Load a dataset from a config's name.

  The loaded dataset consists of:
    - original data (dataset, train_data, train_label),
    - encoded data from a pretrained model (train_mu, train_sigma),
    - index grouped by label (index_grouped_by_label), and
    - the saving path (save_path) for restoring pre-trained models.

  Args:
    config_name: A string indicating the name of the config that
        parameterizes the model associated with the dataset.
    exp_uid: A string representing the unique id of the experiment used in
        the model associated with the dataset.

  Returns:
    A DatasetBlob of the above-mentioned components of the dataset.
  """
  config = common.load_config(config_name)
  this_config_is_wavegan = common.config_is_wavegan(config)
  if this_config_is_wavegan:
    return load_dataset_wavegan()

  model_uid = common.get_model_uid(config_name, exp_uid)

  dataset = common.load_dataset(config)
  train_data = dataset.train_data
  attr_train = dataset.attr_train
  save_path = dataset.save_path
  path_train = os.path.join(dataset.basepath, 'encoded', model_uid,
                            'encoded_train_data.npz')
  train = np.load(path_train)
  train_mu = train['mu']
  train_sigma = train['sigma']
  train_label = np.argmax(attr_train, axis=-1)  # from one-hot to label
  index_grouped_by_label = common.get_index_grouped_by_label(train_label)

  tf.logging.info('index_grouped_by_label size: %s',
                  [len(_) for _ in index_grouped_by_label])
  tf.logging.info('train loaded from %s', path_train)
  tf.logging.info('train shapes: mu = %s, sigma = %s', train_mu.shape,
                  train_sigma.shape)

  return DatasetBlob(
      train_data=train_data,
      train_label=train_label,
      train_mu=train_mu,
      train_sigma=train_sigma,
      index_grouped_by_label=index_grouped_by_label,
      save_path=save_path,
  )
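
# Example usage of `load_dataset` (a minimal sketch; the config name
# 'mnist_0_nlatent64' and experiment id 'exp_uid_0' are hypothetical
# placeholders, not names confirmed by this repo):
#
#   blob = load_dataset('mnist_0_nlatent64', 'exp_uid_0')
#   tf.logging.info('train_data: %s, train_mu: %s',
#                   blob.train_data.shape, blob.train_mu.shape)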
def main(unused_argv):
  del unused_argv

  # Load config.
  config_name = FLAGS.config
  config_module = importlib.import_module('configs.%s' % config_name)
  config = config_module.config
  model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
  n_latent = config['n_latent']

  # Load dataset.
  dataset = common.load_dataset(config)
  basepath = dataset.basepath
  save_path = dataset.save_path
  train_data = dataset.train_data
  attr_train = dataset.attr_train
  train_label = np.argmax(attr_train, axis=-1)  # from one-hot to label
  index_grouped_by_label = common.get_index_grouped_by_label(train_label)

  # Make the directories.
  save_dir = os.path.join(save_path, model_uid)
  best_dir = os.path.join(save_dir, 'best')
  tf.gfile.MakeDirs(save_dir)
  tf.gfile.MakeDirs(best_dir)
  tf.logging.info('Save Dir: %s', save_dir)

  # Set random seeds.
  np.random.seed(FLAGS.random_seed)
  tf.set_random_seed(FLAGS.random_seed)

  # Load the model.
  tf.reset_default_graph()
  sess = tf.Session()
  with tf.device(tf.train.replica_device_setter(ps_tasks=0)):
    m = model_dataspace.Model(config, name=model_uid)
    _ = m()  # noqa

    # Initialize variables.
    sess.run(tf.global_variables_initializer())

    # Restore the best checkpoint.
    m.vae_saver.restore(
        sess, os.path.join(best_dir, 'vae_best_%s.ckpt' % model_uid))

  # Sample from prior.
  sample_count = 60

  image_path = os.path.join(basepath, 'sample', model_uid)
  tf.gfile.MakeDirs(image_path)

  # From the prior.
  z_p = np.random.randn(sample_count, m.n_latent)
  x_p = sess.run(m.x_mean, {m.z: z_p})
  x_p = common.post_proc(x_p, config)
  common.save_image(
      common.batch_image(x_p), os.path.join(image_path, 'sample_prior.png'))

  # Sample from prior, on a grid.
  boundary = 2.0
  number_grid = 50
  blob = common.make_grid(
      boundary=boundary, number_grid=number_grid, dim_latent=n_latent)
  z_grid, dim_grid = blob.z_grid, blob.dim_grid
  x_grid = sess.run(m.x_mean, {m.z: z_grid})
  x_grid = common.post_proc(x_grid, config)
  batch_image_grid = common.make_batch_image_grid(dim_grid, number_grid)
  common.save_image(
      batch_image_grid(x_grid), os.path.join(image_path, 'sample_grid.png'))

  # Reconstruction (images grouped by label).
  sample_count = 60
  sample_index = []
  for i in range(sample_count):
    sample_index.append(index_grouped_by_label[i % 10][i // 10])
  x_real = train_data[sample_index]
  mu, sigma = sess.run([m.mu, m.sigma], {m.x: x_real})
  x_rec = sess.run(m.x_mean, {m.mu: mu, m.sigma: sigma})
  x_rec = common.post_proc(x_rec, config)
  x_real = common.post_proc(x_real, config)
  common.save_image(
      common.batch_image(x_real), os.path.join(image_path, 'image_real.png'))
  common.save_image(
      common.batch_image(x_rec), os.path.join(image_path, 'image_rec.png'))
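
# Script entry point (a sketch assuming the TF1-style `FLAGS`/`tf.app`
# conventions already used above; `tf.app.run` parses flags before calling
# `main`, so this belongs at module level if no entry point exists elsewhere):
#
#   if __name__ == '__main__':
#     tf.app.run(main)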