Пример #1
0
    def test_train_success(self):
        """Build and train the model once per stage using the test config.

        Creates the configured `train_log_dir` if it is missing, then, for
        every stage id returned by `train.get_stage_ids`, builds a model on a
        fresh default graph fed by `provide_random_data`, registers its
        summaries and runs `train.train`. There are no explicit assertions:
        the test passes when none of these calls raises.
        """
        config = self._config
        log_dir = config['train_log_dir']
        if not tf.io.gfile.exists(log_dir):
            tf.io.gfile.makedirs(log_dir)

        for stage in train.get_stage_ids(**config):
            stage_batch_size = train.get_batch_size(stage, **config)
            # Start every stage from an empty default graph.
            tf.compat.v1.reset_default_graph()
            images = provide_random_data(batch_size=stage_batch_size)
            stage_model = train.build_model(
                stage, stage_batch_size, images, **config)
            train.add_model_summaries(stage_model, **config)
            train.train(stage_model, **config)
Пример #2
0
  def test_train_success(self):
    """Build and train the model once per stage using the test config.

    Returns immediately when TensorFlow is executing eagerly. Otherwise
    creates the configured `train_log_dir` if it is missing and, for every
    stage id returned by `train.get_stage_ids`, builds a model on a fresh
    default graph fed by `provide_random_data`, registers its summaries and
    runs `train.train`. There are no explicit assertions: the test passes
    when none of these calls raises.
    """
    if tf.executing_eagerly():
      # `tfgan.gan_model` doesn't work when executing eagerly.
      return
    train_log_dir = self._config['train_log_dir']
    if not tf.io.gfile.exists(train_log_dir):
      tf.io.gfile.makedirs(train_log_dir)

    for stage_id in train.get_stage_ids(**self._config):
      batch_size = train.get_batch_size(stage_id, **self._config)
      # Use the compat.v1 endpoint: the bare `tf.reset_default_graph` symbol
      # is not available under the TF 2.x `tf` namespace.
      tf.compat.v1.reset_default_graph()
      real_images = provide_random_data(batch_size=batch_size)
      model = train.build_model(stage_id, batch_size, real_images,
                                **self._config)
      train.add_model_summaries(model, **self._config)
      train.train(model, **self._config)
Пример #3
0
def main(_):
  """Run training for every stage using the flag-derived configuration.

  Creates `FLAGS.train_log_dir` if it is missing, logs the configuration as
  one `key=value` entry per line, then for each stage id returned by
  `train.get_stage_ids` builds the model on a fresh default graph, adds its
  summaries and runs `train.train`.

  Args:
    _: Unused positional argument supplied by the application runner.
  """
  if not tf.io.gfile.exists(FLAGS.train_log_dir):
    tf.io.gfile.makedirs(FLAGS.train_log_dir)

  config = _make_config_from_flags()
  # `items()` exists on both Python 2 and Python 3 mappings, whereas
  # `iteritems()` is Python 2 only and raises AttributeError on Python 3.
  logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.items()]))

  for stage_id in train.get_stage_ids(**config):
    batch_size = train.get_batch_size(stage_id, **config)
    # Start every stage from an empty default graph.
    tf.compat.v1.reset_default_graph()
    with tf.device(tf.compat.v1.train.replica_device_setter(FLAGS.ps_replicas)):
      # The input ops are pinned to the CPU under the 'inputs' name scope.
      with tf.device('/cpu:0'), tf.compat.v1.name_scope('inputs'):
        real_images = _provide_real_images(batch_size, **config)
      model = train.build_model(stage_id, batch_size, real_images, **config)
      train.add_model_summaries(model, **config)
      train.train(model, **config)