def test_get_batch_size(self):
    """Check the per-stage batch sizes derived from a 3-entry schedule.

    With 5 resolutions, the schedule [8, 4, 2] is expanded to one entry per
    resolution, [8, 8, 8, 4, 2]; at stage level the expected batch sizes are
    [8, 8, 8, 8, 8, 4, 4, 2, 2].
    """
    config = dict(num_resolutions=5, batch_size_schedule=[8, 4, 2])
    expected_per_stage = [8] * 5 + [4] * 2 + [2] * 2
    for stage, want in enumerate(expected_per_stage):
        self.assertEqual(train.get_batch_size(stage, **config), want)
def test_train_success(self):
    """Smoke-test building, summarizing and training a model for every stage.

    Uses inputs from ``provide_random_data``; there are no assertions, so the
    test passes as long as no stage raises.
    """
    cfg = self._config
    root_dir = cfg['train_root_dir']
    if not tf.gfile.Exists(root_dir):
        tf.gfile.MakeDirs(root_dir)

    for stage in train.get_stage_ids(**cfg):
        stage_batch = train.get_batch_size(stage, **cfg)
        # Start every stage from an empty default graph.
        tf.reset_default_graph()
        images = provide_random_data(batch_size=stage_batch)
        stage_model = train.build_model(stage, stage_batch, images, **cfg)
        train.add_model_summaries(stage_model, **cfg)
        train.train(stage_model, **cfg)
def main(_):
    """Train the model stage by stage using configuration built from flags.

    Creates ``FLAGS.train_root_dir`` if it does not exist, logs the
    flag-derived config as ``key=value`` lines, then for every stage id
    returned by ``train.get_stage_ids`` resets the default graph, builds the
    input pipeline, builds the model and its summaries, and runs
    ``train.train``.

    Args:
      _: Unused positional argument (presumably argv passed by the app
        runner -- confirm against the entry point).
    """
    if not tf.gfile.Exists(FLAGS.train_root_dir):
        tf.gfile.MakeDirs(FLAGS.train_root_dir)

    config = _make_config_from_flags()
    # Use the module-level ``logging`` (``LOGGING`` is not defined), and
    # ``.items()`` which works on Python 2 and 3 (``.iteritems()`` is py2-only).
    logging.info('\n'.join('{}={}'.format(k, v) for k, v in config.items()))

    for stage_id in train.get_stage_ids(**config):
        batch_size = train.get_batch_size(stage_id, **config)
        # Each stage is built in a fresh default graph.
        tf.reset_default_graph()
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Input ops are pinned to the CPU and grouped under 'inputs'.
            with tf.device('/cpu:0'), tf.name_scope('inputs'):
                real_images = _provide_real_images(batch_size, **config)
            model = train.build_model(stage_id, batch_size, real_images,
                                      **config)
            train.add_model_summaries(model, **config)
            train.train(model, **config)
def main(_):
    """Train the model stage by stage using configuration built from flags.

    Creates ``FLAGS.train_root_dir`` if it does not exist, logs the
    flag-derived config as ``key=value`` lines, then for every stage id
    returned by ``train.get_stage_ids`` resets the default graph, builds the
    input pipeline, builds the model and its summaries, and runs
    ``train.train``.

    Args:
      _: Unused positional argument (presumably argv passed by the app
        runner -- confirm against the entry point).
    """
    if not tf.gfile.Exists(FLAGS.train_root_dir):
        tf.gfile.MakeDirs(FLAGS.train_root_dir)

    config = _make_config_from_flags()
    # ``.items()`` works on Python 2 and 3; ``.iteritems()`` is Python-2-only.
    logging.info('\n'.join('{}={}'.format(k, v) for k, v in config.items()))

    for stage_id in train.get_stage_ids(**config):
        batch_size = train.get_batch_size(stage_id, **config)
        # Each stage is built in a fresh default graph.
        tf.reset_default_graph()
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Input ops are pinned to the CPU and grouped under 'inputs'.
            with tf.device('/cpu:0'), tf.name_scope('inputs'):
                real_images = _provide_real_images(batch_size, **config)
            model = train.build_model(stage_id, batch_size, real_images,
                                      **config)
            train.add_model_summaries(model, **config)
            train.train(model, **config)
def main(_):
    """Prepare the output directory, then train once per stage id.

    A progress message is logged before each major step so a stalled or
    failing run can be located in the logs. For every stage id returned by
    ``train.get_stage_ids`` the default graph is reset, the input pipeline
    and model are built, summaries are added and ``train.train`` is run.

    Args:
      _: Unused positional argument (presumably argv passed by the app
        runner -- confirm against the entry point).
    """
    logging.info("Setting up directory")
    root_dir = FLAGS.train_root_dir
    if not tf.gfile.Exists(root_dir):
        tf.gfile.MakeDirs(root_dir)

    logging.info("Set up logging")
    cfg = _make_config_from_flags()
    summary_lines = ('{}={}'.format(key, value) for key, value in cfg.items())
    logging.info('\n'.join(summary_lines))

    # NOTE (original author, marked IMPORTANT): to restore from the latest
    # model, the stage ids iterated here need to be changed.
    for stage in train.get_stage_ids(**cfg):
        logging.info("Get batch_size")
        stage_batch = train.get_batch_size(stage, **cfg)
        tf.reset_default_graph()
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            logging.info("Setup data")
            # Input ops are pinned to the CPU and grouped under 'inputs'.
            with tf.device('/cpu:0'), tf.name_scope('inputs'):
                images = _provide_real_images(stage_batch, **cfg)
            logging.info("Building model")
            stage_model = train.build_model(stage, stage_batch, images, **cfg)
            logging.info("Adding Summaries")
            train.add_model_summaries(stage_model, **cfg)
            train.train(stage_model, **cfg)