Example #1
def gen_images(hparams):
    """..."""
    tf.compat.v1.logging.info('Generating Images.')

    # modified body of make_estimator(hparams)
    discriminator = _get_discriminator(hparams)
    generator = _get_generator_to_be_conditioned(hparams)

    if hparams.tpu_params.use_tpu_estimator:
        config = est_lib.get_tpu_run_config_from_hparams(hparams)
        estimator = est_lib.get_tpu_estimator(generator, discriminator,
                                              hparams, config)
    else:
        config = est_lib.get_run_config_from_hparams(hparams)
        estimator = est_lib.get_gpu_estimator(generator, discriminator,
                                              hparams, config)

    # tf.compat.v1.logging.info('Counting params...')
    # total_parameters = 0
    # for variable in estimator.get_variable_names():
    #   vval = estimator.get_variable_value(variable)
    #   nparam = np.prod(estimator.get_variable_value(variable).shape)
    #   total_parameters += int(nparam)
    # tf.compat.v1.logging.info('Found %i params.' % total_parameters)
    # print(total_parameters)

    ckpt_str = evaluation.latest_checkpoint(hparams.model_dir)
    tf.compat.v1.logging.info('Evaluating checkpoint: %s' % ckpt_str)

    try:
        cur_step = int(estimator.get_variable_value('global_step'))
    except ValueError:
        cur_step = 0
    eval_lib.predict_and_write_images(estimator, train_eval_input_fn,
                                      hparams.model_dir, 'step_%i' % cur_step)
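
Both examples call a train_eval_input_fn that is defined elsewhere in the module. The sketch below is a hypothetical stand-in only, assuming a TPUEstimator-style input_fn that receives a params dict carrying 'batch_size'; the random placeholder tensors and feature shapes are assumptions, not the real data pipeline.

import tensorflow as tf

def train_eval_input_fn(mode, params):
    """Hypothetical stand-in for the real input pipeline (sketch only)."""
    del mode  # Unused in this sketch.
    batch_size = params['batch_size']
    # Placeholder data: the real example would build (noise, image) batches
    # from an actual dataset rather than random tensors.
    noise = tf.random.normal([batch_size, 128])
    images = tf.random.normal([batch_size, 32, 32, 3])
    return tf.data.Dataset.from_tensors((noise, images)).repeat()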
Example #2
def run_train_and_eval(hparams):
    """Configure and run the train and estimator jobs."""
    estimator = make_estimator(hparams)

    # Recover from a previous step, if we've trained at all.
    try:
        cur_step = int(estimator.get_variable_value('global_step'))
    except ValueError:
        cur_step = 0

    max_step = hparams.max_number_of_steps
    steps_per_eval = hparams.train_steps_per_eval

    start_time = time.time()
    while cur_step < max_step:
        if hparams.tpu_params.use_tpu_estimator:
            tf.compat.v1.logging.info(
                'About to write sample images at step: %i' % cur_step)
            eval_lib.predict_and_write_images(estimator, train_eval_input_fn,
                                              hparams.model_dir,
                                              'step_%i' % cur_step)

        # Train for a fixed number of steps.
        start_step = cur_step
        step_to_stop_at = min(cur_step + steps_per_eval, max_step)
        tf.compat.v1.logging.info('About to train to step: %i' %
                                  step_to_stop_at)
        start = time.time()
        estimator.train(train_eval_input_fn, max_steps=step_to_stop_at)
        end = time.time()
        cur_step = step_to_stop_at

        # Print some performance statistics.
        steps_taken = step_to_stop_at - start_step
        time_taken = end - start
        _log_performance_statistics(cur_step, steps_taken, time_taken,
                                    start_time)

        # Run evaluation.
        tf.compat.v1.logging.info('Evaluating at step: %i' % cur_step)
        estimator.evaluate(train_eval_input_fn,
                           steps=hparams.num_eval_steps,
                           name='eval')
        tf.compat.v1.logging.info('Finished evaluating step: %i' % cur_step)
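
The _log_performance_statistics helper is not shown in this example. The sketch below infers its signature from the call site and assumes it simply logs steps/sec and total elapsed time; the exact metrics and message format are assumptions.

import time
import tensorflow as tf

def _log_performance_statistics(cur_step, steps_taken, time_taken, start_time):
    """Hypothetical sketch; signature inferred from the call site above."""
    steps_per_sec = steps_taken / time_taken if time_taken > 0 else float('inf')
    total_elapsed = time.time() - start_time
    tf.compat.v1.logging.info(
        'Step %i: %i steps in %.1fs (%.2f steps/sec); %.1fs elapsed total.' %
        (cur_step, steps_taken, time_taken, steps_per_sec, total_elapsed))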