Example #1
0
def gen_matrices(hparams):
    """Save the discriminator's class matrices from the latest checkpoint.

    Writes ``embedding_map_step_<step>.npy`` and, when the checkpoint holds
    the class-projection kernel, ``classification_map_step_<step>.npy``, where
    ``<step>`` is the text after the last ``'-'`` of the checkpoint path.

    Every variable is read from the single checkpoint path returned by
    ``evaluation.latest_checkpoint``. ``Estimator.get_variable_value`` is not
    used because it re-resolves the *latest* checkpoint on each call. If
    training wrote a new checkpoint mid-way, the saved matrices could
    disagree with the step in the file name or with each other.

    Args:
      hparams: Experiment hyper-parameters; checkpoints are looked up in
        ``hparams.model_dir``.

    Raises:
      ValueError: if no checkpoint path is found in ``hparams.model_dir``.
    """
    tf.compat.v1.logging.info('Generating Matrices.')

    ckpt_str = evaluation.latest_checkpoint(hparams.model_dir)
    if not ckpt_str:
        raise ValueError('No checkpoint found in %s' % hparams.model_dir)
    tf.compat.v1.logging.info('Evaluating checkpoint: %s', ckpt_str)

    # Suffix used to tag output files with the checkpoint's step.
    step = ckpt_str.split('-')[-1]
    # NOTE(review): presumably model_dir is remote (e.g. GCS) when running on
    # TPU, which np.save cannot write to, hence $HOME — confirm.
    save_dir = os.environ['HOME'] if flags.FLAGS.use_tpu else hparams.model_dir

    embedding_map = tf.compat.v1.train.load_variable(
        ckpt_str, 'Discriminator/discriminator/d_embedding/embedding_map')
    np.save(os.path.join(save_dir, 'embedding_map_step_%s.npy' % step),
            embedding_map)

    # The class-projection kernel only exists for some discriminator variants.
    class_kernel = 'Discriminator/discriminator/d_sn_linear_class/dense/kernel'
    ckpt_var_names = {
        name for name, _ in tf.compat.v1.train.list_variables(ckpt_str)}
    if class_kernel in ckpt_var_names:
        classification_map = tf.compat.v1.train.load_variable(
            ckpt_str, class_kernel)
        np.save(
            os.path.join(save_dir, 'classification_map_step_%s.npy' % step),
            classification_map)
Example #2
0
def run_intra_fid_eval(hparams):
    """Run intra-class FID evaluation on the latest checkpoint, chunk by chunk.

    The classes ``0 .. --num_classes - 1`` are evaluated in consecutive chunks
    of ``--intra_fid_eval_chunk_size`` classes. For each chunk,
    ``train_eval_input_fn`` is restricted to that chunk's classes, and
    ``shift_classes`` is set to the chunk's first class index. The estimator
    is evaluated for ``hparams.num_eval_steps`` steps under the name
    ``'eval_intra_fid'``, and the ``'eval/intra_fid'`` result is logged.

    Only whole chunks are evaluated. If ``--num_classes`` is not a multiple
    of the chunk size, the trailing classes are skipped and a warning is
    logged.

    Args:
      hparams: Experiment hyper-parameters. Reads ``model_dir``,
        ``num_eval_steps`` and ``tpu_params.use_tpu_estimator``.

    Raises:
      ValueError: if ``--intra_fid_eval_chunk_size`` is not positive.
    """
    tf.compat.v1.logging.info('Intra FID evaluation.')

    # Validate the chunking before building the estimator so a bad flag fails
    # fast instead of raising ZeroDivisionError / silently doing nothing.
    chunk_sz = flags.FLAGS.intra_fid_eval_chunk_size
    num_classes = flags.FLAGS.num_classes
    if chunk_sz <= 0:
        raise ValueError(
            'intra_fid_eval_chunk_size must be positive, got %r' % chunk_sz)
    n_chunks, n_leftover = divmod(num_classes, chunk_sz)
    if n_leftover:
        # Only whole chunks are evaluated; make the skipped classes visible.
        tf.compat.v1.logging.warning(
            'num_classes (%d) is not a multiple of intra_fid_eval_chunk_size '
            '(%d); the last %d class(es) will not be evaluated.',
            num_classes, chunk_sz, n_leftover)

    # modified body of make_estimator(hparams)
    generator = _get_generator_to_be_conditioned(hparams)
    discriminator = _get_discriminator(hparams)

    if hparams.tpu_params.use_tpu_estimator:
        config = est_lib.get_tpu_run_config_from_hparams(hparams)
        estimator = est_lib.get_tpu_estimator(generator, discriminator,
                                              hparams, config)
    else:
        config = est_lib.get_run_config_from_hparams(hparams)
        estimator = est_lib.get_gpu_estimator(generator, discriminator,
                                              hparams, config)

    ckpt_str = evaluation.latest_checkpoint(hparams.model_dir)
    tf.compat.v1.logging.info('Evaluating checkpoint: %s', ckpt_str)

    for chunk_i in range(n_chunks):
        first_class = chunk_i * chunk_sz
        restrict_classes = list(range(first_class, first_class + chunk_sz))
        limited_class_train_eval_input_fn = functools.partial(
            train_eval_input_fn,
            restrict_classes=restrict_classes,
            shift_classes=first_class)
        eval_results = estimator.evaluate(limited_class_train_eval_input_fn,
                                          steps=hparams.num_eval_steps,
                                          name='eval_intra_fid')
        tf.compat.v1.logging.info(
            'Finished intra fid {}/{} evaluation checkpoint: {}. IFID: {}'.
            format(chunk_i, n_chunks, ckpt_str,
                   eval_results['eval/intra_fid']))
