Example #1
0
def main(_):
  """Builds `HParams` from command-line flags and runs the selected mode.

  `FLAGS.mode` chooses between 'train', 'continuous_eval' and
  'train_and_eval'; a mode of `None` is treated as 'train_and_eval'.

  Raises:
    ValueError: If `FLAGS.mode` is not one of the recognized modes.
  """
  hparams = train_experiment.HParams(
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size,
      use_tpu=FLAGS.use_tpu,
      eval_on_tpu=FLAGS.eval_on_tpu,
      generator_lr=FLAGS.generator_lr,
      discriminator_lr=FLAGS.discriminator_lr,
      beta1=FLAGS.beta1,
      gf_dim=FLAGS.gf_dim,
      df_dim=FLAGS.df_dim,
      # NOTE(review): hard-coded class count, presumably ImageNet's 1000
      # classes -- confirm before reusing with another dataset.
      num_classes=1000,
      shuffle_buffer_size=10000,
      z_dim=FLAGS.z_dim,
      model_dir=FLAGS.model_dir,
      continuous_eval_timeout_secs=FLAGS.continuous_eval_timeout_secs,
      use_tpu_estimator=FLAGS.use_tpu_estimator,
      max_number_of_steps=FLAGS.max_number_of_steps,
      train_steps_per_eval=FLAGS.train_steps_per_eval,
      num_eval_steps=FLAGS.num_eval_steps,
      fake_nets=FLAGS.fake_nets,
      tpu_iterations_per_loop=FLAGS.tpu_iterations_per_loop,
  )
  if FLAGS.mode == 'train':
    train_experiment.run_train(hparams)
  elif FLAGS.mode == 'continuous_eval':
    train_experiment.run_continuous_eval(hparams)
  elif FLAGS.mode == 'train_and_eval' or FLAGS.mode is None:
    train_experiment.run_train_and_eval(hparams)
  else:
    # Format the mode into the message: passing it as a second positional
    # argument to ValueError would make the error print as a tuple.
    raise ValueError('Mode not recognized: %s' % (FLAGS.mode,))
Example #2
0
  def test_run_train_cpu_local(self, mock_metrics, tpu_est):
    """Runs `run_train` with `use_tpu_estimator` set from `tpu_est`."""
    # Stub the costly metric computation so it returns no metrics.
    mock_metrics.return_value = {}

    self.hparams = self.hparams._replace(use_tpu_estimator=tpu_est)
    train_experiment.run_train(self.hparams)
Example #3
0
def _setup_file_logging(logging_dir):
    """Sends INFO-and-above records of the 'tensorflow' logger to a file.

    Creates `logging_dir` if it is missing and attaches a `FileHandler` that
    writes to `tensorflow.log` inside it.

    Args:
        logging_dir: Directory (opened via the local filesystem) that will
            hold `tensorflow.log`.
    """
    log = logging.getLogger('tensorflow')
    log.setLevel(logging.INFO)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # `exist_ok` avoids the race between a separate isdir() check and makedirs().
    os.makedirs(logging_dir, exist_ok=True)
    fh = logging.FileHandler(os.path.join(logging_dir, 'tensorflow.log'))
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    log.addHandler(fh)


def main(_):
    """Sets up file logging, builds `HParams` from flags, and runs a mode.

    `FLAGS.mode` selects the action ('train', 'continuous_eval',
    'intra_fid_eval', 'train_and_eval', 'gen_images' or 'gen_matrices');
    a mode of `None` is treated as 'train_and_eval'.

    Raises:
        ValueError: If `FLAGS.use_tpu` is set and `FLAGS.tpu` contains a
            comma (more than one TPU), or if `FLAGS.mode` is not recognized.
    """
    from tensorflow_gan.examples.self_attention_estimator import train_experiment

    # NOTE(review): on TPU the log goes under $HOME instead of model_dir,
    # presumably because model_dir is then a remote path (e.g. GCS) that
    # logging.FileHandler cannot open -- confirm.
    logging_dir = os.environ['HOME'] if FLAGS.use_tpu else FLAGS.model_dir
    _setup_file_logging(logging_dir)

    tpu_location = FLAGS.tpu
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    # An empty/None location is passed through instead of failing the `in` test.
    if FLAGS.use_tpu and tpu_location and ',' in tpu_location:
        raise ValueError('Only using 1 TPU is supported')

    hparams = train_experiment.HParams(
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size,
        generator_lr=FLAGS.generator_lr,
        discriminator_lr=FLAGS.discriminator_lr,
        beta1=FLAGS.beta1,
        gf_dim=FLAGS.gf_dim,
        df_dim=FLAGS.df_dim,
        num_classes=FLAGS.num_classes,
        shuffle_buffer_size=10000,
        z_dim=FLAGS.z_dim,
        model_dir=FLAGS.model_dir,
        max_number_of_steps=FLAGS.max_number_of_steps,
        train_steps_per_eval=FLAGS.train_steps_per_eval,
        num_eval_steps=FLAGS.num_eval_steps,
        debug_params=train_experiment.DebugParams(
            use_tpu=FLAGS.use_tpu,
            eval_on_tpu=FLAGS.eval_on_tpu,
            fake_nets=False,
            fake_data=False,
            continuous_eval_timeout_secs=FLAGS.continuous_eval_timeout_secs,
        ),
        tpu_params=train_experiment.TPUParams(
            use_tpu_estimator=FLAGS.use_tpu_estimator,
            tpu_location=tpu_location,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            tpu_iterations_per_loop=FLAGS.tpu_iterations_per_loop,
        ),
    )
    if FLAGS.mode == 'train':
        train_experiment.run_train(hparams)
    elif FLAGS.mode == 'continuous_eval':
        train_experiment.run_continuous_eval(hparams)
    elif FLAGS.mode == 'intra_fid_eval':
        train_experiment.run_intra_fid_eval(hparams)
    elif FLAGS.mode == 'train_and_eval' or FLAGS.mode is None:
        train_experiment.run_train_and_eval(hparams)
    elif FLAGS.mode == 'gen_images':
        train_experiment.gen_images(hparams)
    elif FLAGS.mode == 'gen_matrices':
        train_experiment.gen_matrices(hparams)
    else:
        # Format the mode into the message: passing it as a second positional
        # argument to ValueError would make the error print as a tuple.
        raise ValueError('Mode not recognized: %s' % (FLAGS.mode,))
Example #4
0
 def test_run_train_cpu_local_gpuestimator(self):
     """Runs `run_train` with the hyperparameters stored on this test case."""
     hparams = self.hparams
     train_experiment.run_train(hparams)