def test_train_success(self):
    """Smoke-test model construction and training for every stage.

    For each stage id yielded by ``train.get_stage_ids``, the default graph is
    reset, a batch of data is obtained from ``provide_random_data``, and the
    model is built, given summaries, and trained via the ``train`` module.
    All ``train`` helpers receive ``self._config`` as keyword arguments.
    """
    cfg = self._config
    log_dir = cfg['train_log_dir']
    # Ensure the training log directory exists before any stage runs.
    if not tf.io.gfile.exists(log_dir):
        tf.io.gfile.makedirs(log_dir)

    for stage in train.get_stage_ids(**cfg):
        stage_batch_size = train.get_batch_size(stage, **cfg)
        # Start every stage from an empty default graph.
        tf.compat.v1.reset_default_graph()
        images = provide_random_data(batch_size=stage_batch_size)
        stage_model = train.build_model(
            stage, stage_batch_size, images, **cfg)
        train.add_model_summaries(stage_model, **cfg)
        train.train(stage_model, **cfg)
def test_train_success(self):
    """Smoke-test model construction and training for every stage.

    Returns immediately, without exercising anything, when TensorFlow is
    executing eagerly, because ``tfgan.gan_model`` does not work in eager mode.
    Otherwise, for each stage id yielded by ``train.get_stage_ids``, the
    default graph is reset, a batch is obtained from ``provide_random_data``,
    and the model is built, given summaries, and trained. All ``train``
    helpers receive ``self._config`` as keyword arguments.
    """
    if tf.executing_eagerly():
        # `tfgan.gan_model` doesn't work when executing eagerly.
        # NOTE(review): `self.skipTest(...)` would report this as skipped
        # rather than passed, if the enclosing class is a unittest TestCase.
        return
    train_log_dir = self._config['train_log_dir']
    # Ensure the training log directory exists before any stage runs.
    if not tf.io.gfile.exists(train_log_dir):
        tf.io.gfile.makedirs(train_log_dir)
    for stage_id in train.get_stage_ids(**self._config):
        batch_size = train.get_batch_size(stage_id, **self._config)
        # `tf.reset_default_graph` is not part of the TF2 `tf` namespace; the
        # `tf.compat.v1` alias works on both TF1 (>=1.13) and TF2.
        tf.compat.v1.reset_default_graph()
        real_images = provide_random_data(batch_size=batch_size)
        model = train.build_model(stage_id, batch_size, real_images,
                                  **self._config)
        train.add_model_summaries(model, **self._config)
        train.train(model, **self._config)
def main(_):
    """Train the model stage by stage, using configuration built from flags.

    Creates ``FLAGS.train_log_dir`` if it is missing and logs every config
    entry as ``key=value``. Then, for each stage id, it resets the default
    graph, reads real images (placed on ``/cpu:0`` under an ``inputs`` name
    scope), and builds, summarizes and trains the model under a
    ``replica_device_setter`` configured with ``FLAGS.ps_replicas``.

    Args:
      _: Unused positional argument (presumably the argv list passed in by
        the application runner -- confirm against the entry point).
    """
    if not tf.io.gfile.exists(FLAGS.train_log_dir):
        tf.io.gfile.makedirs(FLAGS.train_log_dir)

    config = _make_config_from_flags()
    # `dict.iteritems()` only exists on Python 2; `items()` works on both.
    logging.info('\n'.join(['{}={}'.format(k, v) for k, v in config.items()]))

    for stage_id in train.get_stage_ids(**config):
        batch_size = train.get_batch_size(stage_id, **config)
        # Each stage is built in a fresh default graph.
        tf.compat.v1.reset_default_graph()
        with tf.device(
                tf.compat.v1.train.replica_device_setter(FLAGS.ps_replicas)):
            # Pin the input pipeline to the CPU and group it under 'inputs'.
            with tf.device('/cpu:0'), tf.compat.v1.name_scope('inputs'):
                real_images = _provide_real_images(batch_size, **config)
            model = train.build_model(stage_id, batch_size, real_images,
                                      **config)
            train.add_model_summaries(model, **config)
            train.train(model, **config)