Example #3
0
def gen_images(hparams):
    """Write generated samples for the latest checkpoint into ``model_dir``.

    Builds the conditioned generator / discriminator estimator. It is a TPU
    or GPU estimator, depending on ``hparams.tpu_params.use_tpu_estimator``.
    The function then logs the newest checkpoint in ``hparams.model_dir`` and
    hands the estimator to ``eval_lib.predict_and_write_images``. The output
    is tagged ``step_<N>``, where ``N`` is the ``global_step`` value, or 0 if
    reading it raises ``ValueError``.
    """
    tf.compat.v1.logging.info('Generating Images.')

    # Same construction as make_estimator, but with the conditioned generator.
    disc = _get_discriminator(hparams)
    gen = _get_generator_to_be_conditioned(hparams)
    if not hparams.tpu_params.use_tpu_estimator:
        run_config = est_lib.get_run_config_from_hparams(hparams)
        estimator = est_lib.get_gpu_estimator(gen, disc, hparams, run_config)
    else:
        run_config = est_lib.get_tpu_run_config_from_hparams(hparams)
        estimator = est_lib.get_tpu_estimator(gen, disc, hparams, run_config)

    latest_ckpt = evaluation.latest_checkpoint(hparams.model_dir)
    tf.compat.v1.logging.info('Evaluating checkpoint: %s', latest_ckpt)

    # Fall back to step 0 when global_step cannot be read (ValueError).
    try:
        step = int(estimator.get_variable_value('global_step'))
    except ValueError:
        step = 0
    eval_lib.predict_and_write_images(
        estimator, train_eval_input_fn, hparams.model_dir, 'step_%i' % step)
Example #4
0
def make_estimator(hparams):
  """Build the GAN estimator described by `hparams`.

  Returns the TPU estimator when `hparams.tpu_params.use_tpu_estimator` is
  truthy, otherwise the GPU estimator; each is paired with its matching run
  config factory from `est_lib`.
  """
  generator = _get_generator(hparams)
  discriminator = _get_discriminator(hparams)

  # Pick the (run-config factory, estimator factory) pair for the backend.
  if hparams.tpu_params.use_tpu_estimator:
    make_config = est_lib.get_tpu_run_config_from_hparams
    build_estimator = est_lib.get_tpu_estimator
  else:
    make_config = est_lib.get_run_config_from_hparams
    build_estimator = est_lib.get_gpu_estimator

  run_config = make_config(hparams)
  return build_estimator(generator, discriminator, hparams, run_config)
Example #5
0
 def test_get_gpu_estimator_syntax(self):
   """Smoke test: a GPU estimator builds and runs one evaluation step."""
   run_config = estimator_lib.get_run_config_from_hparams(self.hparams)
   gpu_estimator = estimator_lib.get_gpu_estimator(
       generator, discriminator, self.hparams, run_config)

   def eval_input_fn():
     return input_fn({'batch_size': 16})

   gpu_estimator.evaluate(eval_input_fn, steps=1)