示例#1
0
文件: train_test.py 项目: srkm009/gan
 def test_get_batch_size(self):
     """Checks the batch size that `train.get_batch_size` yields per stage.

     With 5 resolutions the schedule [8, 4, 2] expands to [8, 8, 8, 4, 2]
     per resolution, which at stage level is [8, 8, 8, 8, 8, 4, 4, 2, 2].
     """
     schedule_config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
     expected_per_stage = [8, 8, 8, 8, 8, 4, 4, 2, 2]
     for stage_index, want in enumerate(expected_per_stage):
         got = train.get_batch_size(stage_index, **schedule_config)
         self.assertEqual(got, want)
示例#2
0
文件: train_test.py 项目: srkm009/gan
    def test_train_success(self):
        """Builds, summarizes and trains the model for every stage on random data."""
        log_dir = self._config['train_log_dir']
        if not tf.io.gfile.exists(log_dir):
            tf.io.gfile.makedirs(log_dir)

        cfg = self._config
        for stage in train.get_stage_ids(**cfg):
            stage_batch_size = train.get_batch_size(stage, **cfg)
            # Reset the default graph so each stage is built into a fresh graph.
            tf.compat.v1.reset_default_graph()
            images = provide_random_data(batch_size=stage_batch_size)
            stage_model = train.build_model(
                stage, stage_batch_size, images, **cfg)
            train.add_model_summaries(stage_model, **cfg)
            train.train(stage_model, **cfg)
示例#3
0
  def test_train_success(self):
    """Builds, summarizes and trains the model for every stage on random data.

    Skipped when executing eagerly, since `tfgan.gan_model` does not work in
    that mode.
    """
    if tf.executing_eagerly():
      # `tfgan.gan_model` doesn't work when executing eagerly. Report a skip
      # rather than returning early, which would silently count as a pass.
      self.skipTest('`tfgan.gan_model` does not work when executing eagerly.')
    train_log_dir = self._config['train_log_dir']
    if not tf.io.gfile.exists(train_log_dir):
      tf.io.gfile.makedirs(train_log_dir)

    for stage_id in train.get_stage_ids(**self._config):
      batch_size = train.get_batch_size(stage_id, **self._config)
      # Each stage is built into a fresh default graph.
      tf.reset_default_graph()
      real_images = provide_random_data(batch_size=batch_size)
      model = train.build_model(stage_id, batch_size, real_images,
                                **self._config)
      train.add_model_summaries(model, **self._config)
      train.train(model, **self._config)
示例#4
0
文件: train_main.py 项目: yyht/gan
def main(_):
  """Trains the model stage by stage using a config built from flags.

  Creates `FLAGS.train_log_dir` if needed, logs the config, then for each
  stage id resets the default graph, builds inputs and the model, adds
  summaries and runs training.

  Args:
    _: Unused positional argument (presumably argv from `app.run`; TODO
      confirm against the entry point).
  """
  if not tf.io.gfile.exists(FLAGS.train_log_dir):
    tf.io.gfile.makedirs(FLAGS.train_log_dir)

  config = _make_config_from_flags()
  # `items()` works on Python 2 and 3; `iteritems()` does not exist on Python 3.
  logging.info('\n'.join('{}={}'.format(k, v) for k, v in config.items()))

  for stage_id in train.get_stage_ids(**config):
    batch_size = train.get_batch_size(stage_id, **config)
    # Each stage is built into a fresh default graph.
    tf.compat.v1.reset_default_graph()
    with tf.device(tf.compat.v1.train.replica_device_setter(FLAGS.ps_replicas)):
      # Input ops are placed on the CPU under an `inputs` name scope.
      with tf.device('/cpu:0'), tf.compat.v1.name_scope('inputs'):
        real_images = _provide_real_images(batch_size, **config)
      model = train.build_model(stage_id, batch_size, real_images, **config)
      train.add_model_summaries(model, **config)
      train.train(model, **config)