def prepare_dirs(
    signature='unspecified_signature',
    config_name='unspecified_config_name',
    exp_uid='unspecified_exp_uid',
):
  """Prepare saving and sampling directories for training.

  Args:
    signature: A string signature of the model, such as `joint_model`.
    config_name: A string representing the name of the config for the joint
      model.
    exp_uid: A string representing the unique id of the experiment to be used
      in the joint model.

  Returns:
    A tuple of (save_dir, sample_dir). They are strings and are paths to the
      directory for saving checkpoints / summaries and to the directory for
      saving samplings, respectively.
  """
  model_uid = common.get_model_uid(config_name, exp_uid)

  local_base_path = os.path.join(common.get_default_scratch(), signature)
  save_dir = os.path.join(local_base_path, 'ckpts', model_uid)
  tf.gfile.MakeDirs(save_dir)
  sample_dir = os.path.join(local_base_path, 'sample', model_uid)
  tf.gfile.MakeDirs(sample_dir)

  return save_dir, sample_dir

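
# Usage sketch (illustrative only; the config name and experiment uid below
# are hypothetical, not values defined elsewhere in this code):
#
#   save_dir, sample_dir = prepare_dirs(
#       signature='joint_model',
#       config_name='some_joint_config',
#       exp_uid='_exp_0',
#   )
#   # save_dir   -> <scratch>/joint_model/ckpts/<model_uid>
#   # sample_dir -> <scratch>/joint_model/sample/<model_uid>
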
def restore_model(saver, config_name, exp_uid, sess, save_path,
                  ckpt_filename_template):
  """Restore a model from its best checkpoint under `save_path`."""
  model_uid = common.get_model_uid(config_name, exp_uid)
  saver.restore(
      sess,
      os.path.join(save_path, model_uid, 'best',
                   ckpt_filename_template % model_uid))

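
# Usage sketch (illustrative): the dataspace training scripts below write
# their best checkpoints as 'vae_best_%s.ckpt' and 'classifier_best_%s.ckpt',
# so a caller that already has `m`, `sess`, and `save_path` in scope might
# restore the VAE side with:
#
#   restore_model(
#       saver=m.vae_saver,
#       config_name=FLAGS.config,
#       exp_uid=FLAGS.exp_uid,
#       sess=sess,
#       save_path=save_path,
#       ckpt_filename_template='vae_best_%s.ckpt',
#   )
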
def load_dataset(config_name, exp_uid):
  """Load a dataset from a config's name.

  The loaded dataset consists of:
    - original data (dataset_blob, train_data, train_label),
    - encoded data from a pretrained model (train_mu, train_sigma), and
    - index grouped by label (index_grouped_by_label).

  Args:
    config_name: A string indicating the name of the config to parameterize
      the model that associates with the dataset.
    exp_uid: A string representing the unique id of the experiment to be used
      in the model that associates with the dataset.

  Returns:
    A tuple of the above-mentioned components in the dataset.
  """
  config = load_config(config_name)
  if config_is_wavegan(config):
    return load_dataset_wavegan()

  model_uid = common.get_model_uid(config_name, exp_uid)

  dataset = common.load_dataset(config)
  train_data = dataset.train_data
  attr_train = dataset.attr_train
  path_train = os.path.join(dataset.basepath, 'encoded', model_uid,
                            'encoded_train_data.npz')
  train = np.load(path_train)
  train_mu = train['mu']
  train_sigma = train['sigma']
  train_label = np.argmax(attr_train, axis=-1)  # from one-hot to label
  index_grouped_by_label = common.get_index_grouped_by_label(train_label)

  tf.logging.info('index_grouped_by_label size: %s',
                  [len(_) for _ in index_grouped_by_label])
  tf.logging.info('train loaded from %s', path_train)
  tf.logging.info('train shapes: mu = %s, sigma = %s', train_mu.shape,
                  train_sigma.shape)
  dataset_blob = dataset
  return (dataset_blob, train_data, train_label, train_mu, train_sigma,
          index_grouped_by_label)

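
# Usage sketch (illustrative): the joint-model scripts below unpack the
# returned tuple as follows (`config_name_A` and `FLAGS.exp_uid_A` are the
# names those scripts use):
#
#   (dataset_blob_A, train_data_A, train_label_A, train_mu_A, train_sigma_A,
#    index_grouped_by_label_A) = load_dataset(config_name_A, FLAGS.exp_uid_A)
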
def load_model(model_cls, config_name, exp_uid):
  """Load a model.

  Args:
    model_cls: A Sonnet class that is the factory of the model.
    config_name: A string indicating the name of the config to parameterize
      the model.
    exp_uid: A string representing the unique id of the experiment to be used
      in the model.

  Returns:
    An instance of the Sonnet model.
  """
  config = load_config(config_name)
  model_uid = common.get_model_uid(config_name, exp_uid)

  m = model_cls(config, name=model_uid)
  m()  # Connect the module once so that its variables are created.
  return m

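
# Usage sketch: this mirrors how the joint scripts below construct the joint
# model, e.g.
#
#   m = load_model(model_joint.Model, config_name, FLAGS.exp_uid)
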
def main(unused_argv):
  del unused_argv

  # Load Config
  config_name = FLAGS.config
  config_module = importlib.import_module('configs.%s' % config_name)
  config = config_module.config
  model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
  batch_size = config['batch_size']

  # Load dataset
  dataset = common.load_dataset(config)
  basepath = dataset.basepath
  save_path = dataset.save_path
  train_data = dataset.train_data
  eval_data = dataset.eval_data

  # Make the directory
  save_dir = os.path.join(save_path, model_uid)
  best_dir = os.path.join(save_dir, 'best')
  tf.gfile.MakeDirs(save_dir)
  tf.gfile.MakeDirs(best_dir)
  tf.logging.info('Save Dir: %s', save_dir)

  # Load Model
  tf.reset_default_graph()
  sess = tf.Session()
  m = model_dataspace.Model(config, name=model_uid)
  _ = m()  # noqa

  # Initialize
  sess.run(tf.global_variables_initializer())

  # Load
  m.vae_saver.restore(sess,
                      os.path.join(best_dir, 'vae_best_%s.ckpt' % model_uid))

  # Encode
  def encode(data):
    """Encode the data in dataspace to latent space.

    This script runs the encoding in batched mode to limit GPU memory usage.

    Args:
      data: A numpy array of data to be encoded.

    Returns:
      An object with attributes `mu` and `sigma`, the parameters of the
        encoded distributions in the latent space.
    """
    mu_list, sigma_list = [], []

    for i in range(0, len(data), batch_size):
      start, end = i, min(i + batch_size, len(data))
      batch = data[start:end]

      mu, sigma = sess.run([m.mu, m.sigma], {m.x: batch})
      mu_list.append(mu)
      sigma_list.append(sigma)

    mu = np.concatenate(mu_list)
    sigma = np.concatenate(sigma_list)

    return common.ObjectBlob(mu=mu, sigma=sigma)

  encoded_train_data = encode(train_data)
  tf.logging.info(
      'encode train_data: mu.shape = %s sigma.shape = %s',
      encoded_train_data.mu.shape,
      encoded_train_data.sigma.shape,
  )

  encoded_eval_data = encode(eval_data)
  tf.logging.info(
      'encode eval_data: mu.shape = %s sigma.shape = %s',
      encoded_eval_data.mu.shape,
      encoded_eval_data.sigma.shape,
  )

  # Save encoded as npz file
  encoded_save_path = os.path.join(basepath, 'encoded', model_uid)
  tf.gfile.MakeDirs(encoded_save_path)

  tf.logging.info('encoded train_data saved to %s',
                  os.path.join(encoded_save_path, 'encoded_train_data.npz'))
  np.savez(
      os.path.join(encoded_save_path, 'encoded_train_data.npz'),
      mu=encoded_train_data.mu,
      sigma=encoded_train_data.sigma,
  )

  tf.logging.info('encoded eval_data saved to %s',
                  os.path.join(encoded_save_path, 'encoded_eval_data.npz'))
  np.savez(
      os.path.join(encoded_save_path, 'encoded_eval_data.npz'),
      mu=encoded_eval_data.mu,
      sigma=encoded_eval_data.sigma,
  )

def main(unused_argv):
  del unused_argv

  # Load Config
  config_name = FLAGS.config
  config_module = importlib.import_module('configs.%s' % config_name)
  config = config_module.config
  model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
  n_latent = config['n_latent']

  # Load dataset
  dataset = common.load_dataset(config)
  basepath = dataset.basepath
  save_path = dataset.save_path
  train_data = dataset.train_data

  # Make the directory
  save_dir = os.path.join(save_path, model_uid)
  best_dir = os.path.join(save_dir, 'best')
  tf.gfile.MakeDirs(save_dir)
  tf.gfile.MakeDirs(best_dir)
  tf.logging.info('Save Dir: %s', save_dir)

  # Set random seed
  np.random.seed(FLAGS.random_seed)
  tf.set_random_seed(FLAGS.random_seed)

  # Load Model
  tf.reset_default_graph()
  sess = tf.Session()
  with tf.device(tf.train.replica_device_setter(ps_tasks=0)):
    m = model_dataspace.Model(config, name=model_uid)
    _ = m()  # noqa

  # Initialize
  sess.run(tf.global_variables_initializer())

  # Load
  m.vae_saver.restore(sess,
                      os.path.join(best_dir, 'vae_best_%s.ckpt' % model_uid))

  # Sample from prior
  sample_count = 64

  image_path = os.path.join(basepath, 'sample', model_uid)
  tf.gfile.MakeDirs(image_path)

  # from prior
  z_p = np.random.randn(sample_count, m.n_latent)
  x_p = sess.run(m.x_mean, {m.z: z_p})
  x_p = common.post_proc(x_p, config)
  common.save_image(
      common.batch_image(x_p), os.path.join(image_path, 'sample_prior.png'))

  # Sample from prior, as grid
  boundary = 2.0
  number_grid = 50
  blob = common.make_grid(
      boundary=boundary, number_grid=number_grid, dim_latent=n_latent)
  z_grid, dim_grid = blob.z_grid, blob.dim_grid
  x_grid = sess.run(m.x_mean, {m.z: z_grid})
  x_grid = common.post_proc(x_grid, config)
  batch_image_grid = common.make_batch_image_grid(dim_grid, number_grid)
  common.save_image(
      batch_image_grid(x_grid), os.path.join(image_path, 'sample_grid.png'))

  # Reconstruction
  sample_count = 64
  x_real = train_data[:sample_count]
  mu, sigma = sess.run([m.mu, m.sigma], {m.x: x_real})
  x_rec = sess.run(m.x_mean, {m.mu: mu, m.sigma: sigma})
  x_rec = common.post_proc(x_rec, config)
  x_real = common.post_proc(x_real, config)
  common.save_image(
      common.batch_image(x_real), os.path.join(image_path, 'image_real.png'))
  common.save_image(
      common.batch_image(x_rec), os.path.join(image_path, 'image_rec.png'))

def main(unused_argv):
  # pylint:disable=unused-variable
  # Reason:
  #   This training script relies on many programmatic calls to functions and
  #   accesses to variables. Pylint cannot infer this case, so it emits false
  #   unused-variable alarms if we do not disable this warning.

  # pylint:disable=invalid-name
  # Reason:
  #   The following variables have names that pylint considers invalid, so we
  #   disable the warning.
  #   - Variables with A or B in their names, indicating which side of the
  #     data they belong to.

  del unused_argv

  # Load main config
  config_name = FLAGS.config
  config = load_config(config_name)

  config_name_A = config['config_A']
  config_name_B = config['config_B']
  config_name_classifier_A = config['config_classifier_A']
  config_name_classifier_B = config['config_classifier_B']

  # Load dataset
  dataset_A = common_joint.load_dataset(config_name_A, FLAGS.exp_uid_A)
  (dataset_blob_A, train_data_A, train_label_A, train_mu_A, train_sigma_A,
   index_grouped_by_label_A) = dataset_A
  dataset_B = common_joint.load_dataset(config_name_B, FLAGS.exp_uid_B)
  (dataset_blob_B, train_data_B, train_label_B, train_mu_B, train_sigma_B,
   index_grouped_by_label_B) = dataset_B

  # Prepare directories
  dirs = common_joint.prepare_dirs('joint', config_name, FLAGS.exp_uid)
  save_dir, sample_dir = dirs

  # Set random seed
  np.random.seed(FLAGS.random_seed)
  tf.set_random_seed(FLAGS.random_seed)

  # Real Training.
  tf.reset_default_graph()
  sess = tf.Session()

  # Load model's architecture (= build)
  one_side_helper_A = common_joint.OneSideHelper(
      config_name_A, FLAGS.exp_uid_A, config_name_classifier_A,
      FLAGS.exp_uid_classifier)
  one_side_helper_B = common_joint.OneSideHelper(
      config_name_B, FLAGS.exp_uid_B, config_name_classifier_B,
      FLAGS.exp_uid_classifier)
  m = common_joint.load_model(model_joint.Model, config_name, FLAGS.exp_uid)

  # Initialize and restore
  sess.run(tf.global_variables_initializer())

  one_side_helper_A.restore(dataset_blob_A)
  one_side_helper_B.restore(dataset_blob_B)

  # Restore from ckpt
  config_name = FLAGS.config
  model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
  save_name = os.path.join(
      save_dir, 'transfer_%s_%d.ckpt' % (model_uid, FLAGS.load_ckpt_iter))
  m.vae_saver.restore(sess, save_name)

  # Prepare the interpolation directory.
  intepolate_dir = os.path.join(sample_dir, 'interpolate_sample',
                                '%010d' % FLAGS.load_ckpt_iter)
  tf.gfile.MakeDirs(intepolate_dir)

  # Pick the data indices for the labels to interpolate between.
  interpolate_labels = [int(_) for _ in FLAGS.interpolate_labels.split(',')]
  nb_images_between_labels = FLAGS.nb_images_between_labels

  index_list_A = []
  last_pos = [0] * 10
  for label in interpolate_labels:
    index_list_A.append(index_grouped_by_label_A[label][last_pos[label]])
    last_pos[label] += 1

  index_list_B = []
  last_pos = [-1] * 10
  for label in interpolate_labels:
    index_list_B.append(index_grouped_by_label_B[label][last_pos[label]])
    last_pos[label] -= 1

  z_A = []
  z_A.append(train_mu_A[index_list_A[0]])
  for i_label in range(1, len(interpolate_labels)):
    last_z_A = z_A[-1]
    this_z_A = train_mu_A[index_list_A[i_label]]
    for j in range(1, nb_images_between_labels + 1):
      z_A.append(last_z_A +
                 (this_z_A - last_z_A) * (float(j) / nb_images_between_labels))

  z_B = []
  z_B.append(train_mu_B[index_list_B[0]])
  for i_label in range(1, len(interpolate_labels)):
    last_z_B = z_B[-1]
    this_z_B = train_mu_B[index_list_B[i_label]]
    for j in range(1, nb_images_between_labels + 1):
      z_B.append(last_z_B +
                 (this_z_B - last_z_B) * (float(j) / nb_images_between_labels))

  z_B_tr = []
  for this_z_A in z_A:
    this_z_B_tr = sess.run(m.x_A_to_B_direct, {m.x_A: np.array([this_z_A])})
    z_B_tr.append(this_z_B_tr[0])

  # Generate data domain instances and save.
  z_A = np.array(z_A)
  x_A = one_side_helper_A.m_helper.decode(z_A)
  x_A = common.post_proc(x_A, one_side_helper_A.m_helper.config)
  batched_x_A = common.batch_image(
      x_A,
      max_images=len(x_A),
      rows=len(x_A),
      cols=1,
  )
  common.save_image(batched_x_A, os.path.join(intepolate_dir, 'x_A.png'))

  z_B = np.array(z_B)
  x_B = one_side_helper_B.m_helper.decode(z_B)
  x_B = common.post_proc(x_B, one_side_helper_B.m_helper.config)
  batched_x_B = common.batch_image(
      x_B,
      max_images=len(x_B),
      rows=len(x_B),
      cols=1,
  )
  common.save_image(batched_x_B, os.path.join(intepolate_dir, 'x_B.png'))

  z_B_tr = np.array(z_B_tr)
  x_B_tr = one_side_helper_B.m_helper.decode(z_B_tr)
  x_B_tr = common.post_proc(x_B_tr, one_side_helper_B.m_helper.config)
  batched_x_B_tr = common.batch_image(
      x_B_tr,
      max_images=len(x_B_tr),
      rows=len(x_B_tr),
      cols=1,
  )
  common.save_image(batched_x_B_tr, os.path.join(intepolate_dir, 'x_B_tr.png'))

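
# The two interpolation loops above both perform plain linear interpolation in
# latent space. As a minimal sketch (assuming the module-level `numpy as np`
# import used by the surrounding code), the same arithmetic can be factored
# into a small helper; `interpolate_z` is a hypothetical name, not a function
# defined elsewhere in this codebase.
def interpolate_z(z_from, z_to, nb_steps):
  """Return `nb_steps` points linearly interpolated from `z_from` to `z_to`."""
  zs = []
  for j in range(1, nb_steps + 1):
    # j = nb_steps lands exactly on `z_to`, matching the loops above.
    zs.append(z_from + (z_to - z_from) * (float(j) / nb_steps))
  return np.array(zs)
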
def main(unused_argv):
  del unused_argv

  # Load Config
  config_name = FLAGS.config
  config_module = importlib.import_module(configs_module_prefix +
                                          '.%s' % config_name)
  config = config_module.config
  model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
  batch_size = config['batch_size']

  # Load dataset
  dataset = common.load_dataset(config)
  save_path = dataset.save_path
  train_data = dataset.train_data
  attr_train = dataset.attr_train
  eval_data = dataset.eval_data
  attr_eval = dataset.attr_eval

  # Make the directory
  save_dir = os.path.join(save_path, model_uid)
  best_dir = os.path.join(save_dir, 'best')
  tf.gfile.MakeDirs(save_dir)
  tf.gfile.MakeDirs(best_dir)
  tf.logging.info('Save Dir: %s', save_dir)

  np.random.seed(FLAGS.random_seed)

  # We use `N` in variable names to emphasize that they are the Number of
  # something.
  N_train = train_data.shape[0]  # pylint:disable=invalid-name
  N_eval = eval_data.shape[0]  # pylint:disable=invalid-name

  # Load Model
  tf.reset_default_graph()
  sess = tf.Session()
  m = model_dataspace.Model(config, name=model_uid)
  _ = m()  # noqa

  # Create summaries
  tf.summary.scalar('Train_Loss', m.vae_loss)
  tf.summary.scalar('Mean_Recon_LL', m.mean_recons)
  tf.summary.scalar('Mean_KL', m.mean_KL)
  scalar_summaries = tf.summary.merge_all()

  x_mean_, x_ = m.x_mean, m.x
  if common.dataset_is_mnist_family(config['dataset']):
    x_mean_ = tf.reshape(x_mean_, [-1, MNIST_SIZE, MNIST_SIZE, 1])
    x_ = tf.reshape(x_, [-1, MNIST_SIZE, MNIST_SIZE, 1])

  x_mean_summary = tf.summary.image(
      'Reconstruction', nn.tf_batch_image(x_mean_), max_outputs=1)
  x_summary = tf.summary.image(
      'Original', nn.tf_batch_image(x_), max_outputs=1)
  sample_summary = tf.summary.image(
      'Sample', nn.tf_batch_image(x_mean_), max_outputs=1)

  # Summary writers
  train_writer = tf.summary.FileWriter(save_dir + '/vae_train', sess.graph)
  eval_writer = tf.summary.FileWriter(save_dir + '/vae_eval', sess.graph)

  # Initialize
  sess.run(tf.global_variables_initializer())

  i_start = 0
  running_N_eval = 30  # pylint:disable=invalid-name
  traces = {
      'i': [],
      'i_pred': [],
      'loss': [],
      'loss_eval': [],
  }

  best_eval_loss = np.inf
  vae_lr_ = np.logspace(np.log10(FLAGS.lr), np.log10(1e-6), FLAGS.n_iters)

  # Train the VAE
  for i in range(i_start, FLAGS.n_iters):
    start = (i * batch_size) % N_train
    end = start + batch_size
    batch = train_data[start:end]
    labels = attr_train[start:end]

    # train op
    res = sess.run(
        [m.train_vae, m.vae_loss, m.mean_recons, m.mean_KL, scalar_summaries],
        {
            m.x: batch,
            m.vae_lr: vae_lr_[i],
            m.labels: labels,
        })
    tf.logging.info('Iter: %d, Loss: %d', i, res[1])
    train_writer.add_summary(res[-1], i)

    if i % FLAGS.n_iters_per_eval == 0:
      # write training reconstructions
      if batch.shape[0] == batch_size:
        res = sess.run([x_summary, x_mean_summary], {
            m.x: batch,
            m.labels: labels,
        })
        train_writer.add_summary(res[0], i)
        train_writer.add_summary(res[1], i)

      # write sample reconstructions
      prior_sample = sess.run(m.prior_sample)
      res = sess.run([sample_summary], {
          m.q_z_sample: prior_sample,
          m.labels: labels,
      })
      train_writer.add_summary(res[0], i)

      # write eval summaries
      start = (i * batch_size) % N_eval
      end = start + batch_size
      batch = eval_data[start:end]
      labels = attr_eval[start:end]
      if batch.shape[0] == batch_size:
        res_eval = sess.run([
            m.vae_loss, m.mean_recons, m.mean_KL, scalar_summaries, x_summary,
            x_mean_summary
        ], {
            m.x: batch,
            m.labels: labels,
        })
        traces['loss_eval'].append(res_eval[0])
        eval_writer.add_summary(res_eval[-3], i)
        eval_writer.add_summary(res_eval[-2], i)
        eval_writer.add_summary(res_eval[-1], i)

    if i % FLAGS.n_iters_per_save == 0:
      smoothed_eval_loss = np.mean(traces['loss_eval'][-running_N_eval:])
      if smoothed_eval_loss < best_eval_loss:
        # Save the best model
        best_eval_loss = smoothed_eval_loss
        save_name = os.path.join(best_dir, 'vae_best_%s.ckpt' % model_uid)
        tf.logging.info('SAVING BEST! %s Iter: %d', save_name, i)
        m.vae_saver.save(sess, save_name)
        with tf.gfile.Open(os.path.join(best_dir, 'best_ckpt_iters.txt'),
                           'w') as f:
          f.write('%d' % i)

def main(unused_argv):
  del unused_argv

  # Load Config
  config_name = FLAGS.config
  config_module = importlib.import_module(configs_module_prefix +
                                          '.%s' % config_name)
  config = config_module.config
  model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
  batch_size = config['batch_size']

  # Load dataset
  dataset = common.load_dataset(config)
  save_path = dataset.save_path
  train_data = dataset.train_data
  attr_train = dataset.attr_train
  eval_data = dataset.eval_data
  attr_eval = dataset.attr_eval

  # Make the directory
  save_dir = os.path.join(save_path, model_uid)
  best_dir = os.path.join(save_dir, 'best')
  tf.gfile.MakeDirs(save_dir)
  tf.gfile.MakeDirs(best_dir)
  tf.logging.info('Save Dir: %s', save_dir)

  np.random.seed(10003)
  N_train = train_data.shape[0]
  N_eval = eval_data.shape[0]

  # Load Model
  tf.reset_default_graph()
  sess = tf.Session()
  m = model_dataspace.Model(config, name=model_uid)
  _ = m()  # noqa

  # Create summaries
  y_true = m.labels
  y_pred = tf.cast(tf.greater(m.pred_classifier, 0.5), tf.int32)
  accuracy = tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))

  tf.summary.scalar('Loss', m.classifier_loss)
  tf.summary.scalar('Accuracy', accuracy)
  scalar_summaries = tf.summary.merge_all()

  # Summary writers
  train_writer = tf.summary.FileWriter(save_dir + '/train', sess.graph)
  eval_writer = tf.summary.FileWriter(save_dir + '/eval', sess.graph)

  # Initialize
  sess.run(tf.global_variables_initializer())

  i_start = 0
  running_N_eval = 30
  traces = {
      'i': [],
      'i_pred': [],
      'loss': [],
      'loss_eval': [],
  }

  best_eval_loss = np.inf
  classifier_lr_ = np.logspace(
      np.log10(FLAGS.lr), np.log10(1e-6), FLAGS.n_iters)

  # Train the Classifier
  for i in range(i_start, FLAGS.n_iters):
    start = (i * batch_size) % N_train
    end = start + batch_size
    batch = train_data[start:end]
    labels = attr_train[start:end]

    # train op
    res = sess.run([m.train_classifier, m.classifier_loss, scalar_summaries], {
        m.x: batch,
        m.labels: labels,
        m.classifier_lr: classifier_lr_[i]
    })
    tf.logging.info('Iter: %d, Loss: %.2e', i, res[1])
    train_writer.add_summary(res[-1], i)

    if i % 10 == 0:
      if batch.shape[0] == batch_size:
        # write eval summaries
        start = (i * batch_size) % N_eval
        end = start + batch_size
        batch = eval_data[start:end]
        labels = attr_eval[start:end]
        if batch.shape[0] == batch_size:
          res_eval = sess.run([m.classifier_loss, scalar_summaries], {
              m.x: batch,
              m.labels: labels,
          })
          traces['loss_eval'].append(res_eval[0])
          eval_writer.add_summary(res_eval[-1], i)

    if i % FLAGS.n_iters_per_save == 0:
      smoothed_eval_loss = np.mean(traces['loss_eval'][-running_N_eval:])
      if smoothed_eval_loss < best_eval_loss:
        # Save the best model
        best_eval_loss = smoothed_eval_loss
        save_name = os.path.join(best_dir,
                                 'classifier_best_%s.ckpt' % model_uid)
        tf.logging.info('SAVING BEST! %s Iter: %d', save_name, i)
        m.classifier_saver.save(sess, save_name)
        with tf.gfile.Open(os.path.join(best_dir, 'best_ckpt_iters.txt'),
                           'w') as f:
          f.write('%d' % i)

def main(unused_argv):
  # pylint:disable=unused-variable
  # Reason:
  #   This training script relies on many programmatic calls to functions and
  #   accesses to variables. Pylint cannot infer this case, so it emits false
  #   unused-variable alarms if we do not disable this warning.

  # pylint:disable=invalid-name
  # Reason:
  #   The following variables have names that pylint considers invalid, so we
  #   disable the warning.
  #   - Variables with A or B in their names, indicating which side of the
  #     data they belong to.

  del unused_argv

  # Load main config
  config_name = FLAGS.config
  config = load_config(config_name)

  config_name_A = config['config_A']
  config_name_B = config['config_B']
  config_name_classifier_A = config['config_classifier_A']
  config_name_classifier_B = config['config_classifier_B']

  # Load dataset
  dataset_A = common_joint.load_dataset(config_name_A, FLAGS.exp_uid_A)
  (dataset_blob_A, train_data_A, train_label_A, train_mu_A, train_sigma_A,
   index_grouped_by_label_A) = dataset_A
  dataset_B = common_joint.load_dataset(config_name_B, FLAGS.exp_uid_B)
  (dataset_blob_B, train_data_B, train_label_B, train_mu_B, train_sigma_B,
   index_grouped_by_label_B) = dataset_B

  # Prepare directories
  dirs = common_joint.prepare_dirs('joint', config_name, FLAGS.exp_uid)
  save_dir, sample_dir = dirs

  # Set random seed
  np.random.seed(FLAGS.random_seed)
  tf.set_random_seed(FLAGS.random_seed)

  # Real Training.
  tf.reset_default_graph()
  sess = tf.Session()

  # Load model's architecture (= build)
  one_side_helper_A = common_joint.OneSideHelper(
      config_name_A, FLAGS.exp_uid_A, config_name_classifier_A,
      FLAGS.exp_uid_classifier)
  one_side_helper_B = common_joint.OneSideHelper(
      config_name_B, FLAGS.exp_uid_B, config_name_classifier_B,
      FLAGS.exp_uid_classifier)
  m = common_joint.load_model(model_joint.Model, config_name, FLAGS.exp_uid)

  # Prepare summary
  train_writer = tf.summary.FileWriter(save_dir + '/transfer_train',
                                       sess.graph)
  scalar_summaries = tf.summary.merge([
      tf.summary.scalar(key, value)
      for key, value in m.get_summary_kv_dict().items()
  ])
  manual_summary_helper = common_joint.ManualSummaryHelper()

  # Initialize and restore
  sess.run(tf.global_variables_initializer())

  one_side_helper_A.restore(dataset_blob_A)
  one_side_helper_B.restore(dataset_blob_B)

  # Miscs from config
  batch_size = config['batch_size']
  n_latent_shared = config['n_latent_shared']
  pairing_number = config['pairing_number']
  n_latent_A = config['vae_A']['n_latent']
  n_latent_B = config['vae_B']['n_latent']
  i_start = 0

  # Data iterators
  single_data_iterator_A = common_joint.SingleDataIterator(
      train_mu_A, train_sigma_A, batch_size)
  single_data_iterator_B = common_joint.SingleDataIterator(
      train_mu_B, train_sigma_B, batch_size)
  paired_data_iterator = common_joint.PairedDataIterator(
      train_mu_A, train_sigma_A, train_data_A, train_label_A,
      index_grouped_by_label_A, train_mu_B, train_sigma_B, train_data_B,
      train_label_B, index_grouped_by_label_B, pairing_number, batch_size)
  single_data_iterator_A_for_evaluation = common_joint.SingleDataIterator(
      train_mu_A, train_sigma_A, batch_size)
  single_data_iterator_B_for_evaluation = common_joint.SingleDataIterator(
      train_mu_B, train_sigma_B, batch_size)

  # Training loop
  n_iters = FLAGS.n_iters
  for i in tqdm(list(range(i_start, n_iters)), desc='training', unit=' batch'):
    # Prepare data for this batch
    # - Unsupervised (A and B)
    x_A, _ = next(single_data_iterator_A)
    x_B, _ = next(single_data_iterator_B)
    # - Supervised (aligning)
    x_align_A, x_align_B, align_debug_info = next(paired_data_iterator)
    real_x_align_A, real_x_align_B = align_debug_info

    # Run training op and write summary
    res = sess.run([m.train_full, scalar_summaries], {
        m.x_A: x_A,
        m.x_B: x_B,
        m.x_align_A: x_align_A,
        m.x_align_B: x_align_B,
    })
    train_writer.add_summary(res[-1], i)

    if i % FLAGS.n_iters_per_save == 0:
      # Save the model if instructed
      config_name = FLAGS.config
      model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
      save_name = os.path.join(save_dir,
                               'transfer_%s_%d.ckpt' % (model_uid, i))
      m.vae_saver.save(sess, save_name)
      with tf.gfile.Open(os.path.join(save_dir, 'ckpt_iters.txt'), 'w') as f:
        f.write('%d' % i)

    # Evaluate if instructed
    if i % FLAGS.n_iters_per_eval == 0:
      # Helper functions
      def joint_sample(sample_size):
        z_hat = np.random.randn(sample_size, n_latent_shared)
        return sess.run([m.x_joint_A, m.x_joint_B], {
            m.z_hat: z_hat,
        })

      def get_x_from_prior_A():
        return sess.run(m.x_from_prior_A)

      def get_x_from_prior_B():
        return sess.run(m.x_from_prior_B)

      def get_x_from_posterior_A():
        return next(single_data_iterator_A_for_evaluation)[0]

      def get_x_from_posterior_B():
        return next(single_data_iterator_B_for_evaluation)[0]

      def get_x_prime_A(x_A):
        return sess.run(m.x_prime_A, {m.x_A: x_A})

      def get_x_prime_B(x_B):
        return sess.run(m.x_prime_B, {m.x_B: x_B})

      def transfer_A_to_B(x_A):
        return sess.run(m.x_A_to_B, {m.x_A: x_A})

      def transfer_B_to_A(x_B):
        return sess.run(m.x_B_to_A, {m.x_B: x_B})

      def manual_summary(key, value):
        summary = manual_summary_helper.get_summary(sess, key, value)
        # This [cell-var-from-loop] is intended.
        train_writer.add_summary(summary, i)  # pylint: disable=cell-var-from-loop

      # Classifier based evaluation
      sample_total_size = 10000
      sample_batch_size = 100

      def pred(one_side_helper, x):
        real_x = one_side_helper.m_helper.decode(x)
        return one_side_helper.m_classifier_helper.classify(real_x, batch_size)

      def accuarcy(x_1, x_2, type_1, type_2):
        assert type_1 in ('A', 'B') and type_2 in ('A', 'B')
        func_A = partial(pred, one_side_helper=one_side_helper_A)
        func_B = partial(pred, one_side_helper=one_side_helper_B)
        func_1 = func_A if type_1 == 'A' else func_B
        func_2 = func_A if type_2 == 'A' else func_B
        pred_1, pred_2 = func_1(x=x_1), func_2(x=x_2)
        return np.mean(np.equal(pred_1, pred_2).astype('f'))

      def joint_sample_accuarcy():
        x_A, x_B = joint_sample(sample_size=sample_total_size)  # pylint: disable=cell-var-from-loop
        return accuarcy(x_A, x_B, 'A', 'B')

      def transfer_sample_accuarcy_A_B():
        x_A = get_x_from_prior_A()
        x_B = transfer_A_to_B(x_A)
        return accuarcy(x_A, x_B, 'A', 'B')

      def transfer_sample_accuarcy_B_A():
        x_B = get_x_from_prior_B()
        x_A = transfer_B_to_A(x_B)
        return accuarcy(x_A, x_B, 'A', 'B')

      def transfer_accuarcy_A_B():
        x_A = get_x_from_posterior_A()
        x_B = transfer_A_to_B(x_A)
        return accuarcy(x_A, x_B, 'A', 'B')

      def transfer_accuarcy_B_A():
        x_B = get_x_from_posterior_B()
        x_A = transfer_B_to_A(x_B)
        return accuarcy(x_A, x_B, 'A', 'B')

      def recons_accuarcy_A():
        # Use x_A in outer scope.
        # These [cell-var-from-loop]s are intended.
        x_A_prime = get_x_prime_A(x_A)  # pylint: disable=cell-var-from-loop
        return accuarcy(x_A, x_A_prime, 'A', 'A')  # pylint: disable=cell-var-from-loop

      def recons_accuarcy_B():
        # Use x_B in outer scope.
        # These [cell-var-from-loop]s are intended.
        x_B_prime = get_x_prime_B(x_B)  # pylint: disable=cell-var-from-loop
        return accuarcy(x_B, x_B_prime, 'B', 'B')  # pylint: disable=cell-var-from-loop

      # Do all manual summary
      for func_name in (
          'joint_sample_accuarcy',
          'transfer_sample_accuarcy_A_B',
          'transfer_sample_accuarcy_B_A',
          'transfer_accuarcy_A_B',
          'transfer_accuarcy_B_A',
          'recons_accuarcy_A',
          'recons_accuarcy_B',
      ):
        func = locals()[func_name]
        manual_summary(func_name, func())

      # Sampling based evaluation / sampling
      x_prime_A = get_x_prime_A(x_A)
      x_prime_B = get_x_prime_B(x_B)
      x_from_prior_A = get_x_from_prior_A()
      x_from_prior_B = get_x_from_prior_B()
      x_A_to_B = transfer_A_to_B(x_A)
      x_B_to_A = transfer_B_to_A(x_B)
      x_align_A_to_B = transfer_A_to_B(x_align_A)
      x_align_B_to_A = transfer_B_to_A(x_align_B)
      x_joint_A, x_joint_B = joint_sample(sample_size=batch_size)

      this_iter_sample_dir = os.path.join(sample_dir, 'transfer_train_sample',
                                          '%010d' % i)
      tf.gfile.MakeDirs(this_iter_sample_dir)

      for helper, var_names, x_is_real_x in [
          (one_side_helper_A.m_helper,
           ('x_A', 'x_prime_A', 'x_from_prior_A', 'x_B_to_A', 'x_align_A',
            'x_align_B_to_A', 'x_joint_A'), False),
          (one_side_helper_A.m_helper, ('real_x_align_A',), True),
          (one_side_helper_B.m_helper,
           ('x_B', 'x_prime_B', 'x_from_prior_B', 'x_A_to_B', 'x_align_B',
            'x_align_A_to_B', 'x_joint_B'), False),
          (one_side_helper_B.m_helper, ('real_x_align_B',), True),
      ]:
        for var_name in var_names:
          # Here `var` would be None if
          #   - there is no such variable in `locals()`, or
          #   - such a variable exists but its value is None.
          # In both cases we skip saving data for it.
          var = locals().get(var_name, None)
          if var is not None:
            helper.save_data(var, var_name, this_iter_sample_dir, x_is_real_x)
