Exemplo n.º 1
0
    def test_train_success(self):
        """Smoke test: run every configured training stage on random data.

        Ensures the configured root directory exists, then for each stage id
        builds a model on random inputs, attaches summaries and trains it.
        """
        config = self._config
        root_dir = config['train_root_dir']
        if not tf.gfile.Exists(root_dir):
            tf.gfile.MakeDirs(root_dir)

        stage_ids = train.get_stage_ids(**config)
        for stage in stage_ids:
            # Every stage is built in a fresh default graph.
            tf.reset_default_graph()
            images = provide_random_data()
            stage_model = train.build_model(stage, images, **config)
            train.add_model_summaries(stage_model, **config)
            train.train(stage_model, **config)
Exemplo n.º 2
0
  def test_train_success(self):
    """Smoke test: run every configured training stage on random data.

    Ensures the configured root directory exists, then for each stage id
    builds a model on random inputs, attaches summaries and trains it.
    """
    config = self._config
    root_dir = config['train_root_dir']
    if not tf.gfile.Exists(root_dir):
      tf.gfile.MakeDirs(root_dir)

    stage_ids = train.get_stage_ids(**config)
    for stage in stage_ids:
      # Every stage is built in a fresh default graph.
      tf.reset_default_graph()
      images = provide_random_data()
      stage_model = train.build_model(stage, images, **config)
      train.add_model_summaries(stage_model, **config)
      train.train(stage_model, **config)
Exemplo n.º 3
0
def main(_):
  """Train all configured stages, each in a freshly reset default graph.

  Creates FLAGS.train_root_dir if it does not exist, builds the training
  config from flags and logs it as `key=value` lines. Then, for every stage id
  returned by `train.get_stage_ids`, it builds the input pipeline, model and
  summaries and runs `train.train`.

  Args:
    _: Unused positional argument (presumably argv passed by the app
      runner -- confirm).
  """
  if not tf.gfile.Exists(FLAGS.train_root_dir):
    tf.gfile.MakeDirs(FLAGS.train_root_dir)

  config = _make_config_from_flags()
  # `dict.items()` works on Python 2 and 3; `iteritems()` does not exist on
  # Python 3 and would raise AttributeError here.
  logging.info('\n'.join('{}={}'.format(k, v) for k, v in config.items()))

  for stage_id in train.get_stage_ids(**config):
    # Start each stage from an empty default graph so ops built for the
    # previous stage are not carried over.
    tf.reset_default_graph()
    # replica_device_setter places variables on parameter-server tasks when
    # FLAGS.ps_tasks > 0.
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Pin the input pipeline to the CPU under an 'inputs' name scope.
      with tf.device('/cpu:0'), tf.name_scope('inputs'):
        real_images = _provide_real_images(**config)
      model = train.build_model(stage_id, real_images, **config)
      train.add_model_summaries(model, **config)
      train.train(model, **config)
Exemplo n.º 4
0
def main(_):
  """Train all configured stages, each in a freshly reset default graph.

  Creates FLAGS.train_root_dir if it does not exist, builds the training
  config from flags and logs it as `key=value` lines. Then, for every stage id
  returned by `train.get_stage_ids`, it looks up that stage's batch size,
  builds the input pipeline, model and summaries, and runs `train.train`.

  Args:
    _: Unused positional argument (presumably argv passed by the app
      runner -- confirm).
  """
  if not tf.gfile.Exists(FLAGS.train_root_dir):
    tf.gfile.MakeDirs(FLAGS.train_root_dir)

  config = _make_config_from_flags()
  # NOTE(review): `LOGGING` is presumably a logger/module alias defined
  # elsewhere in this file; other variants of this script use `logging`.
  # `dict.items()` works on Python 2 and 3; `iteritems()` does not exist on
  # Python 3 and would raise AttributeError here.
  LOGGING.info('\n'.join('{}={}'.format(k, v) for k, v in config.items()))

  for stage_id in train.get_stage_ids(**config):
    # The batch size can differ per stage, so it is looked up per iteration.
    batch_size = train.get_batch_size(stage_id, **config)
    # Start each stage from an empty default graph so ops built for the
    # previous stage are not carried over.
    tf.reset_default_graph()
    # replica_device_setter places variables on parameter-server tasks when
    # FLAGS.ps_tasks > 0.
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
      # Pin the input pipeline to the CPU under an 'inputs' name scope.
      with tf.device('/cpu:0'), tf.name_scope('inputs'):
        real_images = _provide_real_images(batch_size, **config)
      model = train.build_model(stage_id, batch_size, real_images, **config)
      train.add_model_summaries(model, **config)
      train.train(model, **config)
Exemplo n.º 5
0
def main(_):
  """Train every configured stage, logging progress along the way.

  Makes sure FLAGS.train_root_dir exists, builds and logs the config, then
  trains each stage in a freshly reset default graph using that stage's
  batch size.

  Args:
    _: Unused positional argument (presumably argv passed by the app
      runner -- confirm).
  """
  logging.info("Setting up directory")
  root_dir = FLAGS.train_root_dir
  if not tf.gfile.Exists(root_dir):
    tf.gfile.MakeDirs(root_dir)

  logging.info("Set up logging")
  config = _make_config_from_flags()
  settings = ('{}={}'.format(key, value) for key, value in config.items())
  logging.info('\n'.join(settings))

  # IMPORTANT (author's note): to restore from the latest model, change which
  # stage ids this loop iterates over.
  for stage in train.get_stage_ids(**config):
    logging.info("Get batch_size")
    stage_batch = train.get_batch_size(stage, **config)
    tf.reset_default_graph()
    device_fn = tf.train.replica_device_setter(FLAGS.ps_tasks)
    with tf.device(device_fn):
      logging.info("Setup data")
      with tf.device('/cpu:0'), tf.name_scope('inputs'):
        images = _provide_real_images(stage_batch, **config)
      logging.info("Building model")
      stage_model = train.build_model(stage, stage_batch, images, **config)
      logging.info("Adding Summaries")
      train.add_model_summaries(stage_model, **config)
      train.train(stage_model, **config)