def gen_matrices(hparams):
  """Extracts discriminator weight matrices and saves them as .npy files.

  Loads the latest checkpoint from `hparams.model_dir`, reads the
  discriminator's embedding map (and, when present, the classification
  kernel) out of the estimator, and writes each to disk tagged with the
  checkpoint's step number.

  Args:
    hparams: HParams used to build the estimator; `model_dir` must contain
      a checkpoint.
  """
  tf.compat.v1.logging.info('Generating Matrices.')
  # Mirrors make_estimator(hparams), but swaps in the generator variant
  # whose conditioning can be set after construction.
  discriminator = _get_discriminator(hparams)
  generator = _get_generator_to_be_conditioned(hparams)
  if hparams.tpu_params.use_tpu_estimator:
    run_config = est_lib.get_tpu_run_config_from_hparams(hparams)
    estimator = est_lib.get_tpu_estimator(generator, discriminator, hparams,
                                          run_config)
  else:
    run_config = est_lib.get_run_config_from_hparams(hparams)
    estimator = est_lib.get_gpu_estimator(generator, discriminator, hparams,
                                          run_config)

  ckpt_str = evaluation.latest_checkpoint(hparams.model_dir)
  tf.compat.v1.logging.info('Evaluating checkpoint: %s' % ckpt_str)

  # NOTE(review): on TPU the matrices go to $HOME rather than model_dir —
  # presumably because model_dir is remote storage; confirm.
  save_dir = os.environ['HOME'] if flags.FLAGS.use_tpu else hparams.model_dir
  # Checkpoint paths end in "-<step>"; reuse the step tag for both files.
  step_tag = ckpt_str.split('-')[-1]

  embedding_map = estimator.get_variable_value(
      'Discriminator/discriminator/d_embedding/embedding_map')
  np.save('%s/embedding_map_step_%s.npy' % (save_dir, step_tag),
          embedding_map)

  class_kernel = 'Discriminator/discriminator/d_sn_linear_class/dense/kernel'
  if class_kernel in estimator.get_variable_names():
    np.save('%s/classification_map_step_%s.npy' % (save_dir, step_tag),
            estimator.get_variable_value(class_kernel))
def run_intra_fid_eval(hparams):
  """Runs intra-FID evaluation over the class set, one chunk at a time.

  Splits `flags.FLAGS.num_classes` into chunks of size
  `flags.FLAGS.intra_fid_eval_chunk_size` and evaluates the estimator on
  each restricted class range, logging the resulting intra-FID.

  Args:
    hparams: HParams used to build the estimator; `model_dir` must contain
      a checkpoint to evaluate.
  """
  tf.compat.v1.logging.info('Intra FID evaluation.')
  # Mirrors make_estimator(hparams), but swaps in the generator variant
  # whose conditioning can be restricted to a subset of classes.
  generator = _get_generator_to_be_conditioned(hparams)
  discriminator = _get_discriminator(hparams)
  if hparams.tpu_params.use_tpu_estimator:
    config = est_lib.get_tpu_run_config_from_hparams(hparams)
    estimator = est_lib.get_tpu_estimator(generator, discriminator, hparams,
                                          config)
  else:
    config = est_lib.get_run_config_from_hparams(hparams)
    estimator = est_lib.get_gpu_estimator(generator, discriminator, hparams,
                                          config)

  ckpt_str = evaluation.latest_checkpoint(hparams.model_dir)
  tf.compat.v1.logging.info('Evaluating checkpoint: %s' % ckpt_str)

  chunk_sz = flags.FLAGS.intra_fid_eval_chunk_size
  # NOTE(review): floor division silently skips the trailing
  # num_classes % chunk_sz classes when chunk_sz does not divide evenly —
  # confirm this is intended.
  n_chunks = flags.FLAGS.num_classes // chunk_sz
  for chunk_i in range(n_chunks):
    restrict_classes = list(
        range(chunk_i * chunk_sz, (chunk_i + 1) * chunk_sz))
    # Shift labels so the restricted range maps back to [0, chunk_sz).
    limited_class_train_eval_input_fn = functools.partial(
        train_eval_input_fn,
        restrict_classes=restrict_classes,
        shift_classes=chunk_i * chunk_sz)
    eval_results = estimator.evaluate(
        limited_class_train_eval_input_fn,
        steps=hparams.num_eval_steps,
        name='eval_intra_fid')
    # Bug fix: report 1-based progress. The old code logged `chunk_i`,
    # showing "0/N" after the first chunk and "(N-1)/N" after the last.
    tf.compat.v1.logging.info(
        'Finished intra fid {}/{} evaluation checkpoint: {}. IFID: {}'.format(
            chunk_i + 1, n_chunks, ckpt_str,
            eval_results['eval/intra_fid']))
def gen_images(hparams):
  """Generates images from the latest checkpoint and writes them to disk.

  Builds an estimator, reads the current global step from the checkpoint in
  `hparams.model_dir`, and delegates to `eval_lib.predict_and_write_images`
  with a "step_<N>" tag.

  Args:
    hparams: HParams used to build the estimator; `model_dir` should
      contain a checkpoint (falls back to step 0 if the global step cannot
      be read).
  """
  tf.compat.v1.logging.info('Generating Images.')
  # Mirrors make_estimator(hparams), but swaps in the generator variant
  # whose conditioning can be set after construction.
  discriminator = _get_discriminator(hparams)
  generator = _get_generator_to_be_conditioned(hparams)
  if hparams.tpu_params.use_tpu_estimator:
    config = est_lib.get_tpu_run_config_from_hparams(hparams)
    estimator = est_lib.get_tpu_estimator(generator, discriminator, hparams,
                                          config)
  else:
    config = est_lib.get_run_config_from_hparams(hparams)
    estimator = est_lib.get_gpu_estimator(generator, discriminator, hparams,
                                          config)

  ckpt_str = evaluation.latest_checkpoint(hparams.model_dir)
  tf.compat.v1.logging.info('Evaluating checkpoint: %s' % ckpt_str)

  # get_variable_value raises ValueError when the variable (or checkpoint)
  # is missing; treat that as step 0 rather than failing.
  try:
    cur_step = int(estimator.get_variable_value('global_step'))
  except ValueError:
    cur_step = 0
  eval_lib.predict_and_write_images(estimator, train_eval_input_fn,
                                    hparams.model_dir, 'step_%i' % cur_step)
def make_estimator(hparams):
  """Builds and returns a TPU or GPU Estimator from `hparams`."""
  generator = _get_generator(hparams)
  discriminator = _get_discriminator(hparams)
  # Pick the run-config/builder pair, then construct once.
  if hparams.tpu_params.use_tpu_estimator:
    run_config = est_lib.get_tpu_run_config_from_hparams(hparams)
    build = est_lib.get_tpu_estimator
  else:
    run_config = est_lib.get_run_config_from_hparams(hparams)
    build = est_lib.get_gpu_estimator
  return build(generator, discriminator, hparams, run_config)
def test_get_gpu_estimator_syntax(self):
  """Smoke test: a GPU estimator can be constructed and run one eval step."""
  run_config = estimator_lib.get_run_config_from_hparams(self.hparams)
  gpu_est = estimator_lib.get_gpu_estimator(
      generator, discriminator, self.hparams, run_config)
  # One tiny eval step is enough to exercise the estimator wiring.
  gpu_est.evaluate(lambda: input_fn({'batch_size': 16}), steps=1)