Example #1
 def test_get_batch_size(self):
     config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
     # batch_size_schedule is expanded to [8, 8, 8, 4, 2]
     # At stage level it is [8, 8, 8, 8, 8, 4, 4, 2, 2]
     for i, expected_batch_size in enumerate([8, 8, 8, 8, 8, 4, 4, 2, 2]):
         self.assertEqual(train.get_batch_size(i, **config),
                          expected_batch_size)
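
The comments in this test pin down the behaviour: the schedule is left-padded with its first entry until there is one batch size per resolution, and each resolution after the first covers two stages (a transition stage and a stable stage). A minimal sketch consistent with those comments, not the library's actual implementation:

def get_batch_size(stage_id, batch_size_schedule, num_resolutions, **unused_config):
  # Sketch only: reproduces the mapping described in the test comments above.
  schedule = list(batch_size_schedule)
  # Left-pad with the first entry: [8, 4, 2] with num_resolutions=5 -> [8, 8, 8, 4, 2].
  schedule = [schedule[0]] * (num_resolutions - len(schedule)) + schedule
  # Stage 0 is the stable phase of the first resolution; every later resolution adds
  # a transition stage and a stable stage, so stage_id maps to resolution (stage_id + 1) // 2.
  return schedule[(stage_id + 1) // 2]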
Example #2
  def test_train_success(self):
    train_root_dir = self._config['train_root_dir']
    if not tf.gfile.Exists(train_root_dir):
      tf.gfile.MakeDirs(train_root_dir)

    for stage_id in train.get_stage_ids(**self._config):
      batch_size = train.get_batch_size(stage_id, **self._config)
      tf.reset_default_graph()
      real_images = provide_random_data(batch_size=batch_size)
      model = train.build_model(stage_id, batch_size, real_images,
                                **self._config)
      train.add_model_summaries(model, **self._config)
      train.train(model, **self._config)
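
provide_random_data is not part of this excerpt; for the smoke test it only has to return an image batch of the requested size. A plausible stand-in (the resolution and channel count below are assumptions):

import tensorflow as tf

def provide_random_data(batch_size, resolution=32, num_channels=3):
  # Uniform noise shaped like a batch of images in [-1, 1]; enough to exercise
  # the training graph without a real dataset.
  return tf.random_uniform(
      [batch_size, resolution, resolution, num_channels], minval=-1.0, maxval=1.0)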
Example #3
def main(_):
  if not tf.gfile.Exists(FLAGS.train_root_dir):
    tf.gfile.MakeDirs(FLAGS.train_root_dir)

  config = _make_config_from_flags()
  logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.items()]))

  for stage_id in train.get_stage_ids(**config):
    batch_size = train.get_batch_size(stage_id, **config)
    tf.reset_default_graph()
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      real_images = None
      with tf.device('/cpu:0'), tf.name_scope('inputs'):
        real_images = _provide_real_images(batch_size, **config)
      model = train.build_model(stage_id, batch_size, real_images, **config)
      train.add_model_summaries(model, **config)
      train.train(model, **config)
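
_make_config_from_flags is not shown either; it only has to turn the command-line flags into the config dict that the train.* calls unpack. A hypothetical sketch, listing just the keys that appear elsewhere in these examples:

def _make_config_from_flags():
  # Hypothetical: collect the hyperparameters that train.* unpacks via **config.
  # A real build would add the remaining model and optimizer flags in the same way.
  return {
      'train_root_dir': FLAGS.train_root_dir,
      'num_resolutions': FLAGS.num_resolutions,
      'batch_size_schedule': FLAGS.batch_size_schedule,
  }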
Example #4
def main(_):
  logging.info("Setting up directory")
  if not tf.gfile.Exists(FLAGS.train_root_dir):
    tf.gfile.MakeDirs(FLAGS.train_root_dir)
  logging.info("Set up logging")
  config = _make_config_from_flags()
  logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.items()]))

  for stage_id in train.get_stage_ids(**config):  # IMPORTANT: change stage_id here to restore from the latest model.
    logging.info("Get batch_size")
    batch_size = train.get_batch_size(stage_id, **config)
    tf.reset_default_graph()
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      real_images = None
      logging.info("Setup data")
      with tf.device('/cpu:0'), tf.name_scope('inputs'):
        real_images = _provide_real_images(batch_size, **config)
      logging.info("Building model")
      model = train.build_model(stage_id, batch_size, real_images, **config)
      logging.info("Adding Summaries")
      train.add_model_summaries(model, **config)
      train.train(model, **config)
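
get_stage_ids drives the outer loop in each main() above. From the stage-level schedule in Example #1 (nine stages for five resolutions) it plausibly enumerates 2 * num_resolutions - 1 stages; a fuller version would inspect train_root_dir for the last saved checkpoint, which is what the stage_id comment in the loop is alluding to. A sketch under that assumption:

def get_stage_ids(num_resolutions, **unused_config):
  # Sketch only: one stable stage for the first resolution, then a transition
  # stage plus a stable stage for each additional resolution.
  return list(range(2 * num_resolutions - 1))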